Add files using upload-large-folder tool
Browse files- LTA_openwebtext_dualt/logs/fullycoupled_outwd0p5_8gpu/lta_owt_gpt2cached_len1024_fullycoupled_rmsnorm_nobias_adamw_wd0p1_outwd0p5_nanogpt_tf32_ddit768x12_gbs512_8gpu_1m_20260514_215642.log +0 -0
- LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/cohere/modeling_cohere.py +530 -0
- LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/granitemoeshared/__init__.py +27 -0
- LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/granitemoeshared/configuration_granitemoeshared.py +95 -0
- LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/granitemoeshared/modeling_granitemoeshared.py +800 -0
- LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/granitemoeshared/modular_granitemoeshared.py +154 -0
- LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/instructblip/__init__.py +28 -0
- LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/instructblip/configuration_instructblip.py +186 -0
- LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/instructblip/modeling_instructblip.py +1405 -0
- LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/instructblip/processing_instructblip.py +123 -0
- LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/mllama/__init__.py +30 -0
- LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/mobilevit/modeling_mobilevit.py +963 -0
- LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/speecht5/configuration_speecht5.py +279 -0
- LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/speecht5/modeling_speecht5.py +0 -0
- LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/speecht5/number_normalizer.py +191 -0
- LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/speecht5/tokenization_speecht5.py +166 -0
- LTA_openwebtext_dualt/mini_owt_logdirichlet/runs/owt_t5_elftokenized_full_len1024_C1_to_1024_pow1_d768_l12_h12_gbs512_2x8gpu_50ep_lr4e3_ema0p9999_elfopt_not5_bottleneck128_unfixed_norm_stateprobadd_selfcond_ce_fast_20260610_020108/step_053000.pt +3 -0
- LTA_openwebtext_dualt/mini_owt_logdirichlet/runs/owt_t5_elftokenized_full_len1024_C1_to_1024_pow1_d768_l12_h12_gbs512_2x8gpu_50ep_lr4e3_ema0p9999_elfopt_not5_bottleneck128_unfixed_norm_stateprobadd_selfcond_ce_fast_20260610_020108/step_163000.pt +3 -0
- LTA_openwebtext_dualt/mini_owt_logdirichlet/runs/owt_t5_elftokenized_full_len1024_C1_to_1024_pow1_d768_l12_h12_gbs512_2x8gpu_50ep_lr4e3_ema0p9999_elfopt_not5_bottleneck128_unfixed_norm_stateprobadd_selfcond_ce_fast_20260610_020108/step_172000.pt +3 -0
- LTA_openwebtext_dualt/mini_owt_logdirichlet/runs/owt_t5_elftokenized_full_len1024_C1_to_1024_pow1_d768_l12_h12_gbs512_2x8gpu_50ep_lr4e3_ema0p9999_elfopt_not5_bottleneck128_unfixed_norm_stateprobadd_selfcond_ce_fast_20260610_020108/step_182000.pt +3 -0
LTA_openwebtext_dualt/logs/fullycoupled_outwd0p5_8gpu/lta_owt_gpt2cached_len1024_fullycoupled_rmsnorm_nobias_adamw_wd0p1_outwd0p5_nanogpt_tf32_ddit768x12_gbs512_8gpu_1m_20260514_215642.log
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/cohere/modeling_cohere.py
ADDED
|
@@ -0,0 +1,530 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
| 2 |
+
# This file was automatically generated from src/transformers/models/cohere/modular_cohere.py.
|
| 3 |
+
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
| 4 |
+
# the file from the modular. If any change should be done, please apply the change to the
|
| 5 |
+
# modular_cohere.py file directly. One of our CI enforces this.
|
| 6 |
+
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
| 7 |
+
# Copyright 2024 Cohere team. All rights reserved.
|
| 8 |
+
#
|
| 9 |
+
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
| 10 |
+
# and OPT implementations in this library. It has been modified from its
|
| 11 |
+
# original forms to accommodate minor architectural differences compared
|
| 12 |
+
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
| 13 |
+
#
|
| 14 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 15 |
+
# you may not use this file except in compliance with the License.
|
| 16 |
+
# You may obtain a copy of the License at
|
| 17 |
+
#
|
| 18 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 19 |
+
#
|
| 20 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 21 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 22 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 23 |
+
# See the License for the specific language governing permissions and
|
| 24 |
+
# limitations under the License.
|
| 25 |
+
|
| 26 |
+
# This file is based on the LLama model definition file in transformers
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
from collections.abc import Callable
|
| 30 |
+
from typing import Optional
|
| 31 |
+
|
| 32 |
+
import torch
|
| 33 |
+
from torch import nn
|
| 34 |
+
|
| 35 |
+
from ...activations import ACT2FN
|
| 36 |
+
from ...cache_utils import Cache, DynamicCache
|
| 37 |
+
from ...generation import GenerationMixin
|
| 38 |
+
from ...integrations import use_kernelized_func
|
| 39 |
+
from ...masking_utils import create_causal_mask
|
| 40 |
+
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
| 41 |
+
from ...modeling_layers import GradientCheckpointingLayer
|
| 42 |
+
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
| 43 |
+
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
| 44 |
+
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
| 45 |
+
from ...processing_utils import Unpack
|
| 46 |
+
from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
|
| 47 |
+
from ...utils.generic import maybe_autocast, merge_with_config_defaults
|
| 48 |
+
from ...utils.output_capturing import capture_outputs
|
| 49 |
+
from .configuration_cohere import CohereConfig
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class CohereLayerNorm(nn.Module):
|
| 53 |
+
def __init__(self, hidden_size=None, eps=1e-5, bias=False):
|
| 54 |
+
"""The hidden size can be a tuple or an int. The tuple is used for QKNorm to normalize across head_dim"""
|
| 55 |
+
super().__init__()
|
| 56 |
+
self.weight = nn.Parameter(torch.ones(hidden_size))
|
| 57 |
+
self.variance_epsilon = eps
|
| 58 |
+
|
| 59 |
+
def forward(self, hidden_states):
|
| 60 |
+
input_dtype = hidden_states.dtype
|
| 61 |
+
hidden_states = hidden_states.to(torch.float32)
|
| 62 |
+
mean = hidden_states.mean(-1, keepdim=True)
|
| 63 |
+
variance = (hidden_states - mean).pow(2).mean(-1, keepdim=True)
|
| 64 |
+
hidden_states = (hidden_states - mean) * torch.rsqrt(variance + self.variance_epsilon)
|
| 65 |
+
hidden_states = self.weight.to(torch.float32) * hidden_states
|
| 66 |
+
return hidden_states.to(input_dtype)
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
class CohereRotaryEmbedding(nn.Module):
|
| 70 |
+
inv_freq: torch.Tensor # fix linting for `register_buffer`
|
| 71 |
+
|
| 72 |
+
def __init__(self, config: CohereConfig, device=None):
|
| 73 |
+
super().__init__()
|
| 74 |
+
self.max_seq_len_cached = config.max_position_embeddings
|
| 75 |
+
self.original_max_seq_len = config.max_position_embeddings
|
| 76 |
+
|
| 77 |
+
self.config = config
|
| 78 |
+
|
| 79 |
+
self.rope_type = self.config.rope_parameters["rope_type"]
|
| 80 |
+
rope_init_fn: Callable = self.compute_default_rope_parameters
|
| 81 |
+
if self.rope_type != "default":
|
| 82 |
+
rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
| 83 |
+
inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
|
| 84 |
+
|
| 85 |
+
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
| 86 |
+
self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
|
| 87 |
+
|
| 88 |
+
@staticmethod
|
| 89 |
+
def compute_default_rope_parameters(
|
| 90 |
+
config: CohereConfig | None = None,
|
| 91 |
+
device: Optional["torch.device"] = None,
|
| 92 |
+
seq_len: int | None = None,
|
| 93 |
+
) -> tuple["torch.Tensor", float]:
|
| 94 |
+
"""
|
| 95 |
+
Computes the inverse frequencies according to the original RoPE implementation
|
| 96 |
+
Args:
|
| 97 |
+
config ([`~transformers.PreTrainedConfig`]):
|
| 98 |
+
The model configuration.
|
| 99 |
+
device (`torch.device`):
|
| 100 |
+
The device to use for initialization of the inverse frequencies.
|
| 101 |
+
seq_len (`int`, *optional*):
|
| 102 |
+
The current sequence length. Unused for this type of RoPE.
|
| 103 |
+
Returns:
|
| 104 |
+
Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
|
| 105 |
+
post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
|
| 106 |
+
"""
|
| 107 |
+
base = config.rope_parameters["rope_theta"]
|
| 108 |
+
dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
|
| 109 |
+
|
| 110 |
+
attention_factor = 1.0 # Unused in this type of RoPE
|
| 111 |
+
|
| 112 |
+
# Compute the inverse frequencies
|
| 113 |
+
inv_freq = 1.0 / (
|
| 114 |
+
base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
|
| 115 |
+
)
|
| 116 |
+
return inv_freq, attention_factor
|
| 117 |
+
|
| 118 |
+
@torch.no_grad()
|
| 119 |
+
@dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
|
| 120 |
+
def forward(self, x, position_ids):
|
| 121 |
+
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
|
| 122 |
+
position_ids_expanded = position_ids[:, None, :].float()
|
| 123 |
+
|
| 124 |
+
device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
|
| 125 |
+
with maybe_autocast(device_type=device_type, enabled=False): # Force float32
|
| 126 |
+
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
|
| 127 |
+
emb = torch.repeat_interleave(freqs, 2, dim=-1) # diff from Llama: we interleave() instead of cat()
|
| 128 |
+
cos = emb.cos() * self.attention_scaling
|
| 129 |
+
sin = emb.sin() * self.attention_scaling
|
| 130 |
+
|
| 131 |
+
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
class CohereMLP(nn.Module):
|
| 135 |
+
def __init__(self, config):
|
| 136 |
+
super().__init__()
|
| 137 |
+
self.config = config
|
| 138 |
+
self.hidden_size = config.hidden_size
|
| 139 |
+
self.intermediate_size = config.intermediate_size
|
| 140 |
+
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
| 141 |
+
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
| 142 |
+
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
|
| 143 |
+
self.act_fn = ACT2FN[config.hidden_act]
|
| 144 |
+
|
| 145 |
+
def forward(self, x):
|
| 146 |
+
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
| 147 |
+
return down_proj
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
| 151 |
+
"""
|
| 152 |
+
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
|
| 153 |
+
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
|
| 154 |
+
"""
|
| 155 |
+
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
| 156 |
+
if n_rep == 1:
|
| 157 |
+
return hidden_states
|
| 158 |
+
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
|
| 159 |
+
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
def eager_attention_forward(
|
| 163 |
+
module: nn.Module,
|
| 164 |
+
query: torch.Tensor,
|
| 165 |
+
key: torch.Tensor,
|
| 166 |
+
value: torch.Tensor,
|
| 167 |
+
attention_mask: torch.Tensor | None,
|
| 168 |
+
scaling: float,
|
| 169 |
+
dropout: float = 0.0,
|
| 170 |
+
**kwargs: Unpack[TransformersKwargs],
|
| 171 |
+
):
|
| 172 |
+
key_states = repeat_kv(key, module.num_key_value_groups)
|
| 173 |
+
value_states = repeat_kv(value, module.num_key_value_groups)
|
| 174 |
+
|
| 175 |
+
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
|
| 176 |
+
if attention_mask is not None:
|
| 177 |
+
attn_weights = attn_weights + attention_mask
|
| 178 |
+
|
| 179 |
+
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
| 180 |
+
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
| 181 |
+
attn_output = torch.matmul(attn_weights, value_states)
|
| 182 |
+
attn_output = attn_output.transpose(1, 2).contiguous()
|
| 183 |
+
|
| 184 |
+
return attn_output, attn_weights
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
def rotate_half(x):
|
| 188 |
+
# Split and rotate. Note that this function is different from e.g. Llama.
|
| 189 |
+
x1 = x[..., ::2]
|
| 190 |
+
x2 = x[..., 1::2]
|
| 191 |
+
rot_x = torch.stack([-x2, x1], dim=-1).flatten(-2)
|
| 192 |
+
return rot_x
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
|
| 196 |
+
"""Applies Rotary Position Embedding to the query and key tensors.
|
| 197 |
+
|
| 198 |
+
Args:
|
| 199 |
+
q (`torch.Tensor`): The query tensor.
|
| 200 |
+
k (`torch.Tensor`): The key tensor.
|
| 201 |
+
cos (`torch.Tensor`): The cosine part of the rotary embedding.
|
| 202 |
+
sin (`torch.Tensor`): The sine part of the rotary embedding.
|
| 203 |
+
unsqueeze_dim (`int`, *optional*, defaults to 1):
|
| 204 |
+
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
|
| 205 |
+
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
|
| 206 |
+
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
|
| 207 |
+
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
|
| 208 |
+
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
|
| 209 |
+
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
|
| 210 |
+
Returns:
|
| 211 |
+
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
|
| 212 |
+
"""
|
| 213 |
+
dtype = q.dtype
|
| 214 |
+
q = q.float()
|
| 215 |
+
k = k.float()
|
| 216 |
+
cos = cos.unsqueeze(unsqueeze_dim)
|
| 217 |
+
sin = sin.unsqueeze(unsqueeze_dim)
|
| 218 |
+
q_embed = (q * cos) + (rotate_half(q) * sin)
|
| 219 |
+
k_embed = (k * cos) + (rotate_half(k) * sin)
|
| 220 |
+
return q_embed.to(dtype=dtype), k_embed.to(dtype=dtype)
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
@use_kernelized_func(apply_rotary_pos_emb)
|
| 224 |
+
class CohereAttention(nn.Module):
|
| 225 |
+
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
| 226 |
+
|
| 227 |
+
def __init__(self, config: CohereConfig, layer_idx: int | None = None):
|
| 228 |
+
super().__init__()
|
| 229 |
+
self.config = config
|
| 230 |
+
self.layer_idx = layer_idx
|
| 231 |
+
self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
|
| 232 |
+
self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
|
| 233 |
+
self.scaling = self.head_dim**-0.5
|
| 234 |
+
self.attention_dropout = config.attention_dropout
|
| 235 |
+
self.is_causal = True
|
| 236 |
+
|
| 237 |
+
self.q_proj = nn.Linear(
|
| 238 |
+
config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
|
| 239 |
+
)
|
| 240 |
+
self.k_proj = nn.Linear(
|
| 241 |
+
config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
|
| 242 |
+
)
|
| 243 |
+
self.v_proj = nn.Linear(
|
| 244 |
+
config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
|
| 245 |
+
)
|
| 246 |
+
self.o_proj = nn.Linear(
|
| 247 |
+
config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
|
| 248 |
+
)
|
| 249 |
+
self.use_qk_norm = config.use_qk_norm
|
| 250 |
+
if self.use_qk_norm:
|
| 251 |
+
# When sharding the model using Tensor Parallelism, need to be careful to use n_local_heads
|
| 252 |
+
self.q_norm = CohereLayerNorm(
|
| 253 |
+
hidden_size=(config.num_attention_heads, self.head_dim), eps=config.layer_norm_eps
|
| 254 |
+
)
|
| 255 |
+
self.k_norm = CohereLayerNorm(
|
| 256 |
+
hidden_size=(config.num_key_value_heads, self.head_dim), eps=config.layer_norm_eps
|
| 257 |
+
)
|
| 258 |
+
|
| 259 |
+
def forward(
|
| 260 |
+
self,
|
| 261 |
+
hidden_states: torch.Tensor,
|
| 262 |
+
position_embeddings: tuple[torch.Tensor, torch.Tensor],
|
| 263 |
+
attention_mask: torch.Tensor | None,
|
| 264 |
+
past_key_values: Cache | None = None,
|
| 265 |
+
**kwargs: Unpack[FlashAttentionKwargs],
|
| 266 |
+
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
| 267 |
+
input_shape = hidden_states.shape[:-1]
|
| 268 |
+
hidden_shape = (*input_shape, -1, self.head_dim)
|
| 269 |
+
|
| 270 |
+
query_states = self.q_proj(hidden_states).view(hidden_shape)
|
| 271 |
+
key_states = self.k_proj(hidden_states).view(hidden_shape)
|
| 272 |
+
value_states = self.v_proj(hidden_states).view(hidden_shape)
|
| 273 |
+
|
| 274 |
+
if self.use_qk_norm: # main diff from Llama
|
| 275 |
+
query_states = self.q_norm(query_states)
|
| 276 |
+
key_states = self.k_norm(key_states)
|
| 277 |
+
|
| 278 |
+
query_states = query_states.transpose(1, 2)
|
| 279 |
+
key_states = key_states.transpose(1, 2)
|
| 280 |
+
value_states = value_states.transpose(1, 2)
|
| 281 |
+
|
| 282 |
+
cos, sin = position_embeddings
|
| 283 |
+
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
| 284 |
+
|
| 285 |
+
if past_key_values is not None:
|
| 286 |
+
key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
|
| 287 |
+
|
| 288 |
+
attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
|
| 289 |
+
self.config._attn_implementation, eager_attention_forward
|
| 290 |
+
)
|
| 291 |
+
|
| 292 |
+
attn_output, attn_weights = attention_interface(
|
| 293 |
+
self,
|
| 294 |
+
query_states,
|
| 295 |
+
key_states,
|
| 296 |
+
value_states,
|
| 297 |
+
attention_mask,
|
| 298 |
+
dropout=0.0 if not self.training else self.attention_dropout,
|
| 299 |
+
scaling=self.scaling,
|
| 300 |
+
**kwargs,
|
| 301 |
+
)
|
| 302 |
+
|
| 303 |
+
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
|
| 304 |
+
attn_output = self.o_proj(attn_output)
|
| 305 |
+
return attn_output, attn_weights
|
| 306 |
+
|
| 307 |
+
|
| 308 |
+
class CohereDecoderLayer(GradientCheckpointingLayer):
|
| 309 |
+
def __init__(self, config: CohereConfig, layer_idx: int):
|
| 310 |
+
super().__init__()
|
| 311 |
+
self.hidden_size = config.hidden_size
|
| 312 |
+
self.self_attn = CohereAttention(config=config, layer_idx=layer_idx)
|
| 313 |
+
self.mlp = CohereMLP(config)
|
| 314 |
+
self.input_layernorm = CohereLayerNorm(hidden_size=(config.hidden_size), eps=config.layer_norm_eps)
|
| 315 |
+
|
| 316 |
+
def forward(
|
| 317 |
+
self,
|
| 318 |
+
hidden_states: torch.Tensor,
|
| 319 |
+
attention_mask: torch.Tensor | None = None,
|
| 320 |
+
position_ids: torch.LongTensor | None = None,
|
| 321 |
+
past_key_values: Cache | None = None,
|
| 322 |
+
use_cache: bool | None = False,
|
| 323 |
+
position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
|
| 324 |
+
**kwargs: Unpack[FlashAttentionKwargs],
|
| 325 |
+
) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]:
|
| 326 |
+
"""
|
| 327 |
+
Args:
|
| 328 |
+
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
| 329 |
+
attention_mask (`torch.FloatTensor`, *optional*):
|
| 330 |
+
attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
|
| 331 |
+
query_sequence_length, key_sequence_length)` if default attention is used.
|
| 332 |
+
past_key_values (`Cache`, *optional*): cached past key and value projection states
|
| 333 |
+
output_attentions (`bool`, *optional*):
|
| 334 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
| 335 |
+
returned tensors for more detail.
|
| 336 |
+
use_cache (`bool`, *optional*):
|
| 337 |
+
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
|
| 338 |
+
(see `past_key_values`).
|
| 339 |
+
position_embeddings (`tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
|
| 340 |
+
Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
|
| 341 |
+
with `head_dim` being the embedding dimension of each attention head.
|
| 342 |
+
"""
|
| 343 |
+
residual = hidden_states
|
| 344 |
+
hidden_states = self.input_layernorm(hidden_states)
|
| 345 |
+
|
| 346 |
+
hidden_states_attention, _ = self.self_attn(
|
| 347 |
+
hidden_states=hidden_states,
|
| 348 |
+
attention_mask=attention_mask,
|
| 349 |
+
position_ids=position_ids,
|
| 350 |
+
past_key_values=past_key_values,
|
| 351 |
+
use_cache=use_cache,
|
| 352 |
+
position_embeddings=position_embeddings,
|
| 353 |
+
**kwargs,
|
| 354 |
+
)
|
| 355 |
+
|
| 356 |
+
hidden_states_mlp = self.mlp(hidden_states)
|
| 357 |
+
hidden_states = residual + hidden_states_attention + hidden_states_mlp
|
| 358 |
+
return hidden_states
|
| 359 |
+
|
| 360 |
+
|
| 361 |
+
@auto_docstring
|
| 362 |
+
class CoherePreTrainedModel(PreTrainedModel):
|
| 363 |
+
config: CohereConfig
|
| 364 |
+
base_model_prefix = "model"
|
| 365 |
+
supports_gradient_checkpointing = True
|
| 366 |
+
_no_split_modules = ["CohereDecoderLayer"]
|
| 367 |
+
_skip_keys_device_placement = ["past_key_values"]
|
| 368 |
+
_supports_flash_attn = True
|
| 369 |
+
_supports_sdpa = True
|
| 370 |
+
_supports_flex_attn = True
|
| 371 |
+
|
| 372 |
+
_can_compile_fullgraph = True
|
| 373 |
+
_supports_attention_backend = True
|
| 374 |
+
_can_record_outputs = {
|
| 375 |
+
"hidden_states": CohereDecoderLayer,
|
| 376 |
+
"attentions": CohereAttention,
|
| 377 |
+
}
|
| 378 |
+
|
| 379 |
+
|
| 380 |
+
@auto_docstring
|
| 381 |
+
class CohereModel(CoherePreTrainedModel):
|
| 382 |
+
def __init__(self, config: CohereConfig):
|
| 383 |
+
super().__init__(config)
|
| 384 |
+
self.padding_idx = config.pad_token_id
|
| 385 |
+
self.vocab_size = config.vocab_size
|
| 386 |
+
|
| 387 |
+
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
|
| 388 |
+
self.layers = nn.ModuleList(
|
| 389 |
+
[CohereDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
|
| 390 |
+
)
|
| 391 |
+
self.norm = CohereLayerNorm(hidden_size=(config.hidden_size), eps=config.layer_norm_eps)
|
| 392 |
+
self.rotary_emb = CohereRotaryEmbedding(config=config)
|
| 393 |
+
self.gradient_checkpointing = False
|
| 394 |
+
|
| 395 |
+
# Initialize weights and apply final processing
|
| 396 |
+
self.post_init()
|
| 397 |
+
|
| 398 |
+
@merge_with_config_defaults
|
| 399 |
+
@capture_outputs
|
| 400 |
+
@auto_docstring
|
| 401 |
+
def forward(
|
| 402 |
+
self,
|
| 403 |
+
input_ids: torch.LongTensor | None = None,
|
| 404 |
+
attention_mask: torch.Tensor | None = None,
|
| 405 |
+
position_ids: torch.LongTensor | None = None,
|
| 406 |
+
past_key_values: Cache | None = None,
|
| 407 |
+
inputs_embeds: torch.FloatTensor | None = None,
|
| 408 |
+
use_cache: bool | None = None,
|
| 409 |
+
**kwargs: Unpack[TransformersKwargs],
|
| 410 |
+
) -> BaseModelOutputWithPast:
|
| 411 |
+
if (input_ids is None) ^ (inputs_embeds is not None):
|
| 412 |
+
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
| 413 |
+
|
| 414 |
+
if inputs_embeds is None:
|
| 415 |
+
inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)
|
| 416 |
+
|
| 417 |
+
if use_cache and past_key_values is None:
|
| 418 |
+
past_key_values = DynamicCache(config=self.config)
|
| 419 |
+
|
| 420 |
+
if position_ids is None:
|
| 421 |
+
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
| 422 |
+
position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
|
| 423 |
+
position_ids = position_ids.unsqueeze(0)
|
| 424 |
+
|
| 425 |
+
causal_mask = create_causal_mask(
|
| 426 |
+
config=self.config,
|
| 427 |
+
inputs_embeds=inputs_embeds,
|
| 428 |
+
attention_mask=attention_mask,
|
| 429 |
+
past_key_values=past_key_values,
|
| 430 |
+
position_ids=position_ids,
|
| 431 |
+
)
|
| 432 |
+
|
| 433 |
+
hidden_states = inputs_embeds
|
| 434 |
+
position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)
|
| 435 |
+
|
| 436 |
+
for decoder_layer in self.layers[: self.config.num_hidden_layers]:
|
| 437 |
+
hidden_states = decoder_layer(
|
| 438 |
+
hidden_states,
|
| 439 |
+
attention_mask=causal_mask,
|
| 440 |
+
position_embeddings=position_embeddings,
|
| 441 |
+
position_ids=position_ids,
|
| 442 |
+
past_key_values=past_key_values,
|
| 443 |
+
use_cache=use_cache,
|
| 444 |
+
**kwargs,
|
| 445 |
+
)
|
| 446 |
+
|
| 447 |
+
hidden_states = self.norm(hidden_states)
|
| 448 |
+
return BaseModelOutputWithPast(
|
| 449 |
+
last_hidden_state=hidden_states,
|
| 450 |
+
past_key_values=past_key_values,
|
| 451 |
+
)
|
| 452 |
+
|
| 453 |
+
|
| 454 |
+
@auto_docstring
|
| 455 |
+
class CohereForCausalLM(CoherePreTrainedModel, GenerationMixin):
|
| 456 |
+
_tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
|
| 457 |
+
_tp_plan = {"lm_head": "colwise_gather_output"}
|
| 458 |
+
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
|
| 459 |
+
|
| 460 |
+
def __init__(self, config):
|
| 461 |
+
super().__init__(config)
|
| 462 |
+
self.model = CohereModel(config)
|
| 463 |
+
self.vocab_size = config.vocab_size
|
| 464 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
| 465 |
+
self.logit_scale = config.logit_scale
|
| 466 |
+
self.tie_word_embeddings = config.tie_word_embeddings
|
| 467 |
+
|
| 468 |
+
# Initialize weights and apply final processing
|
| 469 |
+
self.post_init()
|
| 470 |
+
|
| 471 |
+
@can_return_tuple
|
| 472 |
+
@auto_docstring
|
| 473 |
+
def forward(
|
| 474 |
+
self,
|
| 475 |
+
input_ids: torch.LongTensor | None = None,
|
| 476 |
+
attention_mask: torch.Tensor | None = None,
|
| 477 |
+
position_ids: torch.LongTensor | None = None,
|
| 478 |
+
past_key_values: Cache | None = None,
|
| 479 |
+
inputs_embeds: torch.FloatTensor | None = None,
|
| 480 |
+
labels: torch.LongTensor | None = None,
|
| 481 |
+
use_cache: bool | None = None,
|
| 482 |
+
logits_to_keep: int | torch.Tensor = 0,
|
| 483 |
+
**kwargs: Unpack[TransformersKwargs],
|
| 484 |
+
) -> CausalLMOutputWithPast:
|
| 485 |
+
r"""
|
| 486 |
+
Example:
|
| 487 |
+
|
| 488 |
+
```python
|
| 489 |
+
>> from transformers import AutoTokenizer, CohereForCausalLM
|
| 490 |
+
|
| 491 |
+
>> model = CohereForCausalLM.from_pretrained("CohereForAI/c4ai-command-r-v01")
|
| 492 |
+
>> tokenizer = AutoTokenizer.from_pretrained("CohereForAI/c4ai-command-r-v01")
|
| 493 |
+
|
| 494 |
+
>> prompt = "Hey, are you conscious? Can you talk to me?"
|
| 495 |
+
>> inputs = tokenizer(prompt, return_tensors="pt")
|
| 496 |
+
|
| 497 |
+
>> # Generate
|
| 498 |
+
>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
| 499 |
+
>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
| 500 |
+
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
|
| 501 |
+
```"""
|
| 502 |
+
outputs: BaseModelOutputWithPast = self.model(
|
| 503 |
+
input_ids=input_ids,
|
| 504 |
+
attention_mask=attention_mask,
|
| 505 |
+
position_ids=position_ids,
|
| 506 |
+
past_key_values=past_key_values,
|
| 507 |
+
inputs_embeds=inputs_embeds,
|
| 508 |
+
use_cache=use_cache,
|
| 509 |
+
**kwargs,
|
| 510 |
+
)
|
| 511 |
+
|
| 512 |
+
hidden_states = outputs.last_hidden_state
|
| 513 |
+
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
|
| 514 |
+
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
| 515 |
+
logits = logits * self.logit_scale # main diff from Llama
|
| 516 |
+
|
| 517 |
+
loss = None
|
| 518 |
+
if labels is not None:
|
| 519 |
+
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
|
| 520 |
+
|
| 521 |
+
return CausalLMOutputWithPast(
|
| 522 |
+
loss=loss,
|
| 523 |
+
logits=logits,
|
| 524 |
+
past_key_values=outputs.past_key_values,
|
| 525 |
+
hidden_states=outputs.hidden_states,
|
| 526 |
+
attentions=outputs.attentions,
|
| 527 |
+
)
|
| 528 |
+
|
| 529 |
+
|
| 530 |
+
__all__ = ["CohereForCausalLM", "CohereModel", "CoherePreTrainedModel"]
|
LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/granitemoeshared/__init__.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
from typing import TYPE_CHECKING
|
| 15 |
+
|
| 16 |
+
from ...utils import _LazyModule
|
| 17 |
+
from ...utils.import_utils import define_import_structure
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
if TYPE_CHECKING:
|
| 21 |
+
from .configuration_granitemoeshared import *
|
| 22 |
+
from .modeling_granitemoeshared import *
|
| 23 |
+
else:
|
| 24 |
+
import sys
|
| 25 |
+
|
| 26 |
+
_file = globals()["__file__"]
|
| 27 |
+
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
|
LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/granitemoeshared/configuration_granitemoeshared.py
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
| 4 |
+
# and OPT implementations in this library. It has been modified from its
|
| 5 |
+
# original forms to accommodate minor architectural differences compared
|
| 6 |
+
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
| 7 |
+
#
|
| 8 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 9 |
+
# you may not use this file except in compliance with the License.
|
| 10 |
+
# You may obtain a copy of the License at
|
| 11 |
+
#
|
| 12 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 13 |
+
#
|
| 14 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 15 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 16 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 17 |
+
# See the License for the specific language governing permissions and
|
| 18 |
+
# limitations under the License.
|
| 19 |
+
"""GraniteMoeShared model configuration"""
|
| 20 |
+
|
| 21 |
+
from huggingface_hub.dataclasses import strict
|
| 22 |
+
|
| 23 |
+
from ...configuration_utils import PreTrainedConfig
|
| 24 |
+
from ...modeling_rope_utils import RopeParameters
|
| 25 |
+
from ...utils import auto_docstring
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
@auto_docstring(checkpoint="ibm-granite/granite-speech-3.2-8b")
|
| 29 |
+
@strict
|
| 30 |
+
class GraniteMoeSharedConfig(PreTrainedConfig):
|
| 31 |
+
r"""
|
| 32 |
+
embedding_multiplier (`float`, *optional*, defaults to 1.0):
|
| 33 |
+
embedding multiplier
|
| 34 |
+
logits_scaling (`float`, *optional*, defaults to 1.0):
|
| 35 |
+
divisor for output logits
|
| 36 |
+
residual_multiplier (`float`, *optional*, defaults to 1.0):
|
| 37 |
+
residual multiplier
|
| 38 |
+
attention_multiplier (`float`, *optional*, defaults to 1.0):
|
| 39 |
+
attention multiplier
|
| 40 |
+
shared_intermediate_size (`int`, *optional*, defaults to 1024):
|
| 41 |
+
intermediate size for shared experts.
|
| 42 |
+
|
| 43 |
+
```python
|
| 44 |
+
>>> from transformers import GraniteMoeSharedModel, GraniteMoeSharedConfig
|
| 45 |
+
|
| 46 |
+
>>> # Initializing a GraniteMoeShared granitemoe-3b style configuration
|
| 47 |
+
>>> configuration = GraniteMoeSharedConfig()
|
| 48 |
+
|
| 49 |
+
>>> # Initializing a model from the granitemoe-7b style configuration
|
| 50 |
+
>>> model = GraniteMoeSharedModel(configuration)
|
| 51 |
+
|
| 52 |
+
>>> # Accessing the model configuration
|
| 53 |
+
>>> configuration = model.config
|
| 54 |
+
```
|
| 55 |
+
"""
|
| 56 |
+
|
| 57 |
+
model_type = "granitemoeshared"
|
| 58 |
+
keys_to_ignore_at_inference = ["past_key_values"]
|
| 59 |
+
|
| 60 |
+
vocab_size: int = 32000
|
| 61 |
+
hidden_size: int = 4096
|
| 62 |
+
intermediate_size: int = 11008
|
| 63 |
+
num_hidden_layers: int = 32
|
| 64 |
+
num_attention_heads: int = 32
|
| 65 |
+
num_key_value_heads: int | None = None
|
| 66 |
+
hidden_act: str = "silu"
|
| 67 |
+
max_position_embeddings: int = 2048
|
| 68 |
+
initializer_range: float = 0.02
|
| 69 |
+
rms_norm_eps: float = 1e-6
|
| 70 |
+
use_cache: bool = True
|
| 71 |
+
pad_token_id: int | None = None
|
| 72 |
+
bos_token_id: int | None = 1
|
| 73 |
+
eos_token_id: int | list[int] | None = 2
|
| 74 |
+
tie_word_embeddings: bool = False
|
| 75 |
+
rope_parameters: RopeParameters | dict | None = None
|
| 76 |
+
attention_bias: bool = False
|
| 77 |
+
attention_dropout: float | int | None = 0.0
|
| 78 |
+
embedding_multiplier: float | int | None = 1.0
|
| 79 |
+
logits_scaling: float | int | None = 1.0
|
| 80 |
+
residual_multiplier: float | int | None = 1.0
|
| 81 |
+
attention_multiplier: float | int | None = 1.0
|
| 82 |
+
num_local_experts: int | None = 8
|
| 83 |
+
num_experts_per_tok: int | None = 2
|
| 84 |
+
output_router_logits: bool | None = False
|
| 85 |
+
router_aux_loss_coef: float | None = 0.001
|
| 86 |
+
shared_intermediate_size: int = 0
|
| 87 |
+
|
| 88 |
+
def __post_init__(self, **kwargs):
|
| 89 |
+
if self.num_key_value_heads is None:
|
| 90 |
+
self.num_key_value_heads = self.num_attention_heads
|
| 91 |
+
|
| 92 |
+
super().__post_init__(**kwargs)
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
__all__ = ["GraniteMoeSharedConfig"]
|
LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/granitemoeshared/modeling_granitemoeshared.py
ADDED
|
@@ -0,0 +1,800 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
| 2 |
+
# This file was automatically generated from src/transformers/models/granitemoeshared/modular_granitemoeshared.py.
|
| 3 |
+
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
| 4 |
+
# the file from the modular. If any change should be done, please apply the change to the
|
| 5 |
+
# modular_granitemoeshared.py file directly. One of our CI enforces this.
|
| 6 |
+
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
| 7 |
+
# Copyright 2024 IBM and the HuggingFace Inc. team. All rights reserved.
|
| 8 |
+
#
|
| 9 |
+
#
|
| 10 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 11 |
+
# you may not use this file except in compliance with the License.
|
| 12 |
+
# You may obtain a copy of the License at
|
| 13 |
+
#
|
| 14 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 15 |
+
#
|
| 16 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 17 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 18 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 19 |
+
# See the License for the specific language governing permissions and
|
| 20 |
+
# limitations under the License.
|
| 21 |
+
from collections.abc import Callable
|
| 22 |
+
from typing import Optional, TypedDict
|
| 23 |
+
|
| 24 |
+
import torch
|
| 25 |
+
from torch import nn
|
| 26 |
+
from torch.nn import functional as F
|
| 27 |
+
|
| 28 |
+
from ... import initialization as init
|
| 29 |
+
from ...activations import ACT2FN
|
| 30 |
+
from ...cache_utils import Cache, DynamicCache
|
| 31 |
+
from ...generation import GenerationMixin
|
| 32 |
+
from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func
|
| 33 |
+
from ...masking_utils import create_causal_mask
|
| 34 |
+
from ...modeling_layers import GradientCheckpointingLayer
|
| 35 |
+
from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast
|
| 36 |
+
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
| 37 |
+
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
| 38 |
+
from ...processing_utils import Unpack
|
| 39 |
+
from ...utils import TransformersKwargs, auto_docstring
|
| 40 |
+
from ...utils.generic import can_return_tuple, maybe_autocast, merge_with_config_defaults
|
| 41 |
+
from ...utils.output_capturing import capture_outputs
|
| 42 |
+
from .configuration_granitemoeshared import GraniteMoeSharedConfig
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class GraniteFlashAttentionKwargs(TypedDict, total=False):
|
| 46 |
+
"""
|
| 47 |
+
Keyword arguments for advanced Flash Attention, causal-conv1d, and mamba_ssm kernel usage.
|
| 48 |
+
Use cases include padding-free training and fewer `torch.compile` graph breaks.
|
| 49 |
+
|
| 50 |
+
cu_seq_lens_q (`torch.LongTensor`):
|
| 51 |
+
Gets cumulative sequence length for query state.
|
| 52 |
+
cu_seq_lens_k (`torch.LongTensor`):
|
| 53 |
+
Gets cumulative sequence length for key state.
|
| 54 |
+
max_length_q (`int`):
|
| 55 |
+
Maximum sequence length for query state.
|
| 56 |
+
max_length_k (`int`):
|
| 57 |
+
Maximum sequence length for key state.
|
| 58 |
+
seq_idx (`torch.IntTensor):
|
| 59 |
+
Index of each packed sequence.
|
| 60 |
+
"""
|
| 61 |
+
|
| 62 |
+
cu_seq_lens_q: torch.LongTensor
|
| 63 |
+
cu_seq_lens_k: torch.LongTensor
|
| 64 |
+
max_length_q: int
|
| 65 |
+
max_length_k: int
|
| 66 |
+
seq_idx: torch.IntTensor
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
class GraniteMoeSharedMLP(nn.Module):
|
| 70 |
+
"""
|
| 71 |
+
MLP layer for shared experts
|
| 72 |
+
|
| 73 |
+
Args:
|
| 74 |
+
config:
|
| 75 |
+
Configuration object with model hyperparameters.
|
| 76 |
+
"""
|
| 77 |
+
|
| 78 |
+
def __init__(self, config: GraniteMoeSharedConfig):
|
| 79 |
+
super().__init__()
|
| 80 |
+
|
| 81 |
+
self.input_size = config.hidden_size
|
| 82 |
+
self.hidden_size = config.shared_intermediate_size
|
| 83 |
+
self.activation = ACT2FN[config.hidden_act]
|
| 84 |
+
self.input_linear = nn.Linear(self.input_size, self.hidden_size * 2, bias=False)
|
| 85 |
+
self.output_linear = nn.Linear(self.hidden_size, self.input_size, bias=False)
|
| 86 |
+
|
| 87 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 88 |
+
hidden_states = self.input_linear(hidden_states)
|
| 89 |
+
chunked_hidden_states = hidden_states.chunk(2, dim=-1)
|
| 90 |
+
hidden_states = self.activation(chunked_hidden_states[0]) * chunked_hidden_states[1]
|
| 91 |
+
hidden_states = self.output_linear(hidden_states)
|
| 92 |
+
return hidden_states
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
@use_kernel_forward_from_hub("RMSNorm")
|
| 96 |
+
class GraniteMoeSharedRMSNorm(nn.Module):
|
| 97 |
+
def __init__(self, hidden_size, eps: float = 1e-6) -> None:
|
| 98 |
+
"""
|
| 99 |
+
GraniteMoeSharedRMSNorm is equivalent to T5LayerNorm
|
| 100 |
+
"""
|
| 101 |
+
super().__init__()
|
| 102 |
+
self.weight = nn.Parameter(torch.ones(hidden_size))
|
| 103 |
+
self.variance_epsilon = eps
|
| 104 |
+
|
| 105 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 106 |
+
input_dtype = hidden_states.dtype
|
| 107 |
+
hidden_states = hidden_states.to(torch.float32)
|
| 108 |
+
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
| 109 |
+
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
| 110 |
+
return self.weight * hidden_states.to(input_dtype)
|
| 111 |
+
|
| 112 |
+
def extra_repr(self):
|
| 113 |
+
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
class GraniteMoeSharedParallelExperts(nn.Module):
|
| 117 |
+
def __init__(self, num_experts: int, input_size: int, output_size: int) -> None:
|
| 118 |
+
"""
|
| 119 |
+
Initialize the GraniteMoeSharedParallelExperts module.
|
| 120 |
+
The experts weights are stored in [num_experts, output_size, input_size] format. Such that it's compatible with
|
| 121 |
+
many MoE libraries, such as [Megablock](https://github.com/databricks/megablocks) and
|
| 122 |
+
[ScatterMoE](https://github.com/shawntan/scattermoe), as well as the
|
| 123 |
+
[MoE kernel](https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/fused_moe/fused_moe.py)
|
| 124 |
+
used in vllm.
|
| 125 |
+
|
| 126 |
+
Args:
|
| 127 |
+
num_experts (int):
|
| 128 |
+
Number of experts.
|
| 129 |
+
input_size (int):
|
| 130 |
+
Size of the input.
|
| 131 |
+
output_size (int):
|
| 132 |
+
Size of the output.
|
| 133 |
+
"""
|
| 134 |
+
super().__init__()
|
| 135 |
+
self.weight = nn.Parameter(torch.empty(num_experts, output_size, input_size))
|
| 136 |
+
self.num_experts = num_experts
|
| 137 |
+
self.input_size = input_size
|
| 138 |
+
self.output_size = output_size
|
| 139 |
+
|
| 140 |
+
def forward(self, inputs, expert_size):
|
| 141 |
+
"""
|
| 142 |
+
Forward pass of the GraniteMoeSharedParallelExperts module.
|
| 143 |
+
|
| 144 |
+
Args:
|
| 145 |
+
inputs (Tensor):
|
| 146 |
+
Input tensor.
|
| 147 |
+
expert_size:
|
| 148 |
+
Expert size information.
|
| 149 |
+
|
| 150 |
+
Returns:
|
| 151 |
+
Tensor: Output tensor.
|
| 152 |
+
"""
|
| 153 |
+
input_list = inputs.split(expert_size, dim=0)
|
| 154 |
+
output_list = []
|
| 155 |
+
for i in range(self.num_experts):
|
| 156 |
+
output_list.append(F.linear(input_list[i], self.weight[i]))
|
| 157 |
+
results = torch.cat(output_list, dim=0)
|
| 158 |
+
return results
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
class GraniteMoeSharedTopKGating(nn.Module):
|
| 162 |
+
def __init__(self, input_size: int, num_experts: int, top_k: int):
|
| 163 |
+
"""
|
| 164 |
+
Initialize the top-k gating mechanism.
|
| 165 |
+
|
| 166 |
+
Args:
|
| 167 |
+
input_size (`int`):
|
| 168 |
+
Size of the input.
|
| 169 |
+
num_experts (`int`):
|
| 170 |
+
Number of experts.
|
| 171 |
+
top_k (`int`):
|
| 172 |
+
Number of top experts to select.
|
| 173 |
+
"""
|
| 174 |
+
super().__init__()
|
| 175 |
+
|
| 176 |
+
self.num_experts = num_experts
|
| 177 |
+
self.input_size = input_size
|
| 178 |
+
self.top_k = top_k
|
| 179 |
+
|
| 180 |
+
self.layer = nn.Linear(input_size, num_experts, bias=False)
|
| 181 |
+
|
| 182 |
+
def forward(self, hidden_states):
|
| 183 |
+
# compute the top_k routing decision
|
| 184 |
+
logits = self.layer(hidden_states).float() # [batch_size x seq_len, num_experts]
|
| 185 |
+
top_k_logits, top_k_indices = logits.topk(self.top_k, dim=1) # [num_tokens, top_k]
|
| 186 |
+
top_k_gates = torch.softmax(top_k_logits, dim=1).type_as(hidden_states) # [num_tokens, top_k]
|
| 187 |
+
|
| 188 |
+
# compute number of input given to each expert
|
| 189 |
+
zeros = torch.zeros(
|
| 190 |
+
[top_k_gates.size(0), self.num_experts], dtype=top_k_gates.dtype, device=top_k_gates.device
|
| 191 |
+
) # [num_tokens, num_experts]
|
| 192 |
+
gates = zeros.scatter(1, top_k_indices, 1) # [num_tokens, num_experts]
|
| 193 |
+
expert_size = gates.long().sum(0) # [num_experts,]
|
| 194 |
+
# (This cause torch.compile to fail with `torch._dynamo.exc.Unsupported: Backend compiler failed with a fake tensor exception at`)
|
| 195 |
+
# (and `DataDependentOutputException`)
|
| 196 |
+
expert_size = expert_size.tolist()
|
| 197 |
+
|
| 198 |
+
# sort and group input tokens according to expert assignment
|
| 199 |
+
top_k_experts = top_k_indices.flatten() # [num_tokens * top_k]
|
| 200 |
+
_, index_sorted_experts = top_k_experts.sort(0) # [num_tokens * top_k]
|
| 201 |
+
batch_index = index_sorted_experts.div(self.top_k, rounding_mode="trunc") # [num_tokens * top_k]
|
| 202 |
+
|
| 203 |
+
# gather the gate values for grouped input tokens
|
| 204 |
+
top_k_gates = top_k_gates.flatten() # [num_tokens * top_k]
|
| 205 |
+
batch_gates = top_k_gates[index_sorted_experts] # [num_tokens * top_k]
|
| 206 |
+
|
| 207 |
+
return index_sorted_experts, batch_index, batch_gates, expert_size, logits
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
class GraniteMoeSharedMoE(nn.Module):
|
| 211 |
+
"""
|
| 212 |
+
A Sparsely gated mixture of experts layer with 1-layer Feed-Forward networks as experts.
|
| 213 |
+
|
| 214 |
+
Args:
|
| 215 |
+
config:
|
| 216 |
+
Configuration object with model hyperparameters.
|
| 217 |
+
"""
|
| 218 |
+
|
| 219 |
+
def __init__(self, config: GraniteMoeSharedConfig):
|
| 220 |
+
super().__init__()
|
| 221 |
+
|
| 222 |
+
self.input_size = config.hidden_size
|
| 223 |
+
self.hidden_size = config.intermediate_size
|
| 224 |
+
self.activation = ACT2FN[config.hidden_act]
|
| 225 |
+
self.input_linear = GraniteMoeSharedParallelExperts(
|
| 226 |
+
config.num_local_experts, self.input_size, self.hidden_size * 2
|
| 227 |
+
)
|
| 228 |
+
self.output_linear = GraniteMoeSharedParallelExperts(
|
| 229 |
+
config.num_local_experts, self.hidden_size, self.input_size
|
| 230 |
+
)
|
| 231 |
+
|
| 232 |
+
self.router = GraniteMoeSharedTopKGating(
|
| 233 |
+
input_size=self.input_size,
|
| 234 |
+
num_experts=config.num_local_experts,
|
| 235 |
+
top_k=config.num_experts_per_tok,
|
| 236 |
+
)
|
| 237 |
+
|
| 238 |
+
def forward(self, layer_input):
|
| 239 |
+
bsz, length, emb_size = layer_input.size()
|
| 240 |
+
layer_input = layer_input.reshape(-1, emb_size)
|
| 241 |
+
_, batch_index, batch_gates, expert_size, _ = self.router(layer_input)
|
| 242 |
+
|
| 243 |
+
expert_inputs = layer_input[batch_index]
|
| 244 |
+
hidden_states = self.input_linear(expert_inputs, expert_size)
|
| 245 |
+
chunked_hidden_states = hidden_states.chunk(2, dim=-1)
|
| 246 |
+
hidden_states = self.activation(chunked_hidden_states[0]) * chunked_hidden_states[1]
|
| 247 |
+
expert_outputs = self.output_linear(hidden_states, expert_size)
|
| 248 |
+
|
| 249 |
+
expert_outputs = expert_outputs * batch_gates[:, None]
|
| 250 |
+
|
| 251 |
+
zeros = torch.zeros((bsz * length, self.input_size), dtype=expert_outputs.dtype, device=expert_outputs.device)
|
| 252 |
+
layer_output = zeros.index_add(0, batch_index, expert_outputs)
|
| 253 |
+
layer_output = layer_output.view(bsz, length, self.input_size)
|
| 254 |
+
return layer_output
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
def rotate_half(x):
|
| 258 |
+
"""Rotates half the hidden dims of the input."""
|
| 259 |
+
x1 = x[..., : x.shape[-1] // 2]
|
| 260 |
+
x2 = x[..., x.shape[-1] // 2 :]
|
| 261 |
+
return torch.cat((-x2, x1), dim=-1)
|
| 262 |
+
|
| 263 |
+
|
| 264 |
+
@use_kernel_func_from_hub("rotary_pos_emb")
|
| 265 |
+
def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
|
| 266 |
+
"""Applies Rotary Position Embedding to the query and key tensors.
|
| 267 |
+
|
| 268 |
+
Args:
|
| 269 |
+
q (`torch.Tensor`): The query tensor.
|
| 270 |
+
k (`torch.Tensor`): The key tensor.
|
| 271 |
+
cos (`torch.Tensor`): The cosine part of the rotary embedding.
|
| 272 |
+
sin (`torch.Tensor`): The sine part of the rotary embedding.
|
| 273 |
+
unsqueeze_dim (`int`, *optional*, defaults to 1):
|
| 274 |
+
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
|
| 275 |
+
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
|
| 276 |
+
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
|
| 277 |
+
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
|
| 278 |
+
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
|
| 279 |
+
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
|
| 280 |
+
Returns:
|
| 281 |
+
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
|
| 282 |
+
"""
|
| 283 |
+
cos = cos.unsqueeze(unsqueeze_dim)
|
| 284 |
+
sin = sin.unsqueeze(unsqueeze_dim)
|
| 285 |
+
q_embed = (q * cos) + (rotate_half(q) * sin)
|
| 286 |
+
k_embed = (k * cos) + (rotate_half(k) * sin)
|
| 287 |
+
return q_embed, k_embed
|
| 288 |
+
|
| 289 |
+
|
| 290 |
+
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
| 291 |
+
"""
|
| 292 |
+
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
|
| 293 |
+
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
|
| 294 |
+
"""
|
| 295 |
+
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
| 296 |
+
if n_rep == 1:
|
| 297 |
+
return hidden_states
|
| 298 |
+
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
|
| 299 |
+
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
| 300 |
+
|
| 301 |
+
|
| 302 |
+
def eager_attention_forward(
|
| 303 |
+
module: nn.Module,
|
| 304 |
+
query: torch.Tensor,
|
| 305 |
+
key: torch.Tensor,
|
| 306 |
+
value: torch.Tensor,
|
| 307 |
+
attention_mask: torch.Tensor | None,
|
| 308 |
+
scaling: float,
|
| 309 |
+
dropout: float = 0.0,
|
| 310 |
+
**kwargs: Unpack[TransformersKwargs],
|
| 311 |
+
):
|
| 312 |
+
key_states = repeat_kv(key, module.num_key_value_groups)
|
| 313 |
+
value_states = repeat_kv(value, module.num_key_value_groups)
|
| 314 |
+
|
| 315 |
+
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
|
| 316 |
+
if attention_mask is not None:
|
| 317 |
+
attn_weights = attn_weights + attention_mask
|
| 318 |
+
|
| 319 |
+
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
| 320 |
+
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
| 321 |
+
attn_output = torch.matmul(attn_weights, value_states)
|
| 322 |
+
attn_output = attn_output.transpose(1, 2).contiguous()
|
| 323 |
+
|
| 324 |
+
return attn_output, attn_weights
|
| 325 |
+
|
| 326 |
+
|
| 327 |
+
@use_kernelized_func(apply_rotary_pos_emb)
|
| 328 |
+
class GraniteMoeSharedAttention(nn.Module):
|
| 329 |
+
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
| 330 |
+
|
| 331 |
+
def __init__(self, config: GraniteMoeSharedConfig, layer_idx: int):
|
| 332 |
+
super().__init__()
|
| 333 |
+
self.config = config
|
| 334 |
+
self.layer_idx = layer_idx
|
| 335 |
+
self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
|
| 336 |
+
self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
|
| 337 |
+
self.scaling = config.attention_multiplier # Only diff with llama
|
| 338 |
+
self.attention_dropout = config.attention_dropout
|
| 339 |
+
self.is_causal = True
|
| 340 |
+
|
| 341 |
+
self.q_proj = nn.Linear(
|
| 342 |
+
config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
|
| 343 |
+
)
|
| 344 |
+
self.k_proj = nn.Linear(
|
| 345 |
+
config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
|
| 346 |
+
)
|
| 347 |
+
self.v_proj = nn.Linear(
|
| 348 |
+
config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
|
| 349 |
+
)
|
| 350 |
+
self.o_proj = nn.Linear(
|
| 351 |
+
config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
|
| 352 |
+
)
|
| 353 |
+
|
| 354 |
+
def forward(
|
| 355 |
+
self,
|
| 356 |
+
hidden_states: torch.Tensor,
|
| 357 |
+
position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
|
| 358 |
+
attention_mask: torch.Tensor | None = None,
|
| 359 |
+
past_key_values: Cache | None = None,
|
| 360 |
+
**kwargs: Unpack[TransformersKwargs],
|
| 361 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 362 |
+
input_shape = hidden_states.shape[:-1]
|
| 363 |
+
hidden_shape = (*input_shape, -1, self.head_dim)
|
| 364 |
+
|
| 365 |
+
query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
| 366 |
+
key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
| 367 |
+
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
| 368 |
+
|
| 369 |
+
cos, sin = position_embeddings
|
| 370 |
+
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
| 371 |
+
|
| 372 |
+
if past_key_values is not None:
|
| 373 |
+
key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
|
| 374 |
+
|
| 375 |
+
attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
|
| 376 |
+
self.config._attn_implementation, eager_attention_forward
|
| 377 |
+
)
|
| 378 |
+
|
| 379 |
+
attn_output, attn_weights = attention_interface(
|
| 380 |
+
self,
|
| 381 |
+
query_states,
|
| 382 |
+
key_states,
|
| 383 |
+
value_states,
|
| 384 |
+
attention_mask,
|
| 385 |
+
dropout=0.0 if not self.training else self.attention_dropout,
|
| 386 |
+
scaling=self.scaling,
|
| 387 |
+
**kwargs,
|
| 388 |
+
)
|
| 389 |
+
|
| 390 |
+
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
|
| 391 |
+
attn_output = self.o_proj(attn_output)
|
| 392 |
+
return attn_output, attn_weights
|
| 393 |
+
|
| 394 |
+
|
| 395 |
+
class GraniteMoeSharedDecoderLayer(GradientCheckpointingLayer):
|
| 396 |
+
def __init__(self, config: GraniteMoeSharedConfig, layer_idx: int):
|
| 397 |
+
super().__init__()
|
| 398 |
+
self.hidden_size = config.hidden_size
|
| 399 |
+
self.self_attn = GraniteMoeSharedAttention(config=config, layer_idx=layer_idx)
|
| 400 |
+
self.input_layernorm = GraniteMoeSharedRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 401 |
+
self.post_attention_layernorm = GraniteMoeSharedRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 402 |
+
self.block_sparse_moe = GraniteMoeSharedMoE(config)
|
| 403 |
+
self.residual_multiplier = config.residual_multiplier # Only diff with mixtral!
|
| 404 |
+
self.shared_mlp = None if config.shared_intermediate_size == 0 else GraniteMoeSharedMLP(config)
|
| 405 |
+
|
| 406 |
+
def forward(
|
| 407 |
+
self,
|
| 408 |
+
hidden_states: torch.Tensor,
|
| 409 |
+
attention_mask: torch.Tensor | None = None,
|
| 410 |
+
position_ids: torch.LongTensor | None = None,
|
| 411 |
+
past_key_values: Cache | None = None,
|
| 412 |
+
output_attentions: bool | None = False,
|
| 413 |
+
use_cache: bool | None = False,
|
| 414 |
+
position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
|
| 415 |
+
**kwargs: Unpack[GraniteFlashAttentionKwargs],
|
| 416 |
+
) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]:
|
| 417 |
+
residual = hidden_states
|
| 418 |
+
hidden_states = self.input_layernorm(hidden_states)
|
| 419 |
+
|
| 420 |
+
# Self Attention
|
| 421 |
+
hidden_states, _ = self.self_attn(
|
| 422 |
+
hidden_states=hidden_states,
|
| 423 |
+
attention_mask=attention_mask,
|
| 424 |
+
position_ids=position_ids,
|
| 425 |
+
past_key_values=past_key_values,
|
| 426 |
+
output_attentions=output_attentions,
|
| 427 |
+
use_cache=use_cache,
|
| 428 |
+
position_embeddings=position_embeddings,
|
| 429 |
+
**kwargs,
|
| 430 |
+
)
|
| 431 |
+
|
| 432 |
+
hidden_states = residual + hidden_states * self.residual_multiplier
|
| 433 |
+
|
| 434 |
+
residual = hidden_states
|
| 435 |
+
hidden_states = self.post_attention_layernorm(hidden_states)
|
| 436 |
+
moe_hidden_states = self.block_sparse_moe(hidden_states)
|
| 437 |
+
|
| 438 |
+
if self.shared_mlp is None:
|
| 439 |
+
hidden_states = moe_hidden_states
|
| 440 |
+
else:
|
| 441 |
+
hidden_states = moe_hidden_states + self.shared_mlp(hidden_states)
|
| 442 |
+
hidden_states = residual + hidden_states * self.residual_multiplier
|
| 443 |
+
return hidden_states
|
| 444 |
+
|
| 445 |
+
|
| 446 |
+
@auto_docstring
|
| 447 |
+
class GraniteMoeSharedPreTrainedModel(PreTrainedModel):
|
| 448 |
+
config: GraniteMoeSharedConfig
|
| 449 |
+
base_model_prefix = "model"
|
| 450 |
+
supports_gradient_checkpointing = True
|
| 451 |
+
_no_split_modules = ["GraniteMoeSharedDecoderLayer"]
|
| 452 |
+
_skip_keys_device_placement = ["past_key_values"]
|
| 453 |
+
_supports_flash_attn = True
|
| 454 |
+
_supports_sdpa = True
|
| 455 |
+
_supports_flex_attn = True
|
| 456 |
+
_can_compile_fullgraph = False # TopK gating fails fullgraph compilation at "expert_size = expert_size.tolist()"
|
| 457 |
+
_supports_attention_backend = True
|
| 458 |
+
_can_record_outputs = {
|
| 459 |
+
"hidden_states": GraniteMoeSharedDecoderLayer,
|
| 460 |
+
"attentions": GraniteMoeSharedAttention,
|
| 461 |
+
}
|
| 462 |
+
|
| 463 |
+
@torch.no_grad()
|
| 464 |
+
def _init_weights(self, module):
|
| 465 |
+
super()._init_weights(module)
|
| 466 |
+
if isinstance(module, GraniteMoeSharedParallelExperts):
|
| 467 |
+
init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
|
| 468 |
+
|
| 469 |
+
|
| 470 |
+
class GraniteMoeSharedRotaryEmbedding(nn.Module):
|
| 471 |
+
inv_freq: torch.Tensor # fix linting for `register_buffer`
|
| 472 |
+
|
| 473 |
+
def __init__(self, config: GraniteMoeSharedConfig, device=None):
|
| 474 |
+
super().__init__()
|
| 475 |
+
self.max_seq_len_cached = config.max_position_embeddings
|
| 476 |
+
self.original_max_seq_len = config.max_position_embeddings
|
| 477 |
+
|
| 478 |
+
self.config = config
|
| 479 |
+
|
| 480 |
+
self.rope_type = self.config.rope_parameters["rope_type"]
|
| 481 |
+
rope_init_fn: Callable = self.compute_default_rope_parameters
|
| 482 |
+
if self.rope_type != "default":
|
| 483 |
+
rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
| 484 |
+
inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
|
| 485 |
+
|
| 486 |
+
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
| 487 |
+
self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
|
| 488 |
+
|
| 489 |
+
@staticmethod
|
| 490 |
+
def compute_default_rope_parameters(
|
| 491 |
+
config: GraniteMoeSharedConfig | None = None,
|
| 492 |
+
device: Optional["torch.device"] = None,
|
| 493 |
+
seq_len: int | None = None,
|
| 494 |
+
) -> tuple["torch.Tensor", float]:
|
| 495 |
+
"""
|
| 496 |
+
Computes the inverse frequencies according to the original RoPE implementation
|
| 497 |
+
Args:
|
| 498 |
+
config ([`~transformers.PreTrainedConfig`]):
|
| 499 |
+
The model configuration.
|
| 500 |
+
device (`torch.device`):
|
| 501 |
+
The device to use for initialization of the inverse frequencies.
|
| 502 |
+
seq_len (`int`, *optional*):
|
| 503 |
+
The current sequence length. Unused for this type of RoPE.
|
| 504 |
+
Returns:
|
| 505 |
+
Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
|
| 506 |
+
post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
|
| 507 |
+
"""
|
| 508 |
+
base = config.rope_parameters["rope_theta"]
|
| 509 |
+
dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
|
| 510 |
+
|
| 511 |
+
attention_factor = 1.0 # Unused in this type of RoPE
|
| 512 |
+
|
| 513 |
+
# Compute the inverse frequencies
|
| 514 |
+
inv_freq = 1.0 / (
|
| 515 |
+
base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
|
| 516 |
+
)
|
| 517 |
+
return inv_freq, attention_factor
|
| 518 |
+
|
| 519 |
+
@torch.no_grad()
|
| 520 |
+
@dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
|
| 521 |
+
def forward(self, x, position_ids):
|
| 522 |
+
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
|
| 523 |
+
position_ids_expanded = position_ids[:, None, :].float()
|
| 524 |
+
|
| 525 |
+
device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
|
| 526 |
+
with maybe_autocast(device_type=device_type, enabled=False): # Force float32
|
| 527 |
+
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
|
| 528 |
+
emb = torch.cat((freqs, freqs), dim=-1)
|
| 529 |
+
cos = emb.cos() * self.attention_scaling
|
| 530 |
+
sin = emb.sin() * self.attention_scaling
|
| 531 |
+
|
| 532 |
+
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
| 533 |
+
|
| 534 |
+
|
| 535 |
+
@auto_docstring
|
| 536 |
+
class GraniteMoeSharedModel(GraniteMoeSharedPreTrainedModel):
|
| 537 |
+
def __init__(self, config: GraniteMoeSharedConfig):
|
| 538 |
+
super().__init__(config)
|
| 539 |
+
self.padding_idx = config.pad_token_id
|
| 540 |
+
self.vocab_size = config.vocab_size
|
| 541 |
+
|
| 542 |
+
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
|
| 543 |
+
self.layers = nn.ModuleList(
|
| 544 |
+
[GraniteMoeSharedDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
|
| 545 |
+
)
|
| 546 |
+
self.norm = GraniteMoeSharedRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 547 |
+
self.rotary_emb = GraniteMoeSharedRotaryEmbedding(config=config)
|
| 548 |
+
self.gradient_checkpointing = False
|
| 549 |
+
self.embedding_multiplier = config.embedding_multiplier
|
| 550 |
+
|
| 551 |
+
# Initialize weights and apply final processing
|
| 552 |
+
self.post_init()
|
| 553 |
+
|
| 554 |
+
@merge_with_config_defaults
|
| 555 |
+
@capture_outputs
|
| 556 |
+
@auto_docstring
|
| 557 |
+
def forward(
|
| 558 |
+
self,
|
| 559 |
+
input_ids: torch.LongTensor | None = None,
|
| 560 |
+
attention_mask: torch.Tensor | None = None,
|
| 561 |
+
position_ids: torch.LongTensor | None = None,
|
| 562 |
+
past_key_values: Cache | None = None,
|
| 563 |
+
inputs_embeds: torch.FloatTensor | None = None,
|
| 564 |
+
use_cache: bool | None = None,
|
| 565 |
+
**kwargs: Unpack[TransformersKwargs],
|
| 566 |
+
) -> MoeModelOutputWithPast:
|
| 567 |
+
if (input_ids is None) ^ (inputs_embeds is not None):
|
| 568 |
+
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
| 569 |
+
|
| 570 |
+
if use_cache and past_key_values is None:
|
| 571 |
+
past_key_values = DynamicCache(config=self.config)
|
| 572 |
+
|
| 573 |
+
if inputs_embeds is None:
|
| 574 |
+
inputs_embeds = self.embed_tokens(input_ids)
|
| 575 |
+
|
| 576 |
+
if position_ids is None:
|
| 577 |
+
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
| 578 |
+
position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
|
| 579 |
+
position_ids = position_ids.unsqueeze(0)
|
| 580 |
+
|
| 581 |
+
causal_mask = create_causal_mask( # ONLY DIFF WITH MIXTRAL: NO SLIDING
|
| 582 |
+
config=self.config,
|
| 583 |
+
inputs_embeds=inputs_embeds,
|
| 584 |
+
attention_mask=attention_mask,
|
| 585 |
+
past_key_values=past_key_values,
|
| 586 |
+
position_ids=position_ids,
|
| 587 |
+
)
|
| 588 |
+
inputs_embeds = inputs_embeds * self.embedding_multiplier
|
| 589 |
+
hidden_states = inputs_embeds
|
| 590 |
+
|
| 591 |
+
# create position embeddings to be shared across the decoder layers
|
| 592 |
+
position_embeddings = self.rotary_emb(hidden_states, position_ids)
|
| 593 |
+
|
| 594 |
+
for decoder_layer in self.layers[: self.config.num_hidden_layers]:
|
| 595 |
+
hidden_states = decoder_layer(
|
| 596 |
+
hidden_states,
|
| 597 |
+
position_embeddings=position_embeddings,
|
| 598 |
+
attention_mask=causal_mask,
|
| 599 |
+
position_ids=position_ids,
|
| 600 |
+
past_key_values=past_key_values,
|
| 601 |
+
use_cache=use_cache,
|
| 602 |
+
**kwargs,
|
| 603 |
+
)
|
| 604 |
+
|
| 605 |
+
hidden_states = self.norm(hidden_states)
|
| 606 |
+
|
| 607 |
+
return MoeModelOutputWithPast( # only diff with Mistral is the output type, we need MoE
|
| 608 |
+
last_hidden_state=hidden_states,
|
| 609 |
+
past_key_values=past_key_values,
|
| 610 |
+
)
|
| 611 |
+
|
| 612 |
+
|
| 613 |
+
def load_balancing_loss_func(
|
| 614 |
+
gate_logits: torch.Tensor | tuple[torch.Tensor] | None,
|
| 615 |
+
num_experts: int | None = None,
|
| 616 |
+
top_k=2,
|
| 617 |
+
attention_mask: torch.Tensor | None = None,
|
| 618 |
+
) -> torch.Tensor | int:
|
| 619 |
+
r"""
|
| 620 |
+
Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.
|
| 621 |
+
|
| 622 |
+
See Switch Transformer (https://huggingface.co/papers/2101.03961) for more details. This function implements the loss
|
| 623 |
+
function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
|
| 624 |
+
experts is too unbalanced.
|
| 625 |
+
|
| 626 |
+
Args:
|
| 627 |
+
gate_logits:
|
| 628 |
+
Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of
|
| 629 |
+
shape [batch_size X sequence_length, num_experts].
|
| 630 |
+
num_experts:
|
| 631 |
+
Number of experts
|
| 632 |
+
top_k:
|
| 633 |
+
The number of experts to route per-token, can be also interpreted as the `top-k` routing
|
| 634 |
+
parameter.
|
| 635 |
+
attention_mask (`torch.Tensor`, *optional*):
|
| 636 |
+
The attention_mask used in forward function
|
| 637 |
+
shape [batch_size X sequence_length] if not None.
|
| 638 |
+
|
| 639 |
+
Returns:
|
| 640 |
+
The auxiliary loss.
|
| 641 |
+
"""
|
| 642 |
+
if gate_logits is None or not isinstance(gate_logits, tuple):
|
| 643 |
+
return 0
|
| 644 |
+
|
| 645 |
+
if isinstance(gate_logits, tuple):
|
| 646 |
+
compute_device = gate_logits[0].device
|
| 647 |
+
concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0)
|
| 648 |
+
|
| 649 |
+
routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)
|
| 650 |
+
|
| 651 |
+
_, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
|
| 652 |
+
|
| 653 |
+
expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)
|
| 654 |
+
|
| 655 |
+
if attention_mask is None:
|
| 656 |
+
# Compute the percentage of tokens routed to each experts
|
| 657 |
+
tokens_per_expert = torch.mean(expert_mask.float(), dim=0)
|
| 658 |
+
|
| 659 |
+
# Compute the average probability of routing to these experts
|
| 660 |
+
router_prob_per_expert = torch.mean(routing_weights, dim=0)
|
| 661 |
+
else:
|
| 662 |
+
batch_size, sequence_length = attention_mask.shape
|
| 663 |
+
num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length)
|
| 664 |
+
|
| 665 |
+
# Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask
|
| 666 |
+
expert_attention_mask = (
|
| 667 |
+
attention_mask[None, :, :, None, None]
|
| 668 |
+
.expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts))
|
| 669 |
+
.reshape(-1, top_k, num_experts)
|
| 670 |
+
.to(compute_device)
|
| 671 |
+
)
|
| 672 |
+
|
| 673 |
+
# Compute the percentage of tokens routed to each experts
|
| 674 |
+
tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(
|
| 675 |
+
expert_attention_mask, dim=0
|
| 676 |
+
)
|
| 677 |
+
|
| 678 |
+
# Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert
|
| 679 |
+
router_per_expert_attention_mask = (
|
| 680 |
+
attention_mask[None, :, :, None]
|
| 681 |
+
.expand((num_hidden_layers, batch_size, sequence_length, num_experts))
|
| 682 |
+
.reshape(-1, num_experts)
|
| 683 |
+
.to(compute_device)
|
| 684 |
+
)
|
| 685 |
+
|
| 686 |
+
# Compute the average probability of routing to these experts
|
| 687 |
+
router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum(
|
| 688 |
+
router_per_expert_attention_mask, dim=0
|
| 689 |
+
)
|
| 690 |
+
|
| 691 |
+
overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
|
| 692 |
+
return overall_loss * num_experts
|
| 693 |
+
|
| 694 |
+
|
| 695 |
+
@auto_docstring
|
| 696 |
+
class GraniteMoeSharedForCausalLM(GraniteMoeSharedPreTrainedModel, GenerationMixin):
|
| 697 |
+
_tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
|
| 698 |
+
_tp_plan = {"lm_head": "colwise_gather_output"}
|
| 699 |
+
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
|
| 700 |
+
|
| 701 |
+
def __init__(self, config: GraniteMoeSharedConfig):
|
| 702 |
+
super().__init__(config)
|
| 703 |
+
self.model = GraniteMoeSharedModel(config)
|
| 704 |
+
self.vocab_size = config.vocab_size
|
| 705 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
| 706 |
+
self.router_aux_loss_coef = config.router_aux_loss_coef
|
| 707 |
+
self.num_experts = config.num_local_experts
|
| 708 |
+
self.num_experts_per_tok = config.num_experts_per_tok
|
| 709 |
+
self.logits_scaling = config.logits_scaling
|
| 710 |
+
|
| 711 |
+
# Initialize weights and apply final processing
|
| 712 |
+
self.post_init()
|
| 713 |
+
|
| 714 |
+
@auto_docstring
|
| 715 |
+
@can_return_tuple
|
| 716 |
+
def forward(
|
| 717 |
+
self,
|
| 718 |
+
input_ids: torch.LongTensor | None = None,
|
| 719 |
+
attention_mask: torch.Tensor | None = None,
|
| 720 |
+
position_ids: torch.LongTensor | None = None,
|
| 721 |
+
past_key_values: Cache | None = None,
|
| 722 |
+
inputs_embeds: torch.FloatTensor | None = None,
|
| 723 |
+
labels: torch.LongTensor | None = None,
|
| 724 |
+
output_router_logits: bool | None = None,
|
| 725 |
+
logits_to_keep: int | torch.Tensor = 0,
|
| 726 |
+
**kwargs,
|
| 727 |
+
) -> tuple | MoeCausalLMOutputWithPast:
|
| 728 |
+
r"""
|
| 729 |
+
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 730 |
+
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
| 731 |
+
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
| 732 |
+
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
| 733 |
+
|
| 734 |
+
Example:
|
| 735 |
+
|
| 736 |
+
```python
|
| 737 |
+
>>> from transformers import AutoTokenizer, GraniteMoeSharedForCausalLM
|
| 738 |
+
|
| 739 |
+
>>> model = GraniteMoeSharedForCausalLM.from_pretrained("ibm/PowerMoE-3b")
|
| 740 |
+
>>> tokenizer = AutoTokenizer.from_pretrained("ibm/PowerMoE-3b")
|
| 741 |
+
|
| 742 |
+
>>> prompt = "Hey, are you conscious? Can you talk to me?"
|
| 743 |
+
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
| 744 |
+
|
| 745 |
+
>>> # Generate
|
| 746 |
+
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
| 747 |
+
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
| 748 |
+
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
|
| 749 |
+
```"""
|
| 750 |
+
output_router_logits = (
|
| 751 |
+
output_router_logits if output_router_logits is not None else self.config.output_router_logits
|
| 752 |
+
)
|
| 753 |
+
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
| 754 |
+
outputs = self.model(
|
| 755 |
+
input_ids=input_ids,
|
| 756 |
+
attention_mask=attention_mask,
|
| 757 |
+
position_ids=position_ids,
|
| 758 |
+
past_key_values=past_key_values,
|
| 759 |
+
inputs_embeds=inputs_embeds,
|
| 760 |
+
**kwargs,
|
| 761 |
+
)
|
| 762 |
+
|
| 763 |
+
# Only compute necessary logits
|
| 764 |
+
hidden_states = outputs.last_hidden_state
|
| 765 |
+
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
|
| 766 |
+
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
| 767 |
+
logits = logits / self.config.logits_scaling
|
| 768 |
+
|
| 769 |
+
loss = None
|
| 770 |
+
if labels is not None:
|
| 771 |
+
# Flatten the tokens
|
| 772 |
+
loss = self.loss_function(
|
| 773 |
+
logits,
|
| 774 |
+
labels,
|
| 775 |
+
vocab_size=self.config.vocab_size,
|
| 776 |
+
**kwargs,
|
| 777 |
+
)
|
| 778 |
+
|
| 779 |
+
aux_loss = None
|
| 780 |
+
if output_router_logits:
|
| 781 |
+
aux_loss = load_balancing_loss_func(
|
| 782 |
+
outputs.router_logits,
|
| 783 |
+
self.num_experts,
|
| 784 |
+
self.num_experts_per_tok,
|
| 785 |
+
attention_mask,
|
| 786 |
+
)
|
| 787 |
+
if labels is not None:
|
| 788 |
+
loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure to reside in the same device
|
| 789 |
+
return MoeCausalLMOutputWithPast(
|
| 790 |
+
loss=loss,
|
| 791 |
+
aux_loss=aux_loss,
|
| 792 |
+
logits=logits,
|
| 793 |
+
past_key_values=outputs.past_key_values,
|
| 794 |
+
hidden_states=outputs.hidden_states,
|
| 795 |
+
attentions=outputs.attentions,
|
| 796 |
+
router_logits=outputs.router_logits,
|
| 797 |
+
)
|
| 798 |
+
|
| 799 |
+
|
| 800 |
+
__all__ = ["GraniteMoeSharedForCausalLM", "GraniteMoeSharedModel", "GraniteMoeSharedPreTrainedModel"]
|
LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/granitemoeshared/modular_granitemoeshared.py
ADDED
|
@@ -0,0 +1,154 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 IBM and the HuggingFace Inc. team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
from typing import TypedDict
|
| 16 |
+
|
| 17 |
+
import torch
|
| 18 |
+
from torch import nn
|
| 19 |
+
|
| 20 |
+
from ...activations import ACT2FN
|
| 21 |
+
from ...cache_utils import Cache
|
| 22 |
+
from ...processing_utils import Unpack
|
| 23 |
+
from ...utils import logging
|
| 24 |
+
from ..granitemoe.modeling_granitemoe import (
|
| 25 |
+
GraniteMoeDecoderLayer,
|
| 26 |
+
GraniteMoeForCausalLM,
|
| 27 |
+
GraniteMoeModel,
|
| 28 |
+
GraniteMoePreTrainedModel,
|
| 29 |
+
)
|
| 30 |
+
from .configuration_granitemoeshared import GraniteMoeSharedConfig
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
logger = logging.get_logger(__name__)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class GraniteFlashAttentionKwargs(TypedDict, total=False):
|
| 37 |
+
"""
|
| 38 |
+
Keyword arguments for advanced Flash Attention, causal-conv1d, and mamba_ssm kernel usage.
|
| 39 |
+
Use cases include padding-free training and fewer `torch.compile` graph breaks.
|
| 40 |
+
|
| 41 |
+
cu_seq_lens_q (`torch.LongTensor`):
|
| 42 |
+
Gets cumulative sequence length for query state.
|
| 43 |
+
cu_seq_lens_k (`torch.LongTensor`):
|
| 44 |
+
Gets cumulative sequence length for key state.
|
| 45 |
+
max_length_q (`int`):
|
| 46 |
+
Maximum sequence length for query state.
|
| 47 |
+
max_length_k (`int`):
|
| 48 |
+
Maximum sequence length for key state.
|
| 49 |
+
seq_idx (`torch.IntTensor):
|
| 50 |
+
Index of each packed sequence.
|
| 51 |
+
"""
|
| 52 |
+
|
| 53 |
+
cu_seq_lens_q: torch.LongTensor
|
| 54 |
+
cu_seq_lens_k: torch.LongTensor
|
| 55 |
+
max_length_q: int
|
| 56 |
+
max_length_k: int
|
| 57 |
+
seq_idx: torch.IntTensor
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
class GraniteMoeSharedMLP(nn.Module):
|
| 61 |
+
"""
|
| 62 |
+
MLP layer for shared experts
|
| 63 |
+
|
| 64 |
+
Args:
|
| 65 |
+
config:
|
| 66 |
+
Configuration object with model hyperparameters.
|
| 67 |
+
"""
|
| 68 |
+
|
| 69 |
+
def __init__(self, config: GraniteMoeSharedConfig):
|
| 70 |
+
super().__init__()
|
| 71 |
+
|
| 72 |
+
self.input_size = config.hidden_size
|
| 73 |
+
self.hidden_size = config.shared_intermediate_size
|
| 74 |
+
self.activation = ACT2FN[config.hidden_act]
|
| 75 |
+
self.input_linear = nn.Linear(self.input_size, self.hidden_size * 2, bias=False)
|
| 76 |
+
self.output_linear = nn.Linear(self.hidden_size, self.input_size, bias=False)
|
| 77 |
+
|
| 78 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 79 |
+
hidden_states = self.input_linear(hidden_states)
|
| 80 |
+
chunked_hidden_states = hidden_states.chunk(2, dim=-1)
|
| 81 |
+
hidden_states = self.activation(chunked_hidden_states[0]) * chunked_hidden_states[1]
|
| 82 |
+
hidden_states = self.output_linear(hidden_states)
|
| 83 |
+
return hidden_states
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
class GraniteMoeSharedDecoderLayer(GraniteMoeDecoderLayer):
|
| 87 |
+
def __init__(self, config: GraniteMoeSharedConfig, layer_idx: int):
|
| 88 |
+
super().__init__(config, layer_idx)
|
| 89 |
+
self.shared_mlp = None if config.shared_intermediate_size == 0 else GraniteMoeSharedMLP(config)
|
| 90 |
+
|
| 91 |
+
def forward(
|
| 92 |
+
self,
|
| 93 |
+
hidden_states: torch.Tensor,
|
| 94 |
+
attention_mask: torch.Tensor | None = None,
|
| 95 |
+
position_ids: torch.LongTensor | None = None,
|
| 96 |
+
past_key_values: Cache | None = None,
|
| 97 |
+
output_attentions: bool | None = False,
|
| 98 |
+
use_cache: bool | None = False,
|
| 99 |
+
position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
|
| 100 |
+
**kwargs: Unpack[GraniteFlashAttentionKwargs],
|
| 101 |
+
) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]:
|
| 102 |
+
residual = hidden_states
|
| 103 |
+
hidden_states = self.input_layernorm(hidden_states)
|
| 104 |
+
|
| 105 |
+
# Self Attention
|
| 106 |
+
hidden_states, _ = self.self_attn(
|
| 107 |
+
hidden_states=hidden_states,
|
| 108 |
+
attention_mask=attention_mask,
|
| 109 |
+
position_ids=position_ids,
|
| 110 |
+
past_key_values=past_key_values,
|
| 111 |
+
output_attentions=output_attentions,
|
| 112 |
+
use_cache=use_cache,
|
| 113 |
+
position_embeddings=position_embeddings,
|
| 114 |
+
**kwargs,
|
| 115 |
+
)
|
| 116 |
+
|
| 117 |
+
hidden_states = residual + hidden_states * self.residual_multiplier
|
| 118 |
+
|
| 119 |
+
residual = hidden_states
|
| 120 |
+
hidden_states = self.post_attention_layernorm(hidden_states)
|
| 121 |
+
moe_hidden_states = self.block_sparse_moe(hidden_states)
|
| 122 |
+
|
| 123 |
+
if self.shared_mlp is None:
|
| 124 |
+
hidden_states = moe_hidden_states
|
| 125 |
+
else:
|
| 126 |
+
hidden_states = moe_hidden_states + self.shared_mlp(hidden_states)
|
| 127 |
+
hidden_states = residual + hidden_states * self.residual_multiplier
|
| 128 |
+
return hidden_states
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
class GraniteMoeSharedPreTrainedModel(GraniteMoePreTrainedModel):
|
| 132 |
+
config: GraniteMoeSharedConfig
|
| 133 |
+
_no_split_modules = ["GraniteMoeSharedDecoderLayer"]
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
class GraniteMoeSharedModel(GraniteMoeModel):
|
| 137 |
+
def __init__(self, config: GraniteMoeSharedConfig):
|
| 138 |
+
super().__init__(config)
|
| 139 |
+
self.layers = nn.ModuleList(
|
| 140 |
+
[GraniteMoeSharedDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
|
| 141 |
+
)
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
class GraniteMoeSharedForCausalLM(GraniteMoeForCausalLM):
|
| 145 |
+
_tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
|
| 146 |
+
|
| 147 |
+
def __init__(self, config: GraniteMoeSharedConfig):
|
| 148 |
+
super().__init__(config)
|
| 149 |
+
self.model = GraniteMoeSharedModel(config)
|
| 150 |
+
# Initialize weights and apply final processing
|
| 151 |
+
self.post_init()
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
__all__ = ["GraniteMoeSharedForCausalLM", "GraniteMoeSharedModel", "GraniteMoeSharedPreTrainedModel"]
|
LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/instructblip/__init__.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
from typing import TYPE_CHECKING
|
| 15 |
+
|
| 16 |
+
from ...utils import _LazyModule
|
| 17 |
+
from ...utils.import_utils import define_import_structure
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
if TYPE_CHECKING:
|
| 21 |
+
from .configuration_instructblip import *
|
| 22 |
+
from .modeling_instructblip import *
|
| 23 |
+
from .processing_instructblip import *
|
| 24 |
+
else:
|
| 25 |
+
import sys
|
| 26 |
+
|
| 27 |
+
_file = globals()["__file__"]
|
| 28 |
+
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
|
LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/instructblip/configuration_instructblip.py
ADDED
|
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
"""InstructBLIP model configuration"""
|
| 15 |
+
|
| 16 |
+
from huggingface_hub.dataclasses import strict
|
| 17 |
+
|
| 18 |
+
from ...configuration_utils import PreTrainedConfig
|
| 19 |
+
from ...models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
|
| 20 |
+
from ...utils import auto_docstring, logging
|
| 21 |
+
from ..auto import CONFIG_MAPPING, AutoConfig
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
logger = logging.get_logger(__name__)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
@auto_docstring(checkpoint="Salesforce/instructblip-flan-t5-xl")
|
| 28 |
+
@strict
|
| 29 |
+
class InstructBlipVisionConfig(PreTrainedConfig):
|
| 30 |
+
r"""
|
| 31 |
+
Example:
|
| 32 |
+
|
| 33 |
+
```python
|
| 34 |
+
>>> from transformers import InstructBlipVisionConfig, InstructBlipVisionModel
|
| 35 |
+
|
| 36 |
+
>>> # Initializing a InstructBlipVisionConfig with Salesforce/instructblip-flan-t5-xl style configuration
|
| 37 |
+
>>> configuration = InstructBlipVisionConfig()
|
| 38 |
+
|
| 39 |
+
>>> # Initializing a InstructBlipVisionModel (with random weights) from the Salesforce/instructblip-flan-t5-xl style configuration
|
| 40 |
+
>>> model = InstructBlipVisionModel(configuration)
|
| 41 |
+
|
| 42 |
+
>>> # Accessing the model configuration
|
| 43 |
+
>>> configuration = model.config
|
| 44 |
+
```"""
|
| 45 |
+
|
| 46 |
+
model_type = "instructblip_vision_model"
|
| 47 |
+
base_config_key = "vision_config"
|
| 48 |
+
|
| 49 |
+
hidden_size: int = 1408
|
| 50 |
+
intermediate_size: int = 6144
|
| 51 |
+
num_hidden_layers: int = 39
|
| 52 |
+
num_attention_heads: int = 16
|
| 53 |
+
image_size: int | list[int] | tuple[int, int] = 224
|
| 54 |
+
patch_size: int | list[int] | tuple[int, int] = 14
|
| 55 |
+
hidden_act: str = "gelu"
|
| 56 |
+
layer_norm_eps: float = 1e-6
|
| 57 |
+
attention_dropout: float | int = 0.0
|
| 58 |
+
initializer_range: float = 1e-10
|
| 59 |
+
qkv_bias: bool = True
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
@auto_docstring(checkpoint="Salesforce/instructblip-flan-t5-xl")
|
| 63 |
+
@strict
|
| 64 |
+
class InstructBlipQFormerConfig(PreTrainedConfig):
|
| 65 |
+
r"""
|
| 66 |
+
cross_attention_frequency (`int`, *optional*, defaults to 2):
|
| 67 |
+
The frequency of adding cross-attention to the Transformer layers.
|
| 68 |
+
encoder_hidden_size (`int`, *optional*, defaults to 1408):
|
| 69 |
+
The hidden size of the hidden states for cross-attention.
|
| 70 |
+
|
| 71 |
+
Examples:
|
| 72 |
+
|
| 73 |
+
```python
|
| 74 |
+
>>> from transformers import InstructBlipQFormerConfig, InstructBlipQFormerModel
|
| 75 |
+
|
| 76 |
+
>>> # Initializing a InstructBLIP Salesforce/instructblip-flan-t5-xl style configuration
|
| 77 |
+
>>> configuration = InstructBlipQFormerConfig()
|
| 78 |
+
|
| 79 |
+
>>> # Initializing a model (with random weights) from the Salesforce/instructblip-flan-t5-xl style configuration
|
| 80 |
+
>>> model = InstructBlipQFormerModel(configuration)
|
| 81 |
+
>>> # Accessing the model configuration
|
| 82 |
+
>>> configuration = model.config
|
| 83 |
+
```"""
|
| 84 |
+
|
| 85 |
+
model_type = "instructblip_qformer"
|
| 86 |
+
base_config_key = "qformer_config"
|
| 87 |
+
|
| 88 |
+
vocab_size: int = 30522
|
| 89 |
+
hidden_size: int = 768
|
| 90 |
+
num_hidden_layers: int = 12
|
| 91 |
+
num_attention_heads: int = 12
|
| 92 |
+
intermediate_size: int = 3072
|
| 93 |
+
hidden_act: str = "gelu"
|
| 94 |
+
hidden_dropout_prob: float | int = 0.1
|
| 95 |
+
attention_probs_dropout_prob: float | int = 0.1
|
| 96 |
+
max_position_embeddings: int = 512
|
| 97 |
+
initializer_range: float = 0.02
|
| 98 |
+
layer_norm_eps: float = 1e-12
|
| 99 |
+
pad_token_id: int | None = 0
|
| 100 |
+
cross_attention_frequency: int = 2
|
| 101 |
+
encoder_hidden_size: int = 1408
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
@auto_docstring(checkpoint="Salesforce/instructblip-flan-t5-xl")
|
| 105 |
+
@strict
|
| 106 |
+
class InstructBlipConfig(PreTrainedConfig):
|
| 107 |
+
r"""
|
| 108 |
+
qformer_config (`dict`, *optional*):
|
| 109 |
+
Dictionary of configuration options used to initialize [`InstructBlipQFormerConfig`].
|
| 110 |
+
num_query_tokens (`int`, *optional*, defaults to 32):
|
| 111 |
+
The number of query tokens passed through the Transformer.
|
| 112 |
+
|
| 113 |
+
Example:
|
| 114 |
+
|
| 115 |
+
```python
|
| 116 |
+
>>> from transformers import (
|
| 117 |
+
... InstructBlipVisionConfig,
|
| 118 |
+
... InstructBlipQFormerConfig,
|
| 119 |
+
... OPTConfig,
|
| 120 |
+
... InstructBlipConfig,
|
| 121 |
+
... InstructBlipForConditionalGeneration,
|
| 122 |
+
... )
|
| 123 |
+
|
| 124 |
+
>>> # Initializing a InstructBlipConfig with Salesforce/instructblip-flan-t5-xl style configuration
|
| 125 |
+
>>> configuration = InstructBlipConfig()
|
| 126 |
+
|
| 127 |
+
>>> # Initializing a InstructBlipForConditionalGeneration (with random weights) from the Salesforce/instructblip-flan-t5-xl style configuration
|
| 128 |
+
>>> model = InstructBlipForConditionalGeneration(configuration)
|
| 129 |
+
|
| 130 |
+
>>> # Accessing the model configuration
|
| 131 |
+
>>> configuration = model.config
|
| 132 |
+
|
| 133 |
+
>>> # We can also initialize a InstructBlipConfig from a InstructBlipVisionConfig, InstructBlipQFormerConfig and any PreTrainedConfig
|
| 134 |
+
|
| 135 |
+
>>> # Initializing InstructBLIP vision, InstructBLIP Q-Former and language model configurations
|
| 136 |
+
>>> vision_config = InstructBlipVisionConfig()
|
| 137 |
+
>>> qformer_config = InstructBlipQFormerConfig()
|
| 138 |
+
>>> text_config = OPTConfig()
|
| 139 |
+
|
| 140 |
+
>>> config = InstructBlipConfig(vision_config=vision_config, qformer_config=qformer_config, text_config=text_config)
|
| 141 |
+
```"""
|
| 142 |
+
|
| 143 |
+
model_type = "instructblip"
|
| 144 |
+
attribute_map = {
|
| 145 |
+
"image_token_id": "image_token_index",
|
| 146 |
+
}
|
| 147 |
+
sub_configs = {
|
| 148 |
+
"text_config": AutoConfig,
|
| 149 |
+
"qformer_config": InstructBlipQFormerConfig,
|
| 150 |
+
"vision_config": InstructBlipVisionConfig,
|
| 151 |
+
}
|
| 152 |
+
|
| 153 |
+
vision_config: dict | PreTrainedConfig | None = None
|
| 154 |
+
qformer_config: dict | PreTrainedConfig | None = None
|
| 155 |
+
text_config: dict | PreTrainedConfig | None = None
|
| 156 |
+
num_query_tokens: int = 32
|
| 157 |
+
image_token_index: int | None = None
|
| 158 |
+
initializer_factor: float = 1.0
|
| 159 |
+
initializer_range: float = 0.02
|
| 160 |
+
|
| 161 |
+
def __post_init__(self, **kwargs):
|
| 162 |
+
if self.text_config is None:
|
| 163 |
+
self.text_config = CONFIG_MAPPING["opt"]()
|
| 164 |
+
logger.info("text_config is None. Initializing the text config with default values (`OPTConfig`).")
|
| 165 |
+
elif isinstance(self.text_config, dict):
|
| 166 |
+
text_model_type = self.text_config.get("model_type", "opt")
|
| 167 |
+
self.text_config = CONFIG_MAPPING[text_model_type](**self.text_config)
|
| 168 |
+
|
| 169 |
+
if self.qformer_config is None:
|
| 170 |
+
self.qformer_config = InstructBlipQFormerConfig()
|
| 171 |
+
logger.info("qformer_config is None. Initializing the InstructBlipQFormerConfig with default values.")
|
| 172 |
+
elif isinstance(self.qformer_config, dict):
|
| 173 |
+
self.qformer_config = InstructBlipQFormerConfig(**self.qformer_config)
|
| 174 |
+
|
| 175 |
+
if self.vision_config is None:
|
| 176 |
+
self.vision_config = InstructBlipVisionConfig()
|
| 177 |
+
logger.info("`vision_config` is `None`. initializing the `InstructBlipVisionConfig` with default values.")
|
| 178 |
+
elif isinstance(self.vision_config, dict):
|
| 179 |
+
self.vision_config = InstructBlipVisionConfig(**self.vision_config)
|
| 180 |
+
|
| 181 |
+
self.qformer_config.encoder_hidden_size = self.vision_config.hidden_size
|
| 182 |
+
self.use_decoder_only_language_model = self.text_config.model_type in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
|
| 183 |
+
super().__post_init__(**kwargs)
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
__all__ = ["InstructBlipConfig", "InstructBlipQFormerConfig", "InstructBlipVisionConfig"]
|
LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/instructblip/modeling_instructblip.py
ADDED
|
@@ -0,0 +1,1405 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2023 The Salesforce Authors and The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
"""PyTorch InstructBLIP model."""
|
| 15 |
+
|
| 16 |
+
import math
|
| 17 |
+
from collections.abc import Callable
|
| 18 |
+
from dataclasses import dataclass
|
| 19 |
+
from typing import Any
|
| 20 |
+
|
| 21 |
+
import torch
|
| 22 |
+
from torch import nn
|
| 23 |
+
|
| 24 |
+
from ... import initialization as init
|
| 25 |
+
from ...activations import ACT2FN
|
| 26 |
+
from ...generation import GenerationMixin
|
| 27 |
+
from ...masking_utils import create_bidirectional_mask
|
| 28 |
+
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
| 29 |
+
from ...modeling_layers import GradientCheckpointingLayer
|
| 30 |
+
from ...modeling_outputs import (
|
| 31 |
+
BaseModelOutput,
|
| 32 |
+
BaseModelOutputWithPastAndCrossAttentions,
|
| 33 |
+
BaseModelOutputWithPooling,
|
| 34 |
+
BaseModelOutputWithPoolingAndCrossAttentions,
|
| 35 |
+
CausalLMOutputWithPast,
|
| 36 |
+
Seq2SeqLMOutput,
|
| 37 |
+
)
|
| 38 |
+
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
| 39 |
+
from ...processing_utils import Unpack
|
| 40 |
+
from ...pytorch_utils import apply_chunking_to_forward
|
| 41 |
+
from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple, logging, torch_int
|
| 42 |
+
from ...utils.generic import merge_with_config_defaults
|
| 43 |
+
from ...utils.output_capturing import OutputRecorder, capture_outputs
|
| 44 |
+
from ..auto import AutoModel, AutoModelForCausalLM, AutoModelForSeq2SeqLM
|
| 45 |
+
from .configuration_instructblip import InstructBlipConfig, InstructBlipQFormerConfig, InstructBlipVisionConfig
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
logger = logging.get_logger(__name__)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
@auto_docstring
|
| 52 |
+
@dataclass
|
| 53 |
+
class BaseModelOutputWithVisionQformerOutputs(BaseModelOutputWithPooling):
|
| 54 |
+
r"""
|
| 55 |
+
vision_outputs (`BaseModelOutputWithPooling`):
|
| 56 |
+
Outputs of the vision encoder.
|
| 57 |
+
qformer_outputs (`BaseModelOutputWithPoolingAndCrossAttentions`):
|
| 58 |
+
Outputs of the Q-Former (Querying Transformer).
|
| 59 |
+
"""
|
| 60 |
+
|
| 61 |
+
vision_outputs: BaseModelOutputWithPooling | None = None
|
| 62 |
+
qformer_outputs: BaseModelOutputWithPoolingAndCrossAttentions | None = None
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
@auto_docstring(
|
| 66 |
+
custom_intro="""
|
| 67 |
+
Class defining the outputs of [`InstructBlipForConditionalGeneration`].
|
| 68 |
+
"""
|
| 69 |
+
)
|
| 70 |
+
@dataclass
|
| 71 |
+
# Copied from transformers.models.blip_2.modeling_blip_2.Blip2ForConditionalGenerationModelOutput with Blip2->InstructBlip
|
| 72 |
+
class InstructBlipForConditionalGenerationModelOutput(ModelOutput):
|
| 73 |
+
r"""
|
| 74 |
+
loss (`torch.FloatTensor`, *optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`):
|
| 75 |
+
Language modeling loss from the language model.
|
| 76 |
+
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
|
| 77 |
+
Prediction scores of the language modeling head of the language model.
|
| 78 |
+
vision_outputs (`BaseModelOutputWithPooling`):
|
| 79 |
+
Outputs of the vision encoder.
|
| 80 |
+
qformer_outputs (`BaseModelOutputWithPoolingAndCrossAttentions`):
|
| 81 |
+
Outputs of the Q-Former (Querying Transformer).
|
| 82 |
+
language_model_outputs (`CausalLMOutputWithPast` or `Seq2SeqLMOutput`):
|
| 83 |
+
Outputs of the language model.
|
| 84 |
+
"""
|
| 85 |
+
|
| 86 |
+
loss: tuple[torch.FloatTensor] | None = None
|
| 87 |
+
logits: tuple[torch.FloatTensor] | None = None
|
| 88 |
+
vision_outputs: BaseModelOutputWithPooling | None = None
|
| 89 |
+
qformer_outputs: BaseModelOutputWithPoolingAndCrossAttentions | None = None
|
| 90 |
+
language_model_outputs: CausalLMOutputWithPast | Seq2SeqLMOutput | None = None
|
| 91 |
+
|
| 92 |
+
def to_tuple(self) -> tuple[Any]:
|
| 93 |
+
return tuple(
|
| 94 |
+
self[k]
|
| 95 |
+
if k not in ["vision_outputs", "qformer_outputs", "language_model_outputs"]
|
| 96 |
+
else getattr(self, k).to_tuple()
|
| 97 |
+
for k in self.keys()
|
| 98 |
+
)
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
# Copied from transformers.models.blip.modeling_blip.BlipVisionEmbeddings with Blip->InstructBlip
|
| 102 |
+
class InstructBlipVisionEmbeddings(nn.Module):
|
| 103 |
+
def __init__(self, config: InstructBlipVisionConfig):
|
| 104 |
+
super().__init__()
|
| 105 |
+
self.config = config
|
| 106 |
+
self.embed_dim = config.hidden_size
|
| 107 |
+
self.image_size = config.image_size
|
| 108 |
+
self.patch_size = config.patch_size
|
| 109 |
+
|
| 110 |
+
self.class_embedding = nn.Parameter(torch.randn(1, 1, self.embed_dim))
|
| 111 |
+
|
| 112 |
+
self.patch_embedding = nn.Conv2d(
|
| 113 |
+
in_channels=3, out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size
|
| 114 |
+
)
|
| 115 |
+
|
| 116 |
+
self.num_patches = (self.image_size // self.patch_size) ** 2
|
| 117 |
+
self.num_positions = self.num_patches + 1
|
| 118 |
+
|
| 119 |
+
self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim))
|
| 120 |
+
|
| 121 |
+
def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
|
| 122 |
+
"""
|
| 123 |
+
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
|
| 124 |
+
images. This method is also adapted to support torch.jit tracing.
|
| 125 |
+
|
| 126 |
+
Adapted from:
|
| 127 |
+
- https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
|
| 128 |
+
- https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
|
| 129 |
+
"""
|
| 130 |
+
|
| 131 |
+
num_patches = embeddings.shape[1] - 1
|
| 132 |
+
num_positions = self.position_embedding.shape[1] - 1
|
| 133 |
+
|
| 134 |
+
# always interpolate when tracing to ensure the exported model works for dynamic input shapes
|
| 135 |
+
if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
|
| 136 |
+
return self.position_embedding
|
| 137 |
+
|
| 138 |
+
class_pos_embed = self.position_embedding[:, :1]
|
| 139 |
+
patch_pos_embed = self.position_embedding[:, 1:]
|
| 140 |
+
|
| 141 |
+
dim = embeddings.shape[-1]
|
| 142 |
+
|
| 143 |
+
new_height = height // self.patch_size
|
| 144 |
+
new_width = width // self.patch_size
|
| 145 |
+
|
| 146 |
+
sqrt_num_positions = torch_int(num_positions**0.5)
|
| 147 |
+
patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
|
| 148 |
+
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
|
| 149 |
+
|
| 150 |
+
patch_pos_embed = nn.functional.interpolate(
|
| 151 |
+
patch_pos_embed,
|
| 152 |
+
size=(new_height, new_width),
|
| 153 |
+
mode="bicubic",
|
| 154 |
+
align_corners=False,
|
| 155 |
+
)
|
| 156 |
+
|
| 157 |
+
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
|
| 158 |
+
|
| 159 |
+
return torch.cat((class_pos_embed, patch_pos_embed), dim=1)
|
| 160 |
+
|
| 161 |
+
def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
|
| 162 |
+
batch_size, _, height, width = pixel_values.shape
|
| 163 |
+
target_dtype = self.patch_embedding.weight.dtype
|
| 164 |
+
patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid]
|
| 165 |
+
patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
|
| 166 |
+
class_embeds = self.class_embedding.expand(batch_size, 1, -1).to(target_dtype)
|
| 167 |
+
embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
|
| 168 |
+
if interpolate_pos_encoding:
|
| 169 |
+
position_embedding = self.interpolate_pos_encoding(embeddings, height, width)
|
| 170 |
+
else:
|
| 171 |
+
position_embedding = self.position_embedding
|
| 172 |
+
embeddings = embeddings + position_embedding[:, : embeddings.size(1), :].to(target_dtype)
|
| 173 |
+
return embeddings
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
# Adapted from transformers.models.siglip.modeling_siglip.eager_attention_forward -> InstructBLIP doesn't cast attn weights to fp32
|
| 177 |
+
def eager_attention_forward(
|
| 178 |
+
module: nn.Module,
|
| 179 |
+
query: torch.Tensor,
|
| 180 |
+
key: torch.Tensor,
|
| 181 |
+
value: torch.Tensor,
|
| 182 |
+
attention_mask: torch.Tensor | None,
|
| 183 |
+
scaling: float,
|
| 184 |
+
dropout: float = 0.0,
|
| 185 |
+
**kwargs,
|
| 186 |
+
):
|
| 187 |
+
attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
|
| 188 |
+
if attention_mask is not None:
|
| 189 |
+
attn_weights = attn_weights + attention_mask
|
| 190 |
+
|
| 191 |
+
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
| 192 |
+
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
| 193 |
+
|
| 194 |
+
attn_output = torch.matmul(attn_weights, value)
|
| 195 |
+
attn_output = attn_output.transpose(1, 2).contiguous()
|
| 196 |
+
|
| 197 |
+
return attn_output, attn_weights
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
# Copied from transformers.models.blip_2.modeling_blip_2.Blip2Attention with Blip2->InstructBlip
|
| 201 |
+
class InstructBlipAttention(nn.Module):
|
| 202 |
+
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
| 203 |
+
|
| 204 |
+
def __init__(self, config):
|
| 205 |
+
super().__init__()
|
| 206 |
+
self.config = config
|
| 207 |
+
self.embed_dim = config.hidden_size
|
| 208 |
+
self.num_heads = config.num_attention_heads
|
| 209 |
+
self.head_dim = self.embed_dim // self.num_heads
|
| 210 |
+
if self.head_dim * self.num_heads != self.embed_dim:
|
| 211 |
+
raise ValueError(
|
| 212 |
+
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
|
| 213 |
+
f" {self.num_heads})."
|
| 214 |
+
)
|
| 215 |
+
self.scale = self.head_dim**-0.5
|
| 216 |
+
self.is_causal = False
|
| 217 |
+
self.attention_dropout = config.attention_dropout
|
| 218 |
+
|
| 219 |
+
# small tweak here compared to CLIP, no bias here
|
| 220 |
+
self.qkv = nn.Linear(self.embed_dim, 3 * self.embed_dim, bias=False)
|
| 221 |
+
|
| 222 |
+
if config.qkv_bias:
|
| 223 |
+
q_bias = nn.Parameter(torch.zeros(self.embed_dim))
|
| 224 |
+
v_bias = nn.Parameter(torch.zeros(self.embed_dim))
|
| 225 |
+
else:
|
| 226 |
+
q_bias = None
|
| 227 |
+
v_bias = None
|
| 228 |
+
|
| 229 |
+
if q_bias is not None:
|
| 230 |
+
qkv_bias = torch.cat((q_bias, torch.zeros_like(v_bias, requires_grad=False), v_bias))
|
| 231 |
+
self.qkv.bias = nn.Parameter(qkv_bias)
|
| 232 |
+
|
| 233 |
+
self.projection = nn.Linear(self.embed_dim, self.embed_dim)
|
| 234 |
+
|
| 235 |
+
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
| 236 |
+
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
|
| 237 |
+
|
| 238 |
+
def forward(
|
| 239 |
+
self,
|
| 240 |
+
hidden_states: torch.Tensor,
|
| 241 |
+
**kwargs,
|
| 242 |
+
) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]:
|
| 243 |
+
"""Input shape: Batch x Time x Channel"""
|
| 244 |
+
|
| 245 |
+
bsz, tgt_len, embed_dim = hidden_states.size()
|
| 246 |
+
|
| 247 |
+
mixed_qkv = self.qkv(hidden_states)
|
| 248 |
+
|
| 249 |
+
mixed_qkv = mixed_qkv.reshape(bsz, tgt_len, 3, self.num_heads, embed_dim // self.num_heads).permute(
|
| 250 |
+
2, 0, 3, 1, 4
|
| 251 |
+
)
|
| 252 |
+
query_states, key_states, value_states = mixed_qkv[0], mixed_qkv[1], mixed_qkv[2]
|
| 253 |
+
|
| 254 |
+
attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
|
| 255 |
+
self.config._attn_implementation, eager_attention_forward
|
| 256 |
+
)
|
| 257 |
+
|
| 258 |
+
attn_output, attn_weights = attention_interface(
|
| 259 |
+
self,
|
| 260 |
+
query_states,
|
| 261 |
+
key_states,
|
| 262 |
+
value_states,
|
| 263 |
+
attention_mask=None,
|
| 264 |
+
dropout=0.0 if not self.training else self.attention_dropout,
|
| 265 |
+
scaling=self.scale,
|
| 266 |
+
**kwargs,
|
| 267 |
+
)
|
| 268 |
+
|
| 269 |
+
attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
|
| 270 |
+
attn_output = self.projection(attn_output)
|
| 271 |
+
|
| 272 |
+
return attn_output, attn_weights
|
| 273 |
+
|
| 274 |
+
|
| 275 |
+
# Copied from transformers.models.blip.modeling_blip.BlipMLP
|
| 276 |
+
class InstructBlipMLP(nn.Module):
|
| 277 |
+
def __init__(self, config):
|
| 278 |
+
super().__init__()
|
| 279 |
+
self.config = config
|
| 280 |
+
self.activation_fn = ACT2FN[config.hidden_act]
|
| 281 |
+
self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
|
| 282 |
+
self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
|
| 283 |
+
|
| 284 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 285 |
+
hidden_states = self.fc1(hidden_states)
|
| 286 |
+
hidden_states = self.activation_fn(hidden_states)
|
| 287 |
+
hidden_states = self.fc2(hidden_states)
|
| 288 |
+
return hidden_states
|
| 289 |
+
|
| 290 |
+
|
| 291 |
+
# Copied from transformers.models.blip.modeling_blip.BlipEncoderLayer with Blip->InstructBlip
|
| 292 |
+
class InstructBlipEncoderLayer(GradientCheckpointingLayer):
|
| 293 |
+
def __init__(self, config: InstructBlipConfig):
|
| 294 |
+
super().__init__()
|
| 295 |
+
self.embed_dim = config.hidden_size
|
| 296 |
+
self.self_attn = InstructBlipAttention(config)
|
| 297 |
+
self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
|
| 298 |
+
self.mlp = InstructBlipMLP(config)
|
| 299 |
+
self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
|
| 300 |
+
|
| 301 |
+
@auto_docstring
|
| 302 |
+
def forward(
|
| 303 |
+
self,
|
| 304 |
+
hidden_states: torch.Tensor,
|
| 305 |
+
**kwargs: Unpack[TransformersKwargs],
|
| 306 |
+
) -> torch.FloatTensor:
|
| 307 |
+
residual = hidden_states
|
| 308 |
+
|
| 309 |
+
hidden_states = self.layer_norm1(hidden_states)
|
| 310 |
+
hidden_states, _ = self.self_attn(
|
| 311 |
+
hidden_states=hidden_states,
|
| 312 |
+
**kwargs,
|
| 313 |
+
)
|
| 314 |
+
hidden_states = hidden_states + residual
|
| 315 |
+
residual = hidden_states
|
| 316 |
+
hidden_states = self.layer_norm2(hidden_states)
|
| 317 |
+
hidden_states = self.mlp(hidden_states)
|
| 318 |
+
|
| 319 |
+
hidden_states = hidden_states + residual
|
| 320 |
+
|
| 321 |
+
return hidden_states
|
| 322 |
+
|
| 323 |
+
|
| 324 |
+
@auto_docstring
|
| 325 |
+
class InstructBlipPreTrainedModel(PreTrainedModel):
|
| 326 |
+
config: InstructBlipConfig
|
| 327 |
+
base_model_prefix = "blip"
|
| 328 |
+
input_modalities = ("image", "text")
|
| 329 |
+
supports_gradient_checkpointing = True
|
| 330 |
+
_supports_attention_backend = True
|
| 331 |
+
_supports_flash_attn = True
|
| 332 |
+
_supports_sdpa = True
|
| 333 |
+
_supports_flex_attn = True
|
| 334 |
+
|
| 335 |
+
_can_compile_fullgraph = True
|
| 336 |
+
|
| 337 |
+
_no_split_modules = [
|
| 338 |
+
"InstructBlipQFormerEmbeddings",
|
| 339 |
+
"InstructBlipAttention",
|
| 340 |
+
"InstructBlipQFormerMultiHeadAttention",
|
| 341 |
+
"InstructBlipQFormerSelfOutput",
|
| 342 |
+
]
|
| 343 |
+
|
| 344 |
+
@torch.no_grad()
|
| 345 |
+
def _init_weights(self, module):
|
| 346 |
+
"""Initialize the weights"""
|
| 347 |
+
super()._init_weights(module)
|
| 348 |
+
factor = self.config.initializer_range
|
| 349 |
+
if isinstance(module, InstructBlipVisionEmbeddings):
|
| 350 |
+
init.trunc_normal_(module.position_embedding, mean=0.0, std=factor)
|
| 351 |
+
init.trunc_normal_(module.class_embedding, mean=0.0, std=factor)
|
| 352 |
+
elif isinstance(module, (InstructBlipForConditionalGeneration, InstructBlipModel)):
|
| 353 |
+
init.zeros_(module.query_tokens)
|
| 354 |
+
elif isinstance(module, InstructBlipQFormerEmbeddings):
|
| 355 |
+
init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
|
| 356 |
+
|
| 357 |
+
|
| 358 |
+
# Copied from transformers.models.blip.modeling_blip.BlipEncoder with Blip->InstructBlip
|
| 359 |
+
class InstructBlipEncoder(nn.Module):
|
| 360 |
+
"""
|
| 361 |
+
Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
|
| 362 |
+
[`InstructBlipEncoderLayer`].
|
| 363 |
+
|
| 364 |
+
Args:
|
| 365 |
+
config (`InstructBlipConfig`):
|
| 366 |
+
The corresponding vision configuration for the `InstructBlipEncoder`.
|
| 367 |
+
"""
|
| 368 |
+
|
| 369 |
+
def __init__(self, config: InstructBlipConfig):
|
| 370 |
+
super().__init__()
|
| 371 |
+
self.config = config
|
| 372 |
+
self.layers = nn.ModuleList([InstructBlipEncoderLayer(config) for _ in range(config.num_hidden_layers)])
|
| 373 |
+
self.gradient_checkpointing = False
|
| 374 |
+
|
| 375 |
+
@auto_docstring
|
| 376 |
+
def forward(
|
| 377 |
+
self,
|
| 378 |
+
inputs_embeds,
|
| 379 |
+
**kwargs: Unpack[TransformersKwargs],
|
| 380 |
+
) -> tuple | BaseModelOutput:
|
| 381 |
+
hidden_states = inputs_embeds
|
| 382 |
+
for encoder_layer in self.layers:
|
| 383 |
+
hidden_states = encoder_layer(
|
| 384 |
+
hidden_states,
|
| 385 |
+
**kwargs,
|
| 386 |
+
)
|
| 387 |
+
|
| 388 |
+
return BaseModelOutput(last_hidden_state=hidden_states)
|
| 389 |
+
|
| 390 |
+
|
| 391 |
+
class InstructBlipVisionModel(InstructBlipPreTrainedModel):
|
| 392 |
+
main_input_name = "pixel_values"
|
| 393 |
+
input_modalities = ("image",)
|
| 394 |
+
config: InstructBlipVisionConfig
|
| 395 |
+
_can_record_outputs = {
|
| 396 |
+
"hidden_states": InstructBlipEncoderLayer,
|
| 397 |
+
"attentions": InstructBlipAttention,
|
| 398 |
+
}
|
| 399 |
+
|
| 400 |
+
def __init__(self, config: InstructBlipVisionConfig):
|
| 401 |
+
super().__init__(config)
|
| 402 |
+
self.config = config
|
| 403 |
+
embed_dim = config.hidden_size
|
| 404 |
+
|
| 405 |
+
self.embeddings = InstructBlipVisionEmbeddings(config)
|
| 406 |
+
self.encoder = InstructBlipEncoder(config)
|
| 407 |
+
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
|
| 408 |
+
|
| 409 |
+
self.post_init()
|
| 410 |
+
|
| 411 |
+
@merge_with_config_defaults
|
| 412 |
+
@capture_outputs(tie_last_hidden_states=False)
|
| 413 |
+
@auto_docstring
|
| 414 |
+
def forward(
|
| 415 |
+
self,
|
| 416 |
+
pixel_values: torch.FloatTensor | None = None,
|
| 417 |
+
interpolate_pos_encoding: bool = False,
|
| 418 |
+
**kwargs: Unpack[TransformersKwargs],
|
| 419 |
+
) -> tuple | BaseModelOutputWithPooling:
|
| 420 |
+
if pixel_values is None:
|
| 421 |
+
raise ValueError("You have to specify pixel_values")
|
| 422 |
+
|
| 423 |
+
hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
|
| 424 |
+
|
| 425 |
+
encoder_outputs: BaseModelOutput = self.encoder(
|
| 426 |
+
inputs_embeds=hidden_states,
|
| 427 |
+
**kwargs,
|
| 428 |
+
)
|
| 429 |
+
|
| 430 |
+
last_hidden_state = encoder_outputs.last_hidden_state
|
| 431 |
+
last_hidden_state = self.post_layernorm(last_hidden_state)
|
| 432 |
+
|
| 433 |
+
pooled_output = last_hidden_state[:, 0, :]
|
| 434 |
+
pooled_output = self.post_layernorm(pooled_output)
|
| 435 |
+
|
| 436 |
+
return BaseModelOutputWithPooling(
|
| 437 |
+
last_hidden_state=last_hidden_state,
|
| 438 |
+
pooler_output=pooled_output,
|
| 439 |
+
)
|
| 440 |
+
|
| 441 |
+
def get_input_embeddings(self):
|
| 442 |
+
return self.embeddings
|
| 443 |
+
|
| 444 |
+
|
| 445 |
+
class InstructBlipQFormerMultiHeadAttention(nn.Module):
|
| 446 |
+
def __init__(self, config, is_cross_attention=False):
|
| 447 |
+
super().__init__()
|
| 448 |
+
self.config = config
|
| 449 |
+
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
|
| 450 |
+
raise ValueError(
|
| 451 |
+
"The hidden size (%d) is not a multiple of the number of attention heads (%d)"
|
| 452 |
+
% (config.hidden_size, config.num_attention_heads)
|
| 453 |
+
)
|
| 454 |
+
|
| 455 |
+
self.num_attention_heads = config.num_attention_heads
|
| 456 |
+
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
| 457 |
+
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
| 458 |
+
|
| 459 |
+
self.query = nn.Linear(config.hidden_size, self.all_head_size)
|
| 460 |
+
if is_cross_attention:
|
| 461 |
+
self.key = nn.Linear(config.encoder_hidden_size, self.all_head_size)
|
| 462 |
+
self.value = nn.Linear(config.encoder_hidden_size, self.all_head_size)
|
| 463 |
+
else:
|
| 464 |
+
self.key = nn.Linear(config.hidden_size, self.all_head_size)
|
| 465 |
+
self.value = nn.Linear(config.hidden_size, self.all_head_size)
|
| 466 |
+
|
| 467 |
+
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
| 468 |
+
self.save_attention = False
|
| 469 |
+
|
| 470 |
+
def save_attn_gradients(self, attn_gradients):
|
| 471 |
+
self.attn_gradients = attn_gradients
|
| 472 |
+
|
| 473 |
+
def get_attn_gradients(self):
|
| 474 |
+
return self.attn_gradients
|
| 475 |
+
|
| 476 |
+
def save_attention_map(self, attention_map):
|
| 477 |
+
self.attention_map = attention_map
|
| 478 |
+
|
| 479 |
+
def get_attention_map(self):
|
| 480 |
+
return self.attention_map
|
| 481 |
+
|
| 482 |
+
def transpose_for_scores(self, x):
|
| 483 |
+
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
| 484 |
+
x = x.view(*new_x_shape)
|
| 485 |
+
return x.permute(0, 2, 1, 3)
|
| 486 |
+
|
| 487 |
+
def forward(
|
| 488 |
+
self,
|
| 489 |
+
hidden_states,
|
| 490 |
+
attention_mask=None,
|
| 491 |
+
encoder_hidden_states=None,
|
| 492 |
+
encoder_attention_mask=None,
|
| 493 |
+
**kwargs: Unpack[TransformersKwargs],
|
| 494 |
+
):
|
| 495 |
+
# If this is instantiated as a cross-attention module, the keys
|
| 496 |
+
# and values come from an encoder; the attention mask needs to be
|
| 497 |
+
# such that the encoder's padding tokens are not attended to.
|
| 498 |
+
is_cross_attention = encoder_hidden_states is not None
|
| 499 |
+
|
| 500 |
+
if is_cross_attention:
|
| 501 |
+
key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
|
| 502 |
+
value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
|
| 503 |
+
attention_mask = encoder_attention_mask
|
| 504 |
+
else:
|
| 505 |
+
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
| 506 |
+
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
| 507 |
+
|
| 508 |
+
mixed_query_layer = self.query(hidden_states)
|
| 509 |
+
|
| 510 |
+
query_layer = self.transpose_for_scores(mixed_query_layer)
|
| 511 |
+
|
| 512 |
+
# Take the dot product between "query" and "key" to get the raw attention scores.
|
| 513 |
+
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
| 514 |
+
|
| 515 |
+
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
|
| 516 |
+
attention_scores_dtype = attention_scores.dtype
|
| 517 |
+
|
| 518 |
+
if attention_mask is not None:
|
| 519 |
+
# Apply the attention mask is (precomputed for all layers in BertModel forward() function)
|
| 520 |
+
attention_scores = attention_scores + attention_mask
|
| 521 |
+
|
| 522 |
+
# Normalize the attention scores to probabilities.
|
| 523 |
+
attention_probs = nn.Softmax(dim=-1)(attention_scores).to(attention_scores_dtype)
|
| 524 |
+
|
| 525 |
+
if is_cross_attention and self.save_attention:
|
| 526 |
+
self.save_attention_map(attention_probs)
|
| 527 |
+
attention_probs.register_hook(self.save_attn_gradients)
|
| 528 |
+
|
| 529 |
+
# This is actually dropping out entire tokens to attend to, which might
|
| 530 |
+
# seem a bit unusual, but is taken from the original Transformer paper.
|
| 531 |
+
attention_probs_dropped = self.dropout(attention_probs)
|
| 532 |
+
|
| 533 |
+
context_layer = torch.matmul(attention_probs_dropped, value_layer)
|
| 534 |
+
|
| 535 |
+
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
| 536 |
+
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
| 537 |
+
context_layer = context_layer.view(*new_context_layer_shape)
|
| 538 |
+
|
| 539 |
+
return context_layer, attention_probs
|
| 540 |
+
|
| 541 |
+
|
| 542 |
+
# Copied from transformers.models.bert.modeling_bert.BertSelfOutput with Bert->InstructBlipQFormer
|
| 543 |
+
class InstructBlipQFormerSelfOutput(nn.Module):
|
| 544 |
+
def __init__(self, config):
|
| 545 |
+
super().__init__()
|
| 546 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
| 547 |
+
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
| 548 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
| 549 |
+
|
| 550 |
+
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
|
| 551 |
+
hidden_states = self.dense(hidden_states)
|
| 552 |
+
hidden_states = self.dropout(hidden_states)
|
| 553 |
+
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
| 554 |
+
return hidden_states
|
| 555 |
+
|
| 556 |
+
|
| 557 |
+
# Copied from transformers.models.blip_2.modeling_blip_2.Blip2QFormerAttention with Blip2->InstructBlip
|
| 558 |
+
class InstructBlipQFormerAttention(nn.Module):
|
| 559 |
+
def __init__(self, config, is_cross_attention=False):
|
| 560 |
+
super().__init__()
|
| 561 |
+
self.attention = InstructBlipQFormerMultiHeadAttention(config, is_cross_attention)
|
| 562 |
+
self.output = InstructBlipQFormerSelfOutput(config)
|
| 563 |
+
|
| 564 |
+
def forward(
|
| 565 |
+
self,
|
| 566 |
+
hidden_states: torch.Tensor,
|
| 567 |
+
attention_mask: torch.FloatTensor | None = None,
|
| 568 |
+
encoder_hidden_states: torch.FloatTensor | None = None,
|
| 569 |
+
encoder_attention_mask: torch.FloatTensor | None = None,
|
| 570 |
+
**kwargs: Unpack[TransformersKwargs],
|
| 571 |
+
) -> torch.Tensor:
|
| 572 |
+
attn_output, _ = self.attention(
|
| 573 |
+
hidden_states=hidden_states,
|
| 574 |
+
attention_mask=attention_mask,
|
| 575 |
+
encoder_hidden_states=encoder_hidden_states,
|
| 576 |
+
encoder_attention_mask=encoder_attention_mask,
|
| 577 |
+
**kwargs,
|
| 578 |
+
)
|
| 579 |
+
attention_output = self.output(attn_output, hidden_states)
|
| 580 |
+
return attention_output
|
| 581 |
+
|
| 582 |
+
|
| 583 |
+
# Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->InstructBlipQFormer
|
| 584 |
+
class InstructBlipQFormerIntermediate(nn.Module):
|
| 585 |
+
def __init__(self, config):
|
| 586 |
+
super().__init__()
|
| 587 |
+
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
|
| 588 |
+
if isinstance(config.hidden_act, str):
|
| 589 |
+
self.intermediate_act_fn = ACT2FN[config.hidden_act]
|
| 590 |
+
else:
|
| 591 |
+
self.intermediate_act_fn = config.hidden_act
|
| 592 |
+
|
| 593 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 594 |
+
hidden_states = self.dense(hidden_states)
|
| 595 |
+
hidden_states = self.intermediate_act_fn(hidden_states)
|
| 596 |
+
return hidden_states
|
| 597 |
+
|
| 598 |
+
|
| 599 |
+
# Copied from transformers.models.bert.modeling_bert.BertOutput with Bert->InstructBlipQFormer
|
| 600 |
+
class InstructBlipQFormerOutput(nn.Module):
|
| 601 |
+
def __init__(self, config):
|
| 602 |
+
super().__init__()
|
| 603 |
+
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
|
| 604 |
+
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
| 605 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
| 606 |
+
|
| 607 |
+
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
|
| 608 |
+
hidden_states = self.dense(hidden_states)
|
| 609 |
+
hidden_states = self.dropout(hidden_states)
|
| 610 |
+
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
| 611 |
+
return hidden_states
|
| 612 |
+
|
| 613 |
+
|
| 614 |
+
class InstructBlipQFormerLayer(GradientCheckpointingLayer):
|
| 615 |
+
def __init__(self, config, layer_idx):
|
| 616 |
+
super().__init__()
|
| 617 |
+
self.chunk_size_feed_forward = config.chunk_size_feed_forward
|
| 618 |
+
self.seq_len_dim = 1
|
| 619 |
+
self.attention = InstructBlipQFormerAttention(config)
|
| 620 |
+
|
| 621 |
+
self.layer_idx = layer_idx
|
| 622 |
+
|
| 623 |
+
if layer_idx % config.cross_attention_frequency == 0:
|
| 624 |
+
self.crossattention = InstructBlipQFormerAttention(config, is_cross_attention=True)
|
| 625 |
+
self.has_cross_attention = True
|
| 626 |
+
else:
|
| 627 |
+
self.has_cross_attention = False
|
| 628 |
+
|
| 629 |
+
self.intermediate = InstructBlipQFormerIntermediate(config)
|
| 630 |
+
self.output = InstructBlipQFormerOutput(config)
|
| 631 |
+
|
| 632 |
+
self.intermediate_query = InstructBlipQFormerIntermediate(config)
|
| 633 |
+
self.output_query = InstructBlipQFormerOutput(config)
|
| 634 |
+
|
| 635 |
+
def forward(
|
| 636 |
+
self,
|
| 637 |
+
hidden_states,
|
| 638 |
+
attention_mask=None,
|
| 639 |
+
encoder_hidden_states=None,
|
| 640 |
+
encoder_attention_mask=None,
|
| 641 |
+
query_length=0,
|
| 642 |
+
**kwargs: Unpack[TransformersKwargs],
|
| 643 |
+
):
|
| 644 |
+
attention_output = self.attention(
|
| 645 |
+
hidden_states,
|
| 646 |
+
attention_mask=attention_mask,
|
| 647 |
+
**kwargs,
|
| 648 |
+
)
|
| 649 |
+
|
| 650 |
+
if query_length > 0:
|
| 651 |
+
query_attention_output = attention_output[:, :query_length, :]
|
| 652 |
+
|
| 653 |
+
if self.has_cross_attention:
|
| 654 |
+
if encoder_hidden_states is None:
|
| 655 |
+
raise ValueError("encoder_hidden_states must be given for cross-attention layers")
|
| 656 |
+
query_attention_output = self.crossattention(
|
| 657 |
+
query_attention_output,
|
| 658 |
+
attention_mask=attention_mask,
|
| 659 |
+
encoder_hidden_states=encoder_hidden_states,
|
| 660 |
+
encoder_attention_mask=encoder_attention_mask,
|
| 661 |
+
**kwargs,
|
| 662 |
+
)
|
| 663 |
+
|
| 664 |
+
layer_output = apply_chunking_to_forward(
|
| 665 |
+
self.feed_forward_chunk_query,
|
| 666 |
+
self.chunk_size_feed_forward,
|
| 667 |
+
self.seq_len_dim,
|
| 668 |
+
query_attention_output,
|
| 669 |
+
)
|
| 670 |
+
|
| 671 |
+
if attention_output.shape[1] > query_length:
|
| 672 |
+
layer_output_text = apply_chunking_to_forward(
|
| 673 |
+
self.feed_forward_chunk,
|
| 674 |
+
self.chunk_size_feed_forward,
|
| 675 |
+
self.seq_len_dim,
|
| 676 |
+
attention_output[:, query_length:, :],
|
| 677 |
+
).to(layer_output.device)
|
| 678 |
+
layer_output = torch.cat([layer_output, layer_output_text], dim=1)
|
| 679 |
+
else:
|
| 680 |
+
layer_output = apply_chunking_to_forward(
|
| 681 |
+
self.feed_forward_chunk,
|
| 682 |
+
self.chunk_size_feed_forward,
|
| 683 |
+
self.seq_len_dim,
|
| 684 |
+
attention_output,
|
| 685 |
+
)
|
| 686 |
+
return layer_output
|
| 687 |
+
|
| 688 |
+
def feed_forward_chunk(self, attention_output):
|
| 689 |
+
intermediate_output = self.intermediate(attention_output)
|
| 690 |
+
layer_output = self.output(intermediate_output, attention_output)
|
| 691 |
+
return layer_output
|
| 692 |
+
|
| 693 |
+
def feed_forward_chunk_query(self, attention_output):
|
| 694 |
+
intermediate_output = self.intermediate_query(attention_output)
|
| 695 |
+
layer_output = self.output_query(intermediate_output, attention_output)
|
| 696 |
+
return layer_output
|
| 697 |
+
|
| 698 |
+
|
| 699 |
+
# Copied from transformers.models.blip_2.modeling_blip_2.Blip2QFormerEncoder with Blip2->InstructBlip
|
| 700 |
+
class InstructBlipQFormerEncoder(nn.Module):
|
| 701 |
+
def __init__(self, config):
|
| 702 |
+
super().__init__()
|
| 703 |
+
self.config = config
|
| 704 |
+
self.layer = nn.ModuleList(
|
| 705 |
+
[InstructBlipQFormerLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
|
| 706 |
+
)
|
| 707 |
+
self.gradient_checkpointing = False
|
| 708 |
+
|
| 709 |
+
@can_return_tuple
|
| 710 |
+
def forward(
|
| 711 |
+
self,
|
| 712 |
+
hidden_states,
|
| 713 |
+
attention_mask=None,
|
| 714 |
+
encoder_hidden_states=None,
|
| 715 |
+
encoder_attention_mask=None,
|
| 716 |
+
query_length=0,
|
| 717 |
+
**kwargs: Unpack[TransformersKwargs],
|
| 718 |
+
):
|
| 719 |
+
for i in range(self.config.num_hidden_layers):
|
| 720 |
+
layer_module = self.layer[i]
|
| 721 |
+
|
| 722 |
+
hidden_states = layer_module(
|
| 723 |
+
hidden_states,
|
| 724 |
+
attention_mask,
|
| 725 |
+
encoder_hidden_states, # as a positional argument for gradient checkpointing
|
| 726 |
+
encoder_attention_mask=encoder_attention_mask,
|
| 727 |
+
query_length=query_length,
|
| 728 |
+
**kwargs,
|
| 729 |
+
)
|
| 730 |
+
|
| 731 |
+
return BaseModelOutputWithPastAndCrossAttentions(
|
| 732 |
+
last_hidden_state=hidden_states,
|
| 733 |
+
)
|
| 734 |
+
|
| 735 |
+
|
| 736 |
+
class InstructBlipQFormerEmbeddings(nn.Module):
|
| 737 |
+
"""Construct the embeddings from word and position embeddings."""
|
| 738 |
+
|
| 739 |
+
def __init__(self, config):
|
| 740 |
+
super().__init__()
|
| 741 |
+
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
|
| 742 |
+
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
|
| 743 |
+
|
| 744 |
+
self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
| 745 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
| 746 |
+
|
| 747 |
+
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
|
| 748 |
+
self.register_buffer(
|
| 749 |
+
"position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
|
| 750 |
+
)
|
| 751 |
+
|
| 752 |
+
self.config = config
|
| 753 |
+
|
| 754 |
+
def forward(
|
| 755 |
+
self,
|
| 756 |
+
input_ids=None,
|
| 757 |
+
position_ids=None,
|
| 758 |
+
query_embeds=None,
|
| 759 |
+
past_key_values_length=0,
|
| 760 |
+
):
|
| 761 |
+
if input_ids is not None:
|
| 762 |
+
seq_length = input_ids.size()[1]
|
| 763 |
+
else:
|
| 764 |
+
seq_length = 0
|
| 765 |
+
|
| 766 |
+
if position_ids is None:
|
| 767 |
+
position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length].clone()
|
| 768 |
+
|
| 769 |
+
if input_ids is not None:
|
| 770 |
+
embeddings = self.word_embeddings(input_ids)
|
| 771 |
+
|
| 772 |
+
position_embeddings = self.position_embeddings(position_ids.to(embeddings.device))
|
| 773 |
+
embeddings = embeddings + position_embeddings
|
| 774 |
+
|
| 775 |
+
if query_embeds is not None:
|
| 776 |
+
embeddings = torch.cat((query_embeds, embeddings), dim=1)
|
| 777 |
+
else:
|
| 778 |
+
embeddings = query_embeds
|
| 779 |
+
|
| 780 |
+
embeddings = embeddings.to(self.layernorm.weight.dtype)
|
| 781 |
+
embeddings = self.layernorm(embeddings)
|
| 782 |
+
embeddings = self.dropout(embeddings)
|
| 783 |
+
return embeddings
|
| 784 |
+
|
| 785 |
+
|
| 786 |
+
class InstructBlipQFormerModel(InstructBlipPreTrainedModel):
|
| 787 |
+
"""
|
| 788 |
+
Querying Transformer (Q-Former), used in InstructBLIP. Slightly modified from BLIP-2 as it also takes the
|
| 789 |
+
instruction as input.
|
| 790 |
+
"""
|
| 791 |
+
|
| 792 |
+
_supports_attention_backend = False # adds position on attn weights before last matmul
|
| 793 |
+
_supports_flash_attn = False
|
| 794 |
+
_supports_sdpa = False
|
| 795 |
+
_supports_flex_attn = False
|
| 796 |
+
|
| 797 |
+
_can_record_outputs = {
|
| 798 |
+
"hidden_states": InstructBlipQFormerLayer,
|
| 799 |
+
"attentions": [
|
| 800 |
+
OutputRecorder(InstructBlipQFormerMultiHeadAttention, index=1, layer_name=".attention"),
|
| 801 |
+
],
|
| 802 |
+
"cross_attentions": [
|
| 803 |
+
OutputRecorder(InstructBlipQFormerMultiHeadAttention, index=1, layer_name=".crossattention"),
|
| 804 |
+
],
|
| 805 |
+
}
|
| 806 |
+
|
| 807 |
+
def __init__(self, config: InstructBlipQFormerConfig):
|
| 808 |
+
super().__init__(config)
|
| 809 |
+
self.config = config
|
| 810 |
+
|
| 811 |
+
self.embeddings = InstructBlipQFormerEmbeddings(config)
|
| 812 |
+
|
| 813 |
+
self.encoder = InstructBlipQFormerEncoder(config)
|
| 814 |
+
|
| 815 |
+
self.post_init()
|
| 816 |
+
|
| 817 |
+
def get_input_embeddings(self):
|
| 818 |
+
return self.embeddings.word_embeddings
|
| 819 |
+
|
| 820 |
+
def set_input_embeddings(self, value):
|
| 821 |
+
self.embeddings.word_embeddings = value
|
| 822 |
+
|
| 823 |
+
@merge_with_config_defaults
|
| 824 |
+
@capture_outputs
|
| 825 |
+
@auto_docstring
|
| 826 |
+
def forward(
|
| 827 |
+
self,
|
| 828 |
+
input_ids: torch.LongTensor,
|
| 829 |
+
attention_mask: torch.FloatTensor | None = None,
|
| 830 |
+
position_ids: torch.LongTensor | None = None,
|
| 831 |
+
query_embeds: torch.Tensor | None = None,
|
| 832 |
+
encoder_hidden_states: torch.FloatTensor | None = None,
|
| 833 |
+
encoder_attention_mask: torch.FloatTensor | None = None,
|
| 834 |
+
**kwargs: Unpack[TransformersKwargs],
|
| 835 |
+
) -> tuple[torch.FloatTensor] | BaseModelOutputWithPoolingAndCrossAttentions:
|
| 836 |
+
r"""
|
| 837 |
+
query_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
| 838 |
+
Hidden states to be used in the attention computation. If cross-attention,
|
| 839 |
+
will be used for the query (i.e., key and value will use the encoder_hidden_states).
|
| 840 |
+
"""
|
| 841 |
+
if input_ids is None and query_embeds is None:
|
| 842 |
+
raise ValueError("You have to specify query_embeds when input_ids is None")
|
| 843 |
+
|
| 844 |
+
query_length = query_embeds.shape[1] if query_embeds is not None else 0
|
| 845 |
+
|
| 846 |
+
embedding_output = self.embeddings(
|
| 847 |
+
input_ids=input_ids,
|
| 848 |
+
position_ids=position_ids,
|
| 849 |
+
query_embeds=query_embeds,
|
| 850 |
+
)
|
| 851 |
+
|
| 852 |
+
attention_mask = create_bidirectional_mask(
|
| 853 |
+
config=self.config,
|
| 854 |
+
inputs_embeds=embedding_output,
|
| 855 |
+
attention_mask=attention_mask,
|
| 856 |
+
)
|
| 857 |
+
|
| 858 |
+
if encoder_attention_mask is not None:
|
| 859 |
+
encoder_attention_mask = create_bidirectional_mask(
|
| 860 |
+
config=self.config,
|
| 861 |
+
inputs_embeds=embedding_output,
|
| 862 |
+
attention_mask=encoder_attention_mask,
|
| 863 |
+
encoder_hidden_states=encoder_hidden_states,
|
| 864 |
+
)
|
| 865 |
+
|
| 866 |
+
encoder_outputs: BaseModelOutput = self.encoder(
|
| 867 |
+
embedding_output,
|
| 868 |
+
attention_mask=attention_mask,
|
| 869 |
+
encoder_hidden_states=encoder_hidden_states,
|
| 870 |
+
encoder_attention_mask=encoder_attention_mask,
|
| 871 |
+
query_length=query_length,
|
| 872 |
+
**kwargs,
|
| 873 |
+
)
|
| 874 |
+
sequence_output = encoder_outputs.last_hidden_state
|
| 875 |
+
pooled_output = sequence_output[:, 0, :]
|
| 876 |
+
|
| 877 |
+
return BaseModelOutputWithPoolingAndCrossAttentions(
|
| 878 |
+
last_hidden_state=sequence_output,
|
| 879 |
+
pooler_output=pooled_output,
|
| 880 |
+
)
|
| 881 |
+
|
| 882 |
+
|
| 883 |
+
@auto_docstring(
|
| 884 |
+
custom_intro="""
|
| 885 |
+
InstructBLIP base Model consisting of language model, qformer and vision encoder.
|
| 886 |
+
"""
|
| 887 |
+
)
|
| 888 |
+
class InstructBlipModel(InstructBlipPreTrainedModel):
|
| 889 |
+
main_input_name = "pixel_values"
|
| 890 |
+
_keep_in_fp32_modules = ["query_tokens"] # TODO @ArthurZucker I don't know why this is required for FP8
|
| 891 |
+
|
| 892 |
+
def __init__(self, config: InstructBlipConfig):
|
| 893 |
+
super().__init__(config)
|
| 894 |
+
|
| 895 |
+
self.vision_model = InstructBlipVisionModel(config.vision_config)
|
| 896 |
+
self.query_tokens = nn.Parameter(torch.zeros(1, config.num_query_tokens, config.qformer_config.hidden_size))
|
| 897 |
+
self.qformer = InstructBlipQFormerModel(config.qformer_config)
|
| 898 |
+
|
| 899 |
+
self.language_projection = nn.Linear(config.qformer_config.hidden_size, config.text_config.hidden_size)
|
| 900 |
+
self.language_model = AutoModel.from_config(config.text_config)
|
| 901 |
+
|
| 902 |
+
# Initialize weights and apply final processing
|
| 903 |
+
self.post_init()
|
| 904 |
+
|
| 905 |
+
def _preprocess_accelerate(self):
|
| 906 |
+
r"""
|
| 907 |
+
Some pre-processing hacks to make the model `accelerate` compatible. Check
|
| 908 |
+
https://github.com/huggingface/transformers/pull/21707 for more details.
|
| 909 |
+
"""
|
| 910 |
+
hf_device_map = self.hf_device_map
|
| 911 |
+
|
| 912 |
+
if len(hf_device_map) > 1 and "language_model" not in hf_device_map and torch.cuda.device_count() > 1:
|
| 913 |
+
# warn users about unexpected behavior when using multi-GPU + InstructBLIP + `accelerate`.
|
| 914 |
+
logger.warning(
|
| 915 |
+
"The `language_model` is not in the `hf_device_map` dictionary and you are running your script"
|
| 916 |
+
" in a multi-GPU environment. this may lead to unexpected behavior when using `accelerate`."
|
| 917 |
+
" Please pass a `device_map` that contains `language_model` to remove this warning."
|
| 918 |
+
" Please refer to https://github.com/huggingface/blog/blob/main/accelerate-large-models.md for"
|
| 919 |
+
" more details on creating a `device_map` for large models.",
|
| 920 |
+
)
|
| 921 |
+
|
| 922 |
+
if hasattr(self.language_model, "_hf_hook"):
|
| 923 |
+
self.language_model._hf_hook.io_same_device = True # For `generate` compatibility
|
| 924 |
+
|
| 925 |
+
def get_placeholder_mask(self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor):
|
| 926 |
+
"""
|
| 927 |
+
Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`.
|
| 928 |
+
"""
|
| 929 |
+
if input_ids is None:
|
| 930 |
+
special_image_mask = inputs_embeds == self.get_input_embeddings()(
|
| 931 |
+
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
|
| 932 |
+
)
|
| 933 |
+
special_image_mask = special_image_mask.all(-1)
|
| 934 |
+
else:
|
| 935 |
+
special_image_mask = input_ids == self.config.image_token_id
|
| 936 |
+
|
| 937 |
+
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
|
| 938 |
+
return special_image_mask
|
| 939 |
+
|
| 940 |
+
@can_return_tuple
|
| 941 |
+
@auto_docstring
|
| 942 |
+
def forward(
|
| 943 |
+
self,
|
| 944 |
+
pixel_values: torch.FloatTensor,
|
| 945 |
+
qformer_input_ids: torch.FloatTensor,
|
| 946 |
+
qformer_attention_mask: torch.LongTensor | None = None,
|
| 947 |
+
input_ids: torch.FloatTensor | None = None,
|
| 948 |
+
attention_mask: torch.LongTensor | None = None,
|
| 949 |
+
decoder_input_ids: torch.LongTensor | None = None,
|
| 950 |
+
decoder_attention_mask: torch.LongTensor | None = None,
|
| 951 |
+
inputs_embeds: torch.Tensor | None = None,
|
| 952 |
+
interpolate_pos_encoding: bool = False,
|
| 953 |
+
**kwargs: Unpack[FlashAttentionKwargs],
|
| 954 |
+
) -> tuple | InstructBlipForConditionalGenerationModelOutput:
|
| 955 |
+
r"""
|
| 956 |
+
qformer_input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 957 |
+
Indices of input sequence tokens in the vocabulary of the Q-Former. Input tokens can optionally be provided
|
| 958 |
+
to serve as text prompt, which the Q-Former model will encode.
|
| 959 |
+
|
| 960 |
+
Indices can be obtained using [`InstructBlipProcessor`]. See [`InstructBlipProcessor.__call__`] for
|
| 961 |
+
details.
|
| 962 |
+
|
| 963 |
+
[What are input IDs?](../glossary#input-ids)
|
| 964 |
+
qformer_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 965 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
| 966 |
+
|
| 967 |
+
- 1 for tokens that are **not masked**,
|
| 968 |
+
- 0 for tokens that are **masked**.
|
| 969 |
+
|
| 970 |
+
[What are attention masks?](../glossary#attention-mask)
|
| 971 |
+
decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
|
| 972 |
+
Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
|
| 973 |
+
be used by default.
|
| 974 |
+
|
| 975 |
+
Only relevant in case an encoder-decoder language model (like T5) is used.
|
| 976 |
+
"""
|
| 977 |
+
|
| 978 |
+
# step 1: forward the images through the vision encoder,
|
| 979 |
+
# to get image embeddings of shape (batch_size, seq_len, hidden_size)
|
| 980 |
+
vision_outputs = self.vision_model(
|
| 981 |
+
pixel_values=pixel_values,
|
| 982 |
+
interpolate_pos_encoding=interpolate_pos_encoding,
|
| 983 |
+
**kwargs,
|
| 984 |
+
)
|
| 985 |
+
image_embeds = vision_outputs[0]
|
| 986 |
+
|
| 987 |
+
# step 2: forward the query tokens through the QFormer, using the image embeddings for cross-attention
|
| 988 |
+
image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)
|
| 989 |
+
|
| 990 |
+
# difference with BLIP-2 here: we also feed the instruction prompt to the Q-Former
|
| 991 |
+
query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
|
| 992 |
+
query_attention_mask = torch.ones(query_tokens.size()[:-1], dtype=torch.long, device=image_embeds.device)
|
| 993 |
+
if qformer_attention_mask is None:
|
| 994 |
+
qformer_attention_mask = torch.ones_like(qformer_input_ids)
|
| 995 |
+
qformer_attention_mask = torch.cat([query_attention_mask, qformer_attention_mask], dim=1)
|
| 996 |
+
query_outputs = self.qformer(
|
| 997 |
+
input_ids=qformer_input_ids,
|
| 998 |
+
attention_mask=qformer_attention_mask,
|
| 999 |
+
query_embeds=query_tokens,
|
| 1000 |
+
encoder_hidden_states=image_embeds,
|
| 1001 |
+
encoder_attention_mask=image_attention_mask,
|
| 1002 |
+
**kwargs,
|
| 1003 |
+
)
|
| 1004 |
+
query_output = query_outputs[0][:, : query_tokens.size(1), :]
|
| 1005 |
+
|
| 1006 |
+
if inputs_embeds is None:
|
| 1007 |
+
inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
|
| 1008 |
+
if attention_mask is None:
|
| 1009 |
+
attention_mask = torch.ones_like(input_ids)
|
| 1010 |
+
|
| 1011 |
+
# step 3: use the language model, conditioned on the query outputs and the prompt
|
| 1012 |
+
language_model_inputs = self.language_projection(query_output)
|
| 1013 |
+
language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype)
|
| 1014 |
+
special_image_mask = self.get_placeholder_mask(input_ids, inputs_embeds=inputs_embeds)
|
| 1015 |
+
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs)
|
| 1016 |
+
|
| 1017 |
+
if self.config.use_decoder_only_language_model:
|
| 1018 |
+
outputs = self.language_model(
|
| 1019 |
+
inputs_embeds=inputs_embeds,
|
| 1020 |
+
attention_mask=attention_mask,
|
| 1021 |
+
**kwargs,
|
| 1022 |
+
)
|
| 1023 |
+
else:
|
| 1024 |
+
outputs = self.language_model(
|
| 1025 |
+
inputs_embeds=inputs_embeds,
|
| 1026 |
+
attention_mask=attention_mask,
|
| 1027 |
+
decoder_input_ids=decoder_input_ids,
|
| 1028 |
+
decoder_attention_mask=decoder_attention_mask,
|
| 1029 |
+
**kwargs,
|
| 1030 |
+
)
|
| 1031 |
+
|
| 1032 |
+
return InstructBlipForConditionalGenerationModelOutput(
|
| 1033 |
+
vision_outputs=vision_outputs,
|
| 1034 |
+
qformer_outputs=query_outputs,
|
| 1035 |
+
language_model_outputs=outputs,
|
| 1036 |
+
)
|
| 1037 |
+
|
| 1038 |
+
|
| 1039 |
+
@auto_docstring(
|
| 1040 |
+
custom_intro="""
|
| 1041 |
+
InstructBLIP Model for generating text given an image and an optional text prompt. The model consists of a vision
|
| 1042 |
+
encoder, Querying Transformer (Q-Former) and a language model.
|
| 1043 |
+
|
| 1044 |
+
One can optionally pass `input_ids` to the model, which serve as a text prompt, to make the language model continue
|
| 1045 |
+
the prompt. Otherwise, the language model starts generating text from the [BOS] (beginning-of-sequence) token.
|
| 1046 |
+
"""
|
| 1047 |
+
)
|
| 1048 |
+
class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel, GenerationMixin):
|
| 1049 |
+
config: InstructBlipConfig
|
| 1050 |
+
main_input_name = "pixel_values"
|
| 1051 |
+
|
| 1052 |
+
_can_compile_fullgraph = True
|
| 1053 |
+
_keep_in_fp32_modules = ["query_tokens"] # TODO @ArthurZucker I don't know why this is required for FP8
|
| 1054 |
+
|
| 1055 |
+
def __init__(self, config: InstructBlipConfig):
|
| 1056 |
+
super().__init__(config)
|
| 1057 |
+
|
| 1058 |
+
self.vision_model = InstructBlipVisionModel._from_config(config.vision_config)
|
| 1059 |
+
|
| 1060 |
+
self.query_tokens = nn.Parameter(torch.zeros(1, config.num_query_tokens, config.qformer_config.hidden_size))
|
| 1061 |
+
self.qformer = InstructBlipQFormerModel._from_config(config.qformer_config)
|
| 1062 |
+
|
| 1063 |
+
self.language_projection = nn.Linear(config.qformer_config.hidden_size, config.text_config.hidden_size)
|
| 1064 |
+
|
| 1065 |
+
if config.use_decoder_only_language_model:
|
| 1066 |
+
language_model = AutoModelForCausalLM.from_config(config.text_config)
|
| 1067 |
+
else:
|
| 1068 |
+
language_model = AutoModelForSeq2SeqLM.from_config(config.text_config)
|
| 1069 |
+
|
| 1070 |
+
self.language_model = language_model
|
| 1071 |
+
|
| 1072 |
+
# Initialize weights and apply final processing
|
| 1073 |
+
self.post_init()
|
| 1074 |
+
|
| 1075 |
+
def set_output_embeddings(self, new_embeddings):
|
| 1076 |
+
self.language_model.set_output_embeddings(new_embeddings)
|
| 1077 |
+
|
| 1078 |
+
def get_output_embeddings(self) -> nn.Module:
|
| 1079 |
+
return self.language_model.get_output_embeddings()
|
| 1080 |
+
|
| 1081 |
+
def get_encoder(self, modality=None):
|
| 1082 |
+
if modality is None:
|
| 1083 |
+
return self.language_model.get_encoder()
|
| 1084 |
+
else:
|
| 1085 |
+
return super().get_encoder(modality=modality)
|
| 1086 |
+
|
| 1087 |
+
def get_decoder(self):
|
| 1088 |
+
return self.language_model.get_decoder()
|
| 1089 |
+
|
| 1090 |
+
# Copied from transformers.models.instructblip.modeling_instructblip.InstructBlipModel._preprocess_accelerate
|
| 1091 |
+
def _preprocess_accelerate(self):
|
| 1092 |
+
r"""
|
| 1093 |
+
Some pre-processing hacks to make the model `accelerate` compatible. Check
|
| 1094 |
+
https://github.com/huggingface/transformers/pull/21707 for more details.
|
| 1095 |
+
"""
|
| 1096 |
+
hf_device_map = self.hf_device_map
|
| 1097 |
+
|
| 1098 |
+
if len(hf_device_map) > 1 and "language_model" not in hf_device_map and torch.cuda.device_count() > 1:
|
| 1099 |
+
# warn users about unexpected behavior when using multi-GPU + InstructBLIP + `accelerate`.
|
| 1100 |
+
logger.warning(
|
| 1101 |
+
"The `language_model` is not in the `hf_device_map` dictionary and you are running your script"
|
| 1102 |
+
" in a multi-GPU environment. this may lead to unexpected behavior when using `accelerate`."
|
| 1103 |
+
" Please pass a `device_map` that contains `language_model` to remove this warning."
|
| 1104 |
+
" Please refer to https://github.com/huggingface/blog/blob/main/accelerate-large-models.md for"
|
| 1105 |
+
" more details on creating a `device_map` for large models.",
|
| 1106 |
+
)
|
| 1107 |
+
|
| 1108 |
+
if hasattr(self.language_model, "_hf_hook"):
|
| 1109 |
+
self.language_model._hf_hook.io_same_device = True # For `generate` compatibility
|
| 1110 |
+
|
| 1111 |
+
@can_return_tuple
|
| 1112 |
+
@auto_docstring
|
| 1113 |
+
def get_image_features(
|
| 1114 |
+
self,
|
| 1115 |
+
pixel_values: torch.FloatTensor,
|
| 1116 |
+
qformer_input_ids: torch.LongTensor,
|
| 1117 |
+
qformer_attention_mask: torch.LongTensor | None = None,
|
| 1118 |
+
interpolate_pos_encoding: bool | None = False,
|
| 1119 |
+
**kwargs: Unpack[TransformersKwargs],
|
| 1120 |
+
) -> tuple | BaseModelOutputWithVisionQformerOutputs:
|
| 1121 |
+
r"""
|
| 1122 |
+
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
|
| 1123 |
+
The tensors corresponding to the input images.
|
| 1124 |
+
qformer_input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 1125 |
+
Indices of input sequence tokens in the vocabulary of the Q-Former. Input tokens can optionally be provided
|
| 1126 |
+
to serve as text prompt, which the Q-Former model will encode.
|
| 1127 |
+
|
| 1128 |
+
Indices can be obtained using [`InstructBlipProcessor`]. See [`InstructBlipProcessor.__call__`] for
|
| 1129 |
+
details.
|
| 1130 |
+
|
| 1131 |
+
[What are input IDs?](../glossary#input-ids)
|
| 1132 |
+
qformer_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 1133 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
| 1134 |
+
|
| 1135 |
+
- 1 for tokens that are **not masked**,
|
| 1136 |
+
- 0 for tokens that are **masked**.
|
| 1137 |
+
|
| 1138 |
+
[What are attention masks?](../glossary#attention-mask)
|
| 1139 |
+
"""
|
| 1140 |
+
# step 1: forward the images through the vision encoder,
|
| 1141 |
+
# to get image embeddings of shape (batch_size, seq_len, hidden_size)
|
| 1142 |
+
vision_outputs: BaseModelOutputWithPooling = self.vision_model(
|
| 1143 |
+
pixel_values=pixel_values,
|
| 1144 |
+
interpolate_pos_encoding=interpolate_pos_encoding,
|
| 1145 |
+
return_dict=True,
|
| 1146 |
+
**kwargs,
|
| 1147 |
+
)
|
| 1148 |
+
vision_outputs = BaseModelOutputWithVisionQformerOutputs(**vision_outputs, vision_outputs=vision_outputs)
|
| 1149 |
+
image_embeds = vision_outputs[0]
|
| 1150 |
+
|
| 1151 |
+
# step 2: forward the query tokens through the QFormer, using the image embeddings for cross-attention
|
| 1152 |
+
image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)
|
| 1153 |
+
|
| 1154 |
+
# difference with BLIP-2 here: we also feed the instruction prompt to the Q-Former
|
| 1155 |
+
query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
|
| 1156 |
+
query_attention_mask = torch.ones(query_tokens.size()[:-1], dtype=torch.long, device=image_embeds.device)
|
| 1157 |
+
if qformer_attention_mask is None:
|
| 1158 |
+
qformer_attention_mask = torch.ones_like(qformer_input_ids)
|
| 1159 |
+
qformer_attention_mask = torch.cat([query_attention_mask, qformer_attention_mask], dim=1)
|
| 1160 |
+
qformer_outputs = self.qformer(
|
| 1161 |
+
input_ids=qformer_input_ids,
|
| 1162 |
+
attention_mask=qformer_attention_mask,
|
| 1163 |
+
query_embeds=query_tokens,
|
| 1164 |
+
encoder_hidden_states=image_embeds,
|
| 1165 |
+
encoder_attention_mask=image_attention_mask,
|
| 1166 |
+
return_dict=True,
|
| 1167 |
+
**kwargs,
|
| 1168 |
+
)
|
| 1169 |
+
vision_outputs.qformer_outputs = qformer_outputs
|
| 1170 |
+
query_output = qformer_outputs[0][:, : query_tokens.size(1), :]
|
| 1171 |
+
|
| 1172 |
+
# step 3: use the language model, conditioned on the query outputs and the prompt
|
| 1173 |
+
image_features = self.language_projection(query_output)
|
| 1174 |
+
vision_outputs.pooler_output = image_features
|
| 1175 |
+
|
| 1176 |
+
return vision_outputs
|
| 1177 |
+
|
| 1178 |
+
def get_placeholder_mask(self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor):
|
| 1179 |
+
"""
|
| 1180 |
+
Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`.
|
| 1181 |
+
"""
|
| 1182 |
+
if input_ids is None:
|
| 1183 |
+
special_image_mask = inputs_embeds == self.get_input_embeddings()(
|
| 1184 |
+
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
|
| 1185 |
+
)
|
| 1186 |
+
special_image_mask = special_image_mask.all(-1)
|
| 1187 |
+
else:
|
| 1188 |
+
special_image_mask = input_ids == self.config.image_token_id
|
| 1189 |
+
|
| 1190 |
+
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
|
| 1191 |
+
return special_image_mask
|
| 1192 |
+
|
| 1193 |
+
@can_return_tuple
|
| 1194 |
+
@auto_docstring
|
| 1195 |
+
def forward(
|
| 1196 |
+
self,
|
| 1197 |
+
pixel_values: torch.FloatTensor,
|
| 1198 |
+
qformer_input_ids: torch.FloatTensor,
|
| 1199 |
+
qformer_attention_mask: torch.LongTensor | None = None,
|
| 1200 |
+
input_ids: torch.FloatTensor | None = None,
|
| 1201 |
+
attention_mask: torch.LongTensor | None = None,
|
| 1202 |
+
decoder_input_ids: torch.LongTensor | None = None,
|
| 1203 |
+
decoder_attention_mask: torch.LongTensor | None = None,
|
| 1204 |
+
inputs_embeds: torch.FloatTensor | None = None,
|
| 1205 |
+
labels: torch.LongTensor | None = None,
|
| 1206 |
+
interpolate_pos_encoding: bool = False,
|
| 1207 |
+
**kwargs: Unpack[TransformersKwargs],
|
| 1208 |
+
) -> tuple | InstructBlipForConditionalGenerationModelOutput:
|
| 1209 |
+
r"""
|
| 1210 |
+
qformer_input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 1211 |
+
Indices of input sequence tokens in the vocabulary of the Q-Former. Input tokens can optionally be provided
|
| 1212 |
+
to serve as text prompt, which the Q-Former model will encode.
|
| 1213 |
+
|
| 1214 |
+
Indices can be obtained using [`InstructBlipProcessor`]. See [`InstructBlipProcessor.__call__`] for
|
| 1215 |
+
details.
|
| 1216 |
+
|
| 1217 |
+
[What are input IDs?](../glossary#input-ids)
|
| 1218 |
+
qformer_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 1219 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
| 1220 |
+
|
| 1221 |
+
- 1 for tokens that are **not masked**,
|
| 1222 |
+
- 0 for tokens that are **masked**.
|
| 1223 |
+
|
| 1224 |
+
[What are attention masks?](../glossary#attention-mask)
|
| 1225 |
+
decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
|
| 1226 |
+
Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
|
| 1227 |
+
be used by default.
|
| 1228 |
+
|
| 1229 |
+
Only relevant in case an encoder-decoder language model (like T5) is used.
|
| 1230 |
+
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
| 1231 |
+
Labels for computing the language modeling loss. Indices should be in `[-100, 0, ..., config.vocab_size -
|
| 1232 |
+
1]`. All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ...,
|
| 1233 |
+
config.vocab_size]`
|
| 1234 |
+
|
| 1235 |
+
Examples:
|
| 1236 |
+
|
| 1237 |
+
```python
|
| 1238 |
+
>>> from transformers import InstructBlipProcessor, InstructBlipForConditionalGeneration
|
| 1239 |
+
>>> import torch
|
| 1240 |
+
>>> from PIL import Image
|
| 1241 |
+
>>> import httpx
|
| 1242 |
+
>>> from io import BytesIO
|
| 1243 |
+
|
| 1244 |
+
>>> model = InstructBlipForConditionalGeneration.from_pretrained("Salesforce/instructblip-vicuna-7b")
|
| 1245 |
+
>>> processor = InstructBlipProcessor.from_pretrained("Salesforce/instructblip-vicuna-7b")
|
| 1246 |
+
|
| 1247 |
+
>>> device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 1248 |
+
>>> model.to(device) # doctest: +IGNORE_RESULT
|
| 1249 |
+
|
| 1250 |
+
>>> url = "https://raw.githubusercontent.com/salesforce/LAVIS/main/docs/_static/Confusing-Pictures.jpg"
|
| 1251 |
+
>>> with httpx.stream("GET", url) as response:
|
| 1252 |
+
... image = Image.open(BytesIO(response.read())).convert("RGB")
|
| 1253 |
+
>>> prompt = "What is unusual about this image?"
|
| 1254 |
+
>>> inputs = processor(images=image, text=prompt, return_tensors="pt").to(device)
|
| 1255 |
+
|
| 1256 |
+
>>> outputs = model.generate(
|
| 1257 |
+
... **inputs,
|
| 1258 |
+
... do_sample=False,
|
| 1259 |
+
... num_beams=5,
|
| 1260 |
+
... max_length=256,
|
| 1261 |
+
... min_length=1,
|
| 1262 |
+
... top_p=0.9,
|
| 1263 |
+
... repetition_penalty=1.5,
|
| 1264 |
+
... length_penalty=1.0,
|
| 1265 |
+
... temperature=1,
|
| 1266 |
+
... )
|
| 1267 |
+
>>> generated_text = processor.batch_decode(outputs, skip_special_tokens=True)[0].strip()
|
| 1268 |
+
>>> print(generated_text)
|
| 1269 |
+
The unusual aspect of this image is that a man is ironing clothes on the back of a yellow SUV, which is parked in the middle of a busy city street. This is an unconventional approach to ironing clothes, as it requires the man to balance himself and his ironing equipment on top of the vehicle while navigating through traffic. Additionally, the presence of taxis and other vehicles in the scene further emphasizes the unusual nature of this situation.
|
| 1270 |
+
```"""
|
| 1271 |
+
|
| 1272 |
+
image_features: BaseModelOutputWithVisionQformerOutputs = self.get_image_features(
|
| 1273 |
+
pixel_values,
|
| 1274 |
+
qformer_input_ids=qformer_input_ids,
|
| 1275 |
+
qformer_attention_mask=qformer_attention_mask,
|
| 1276 |
+
interpolate_pos_encoding=interpolate_pos_encoding,
|
| 1277 |
+
return_dict=True,
|
| 1278 |
+
)
|
| 1279 |
+
language_model_inputs = image_features.pooler_output
|
| 1280 |
+
qformer_outputs = image_features.qformer_outputs
|
| 1281 |
+
vision_outputs = image_features.vision_outputs
|
| 1282 |
+
|
| 1283 |
+
if inputs_embeds is None:
|
| 1284 |
+
inputs_embeds = self.get_input_embeddings()(input_ids)
|
| 1285 |
+
|
| 1286 |
+
if attention_mask is None:
|
| 1287 |
+
attention_mask = torch.ones_like(input_ids)
|
| 1288 |
+
|
| 1289 |
+
language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype)
|
| 1290 |
+
special_image_mask = self.get_placeholder_mask(input_ids, inputs_embeds=inputs_embeds)
|
| 1291 |
+
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs)
|
| 1292 |
+
|
| 1293 |
+
if self.config.use_decoder_only_language_model:
|
| 1294 |
+
outputs = self.language_model(
|
| 1295 |
+
inputs_embeds=inputs_embeds,
|
| 1296 |
+
attention_mask=attention_mask,
|
| 1297 |
+
**kwargs,
|
| 1298 |
+
)
|
| 1299 |
+
logits = outputs[0]
|
| 1300 |
+
loss = None
|
| 1301 |
+
if labels is not None:
|
| 1302 |
+
loss = self.loss_function(
|
| 1303 |
+
logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
|
| 1304 |
+
)
|
| 1305 |
+
|
| 1306 |
+
else:
|
| 1307 |
+
kwargs["return_dict"] = True
|
| 1308 |
+
outputs = self.language_model(
|
| 1309 |
+
inputs_embeds=inputs_embeds,
|
| 1310 |
+
attention_mask=attention_mask,
|
| 1311 |
+
decoder_input_ids=decoder_input_ids,
|
| 1312 |
+
decoder_attention_mask=decoder_attention_mask,
|
| 1313 |
+
labels=labels,
|
| 1314 |
+
**kwargs,
|
| 1315 |
+
)
|
| 1316 |
+
loss = outputs.loss
|
| 1317 |
+
logits = outputs.logits
|
| 1318 |
+
|
| 1319 |
+
return InstructBlipForConditionalGenerationModelOutput(
|
| 1320 |
+
loss=loss,
|
| 1321 |
+
logits=logits,
|
| 1322 |
+
vision_outputs=vision_outputs,
|
| 1323 |
+
qformer_outputs=qformer_outputs,
|
| 1324 |
+
language_model_outputs=outputs,
|
| 1325 |
+
)
|
| 1326 |
+
|
| 1327 |
+
@torch.no_grad()
|
| 1328 |
+
def generate(
|
| 1329 |
+
self,
|
| 1330 |
+
pixel_values: torch.FloatTensor,
|
| 1331 |
+
qformer_input_ids: torch.LongTensor | None = None,
|
| 1332 |
+
qformer_attention_mask: torch.LongTensor | None = None,
|
| 1333 |
+
input_ids: torch.LongTensor | None = None,
|
| 1334 |
+
attention_mask: torch.LongTensor | None = None,
|
| 1335 |
+
inputs_embeds: torch.FloatTensor | None = None,
|
| 1336 |
+
interpolate_pos_encoding: bool = False,
|
| 1337 |
+
**generate_kwargs,
|
| 1338 |
+
) -> torch.LongTensor:
|
| 1339 |
+
"""
|
| 1340 |
+
Overrides `generate` function to be able to use the model as a conditional generator.
|
| 1341 |
+
|
| 1342 |
+
Args:
|
| 1343 |
+
pixel_values (`torch.FloatTensor` of shape (batch_size, num_channels, height, width)):
|
| 1344 |
+
Input images to be processed.
|
| 1345 |
+
qformer_input_ids (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*):
|
| 1346 |
+
The sequence used as a prompt to be fed to the Q-Former module.
|
| 1347 |
+
qformer_attention_mask (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*):
|
| 1348 |
+
Mask to avoid performing attention on padding token indices.
|
| 1349 |
+
input_ids (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*):
|
| 1350 |
+
The sequence used as a prompt for the generation.
|
| 1351 |
+
attention_mask (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*):
|
| 1352 |
+
Mask to avoid performing attention on padding token indices.
|
| 1353 |
+
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
| 1354 |
+
Embedded representation of the inputs. Should be float, not int tokens.
|
| 1355 |
+
interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
|
| 1356 |
+
Whether to interpolate the positional encoding of the image embeddings.
|
| 1357 |
+
|
| 1358 |
+
Returns:
|
| 1359 |
+
captions (list): A list of strings of length batch_size * num_captions.
|
| 1360 |
+
"""
|
| 1361 |
+
if hasattr(self, "hf_device_map"):
|
| 1362 |
+
# preprocess for `accelerate`
|
| 1363 |
+
self._preprocess_accelerate()
|
| 1364 |
+
|
| 1365 |
+
batch_size = pixel_values.shape[0]
|
| 1366 |
+
image_features: BaseModelOutputWithVisionQformerOutputs = self.get_image_features(
|
| 1367 |
+
pixel_values,
|
| 1368 |
+
qformer_input_ids=qformer_input_ids,
|
| 1369 |
+
qformer_attention_mask=qformer_attention_mask,
|
| 1370 |
+
interpolate_pos_encoding=interpolate_pos_encoding,
|
| 1371 |
+
return_dict=True,
|
| 1372 |
+
)
|
| 1373 |
+
language_model_inputs = image_features.pooler_output
|
| 1374 |
+
|
| 1375 |
+
if inputs_embeds is None:
|
| 1376 |
+
if input_ids is None:
|
| 1377 |
+
image_tokens = [self.config.image_token_index] * self.config.num_query_tokens
|
| 1378 |
+
start_tokens = image_tokens + [self.config.text_config.bos_token_id]
|
| 1379 |
+
input_ids = torch.tensor([start_tokens], dtype=torch.long, device=pixel_values.device)
|
| 1380 |
+
input_ids = input_ids.repeat(batch_size, 1)
|
| 1381 |
+
inputs_embeds = self.get_input_embeddings()(input_ids)
|
| 1382 |
+
|
| 1383 |
+
if attention_mask is None:
|
| 1384 |
+
attention_mask = torch.ones_like(input_ids)
|
| 1385 |
+
|
| 1386 |
+
language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype)
|
| 1387 |
+
special_image_mask = self.get_placeholder_mask(input_ids, inputs_embeds=inputs_embeds)
|
| 1388 |
+
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs)
|
| 1389 |
+
|
| 1390 |
+
inputs = {"inputs_embeds": inputs_embeds, "attention_mask": attention_mask}
|
| 1391 |
+
if not self.language_model.config.is_encoder_decoder:
|
| 1392 |
+
inputs["input_ids"] = input_ids
|
| 1393 |
+
|
| 1394 |
+
outputs = self.language_model.generate(**inputs, **generate_kwargs)
|
| 1395 |
+
|
| 1396 |
+
return outputs
|
| 1397 |
+
|
| 1398 |
+
|
| 1399 |
+
__all__ = [
|
| 1400 |
+
"InstructBlipQFormerModel",
|
| 1401 |
+
"InstructBlipPreTrainedModel",
|
| 1402 |
+
"InstructBlipModel",
|
| 1403 |
+
"InstructBlipForConditionalGeneration",
|
| 1404 |
+
"InstructBlipVisionModel",
|
| 1405 |
+
]
|
LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/instructblip/processing_instructblip.py
ADDED
|
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2023 The HuggingFace Inc. team.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
"""
|
| 15 |
+
Processor class for InstructBLIP. Largely copy of Blip2Processor with addition of a tokenizer for the Q-Former.
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
from ...image_processing_utils import BatchFeature
|
| 19 |
+
from ...image_utils import ImageInput
|
| 20 |
+
from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
|
| 21 |
+
from ...tokenization_utils_base import AddedToken, PreTokenizedInput, TextInput
|
| 22 |
+
from ...utils import auto_docstring, logging
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
logger = logging.get_logger(__name__)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class InstructBlipProcessorKwargs(ProcessingKwargs, total=False):
|
| 29 |
+
_defaults = {
|
| 30 |
+
"text_kwargs": {
|
| 31 |
+
"add_special_tokens": True,
|
| 32 |
+
"padding": False,
|
| 33 |
+
"stride": 0,
|
| 34 |
+
"return_overflowing_tokens": False,
|
| 35 |
+
"return_special_tokens_mask": False,
|
| 36 |
+
"return_offsets_mapping": False,
|
| 37 |
+
"return_token_type_ids": False,
|
| 38 |
+
"return_length": False,
|
| 39 |
+
"verbose": True,
|
| 40 |
+
},
|
| 41 |
+
}
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
@auto_docstring
|
| 45 |
+
class InstructBlipProcessor(ProcessorMixin):
|
| 46 |
+
def __init__(self, image_processor, tokenizer, qformer_tokenizer, num_query_tokens=None, **kwargs):
|
| 47 |
+
r"""
|
| 48 |
+
qformer_tokenizer (`AutoTokenizer`):
|
| 49 |
+
An instance of ['PreTrainedTokenizer`]. The Q-Former tokenizer is a required input.
|
| 50 |
+
num_query_tokens (`int`, *optional*):
|
| 51 |
+
"
|
| 52 |
+
Number of tokens used by the Qformer as queries, should be same as in model's config.
|
| 53 |
+
"""
|
| 54 |
+
if not hasattr(tokenizer, "image_token"):
|
| 55 |
+
self.image_token = AddedToken("<image>", normalized=False, special=True)
|
| 56 |
+
tokenizer.add_tokens([self.image_token], special_tokens=True)
|
| 57 |
+
else:
|
| 58 |
+
self.image_token = tokenizer.image_token
|
| 59 |
+
self.num_query_tokens = num_query_tokens
|
| 60 |
+
|
| 61 |
+
super().__init__(image_processor, tokenizer, qformer_tokenizer)
|
| 62 |
+
|
| 63 |
+
@auto_docstring
|
| 64 |
+
def __call__(
|
| 65 |
+
self,
|
| 66 |
+
images: ImageInput | None = None,
|
| 67 |
+
text: TextInput | PreTokenizedInput | list[TextInput] | list[PreTokenizedInput] = None,
|
| 68 |
+
**kwargs: Unpack[InstructBlipProcessorKwargs],
|
| 69 |
+
) -> BatchFeature:
|
| 70 |
+
if images is None and text is None:
|
| 71 |
+
raise ValueError("You have to specify at least images or text.")
|
| 72 |
+
|
| 73 |
+
output_kwargs = self._merge_kwargs(
|
| 74 |
+
InstructBlipProcessorKwargs,
|
| 75 |
+
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
|
| 76 |
+
**kwargs,
|
| 77 |
+
)
|
| 78 |
+
|
| 79 |
+
return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
|
| 80 |
+
encoding = {}
|
| 81 |
+
if text is not None:
|
| 82 |
+
if isinstance(text, str):
|
| 83 |
+
text = [text]
|
| 84 |
+
elif not isinstance(text, list) and not isinstance(text[0], str):
|
| 85 |
+
raise ValueError("Invalid input text. Please provide a string, or a list of strings")
|
| 86 |
+
|
| 87 |
+
qformer_text_encoding = self.qformer_tokenizer(text, **output_kwargs["text_kwargs"])
|
| 88 |
+
encoding["qformer_input_ids"] = qformer_text_encoding.pop("input_ids")
|
| 89 |
+
encoding["qformer_attention_mask"] = qformer_text_encoding.pop("attention_mask")
|
| 90 |
+
|
| 91 |
+
# We need this hacky manipulation because BLIP expects image tokens to be at the beginning even before BOS token
|
| 92 |
+
if output_kwargs["text_kwargs"].get("max_length") is not None:
|
| 93 |
+
output_kwargs["text_kwargs"]["max_length"] -= self.num_query_tokens
|
| 94 |
+
text_encoding = self.tokenizer(text, **output_kwargs["text_kwargs"])
|
| 95 |
+
|
| 96 |
+
if images is not None:
|
| 97 |
+
# Image tokens should not be padded/truncated or prepended with special BOS token
|
| 98 |
+
image_tokens = self.image_token.content * self.num_query_tokens
|
| 99 |
+
output_kwargs["text_kwargs"]["add_special_tokens"] = False
|
| 100 |
+
output_kwargs["text_kwargs"]["padding"] = False
|
| 101 |
+
output_kwargs["text_kwargs"]["truncation"] = False
|
| 102 |
+
image_text_encoding = self.tokenizer(image_tokens, **output_kwargs["text_kwargs"])
|
| 103 |
+
for k in text_encoding:
|
| 104 |
+
text_encoding[k] = [image_text_encoding[k] + sample for sample in text_encoding[k]]
|
| 105 |
+
encoding.update(text_encoding)
|
| 106 |
+
|
| 107 |
+
if images is not None:
|
| 108 |
+
image_encoding = self.image_processor(images, **output_kwargs["images_kwargs"])
|
| 109 |
+
encoding.update(image_encoding)
|
| 110 |
+
|
| 111 |
+
# Cast to desired return tensors type
|
| 112 |
+
encoding = BatchFeature(encoding, tensor_type=return_tensors)
|
| 113 |
+
return encoding
|
| 114 |
+
|
| 115 |
+
@property
|
| 116 |
+
def model_input_names(self):
|
| 117 |
+
tokenizer_input_names = self.tokenizer.model_input_names
|
| 118 |
+
image_processor_input_names = self.image_processor.model_input_names
|
| 119 |
+
qformer_input_names = ["qformer_input_ids", "qformer_attention_mask"]
|
| 120 |
+
return tokenizer_input_names + image_processor_input_names + qformer_input_names
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
__all__ = ["InstructBlipProcessor"]
|
LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/mllama/__init__.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
from typing import TYPE_CHECKING
|
| 15 |
+
|
| 16 |
+
from ...utils import _LazyModule
|
| 17 |
+
from ...utils.import_utils import define_import_structure
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
if TYPE_CHECKING:
|
| 21 |
+
from .configuration_mllama import *
|
| 22 |
+
from .image_processing_mllama import *
|
| 23 |
+
from .image_processing_pil_mllama import *
|
| 24 |
+
from .modeling_mllama import *
|
| 25 |
+
from .processing_mllama import *
|
| 26 |
+
else:
|
| 27 |
+
import sys
|
| 28 |
+
|
| 29 |
+
_file = globals()["__file__"]
|
| 30 |
+
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
|
LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/mobilevit/modeling_mobilevit.py
ADDED
|
@@ -0,0 +1,963 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2022 Apple Inc. and The HuggingFace Inc. team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
#
|
| 15 |
+
# Original license: https://github.com/apple/ml-cvnets/blob/main/LICENSE
|
| 16 |
+
"""PyTorch MobileViT model."""
|
| 17 |
+
|
| 18 |
+
import math
|
| 19 |
+
|
| 20 |
+
import torch
|
| 21 |
+
from torch import nn
|
| 22 |
+
from torch.nn import CrossEntropyLoss
|
| 23 |
+
|
| 24 |
+
from ... import initialization as init
|
| 25 |
+
from ...activations import ACT2FN
|
| 26 |
+
from ...modeling_layers import GradientCheckpointingLayer
|
| 27 |
+
from ...modeling_outputs import (
|
| 28 |
+
BaseModelOutputWithNoAttention,
|
| 29 |
+
BaseModelOutputWithPoolingAndNoAttention,
|
| 30 |
+
ImageClassifierOutputWithNoAttention,
|
| 31 |
+
SemanticSegmenterOutput,
|
| 32 |
+
)
|
| 33 |
+
from ...modeling_utils import PreTrainedModel
|
| 34 |
+
from ...utils import auto_docstring, logging, torch_int
|
| 35 |
+
from .configuration_mobilevit import MobileViTConfig
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
logger = logging.get_logger(__name__)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def make_divisible(value: int, divisor: int = 8, min_value: int | None = None) -> int:
|
| 42 |
+
"""
|
| 43 |
+
Ensure that all layers have a channel count that is divisible by `divisor`.
|
| 44 |
+
"""
|
| 45 |
+
if min_value is None:
|
| 46 |
+
min_value = divisor
|
| 47 |
+
new_value = max(min_value, int(value + divisor / 2) // divisor * divisor)
|
| 48 |
+
# Make sure that round down does not go down by more than 10%.
|
| 49 |
+
if new_value < 0.9 * value:
|
| 50 |
+
new_value += divisor
|
| 51 |
+
return int(new_value)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
class MobileViTConvLayer(nn.Module):
|
| 55 |
+
def __init__(
|
| 56 |
+
self,
|
| 57 |
+
config: MobileViTConfig,
|
| 58 |
+
in_channels: int,
|
| 59 |
+
out_channels: int,
|
| 60 |
+
kernel_size: int,
|
| 61 |
+
stride: int = 1,
|
| 62 |
+
groups: int = 1,
|
| 63 |
+
bias: bool = False,
|
| 64 |
+
dilation: int = 1,
|
| 65 |
+
use_normalization: bool = True,
|
| 66 |
+
use_activation: bool | str = True,
|
| 67 |
+
) -> None:
|
| 68 |
+
super().__init__()
|
| 69 |
+
padding = int((kernel_size - 1) / 2) * dilation
|
| 70 |
+
|
| 71 |
+
if in_channels % groups != 0:
|
| 72 |
+
raise ValueError(f"Input channels ({in_channels}) are not divisible by {groups} groups.")
|
| 73 |
+
if out_channels % groups != 0:
|
| 74 |
+
raise ValueError(f"Output channels ({out_channels}) are not divisible by {groups} groups.")
|
| 75 |
+
|
| 76 |
+
self.convolution = nn.Conv2d(
|
| 77 |
+
in_channels=in_channels,
|
| 78 |
+
out_channels=out_channels,
|
| 79 |
+
kernel_size=kernel_size,
|
| 80 |
+
stride=stride,
|
| 81 |
+
padding=padding,
|
| 82 |
+
dilation=dilation,
|
| 83 |
+
groups=groups,
|
| 84 |
+
bias=bias,
|
| 85 |
+
padding_mode="zeros",
|
| 86 |
+
)
|
| 87 |
+
|
| 88 |
+
if use_normalization:
|
| 89 |
+
self.normalization = nn.BatchNorm2d(
|
| 90 |
+
num_features=out_channels,
|
| 91 |
+
eps=1e-5,
|
| 92 |
+
momentum=0.1,
|
| 93 |
+
affine=True,
|
| 94 |
+
track_running_stats=True,
|
| 95 |
+
)
|
| 96 |
+
else:
|
| 97 |
+
self.normalization = None
|
| 98 |
+
|
| 99 |
+
if use_activation:
|
| 100 |
+
if isinstance(use_activation, str):
|
| 101 |
+
self.activation = ACT2FN[use_activation]
|
| 102 |
+
elif isinstance(config.hidden_act, str):
|
| 103 |
+
self.activation = ACT2FN[config.hidden_act]
|
| 104 |
+
else:
|
| 105 |
+
self.activation = config.hidden_act
|
| 106 |
+
else:
|
| 107 |
+
self.activation = None
|
| 108 |
+
|
| 109 |
+
def forward(self, features: torch.Tensor) -> torch.Tensor:
|
| 110 |
+
features = self.convolution(features)
|
| 111 |
+
if self.normalization is not None:
|
| 112 |
+
features = self.normalization(features)
|
| 113 |
+
if self.activation is not None:
|
| 114 |
+
features = self.activation(features)
|
| 115 |
+
return features
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
class MobileViTInvertedResidual(nn.Module):
|
| 119 |
+
"""
|
| 120 |
+
Inverted residual block (MobileNetv2): https://huggingface.co/papers/1801.04381
|
| 121 |
+
"""
|
| 122 |
+
|
| 123 |
+
def __init__(
|
| 124 |
+
self, config: MobileViTConfig, in_channels: int, out_channels: int, stride: int, dilation: int = 1
|
| 125 |
+
) -> None:
|
| 126 |
+
super().__init__()
|
| 127 |
+
expanded_channels = make_divisible(int(round(in_channels * config.expand_ratio)), 8)
|
| 128 |
+
|
| 129 |
+
if stride not in [1, 2]:
|
| 130 |
+
raise ValueError(f"Invalid stride {stride}.")
|
| 131 |
+
|
| 132 |
+
self.use_residual = (stride == 1) and (in_channels == out_channels)
|
| 133 |
+
|
| 134 |
+
self.expand_1x1 = MobileViTConvLayer(
|
| 135 |
+
config, in_channels=in_channels, out_channels=expanded_channels, kernel_size=1
|
| 136 |
+
)
|
| 137 |
+
|
| 138 |
+
self.conv_3x3 = MobileViTConvLayer(
|
| 139 |
+
config,
|
| 140 |
+
in_channels=expanded_channels,
|
| 141 |
+
out_channels=expanded_channels,
|
| 142 |
+
kernel_size=3,
|
| 143 |
+
stride=stride,
|
| 144 |
+
groups=expanded_channels,
|
| 145 |
+
dilation=dilation,
|
| 146 |
+
)
|
| 147 |
+
|
| 148 |
+
self.reduce_1x1 = MobileViTConvLayer(
|
| 149 |
+
config,
|
| 150 |
+
in_channels=expanded_channels,
|
| 151 |
+
out_channels=out_channels,
|
| 152 |
+
kernel_size=1,
|
| 153 |
+
use_activation=False,
|
| 154 |
+
)
|
| 155 |
+
|
| 156 |
+
def forward(self, features: torch.Tensor) -> torch.Tensor:
|
| 157 |
+
residual = features
|
| 158 |
+
|
| 159 |
+
features = self.expand_1x1(features)
|
| 160 |
+
features = self.conv_3x3(features)
|
| 161 |
+
features = self.reduce_1x1(features)
|
| 162 |
+
|
| 163 |
+
return residual + features if self.use_residual else features
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
class MobileViTMobileNetLayer(nn.Module):
|
| 167 |
+
def __init__(
|
| 168 |
+
self, config: MobileViTConfig, in_channels: int, out_channels: int, stride: int = 1, num_stages: int = 1
|
| 169 |
+
) -> None:
|
| 170 |
+
super().__init__()
|
| 171 |
+
|
| 172 |
+
self.layer = nn.ModuleList()
|
| 173 |
+
for i in range(num_stages):
|
| 174 |
+
layer = MobileViTInvertedResidual(
|
| 175 |
+
config,
|
| 176 |
+
in_channels=in_channels,
|
| 177 |
+
out_channels=out_channels,
|
| 178 |
+
stride=stride if i == 0 else 1,
|
| 179 |
+
)
|
| 180 |
+
self.layer.append(layer)
|
| 181 |
+
in_channels = out_channels
|
| 182 |
+
|
| 183 |
+
def forward(self, features: torch.Tensor) -> torch.Tensor:
|
| 184 |
+
for layer_module in self.layer:
|
| 185 |
+
features = layer_module(features)
|
| 186 |
+
return features
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
class MobileViTSelfAttention(nn.Module):
|
| 190 |
+
def __init__(self, config: MobileViTConfig, hidden_size: int) -> None:
|
| 191 |
+
super().__init__()
|
| 192 |
+
|
| 193 |
+
if hidden_size % config.num_attention_heads != 0:
|
| 194 |
+
raise ValueError(
|
| 195 |
+
f"The hidden size {hidden_size} is not a multiple of the number of attention "
|
| 196 |
+
f"heads {config.num_attention_heads}."
|
| 197 |
+
)
|
| 198 |
+
|
| 199 |
+
self.num_attention_heads = config.num_attention_heads
|
| 200 |
+
self.attention_head_size = int(hidden_size / config.num_attention_heads)
|
| 201 |
+
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
| 202 |
+
|
| 203 |
+
self.query = nn.Linear(hidden_size, self.all_head_size, bias=config.qkv_bias)
|
| 204 |
+
self.key = nn.Linear(hidden_size, self.all_head_size, bias=config.qkv_bias)
|
| 205 |
+
self.value = nn.Linear(hidden_size, self.all_head_size, bias=config.qkv_bias)
|
| 206 |
+
|
| 207 |
+
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
| 208 |
+
|
| 209 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 210 |
+
input_shape = hidden_states.shape[:-1]
|
| 211 |
+
hidden_shape = (*input_shape, -1, self.attention_head_size)
|
| 212 |
+
query_layer = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
|
| 213 |
+
key_layer = self.key(hidden_states).view(hidden_shape).transpose(1, 2)
|
| 214 |
+
value_layer = self.value(hidden_states).view(hidden_shape).transpose(1, 2)
|
| 215 |
+
|
| 216 |
+
# Take the dot product between "query" and "key" to get the raw attention scores.
|
| 217 |
+
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
| 218 |
+
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
|
| 219 |
+
|
| 220 |
+
# Normalize the attention scores to probabilities.
|
| 221 |
+
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
|
| 222 |
+
|
| 223 |
+
# This is actually dropping out entire tokens to attend to, which might
|
| 224 |
+
# seem a bit unusual, but is taken from the original Transformer paper.
|
| 225 |
+
attention_probs = self.dropout(attention_probs)
|
| 226 |
+
|
| 227 |
+
context_layer = torch.matmul(attention_probs, value_layer)
|
| 228 |
+
|
| 229 |
+
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
| 230 |
+
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
| 231 |
+
context_layer = context_layer.view(*new_context_layer_shape)
|
| 232 |
+
return context_layer
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
class MobileViTSelfOutput(nn.Module):
|
| 236 |
+
def __init__(self, config: MobileViTConfig, hidden_size: int) -> None:
|
| 237 |
+
super().__init__()
|
| 238 |
+
self.dense = nn.Linear(hidden_size, hidden_size)
|
| 239 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
| 240 |
+
|
| 241 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 242 |
+
hidden_states = self.dense(hidden_states)
|
| 243 |
+
hidden_states = self.dropout(hidden_states)
|
| 244 |
+
return hidden_states
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
class MobileViTAttention(nn.Module):
|
| 248 |
+
def __init__(self, config: MobileViTConfig, hidden_size: int) -> None:
|
| 249 |
+
super().__init__()
|
| 250 |
+
self.attention = MobileViTSelfAttention(config, hidden_size)
|
| 251 |
+
self.output = MobileViTSelfOutput(config, hidden_size)
|
| 252 |
+
|
| 253 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 254 |
+
self_outputs = self.attention(hidden_states)
|
| 255 |
+
attention_output = self.output(self_outputs)
|
| 256 |
+
return attention_output
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
class MobileViTIntermediate(nn.Module):
|
| 260 |
+
def __init__(self, config: MobileViTConfig, hidden_size: int, intermediate_size: int) -> None:
|
| 261 |
+
super().__init__()
|
| 262 |
+
self.dense = nn.Linear(hidden_size, intermediate_size)
|
| 263 |
+
if isinstance(config.hidden_act, str):
|
| 264 |
+
self.intermediate_act_fn = ACT2FN[config.hidden_act]
|
| 265 |
+
else:
|
| 266 |
+
self.intermediate_act_fn = config.hidden_act
|
| 267 |
+
|
| 268 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 269 |
+
hidden_states = self.dense(hidden_states)
|
| 270 |
+
hidden_states = self.intermediate_act_fn(hidden_states)
|
| 271 |
+
return hidden_states
|
| 272 |
+
|
| 273 |
+
|
| 274 |
+
class MobileViTOutput(nn.Module):
|
| 275 |
+
def __init__(self, config: MobileViTConfig, hidden_size: int, intermediate_size: int) -> None:
|
| 276 |
+
super().__init__()
|
| 277 |
+
self.dense = nn.Linear(intermediate_size, hidden_size)
|
| 278 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
| 279 |
+
|
| 280 |
+
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
|
| 281 |
+
hidden_states = self.dense(hidden_states)
|
| 282 |
+
hidden_states = self.dropout(hidden_states)
|
| 283 |
+
hidden_states = hidden_states + input_tensor
|
| 284 |
+
return hidden_states
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
class MobileViTTransformerLayer(nn.Module):
|
| 288 |
+
def __init__(self, config: MobileViTConfig, hidden_size: int, intermediate_size: int) -> None:
|
| 289 |
+
super().__init__()
|
| 290 |
+
self.attention = MobileViTAttention(config, hidden_size)
|
| 291 |
+
self.intermediate = MobileViTIntermediate(config, hidden_size, intermediate_size)
|
| 292 |
+
self.output = MobileViTOutput(config, hidden_size, intermediate_size)
|
| 293 |
+
self.layernorm_before = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)
|
| 294 |
+
self.layernorm_after = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)
|
| 295 |
+
|
| 296 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 297 |
+
attention_output = self.attention(self.layernorm_before(hidden_states))
|
| 298 |
+
hidden_states = attention_output + hidden_states
|
| 299 |
+
|
| 300 |
+
layer_output = self.layernorm_after(hidden_states)
|
| 301 |
+
layer_output = self.intermediate(layer_output)
|
| 302 |
+
layer_output = self.output(layer_output, hidden_states)
|
| 303 |
+
return layer_output
|
| 304 |
+
|
| 305 |
+
|
| 306 |
+
class MobileViTTransformer(nn.Module):
|
| 307 |
+
def __init__(self, config: MobileViTConfig, hidden_size: int, num_stages: int) -> None:
|
| 308 |
+
super().__init__()
|
| 309 |
+
|
| 310 |
+
self.layer = nn.ModuleList()
|
| 311 |
+
for _ in range(num_stages):
|
| 312 |
+
transformer_layer = MobileViTTransformerLayer(
|
| 313 |
+
config,
|
| 314 |
+
hidden_size=hidden_size,
|
| 315 |
+
intermediate_size=int(hidden_size * config.mlp_ratio),
|
| 316 |
+
)
|
| 317 |
+
self.layer.append(transformer_layer)
|
| 318 |
+
|
| 319 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 320 |
+
for layer_module in self.layer:
|
| 321 |
+
hidden_states = layer_module(hidden_states)
|
| 322 |
+
return hidden_states
|
| 323 |
+
|
| 324 |
+
|
| 325 |
+
class MobileViTLayer(GradientCheckpointingLayer):
|
| 326 |
+
"""
|
| 327 |
+
MobileViT block: https://huggingface.co/papers/2110.02178
|
| 328 |
+
"""
|
| 329 |
+
|
| 330 |
+
def __init__(
|
| 331 |
+
self,
|
| 332 |
+
config: MobileViTConfig,
|
| 333 |
+
in_channels: int,
|
| 334 |
+
out_channels: int,
|
| 335 |
+
stride: int,
|
| 336 |
+
hidden_size: int,
|
| 337 |
+
num_stages: int,
|
| 338 |
+
dilation: int = 1,
|
| 339 |
+
) -> None:
|
| 340 |
+
super().__init__()
|
| 341 |
+
self.patch_width = config.patch_size
|
| 342 |
+
self.patch_height = config.patch_size
|
| 343 |
+
|
| 344 |
+
if stride == 2:
|
| 345 |
+
self.downsampling_layer = MobileViTInvertedResidual(
|
| 346 |
+
config,
|
| 347 |
+
in_channels=in_channels,
|
| 348 |
+
out_channels=out_channels,
|
| 349 |
+
stride=stride if dilation == 1 else 1,
|
| 350 |
+
dilation=dilation // 2 if dilation > 1 else 1,
|
| 351 |
+
)
|
| 352 |
+
in_channels = out_channels
|
| 353 |
+
else:
|
| 354 |
+
self.downsampling_layer = None
|
| 355 |
+
|
| 356 |
+
self.conv_kxk = MobileViTConvLayer(
|
| 357 |
+
config,
|
| 358 |
+
in_channels=in_channels,
|
| 359 |
+
out_channels=in_channels,
|
| 360 |
+
kernel_size=config.conv_kernel_size,
|
| 361 |
+
)
|
| 362 |
+
|
| 363 |
+
self.conv_1x1 = MobileViTConvLayer(
|
| 364 |
+
config,
|
| 365 |
+
in_channels=in_channels,
|
| 366 |
+
out_channels=hidden_size,
|
| 367 |
+
kernel_size=1,
|
| 368 |
+
use_normalization=False,
|
| 369 |
+
use_activation=False,
|
| 370 |
+
)
|
| 371 |
+
|
| 372 |
+
self.transformer = MobileViTTransformer(
|
| 373 |
+
config,
|
| 374 |
+
hidden_size=hidden_size,
|
| 375 |
+
num_stages=num_stages,
|
| 376 |
+
)
|
| 377 |
+
|
| 378 |
+
self.layernorm = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)
|
| 379 |
+
|
| 380 |
+
self.conv_projection = MobileViTConvLayer(
|
| 381 |
+
config, in_channels=hidden_size, out_channels=in_channels, kernel_size=1
|
| 382 |
+
)
|
| 383 |
+
|
| 384 |
+
self.fusion = MobileViTConvLayer(
|
| 385 |
+
config, in_channels=2 * in_channels, out_channels=in_channels, kernel_size=config.conv_kernel_size
|
| 386 |
+
)
|
| 387 |
+
|
| 388 |
+
def unfolding(self, features: torch.Tensor) -> tuple[torch.Tensor, dict]:
|
| 389 |
+
patch_width, patch_height = self.patch_width, self.patch_height
|
| 390 |
+
patch_area = int(patch_width * patch_height)
|
| 391 |
+
|
| 392 |
+
batch_size, channels, orig_height, orig_width = features.shape
|
| 393 |
+
|
| 394 |
+
new_height = (
|
| 395 |
+
torch_int(torch.ceil(orig_height / patch_height) * patch_height)
|
| 396 |
+
if torch.jit.is_tracing()
|
| 397 |
+
else int(math.ceil(orig_height / patch_height) * patch_height)
|
| 398 |
+
)
|
| 399 |
+
new_width = (
|
| 400 |
+
torch_int(torch.ceil(orig_width / patch_width) * patch_width)
|
| 401 |
+
if torch.jit.is_tracing()
|
| 402 |
+
else int(math.ceil(orig_width / patch_width) * patch_width)
|
| 403 |
+
)
|
| 404 |
+
|
| 405 |
+
interpolate = False
|
| 406 |
+
if new_width != orig_width or new_height != orig_height:
|
| 407 |
+
# Note: Padding can be done, but then it needs to be handled in attention function.
|
| 408 |
+
features = nn.functional.interpolate(
|
| 409 |
+
features, size=(new_height, new_width), mode="bilinear", align_corners=False
|
| 410 |
+
)
|
| 411 |
+
interpolate = True
|
| 412 |
+
|
| 413 |
+
# number of patches along width and height
|
| 414 |
+
num_patch_width = new_width // patch_width
|
| 415 |
+
num_patch_height = new_height // patch_height
|
| 416 |
+
num_patches = num_patch_height * num_patch_width
|
| 417 |
+
|
| 418 |
+
# convert from shape (batch_size, channels, orig_height, orig_width)
|
| 419 |
+
# to the shape (batch_size * patch_area, num_patches, channels)
|
| 420 |
+
patches = features.reshape(
|
| 421 |
+
batch_size * channels * num_patch_height, patch_height, num_patch_width, patch_width
|
| 422 |
+
)
|
| 423 |
+
patches = patches.transpose(1, 2)
|
| 424 |
+
patches = patches.reshape(batch_size, channels, num_patches, patch_area)
|
| 425 |
+
patches = patches.transpose(1, 3)
|
| 426 |
+
patches = patches.reshape(batch_size * patch_area, num_patches, -1)
|
| 427 |
+
|
| 428 |
+
info_dict = {
|
| 429 |
+
"orig_size": (orig_height, orig_width),
|
| 430 |
+
"batch_size": batch_size,
|
| 431 |
+
"channels": channels,
|
| 432 |
+
"interpolate": interpolate,
|
| 433 |
+
"num_patches": num_patches,
|
| 434 |
+
"num_patches_width": num_patch_width,
|
| 435 |
+
"num_patches_height": num_patch_height,
|
| 436 |
+
}
|
| 437 |
+
return patches, info_dict
|
| 438 |
+
|
| 439 |
+
def folding(self, patches: torch.Tensor, info_dict: dict) -> torch.Tensor:
|
| 440 |
+
patch_width, patch_height = self.patch_width, self.patch_height
|
| 441 |
+
patch_area = int(patch_width * patch_height)
|
| 442 |
+
|
| 443 |
+
batch_size = info_dict["batch_size"]
|
| 444 |
+
channels = info_dict["channels"]
|
| 445 |
+
num_patches = info_dict["num_patches"]
|
| 446 |
+
num_patch_height = info_dict["num_patches_height"]
|
| 447 |
+
num_patch_width = info_dict["num_patches_width"]
|
| 448 |
+
|
| 449 |
+
# convert from shape (batch_size * patch_area, num_patches, channels)
|
| 450 |
+
# back to shape (batch_size, channels, orig_height, orig_width)
|
| 451 |
+
features = patches.contiguous().view(batch_size, patch_area, num_patches, -1)
|
| 452 |
+
features = features.transpose(1, 3)
|
| 453 |
+
features = features.reshape(
|
| 454 |
+
batch_size * channels * num_patch_height, num_patch_width, patch_height, patch_width
|
| 455 |
+
)
|
| 456 |
+
features = features.transpose(1, 2)
|
| 457 |
+
features = features.reshape(
|
| 458 |
+
batch_size, channels, num_patch_height * patch_height, num_patch_width * patch_width
|
| 459 |
+
)
|
| 460 |
+
|
| 461 |
+
if info_dict["interpolate"]:
|
| 462 |
+
features = nn.functional.interpolate(
|
| 463 |
+
features, size=info_dict["orig_size"], mode="bilinear", align_corners=False
|
| 464 |
+
)
|
| 465 |
+
|
| 466 |
+
return features
|
| 467 |
+
|
| 468 |
+
def forward(self, features: torch.Tensor) -> torch.Tensor:
|
| 469 |
+
# reduce spatial dimensions if needed
|
| 470 |
+
if self.downsampling_layer:
|
| 471 |
+
features = self.downsampling_layer(features)
|
| 472 |
+
|
| 473 |
+
residual = features
|
| 474 |
+
|
| 475 |
+
# local representation
|
| 476 |
+
features = self.conv_kxk(features)
|
| 477 |
+
features = self.conv_1x1(features)
|
| 478 |
+
|
| 479 |
+
# convert feature map to patches
|
| 480 |
+
patches, info_dict = self.unfolding(features)
|
| 481 |
+
|
| 482 |
+
# learn global representations
|
| 483 |
+
patches = self.transformer(patches)
|
| 484 |
+
patches = self.layernorm(patches)
|
| 485 |
+
|
| 486 |
+
# convert patches back to feature maps
|
| 487 |
+
features = self.folding(patches, info_dict)
|
| 488 |
+
|
| 489 |
+
features = self.conv_projection(features)
|
| 490 |
+
features = self.fusion(torch.cat((residual, features), dim=1))
|
| 491 |
+
return features
|
| 492 |
+
|
| 493 |
+
|
| 494 |
+
class MobileViTEncoder(nn.Module):
|
| 495 |
+
def __init__(self, config: MobileViTConfig) -> None:
|
| 496 |
+
super().__init__()
|
| 497 |
+
self.config = config
|
| 498 |
+
|
| 499 |
+
self.layer = nn.ModuleList()
|
| 500 |
+
self.gradient_checkpointing = False
|
| 501 |
+
|
| 502 |
+
# segmentation architectures like DeepLab and PSPNet modify the strides
|
| 503 |
+
# of the classification backbones
|
| 504 |
+
dilate_layer_4 = dilate_layer_5 = False
|
| 505 |
+
if config.output_stride == 8:
|
| 506 |
+
dilate_layer_4 = True
|
| 507 |
+
dilate_layer_5 = True
|
| 508 |
+
elif config.output_stride == 16:
|
| 509 |
+
dilate_layer_5 = True
|
| 510 |
+
|
| 511 |
+
dilation = 1
|
| 512 |
+
|
| 513 |
+
layer_1 = MobileViTMobileNetLayer(
|
| 514 |
+
config,
|
| 515 |
+
in_channels=config.neck_hidden_sizes[0],
|
| 516 |
+
out_channels=config.neck_hidden_sizes[1],
|
| 517 |
+
stride=1,
|
| 518 |
+
num_stages=1,
|
| 519 |
+
)
|
| 520 |
+
self.layer.append(layer_1)
|
| 521 |
+
|
| 522 |
+
layer_2 = MobileViTMobileNetLayer(
|
| 523 |
+
config,
|
| 524 |
+
in_channels=config.neck_hidden_sizes[1],
|
| 525 |
+
out_channels=config.neck_hidden_sizes[2],
|
| 526 |
+
stride=2,
|
| 527 |
+
num_stages=3,
|
| 528 |
+
)
|
| 529 |
+
self.layer.append(layer_2)
|
| 530 |
+
|
| 531 |
+
layer_3 = MobileViTLayer(
|
| 532 |
+
config,
|
| 533 |
+
in_channels=config.neck_hidden_sizes[2],
|
| 534 |
+
out_channels=config.neck_hidden_sizes[3],
|
| 535 |
+
stride=2,
|
| 536 |
+
hidden_size=config.hidden_sizes[0],
|
| 537 |
+
num_stages=2,
|
| 538 |
+
)
|
| 539 |
+
self.layer.append(layer_3)
|
| 540 |
+
|
| 541 |
+
if dilate_layer_4:
|
| 542 |
+
dilation *= 2
|
| 543 |
+
|
| 544 |
+
layer_4 = MobileViTLayer(
|
| 545 |
+
config,
|
| 546 |
+
in_channels=config.neck_hidden_sizes[3],
|
| 547 |
+
out_channels=config.neck_hidden_sizes[4],
|
| 548 |
+
stride=2,
|
| 549 |
+
hidden_size=config.hidden_sizes[1],
|
| 550 |
+
num_stages=4,
|
| 551 |
+
dilation=dilation,
|
| 552 |
+
)
|
| 553 |
+
self.layer.append(layer_4)
|
| 554 |
+
|
| 555 |
+
if dilate_layer_5:
|
| 556 |
+
dilation *= 2
|
| 557 |
+
|
| 558 |
+
layer_5 = MobileViTLayer(
|
| 559 |
+
config,
|
| 560 |
+
in_channels=config.neck_hidden_sizes[4],
|
| 561 |
+
out_channels=config.neck_hidden_sizes[5],
|
| 562 |
+
stride=2,
|
| 563 |
+
hidden_size=config.hidden_sizes[2],
|
| 564 |
+
num_stages=3,
|
| 565 |
+
dilation=dilation,
|
| 566 |
+
)
|
| 567 |
+
self.layer.append(layer_5)
|
| 568 |
+
|
| 569 |
+
def forward(
|
| 570 |
+
self,
|
| 571 |
+
hidden_states: torch.Tensor,
|
| 572 |
+
output_hidden_states: bool = False,
|
| 573 |
+
return_dict: bool = True,
|
| 574 |
+
) -> tuple | BaseModelOutputWithNoAttention:
|
| 575 |
+
all_hidden_states = () if output_hidden_states else None
|
| 576 |
+
|
| 577 |
+
for i, layer_module in enumerate(self.layer):
|
| 578 |
+
hidden_states = layer_module(hidden_states)
|
| 579 |
+
|
| 580 |
+
if output_hidden_states:
|
| 581 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
| 582 |
+
|
| 583 |
+
if not return_dict:
|
| 584 |
+
return tuple(v for v in [hidden_states, all_hidden_states] if v is not None)
|
| 585 |
+
|
| 586 |
+
return BaseModelOutputWithNoAttention(last_hidden_state=hidden_states, hidden_states=all_hidden_states)
|
| 587 |
+
|
| 588 |
+
|
| 589 |
+
@auto_docstring
|
| 590 |
+
class MobileViTPreTrainedModel(PreTrainedModel):
|
| 591 |
+
config: MobileViTConfig
|
| 592 |
+
base_model_prefix = "mobilevit"
|
| 593 |
+
main_input_name = "pixel_values"
|
| 594 |
+
input_modalities = ("image",)
|
| 595 |
+
supports_gradient_checkpointing = True
|
| 596 |
+
_no_split_modules = ["MobileViTLayer"]
|
| 597 |
+
|
| 598 |
+
@torch.no_grad()
|
| 599 |
+
def _init_weights(self, module: nn.Module) -> None:
|
| 600 |
+
"""Initialize the weights"""
|
| 601 |
+
if isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)):
|
| 602 |
+
init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
|
| 603 |
+
if module.bias is not None:
|
| 604 |
+
init.zeros_(module.bias)
|
| 605 |
+
if getattr(module, "running_mean", None) is not None:
|
| 606 |
+
init.zeros_(module.running_mean)
|
| 607 |
+
init.ones_(module.running_var)
|
| 608 |
+
init.zeros_(module.num_batches_tracked)
|
| 609 |
+
elif isinstance(module, nn.LayerNorm):
|
| 610 |
+
init.zeros_(module.bias)
|
| 611 |
+
init.ones_(module.weight)
|
| 612 |
+
|
| 613 |
+
|
| 614 |
+
@auto_docstring
|
| 615 |
+
class MobileViTModel(MobileViTPreTrainedModel):
|
| 616 |
+
def __init__(self, config: MobileViTConfig, expand_output: bool = True):
|
| 617 |
+
r"""
|
| 618 |
+
expand_output (`bool`, *optional*, defaults to `True`):
|
| 619 |
+
Whether to expand the output of the model using a 1x1 convolution. If `True`, the model will apply an additional
|
| 620 |
+
1x1 convolution to expand the output channels from `config.neck_hidden_sizes[5]` to `config.neck_hidden_sizes[6]`.
|
| 621 |
+
"""
|
| 622 |
+
super().__init__(config)
|
| 623 |
+
self.config = config
|
| 624 |
+
self.expand_output = expand_output
|
| 625 |
+
|
| 626 |
+
self.conv_stem = MobileViTConvLayer(
|
| 627 |
+
config,
|
| 628 |
+
in_channels=config.num_channels,
|
| 629 |
+
out_channels=config.neck_hidden_sizes[0],
|
| 630 |
+
kernel_size=3,
|
| 631 |
+
stride=2,
|
| 632 |
+
)
|
| 633 |
+
|
| 634 |
+
self.encoder = MobileViTEncoder(config)
|
| 635 |
+
|
| 636 |
+
if self.expand_output:
|
| 637 |
+
self.conv_1x1_exp = MobileViTConvLayer(
|
| 638 |
+
config,
|
| 639 |
+
in_channels=config.neck_hidden_sizes[5],
|
| 640 |
+
out_channels=config.neck_hidden_sizes[6],
|
| 641 |
+
kernel_size=1,
|
| 642 |
+
)
|
| 643 |
+
|
| 644 |
+
# Initialize weights and apply final processing
|
| 645 |
+
self.post_init()
|
| 646 |
+
|
| 647 |
+
@auto_docstring
|
| 648 |
+
def forward(
|
| 649 |
+
self,
|
| 650 |
+
pixel_values: torch.Tensor | None = None,
|
| 651 |
+
output_hidden_states: bool | None = None,
|
| 652 |
+
return_dict: bool | None = None,
|
| 653 |
+
**kwargs,
|
| 654 |
+
) -> tuple | BaseModelOutputWithPoolingAndNoAttention:
|
| 655 |
+
output_hidden_states = (
|
| 656 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 657 |
+
)
|
| 658 |
+
return_dict = return_dict if return_dict is not None else self.config.return_dict
|
| 659 |
+
|
| 660 |
+
if pixel_values is None:
|
| 661 |
+
raise ValueError("You have to specify pixel_values")
|
| 662 |
+
|
| 663 |
+
embedding_output = self.conv_stem(pixel_values)
|
| 664 |
+
|
| 665 |
+
encoder_outputs = self.encoder(
|
| 666 |
+
embedding_output,
|
| 667 |
+
output_hidden_states=output_hidden_states,
|
| 668 |
+
return_dict=return_dict,
|
| 669 |
+
)
|
| 670 |
+
|
| 671 |
+
if self.expand_output:
|
| 672 |
+
last_hidden_state = self.conv_1x1_exp(encoder_outputs[0])
|
| 673 |
+
|
| 674 |
+
# global average pooling: (batch_size, channels, height, width) -> (batch_size, channels)
|
| 675 |
+
pooled_output = torch.mean(last_hidden_state, dim=[-2, -1], keepdim=False)
|
| 676 |
+
else:
|
| 677 |
+
last_hidden_state = encoder_outputs[0]
|
| 678 |
+
pooled_output = None
|
| 679 |
+
|
| 680 |
+
if not return_dict:
|
| 681 |
+
output = (last_hidden_state, pooled_output) if pooled_output is not None else (last_hidden_state,)
|
| 682 |
+
return output + encoder_outputs[1:]
|
| 683 |
+
|
| 684 |
+
return BaseModelOutputWithPoolingAndNoAttention(
|
| 685 |
+
last_hidden_state=last_hidden_state,
|
| 686 |
+
pooler_output=pooled_output,
|
| 687 |
+
hidden_states=encoder_outputs.hidden_states,
|
| 688 |
+
)
|
| 689 |
+
|
| 690 |
+
|
| 691 |
+
@auto_docstring(
|
| 692 |
+
custom_intro="""
|
| 693 |
+
MobileViT model with an image classification head on top (a linear layer on top of the pooled features), e.g. for
|
| 694 |
+
ImageNet.
|
| 695 |
+
"""
|
| 696 |
+
)
|
| 697 |
+
class MobileViTForImageClassification(MobileViTPreTrainedModel):
|
| 698 |
+
def __init__(self, config: MobileViTConfig) -> None:
|
| 699 |
+
super().__init__(config)
|
| 700 |
+
|
| 701 |
+
self.num_labels = config.num_labels
|
| 702 |
+
self.mobilevit = MobileViTModel(config)
|
| 703 |
+
|
| 704 |
+
# Classifier head
|
| 705 |
+
self.dropout = nn.Dropout(config.classifier_dropout_prob, inplace=True)
|
| 706 |
+
self.classifier = (
|
| 707 |
+
nn.Linear(config.neck_hidden_sizes[-1], config.num_labels) if config.num_labels > 0 else nn.Identity()
|
| 708 |
+
)
|
| 709 |
+
|
| 710 |
+
# Initialize weights and apply final processing
|
| 711 |
+
self.post_init()
|
| 712 |
+
|
| 713 |
+
@auto_docstring
|
| 714 |
+
def forward(
|
| 715 |
+
self,
|
| 716 |
+
pixel_values: torch.Tensor | None = None,
|
| 717 |
+
output_hidden_states: bool | None = None,
|
| 718 |
+
labels: torch.Tensor | None = None,
|
| 719 |
+
return_dict: bool | None = None,
|
| 720 |
+
**kwargs,
|
| 721 |
+
) -> tuple | ImageClassifierOutputWithNoAttention:
|
| 722 |
+
r"""
|
| 723 |
+
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
| 724 |
+
Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
|
| 725 |
+
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss). If
|
| 726 |
+
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
| 727 |
+
"""
|
| 728 |
+
return_dict = return_dict if return_dict is not None else self.config.return_dict
|
| 729 |
+
|
| 730 |
+
outputs = self.mobilevit(pixel_values, output_hidden_states=output_hidden_states, return_dict=return_dict)
|
| 731 |
+
|
| 732 |
+
pooled_output = outputs.pooler_output if return_dict else outputs[1]
|
| 733 |
+
|
| 734 |
+
logits = self.classifier(self.dropout(pooled_output))
|
| 735 |
+
|
| 736 |
+
loss = None
|
| 737 |
+
if labels is not None:
|
| 738 |
+
loss = self.loss_function(labels, logits, self.config)
|
| 739 |
+
|
| 740 |
+
if not return_dict:
|
| 741 |
+
output = (logits,) + outputs[2:]
|
| 742 |
+
return ((loss,) + output) if loss is not None else output
|
| 743 |
+
|
| 744 |
+
return ImageClassifierOutputWithNoAttention(
|
| 745 |
+
loss=loss,
|
| 746 |
+
logits=logits,
|
| 747 |
+
hidden_states=outputs.hidden_states,
|
| 748 |
+
)
|
| 749 |
+
|
| 750 |
+
|
| 751 |
+
class MobileViTASPPPooling(nn.Module):
|
| 752 |
+
def __init__(self, config: MobileViTConfig, in_channels: int, out_channels: int) -> None:
|
| 753 |
+
super().__init__()
|
| 754 |
+
|
| 755 |
+
self.global_pool = nn.AdaptiveAvgPool2d(output_size=1)
|
| 756 |
+
|
| 757 |
+
self.conv_1x1 = MobileViTConvLayer(
|
| 758 |
+
config,
|
| 759 |
+
in_channels=in_channels,
|
| 760 |
+
out_channels=out_channels,
|
| 761 |
+
kernel_size=1,
|
| 762 |
+
stride=1,
|
| 763 |
+
use_normalization=True,
|
| 764 |
+
use_activation="relu",
|
| 765 |
+
)
|
| 766 |
+
|
| 767 |
+
def forward(self, features: torch.Tensor) -> torch.Tensor:
|
| 768 |
+
spatial_size = features.shape[-2:]
|
| 769 |
+
features = self.global_pool(features)
|
| 770 |
+
features = self.conv_1x1(features)
|
| 771 |
+
features = nn.functional.interpolate(features, size=spatial_size, mode="bilinear", align_corners=False)
|
| 772 |
+
return features
|
| 773 |
+
|
| 774 |
+
|
| 775 |
+
class MobileViTASPP(nn.Module):
|
| 776 |
+
"""
|
| 777 |
+
ASPP module defined in DeepLab papers: https://huggingface.co/papers/1606.00915, https://huggingface.co/papers/1706.05587
|
| 778 |
+
"""
|
| 779 |
+
|
| 780 |
+
def __init__(self, config: MobileViTConfig) -> None:
|
| 781 |
+
super().__init__()
|
| 782 |
+
|
| 783 |
+
in_channels = config.neck_hidden_sizes[-2]
|
| 784 |
+
out_channels = config.aspp_out_channels
|
| 785 |
+
|
| 786 |
+
if len(config.atrous_rates) != 3:
|
| 787 |
+
raise ValueError("Expected 3 values for atrous_rates")
|
| 788 |
+
|
| 789 |
+
self.convs = nn.ModuleList()
|
| 790 |
+
|
| 791 |
+
in_projection = MobileViTConvLayer(
|
| 792 |
+
config,
|
| 793 |
+
in_channels=in_channels,
|
| 794 |
+
out_channels=out_channels,
|
| 795 |
+
kernel_size=1,
|
| 796 |
+
use_activation="relu",
|
| 797 |
+
)
|
| 798 |
+
self.convs.append(in_projection)
|
| 799 |
+
|
| 800 |
+
self.convs.extend(
|
| 801 |
+
[
|
| 802 |
+
MobileViTConvLayer(
|
| 803 |
+
config,
|
| 804 |
+
in_channels=in_channels,
|
| 805 |
+
out_channels=out_channels,
|
| 806 |
+
kernel_size=3,
|
| 807 |
+
dilation=rate,
|
| 808 |
+
use_activation="relu",
|
| 809 |
+
)
|
| 810 |
+
for rate in config.atrous_rates
|
| 811 |
+
]
|
| 812 |
+
)
|
| 813 |
+
|
| 814 |
+
pool_layer = MobileViTASPPPooling(config, in_channels, out_channels)
|
| 815 |
+
self.convs.append(pool_layer)
|
| 816 |
+
|
| 817 |
+
self.project = MobileViTConvLayer(
|
| 818 |
+
config, in_channels=5 * out_channels, out_channels=out_channels, kernel_size=1, use_activation="relu"
|
| 819 |
+
)
|
| 820 |
+
|
| 821 |
+
self.dropout = nn.Dropout(p=config.aspp_dropout_prob)
|
| 822 |
+
|
| 823 |
+
def forward(self, features: torch.Tensor) -> torch.Tensor:
|
| 824 |
+
pyramid = []
|
| 825 |
+
for conv in self.convs:
|
| 826 |
+
pyramid.append(conv(features))
|
| 827 |
+
pyramid = torch.cat(pyramid, dim=1)
|
| 828 |
+
|
| 829 |
+
pooled_features = self.project(pyramid)
|
| 830 |
+
pooled_features = self.dropout(pooled_features)
|
| 831 |
+
return pooled_features
|
| 832 |
+
|
| 833 |
+
|
| 834 |
+
class MobileViTDeepLabV3(nn.Module):
|
| 835 |
+
"""
|
| 836 |
+
DeepLabv3 architecture: https://huggingface.co/papers/1706.05587
|
| 837 |
+
"""
|
| 838 |
+
|
| 839 |
+
def __init__(self, config: MobileViTConfig) -> None:
|
| 840 |
+
super().__init__()
|
| 841 |
+
self.aspp = MobileViTASPP(config)
|
| 842 |
+
|
| 843 |
+
self.dropout = nn.Dropout2d(config.classifier_dropout_prob)
|
| 844 |
+
|
| 845 |
+
self.classifier = MobileViTConvLayer(
|
| 846 |
+
config,
|
| 847 |
+
in_channels=config.aspp_out_channels,
|
| 848 |
+
out_channels=config.num_labels,
|
| 849 |
+
kernel_size=1,
|
| 850 |
+
use_normalization=False,
|
| 851 |
+
use_activation=False,
|
| 852 |
+
bias=True,
|
| 853 |
+
)
|
| 854 |
+
|
| 855 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 856 |
+
features = self.aspp(hidden_states[-1])
|
| 857 |
+
features = self.dropout(features)
|
| 858 |
+
features = self.classifier(features)
|
| 859 |
+
return features
|
| 860 |
+
|
| 861 |
+
|
| 862 |
+
@auto_docstring(
|
| 863 |
+
custom_intro="""
|
| 864 |
+
MobileViT model with a semantic segmentation head on top, e.g. for Pascal VOC.
|
| 865 |
+
"""
|
| 866 |
+
)
|
| 867 |
+
class MobileViTForSemanticSegmentation(MobileViTPreTrainedModel):
|
| 868 |
+
def __init__(self, config: MobileViTConfig) -> None:
|
| 869 |
+
super().__init__(config)
|
| 870 |
+
|
| 871 |
+
self.num_labels = config.num_labels
|
| 872 |
+
self.mobilevit = MobileViTModel(config, expand_output=False)
|
| 873 |
+
self.segmentation_head = MobileViTDeepLabV3(config)
|
| 874 |
+
|
| 875 |
+
# Initialize weights and apply final processing
|
| 876 |
+
self.post_init()
|
| 877 |
+
|
| 878 |
+
@auto_docstring
|
| 879 |
+
def forward(
|
| 880 |
+
self,
|
| 881 |
+
pixel_values: torch.Tensor | None = None,
|
| 882 |
+
labels: torch.Tensor | None = None,
|
| 883 |
+
output_hidden_states: bool | None = None,
|
| 884 |
+
return_dict: bool | None = None,
|
| 885 |
+
**kwargs,
|
| 886 |
+
) -> tuple | SemanticSegmenterOutput:
|
| 887 |
+
r"""
|
| 888 |
+
labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
|
| 889 |
+
Ground truth semantic segmentation maps for computing the loss. Indices should be in `[0, ...,
|
| 890 |
+
config.num_labels - 1]`. If `config.num_labels > 1`, a classification loss is computed (Cross-Entropy).
|
| 891 |
+
|
| 892 |
+
Examples:
|
| 893 |
+
|
| 894 |
+
```python
|
| 895 |
+
>>> import httpx
|
| 896 |
+
>>> from io import BytesIO
|
| 897 |
+
>>> import torch
|
| 898 |
+
>>> from PIL import Image
|
| 899 |
+
>>> from transformers import AutoImageProcessor, MobileViTForSemanticSegmentation
|
| 900 |
+
|
| 901 |
+
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
| 902 |
+
>>> with httpx.stream("GET", url) as response:
|
| 903 |
+
... image = Image.open(BytesIO(response.read()))
|
| 904 |
+
|
| 905 |
+
>>> image_processor = AutoImageProcessor.from_pretrained("apple/deeplabv3-mobilevit-small")
|
| 906 |
+
>>> model = MobileViTForSemanticSegmentation.from_pretrained("apple/deeplabv3-mobilevit-small")
|
| 907 |
+
|
| 908 |
+
>>> inputs = image_processor(images=image, return_tensors="pt")
|
| 909 |
+
|
| 910 |
+
>>> with torch.no_grad():
|
| 911 |
+
... outputs = model(**inputs)
|
| 912 |
+
|
| 913 |
+
>>> # logits are of shape (batch_size, num_labels, height, width)
|
| 914 |
+
>>> logits = outputs.logits
|
| 915 |
+
```"""
|
| 916 |
+
output_hidden_states = (
|
| 917 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 918 |
+
)
|
| 919 |
+
return_dict = return_dict if return_dict is not None else self.config.return_dict
|
| 920 |
+
|
| 921 |
+
if labels is not None and self.config.num_labels == 1:
|
| 922 |
+
raise ValueError("The number of labels should be greater than one")
|
| 923 |
+
|
| 924 |
+
outputs = self.mobilevit(
|
| 925 |
+
pixel_values,
|
| 926 |
+
output_hidden_states=True, # we need the intermediate hidden states
|
| 927 |
+
return_dict=return_dict,
|
| 928 |
+
)
|
| 929 |
+
|
| 930 |
+
encoder_hidden_states = outputs.hidden_states if return_dict else outputs[1]
|
| 931 |
+
|
| 932 |
+
logits = self.segmentation_head(encoder_hidden_states)
|
| 933 |
+
|
| 934 |
+
loss = None
|
| 935 |
+
if labels is not None:
|
| 936 |
+
# upsample logits to the images' original size
|
| 937 |
+
upsampled_logits = nn.functional.interpolate(
|
| 938 |
+
logits, size=labels.shape[-2:], mode="bilinear", align_corners=False
|
| 939 |
+
)
|
| 940 |
+
loss_fct = CrossEntropyLoss(ignore_index=self.config.semantic_loss_ignore_index)
|
| 941 |
+
loss = loss_fct(upsampled_logits, labels)
|
| 942 |
+
|
| 943 |
+
if not return_dict:
|
| 944 |
+
if output_hidden_states:
|
| 945 |
+
output = (logits,) + outputs[1:]
|
| 946 |
+
else:
|
| 947 |
+
output = (logits,) + outputs[2:]
|
| 948 |
+
return ((loss,) + output) if loss is not None else output
|
| 949 |
+
|
| 950 |
+
return SemanticSegmenterOutput(
|
| 951 |
+
loss=loss,
|
| 952 |
+
logits=logits,
|
| 953 |
+
hidden_states=outputs.hidden_states if output_hidden_states else None,
|
| 954 |
+
attentions=None,
|
| 955 |
+
)
|
| 956 |
+
|
| 957 |
+
|
| 958 |
+
__all__ = [
|
| 959 |
+
"MobileViTForImageClassification",
|
| 960 |
+
"MobileViTForSemanticSegmentation",
|
| 961 |
+
"MobileViTModel",
|
| 962 |
+
"MobileViTPreTrainedModel",
|
| 963 |
+
]
|
LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/speecht5/configuration_speecht5.py
ADDED
|
@@ -0,0 +1,279 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2023 The Fairseq Authors, Microsoft Research, and the HuggingFace Inc. team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
"""SpeechT5 model configuration"""
|
| 15 |
+
|
| 16 |
+
import functools
|
| 17 |
+
import operator
|
| 18 |
+
|
| 19 |
+
from huggingface_hub.dataclasses import strict
|
| 20 |
+
|
| 21 |
+
from ...configuration_utils import PreTrainedConfig
|
| 22 |
+
from ...utils import auto_docstring
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
@auto_docstring(checkpoint="microsoft/speecht5_asr")
|
| 26 |
+
@strict
|
| 27 |
+
class SpeechT5Config(PreTrainedConfig):
|
| 28 |
+
r"""
|
| 29 |
+
positional_dropout (`float`, *optional*, defaults to 0.1):
|
| 30 |
+
The dropout probability for the text position encoding layers.
|
| 31 |
+
feat_extract_norm (`str`, *optional*, defaults to `"group"`):
|
| 32 |
+
The norm to be applied to 1D convolutional layers in the speech encoder pre-net. One of `"group"` for group
|
| 33 |
+
normalization of only the first 1D convolutional layer or `"layer"` for layer normalization of all 1D
|
| 34 |
+
convolutional layers.
|
| 35 |
+
feat_proj_dropout (`float`, *optional*, defaults to 0.0):
|
| 36 |
+
The dropout probability for output of the speech encoder pre-net.
|
| 37 |
+
feat_extract_activation (`str, `optional`, defaults to `"gelu"`):
|
| 38 |
+
The non-linear activation function (function or string) in the 1D convolutional layers of the feature
|
| 39 |
+
extractor. If string, `"gelu"`, `"relu"`, `"selu"` and `"gelu_new"` are supported.
|
| 40 |
+
conv_dim (`tuple[int]` or `list[int]`, *optional*, defaults to `(512, 512, 512, 512, 512, 512, 512)`):
|
| 41 |
+
A tuple of integers defining the number of input and output channels of each 1D convolutional layer in the
|
| 42 |
+
speech encoder pre-net. The length of *conv_dim* defines the number of 1D convolutional layers.
|
| 43 |
+
conv_stride (`tuple[int]` or `list[int]`, *optional*, defaults to `(5, 2, 2, 2, 2, 2, 2)`):
|
| 44 |
+
A tuple of integers defining the stride of each 1D convolutional layer in the speech encoder pre-net. The
|
| 45 |
+
length of *conv_stride* defines the number of convolutional layers and has to match the length of
|
| 46 |
+
*conv_dim*.
|
| 47 |
+
conv_kernel (`tuple[int]` or `list[int]`, *optional*, defaults to `(10, 3, 3, 3, 3, 3, 3)`):
|
| 48 |
+
A tuple of integers defining the kernel size of each 1D convolutional layer in the speech encoder pre-net.
|
| 49 |
+
The length of *conv_kernel* defines the number of convolutional layers and has to match the length of
|
| 50 |
+
*conv_dim*.
|
| 51 |
+
conv_bias (`bool`, *optional*, defaults to `False`):
|
| 52 |
+
Whether the 1D convolutional layers have a bias.
|
| 53 |
+
num_conv_pos_embeddings (`int`, *optional*, defaults to 128):
|
| 54 |
+
Number of convolutional positional embeddings. Defines the kernel size of 1D convolutional positional
|
| 55 |
+
embeddings layer.
|
| 56 |
+
num_conv_pos_embedding_groups (`int`, *optional*, defaults to 16):
|
| 57 |
+
Number of groups of 1D convolutional positional embeddings layer.
|
| 58 |
+
apply_spec_augment (`bool`, *optional*, defaults to `True`):
|
| 59 |
+
Whether to apply *SpecAugment* data augmentation to the outputs of the speech encoder pre-net. For
|
| 60 |
+
reference see [SpecAugment: A Simple Data Augmentation Method for Automatic Speech
|
| 61 |
+
Recognition](https://huggingface.co/papers/1904.08779).
|
| 62 |
+
mask_time_prob (`float`, *optional*, defaults to 0.05):
|
| 63 |
+
Percentage (between 0 and 1) of all feature vectors along the time axis which will be masked. The masking
|
| 64 |
+
procedure generates ''mask_time_prob*len(time_axis)/mask_time_length'' independent masks over the axis. If
|
| 65 |
+
reasoning from the probability of each feature vector to be chosen as the start of the vector span to be
|
| 66 |
+
masked, *mask_time_prob* should be `prob_vector_start*mask_time_length`. Note that overlap may decrease the
|
| 67 |
+
actual percentage of masked vectors. This is only relevant if `apply_spec_augment is True`.
|
| 68 |
+
mask_time_length (`int`, *optional*, defaults to 10):
|
| 69 |
+
Length of vector span along the time axis.
|
| 70 |
+
mask_time_min_masks (`int`, *optional*, defaults to 2),:
|
| 71 |
+
The minimum number of masks of length `mask_feature_length` generated along the time axis, each time step,
|
| 72 |
+
irrespectively of `mask_feature_prob`. Only relevant if ''mask_time_prob*len(time_axis)/mask_time_length <
|
| 73 |
+
mask_time_min_masks''
|
| 74 |
+
mask_feature_prob (`float`, *optional*, defaults to 0.0):
|
| 75 |
+
Percentage (between 0 and 1) of all feature vectors along the feature axis which will be masked. The
|
| 76 |
+
masking procedure generates ''mask_feature_prob*len(feature_axis)/mask_time_length'' independent masks over
|
| 77 |
+
the axis. If reasoning from the probability of each feature vector to be chosen as the start of the vector
|
| 78 |
+
span to be masked, *mask_feature_prob* should be `prob_vector_start*mask_feature_length`. Note that overlap
|
| 79 |
+
may decrease the actual percentage of masked vectors. This is only relevant if `apply_spec_augment is
|
| 80 |
+
True`.
|
| 81 |
+
mask_feature_length (`int`, *optional*, defaults to 10):
|
| 82 |
+
Length of vector span along the feature axis.
|
| 83 |
+
mask_feature_min_masks (`int`, *optional*, defaults to 0),:
|
| 84 |
+
The minimum number of masks of length `mask_feature_length` generated along the feature axis, each time
|
| 85 |
+
step, irrespectively of `mask_feature_prob`. Only relevant if
|
| 86 |
+
''mask_feature_prob*len(feature_axis)/mask_feature_length < mask_feature_min_masks''
|
| 87 |
+
num_mel_bins (`int`, *optional*, defaults to 80):
|
| 88 |
+
Number of mel features used per input features. Used by the speech decoder pre-net. Should correspond to
|
| 89 |
+
the value used in the [`SpeechT5Processor`] class.
|
| 90 |
+
speech_decoder_prenet_layers (`int`, *optional*, defaults to 2):
|
| 91 |
+
Number of layers in the speech decoder pre-net.
|
| 92 |
+
speech_decoder_prenet_units (`int`, *optional*, defaults to 256):
|
| 93 |
+
Dimensionality of the layers in the speech decoder pre-net.
|
| 94 |
+
speech_decoder_prenet_dropout (`float`, *optional*, defaults to 0.5):
|
| 95 |
+
The dropout probability for the speech decoder pre-net layers.
|
| 96 |
+
speaker_embedding_dim (`int`, *optional*, defaults to 512):
|
| 97 |
+
Dimensionality of the *XVector* embedding vectors.
|
| 98 |
+
speech_decoder_postnet_layers (`int`, *optional*, defaults to 5):
|
| 99 |
+
Number of layers in the speech decoder post-net.
|
| 100 |
+
speech_decoder_postnet_units (`int`, *optional*, defaults to 256):
|
| 101 |
+
Dimensionality of the layers in the speech decoder post-net.
|
| 102 |
+
speech_decoder_postnet_kernel (`int`, *optional*, defaults to 5):
|
| 103 |
+
Number of convolutional filter channels in the speech decoder post-net.
|
| 104 |
+
speech_decoder_postnet_dropout (`float`, *optional*, defaults to 0.5):
|
| 105 |
+
The dropout probability for the speech decoder post-net layers.
|
| 106 |
+
reduction_factor (`int`, *optional*, defaults to 2):
|
| 107 |
+
Spectrogram length reduction factor for the speech decoder inputs.
|
| 108 |
+
max_speech_positions (`int`, *optional*, defaults to 4000):
|
| 109 |
+
The maximum sequence length of speech features that this model might ever be used with.
|
| 110 |
+
max_text_positions (`int`, *optional*, defaults to 450):
|
| 111 |
+
The maximum sequence length of text features that this model might ever be used with.
|
| 112 |
+
encoder_max_relative_position (`int`, *optional*, defaults to 160):
|
| 113 |
+
Maximum distance for relative position embedding in the encoder.
|
| 114 |
+
use_guided_attention_loss (`bool`, *optional*, defaults to `True`):
|
| 115 |
+
Whether to apply guided attention loss while training the TTS model.
|
| 116 |
+
guided_attention_loss_num_heads (`int`, *optional*, defaults to 2):
|
| 117 |
+
Number of attention heads the guided attention loss will be applied to. Use -1 to apply this loss to all
|
| 118 |
+
attention heads.
|
| 119 |
+
guided_attention_loss_sigma (`float`, *optional*, defaults to 0.4):
|
| 120 |
+
Standard deviation for guided attention loss.
|
| 121 |
+
guided_attention_loss_scale (`float`, *optional*, defaults to 10.0):
|
| 122 |
+
Scaling coefficient for guided attention loss (also known as lambda).
|
| 123 |
+
|
| 124 |
+
Example:
|
| 125 |
+
|
| 126 |
+
```python
|
| 127 |
+
>>> from transformers import SpeechT5Model, SpeechT5Config
|
| 128 |
+
|
| 129 |
+
>>> # Initializing a "microsoft/speecht5_asr" style configuration
|
| 130 |
+
>>> configuration = SpeechT5Config()
|
| 131 |
+
|
| 132 |
+
>>> # Initializing a model (with random weights) from the "microsoft/speecht5_asr" style configuration
|
| 133 |
+
>>> model = SpeechT5Model(configuration)
|
| 134 |
+
|
| 135 |
+
>>> # Accessing the model configuration
|
| 136 |
+
>>> configuration = model.config
|
| 137 |
+
```"""
|
| 138 |
+
|
| 139 |
+
model_type = "speecht5"
|
| 140 |
+
attribute_map = {"num_attention_heads": "encoder_attention_heads", "num_hidden_layers": "encoder_layers"}
|
| 141 |
+
|
| 142 |
+
vocab_size: int = 81
|
| 143 |
+
hidden_size: int = 768
|
| 144 |
+
encoder_layers: int = 12
|
| 145 |
+
encoder_attention_heads: int = 12
|
| 146 |
+
encoder_ffn_dim: int = 3072
|
| 147 |
+
encoder_layerdrop: float | int = 0.1
|
| 148 |
+
decoder_layers: int = 6
|
| 149 |
+
decoder_ffn_dim: int = 3072
|
| 150 |
+
decoder_attention_heads: int = 12
|
| 151 |
+
decoder_layerdrop: float | int = 0.1
|
| 152 |
+
hidden_act: str = "gelu"
|
| 153 |
+
positional_dropout: float | int = 0.1
|
| 154 |
+
hidden_dropout: float | int = 0.1
|
| 155 |
+
attention_dropout: float | int = 0.1
|
| 156 |
+
activation_dropout: float | int = 0.1
|
| 157 |
+
initializer_range: float = 0.02
|
| 158 |
+
layer_norm_eps: float = 1e-5
|
| 159 |
+
scale_embedding: bool = False
|
| 160 |
+
feat_extract_norm: str = "group"
|
| 161 |
+
feat_proj_dropout: float | int = 0.0
|
| 162 |
+
feat_extract_activation: str = "gelu"
|
| 163 |
+
conv_dim: list[int] | tuple[int, ...] = (512, 512, 512, 512, 512, 512, 512)
|
| 164 |
+
conv_stride: list[int] | tuple[int, ...] = (5, 2, 2, 2, 2, 2, 2)
|
| 165 |
+
conv_kernel: list[int] | tuple[int, ...] = (10, 3, 3, 3, 3, 2, 2)
|
| 166 |
+
conv_bias: bool = False
|
| 167 |
+
num_conv_pos_embeddings: int = 128
|
| 168 |
+
num_conv_pos_embedding_groups: int = 16
|
| 169 |
+
apply_spec_augment: bool = True
|
| 170 |
+
mask_time_prob: float | int = 0.05
|
| 171 |
+
mask_time_length: int = 10
|
| 172 |
+
mask_time_min_masks: int = 2
|
| 173 |
+
mask_feature_prob: float | int = 0.0
|
| 174 |
+
mask_feature_length: int = 10
|
| 175 |
+
mask_feature_min_masks: int = 0
|
| 176 |
+
pad_token_id: int | None = 1
|
| 177 |
+
bos_token_id: int | None = 0
|
| 178 |
+
eos_token_id: int | list[int] | None = 2
|
| 179 |
+
decoder_start_token_id: int | None = 2
|
| 180 |
+
num_mel_bins: int = 80
|
| 181 |
+
speech_decoder_prenet_layers: int = 2
|
| 182 |
+
speech_decoder_prenet_units: int = 256
|
| 183 |
+
speech_decoder_prenet_dropout: float | int = 0.5
|
| 184 |
+
speaker_embedding_dim: int = 512
|
| 185 |
+
speech_decoder_postnet_layers: int = 5
|
| 186 |
+
speech_decoder_postnet_units: int = 256
|
| 187 |
+
speech_decoder_postnet_kernel: int = 5
|
| 188 |
+
speech_decoder_postnet_dropout: float | int = 0.5
|
| 189 |
+
reduction_factor: int = 2
|
| 190 |
+
max_speech_positions: int = 4000
|
| 191 |
+
max_text_positions: int = 450
|
| 192 |
+
encoder_max_relative_position: int = 160
|
| 193 |
+
use_guided_attention_loss: bool = True
|
| 194 |
+
guided_attention_loss_num_heads: int = 2
|
| 195 |
+
guided_attention_loss_sigma: float = 0.4
|
| 196 |
+
guided_attention_loss_scale: float = 10.0
|
| 197 |
+
use_cache: bool = True
|
| 198 |
+
is_encoder_decoder: bool = True
|
| 199 |
+
tie_word_embeddings: bool = True
|
| 200 |
+
|
| 201 |
+
def __post_init__(self, **kwargs):
|
| 202 |
+
self.num_feat_extract_layers = len(self.conv_dim)
|
| 203 |
+
super().__post_init__(**kwargs)
|
| 204 |
+
|
| 205 |
+
def validate_architecture(self):
|
| 206 |
+
"""Part of `@strict`-powered validation. Validates the architecture of the config."""
|
| 207 |
+
if (
|
| 208 |
+
(len(self.conv_stride) != self.num_feat_extract_layers)
|
| 209 |
+
or (len(self.conv_kernel) != self.num_feat_extract_layers)
|
| 210 |
+
or (len(self.conv_dim) != self.num_feat_extract_layers)
|
| 211 |
+
):
|
| 212 |
+
raise ValueError(
|
| 213 |
+
"Configuration for convolutional layers is incorrect. It is required that `len(config.conv_dim)` =="
|
| 214 |
+
" `len(config.conv_stride)` == `len(config.conv_kernel)`, but is `len(config.conv_dim) ="
|
| 215 |
+
f" {len(self.conv_dim)}`, `len(config.conv_stride) = {len(self.conv_stride)}`,"
|
| 216 |
+
f" `len(config.conv_kernel) = {len(self.conv_kernel)}`."
|
| 217 |
+
)
|
| 218 |
+
|
| 219 |
+
def inputs_to_logits_ratio(self):
|
| 220 |
+
return functools.reduce(operator.mul, self.conv_stride, 1)
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
@auto_docstring(checkpoint="microsoft/speecht5_asr")
|
| 224 |
+
@strict
|
| 225 |
+
class SpeechT5HifiGanConfig(PreTrainedConfig):
|
| 226 |
+
r"""
|
| 227 |
+
model_in_dim (`int`, *optional*, defaults to 80):
|
| 228 |
+
The number of frequency bins in the input log-mel spectrogram.
|
| 229 |
+
upsample_initial_channel (`int`, *optional*, defaults to 512):
|
| 230 |
+
The number of input channels into the upsampling network.
|
| 231 |
+
upsample_rates (`tuple[int]` or `list[int]`, *optional*, defaults to `[4, 4, 4, 4]`):
|
| 232 |
+
A tuple of integers defining the stride of each 1D convolutional layer in the upsampling network. The
|
| 233 |
+
length of *upsample_rates* defines the number of convolutional layers and has to match the length of
|
| 234 |
+
*upsample_kernel_sizes*.
|
| 235 |
+
upsample_kernel_sizes (`tuple[int]` or `list[int]`, *optional*, defaults to `[8, 8, 8, 8]`):
|
| 236 |
+
A tuple of integers defining the kernel size of each 1D convolutional layer in the upsampling network. The
|
| 237 |
+
length of *upsample_kernel_sizes* defines the number of convolutional layers and has to match the length of
|
| 238 |
+
*upsample_rates*.
|
| 239 |
+
resblock_kernel_sizes (`tuple[int]` or `list[int]`, *optional*, defaults to `[3, 7, 11]`):
|
| 240 |
+
A tuple of integers defining the kernel sizes of the 1D convolutional layers in the multi-receptive field
|
| 241 |
+
fusion (MRF) module.
|
| 242 |
+
resblock_dilation_sizes (`tuple[tuple[int]]` or `list[list[int]]`, *optional*, defaults to `[[1, 3, 5], [1, 3, 5], [1, 3, 5]]`):
|
| 243 |
+
A nested tuple of integers defining the dilation rates of the dilated 1D convolutional layers in the
|
| 244 |
+
multi-receptive field fusion (MRF) module.
|
| 245 |
+
leaky_relu_slope (`float`, *optional*, defaults to 0.1):
|
| 246 |
+
The angle of the negative slope used by the leaky ReLU activation.
|
| 247 |
+
normalize_before (`bool`, *optional*, defaults to `True`):
|
| 248 |
+
Whether or not to normalize the spectrogram before vocoding using the vocoder's learned mean and variance.
|
| 249 |
+
|
| 250 |
+
Example:
|
| 251 |
+
|
| 252 |
+
```python
|
| 253 |
+
>>> from transformers import SpeechT5HifiGan, SpeechT5HifiGanConfig
|
| 254 |
+
|
| 255 |
+
>>> # Initializing a "microsoft/speecht5_hifigan" style configuration
|
| 256 |
+
>>> configuration = SpeechT5HifiGanConfig()
|
| 257 |
+
|
| 258 |
+
>>> # Initializing a model (with random weights) from the "microsoft/speecht5_hifigan" style configuration
|
| 259 |
+
>>> model = SpeechT5HifiGan(configuration)
|
| 260 |
+
|
| 261 |
+
>>> # Accessing the model configuration
|
| 262 |
+
>>> configuration = model.config
|
| 263 |
+
```"""
|
| 264 |
+
|
| 265 |
+
model_type = "speecht5_hifigan"
|
| 266 |
+
|
| 267 |
+
model_in_dim: int = 80
|
| 268 |
+
sampling_rate: int = 16000
|
| 269 |
+
upsample_initial_channel: int = 512
|
| 270 |
+
upsample_rates: list[int] | tuple[int, ...] = (4, 4, 4, 4)
|
| 271 |
+
upsample_kernel_sizes: list[int] | tuple[int, ...] = (8, 8, 8, 8)
|
| 272 |
+
resblock_kernel_sizes: list[int] | tuple[int, ...] = (3, 7, 11)
|
| 273 |
+
resblock_dilation_sizes: list | tuple = ((1, 3, 5), (1, 3, 5), (1, 3, 5))
|
| 274 |
+
initializer_range: float = 0.01
|
| 275 |
+
leaky_relu_slope: float = 0.1
|
| 276 |
+
normalize_before: bool = True
|
| 277 |
+
|
| 278 |
+
|
| 279 |
+
__all__ = ["SpeechT5Config", "SpeechT5HifiGanConfig"]
|
LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/speecht5/modeling_speecht5.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/speecht5/number_normalizer.py
ADDED
|
@@ -0,0 +1,191 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2023 The Fairseq Authors, Microsoft Research, and the HuggingFace Inc. team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
"""Number Normalizer class for SpeechT5."""
|
| 15 |
+
|
| 16 |
+
import re
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class EnglishNumberNormalizer:
|
| 20 |
+
def __init__(self):
|
| 21 |
+
self.ones = ["", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine"]
|
| 22 |
+
self.teens = [
|
| 23 |
+
"",
|
| 24 |
+
"eleven",
|
| 25 |
+
"twelve",
|
| 26 |
+
"thirteen",
|
| 27 |
+
"fourteen",
|
| 28 |
+
"fifteen",
|
| 29 |
+
"sixteen",
|
| 30 |
+
"seventeen",
|
| 31 |
+
"eighteen",
|
| 32 |
+
"nineteen",
|
| 33 |
+
]
|
| 34 |
+
self.tens = ["", "ten", "twenty", "thirty", "forty", "fifty", "sixty", "seventy", "eighty", "ninety"]
|
| 35 |
+
self.thousands = [
|
| 36 |
+
"",
|
| 37 |
+
"thousand",
|
| 38 |
+
"million",
|
| 39 |
+
"billion",
|
| 40 |
+
"trillion",
|
| 41 |
+
"quadrillion",
|
| 42 |
+
"quintillion",
|
| 43 |
+
"sextillion",
|
| 44 |
+
"septillion",
|
| 45 |
+
"octillion",
|
| 46 |
+
"nonillion",
|
| 47 |
+
"decillion",
|
| 48 |
+
]
|
| 49 |
+
|
| 50 |
+
# Define a dictionary to map currency symbols to their names
|
| 51 |
+
# Top most traded currencies according to
|
| 52 |
+
# https://en.wikipedia.org/wiki/Template:Most_traded_currencies
|
| 53 |
+
self.currency_symbols = {
|
| 54 |
+
"$": " dollars",
|
| 55 |
+
"€": " euros",
|
| 56 |
+
"£": " pounds",
|
| 57 |
+
"¢": " cents",
|
| 58 |
+
"¥": " japanese yen",
|
| 59 |
+
"﷼": " saudi riyal",
|
| 60 |
+
"₹": " indian rupees",
|
| 61 |
+
"₽": " russian rubles",
|
| 62 |
+
"฿": " thai baht",
|
| 63 |
+
"₺": " turkish liras",
|
| 64 |
+
"₴": " ukrainian hryvnia",
|
| 65 |
+
"₣": " swiss francs",
|
| 66 |
+
"₡": " costa rican colon",
|
| 67 |
+
"₱": " philippine peso",
|
| 68 |
+
"₪": " israeli shekels",
|
| 69 |
+
"₮": " mongolian tögrög",
|
| 70 |
+
"₩": " south korean won",
|
| 71 |
+
"₦": " nigerian naira",
|
| 72 |
+
"₫": " vietnamese Đồng",
|
| 73 |
+
}
|
| 74 |
+
|
| 75 |
+
def spell_number(self, num):
|
| 76 |
+
if num == 0:
|
| 77 |
+
return "zero"
|
| 78 |
+
|
| 79 |
+
parts = []
|
| 80 |
+
for i in range(0, len(self.thousands)):
|
| 81 |
+
if num % 1000 != 0:
|
| 82 |
+
part = ""
|
| 83 |
+
hundreds = num % 1000 // 100
|
| 84 |
+
tens_units = num % 100
|
| 85 |
+
|
| 86 |
+
if hundreds > 0:
|
| 87 |
+
part += self.ones[hundreds] + " hundred"
|
| 88 |
+
if tens_units > 0:
|
| 89 |
+
part += " and "
|
| 90 |
+
|
| 91 |
+
if tens_units > 10 and tens_units < 20:
|
| 92 |
+
part += self.teens[tens_units - 10]
|
| 93 |
+
else:
|
| 94 |
+
tens_digit = self.tens[tens_units // 10]
|
| 95 |
+
ones_digit = self.ones[tens_units % 10]
|
| 96 |
+
if tens_digit:
|
| 97 |
+
part += tens_digit
|
| 98 |
+
if ones_digit:
|
| 99 |
+
if tens_digit:
|
| 100 |
+
part += " "
|
| 101 |
+
part += ones_digit
|
| 102 |
+
|
| 103 |
+
parts.append(part)
|
| 104 |
+
|
| 105 |
+
num //= 1000
|
| 106 |
+
|
| 107 |
+
return " ".join(reversed(parts))
|
| 108 |
+
|
| 109 |
+
def convert(self, number):
|
| 110 |
+
"""
|
| 111 |
+
Converts an individual number passed in string form to spelt-out form
|
| 112 |
+
"""
|
| 113 |
+
if "." in number:
|
| 114 |
+
integer_part, decimal_part = number.split(".")
|
| 115 |
+
else:
|
| 116 |
+
integer_part, decimal_part = number, "00"
|
| 117 |
+
|
| 118 |
+
# Extract currency symbol if present
|
| 119 |
+
currency_symbol = ""
|
| 120 |
+
for symbol, name in self.currency_symbols.items():
|
| 121 |
+
if integer_part.startswith(symbol):
|
| 122 |
+
currency_symbol = name
|
| 123 |
+
integer_part = integer_part[len(symbol) :]
|
| 124 |
+
break
|
| 125 |
+
|
| 126 |
+
if integer_part.startswith("-"):
|
| 127 |
+
if integer_part[1:].startswith(symbol):
|
| 128 |
+
currency_symbol = name
|
| 129 |
+
integer_part = "-" + integer_part[len(symbol) + 1 :]
|
| 130 |
+
break
|
| 131 |
+
|
| 132 |
+
# Extract 'minus' prefix for negative numbers
|
| 133 |
+
minus_prefix = ""
|
| 134 |
+
if integer_part.startswith("-"):
|
| 135 |
+
minus_prefix = "minus "
|
| 136 |
+
integer_part = integer_part[1:]
|
| 137 |
+
elif integer_part.startswith("minus"):
|
| 138 |
+
minus_prefix = "minus "
|
| 139 |
+
integer_part = integer_part[len("minus") :]
|
| 140 |
+
|
| 141 |
+
percent_suffix = ""
|
| 142 |
+
if "%" in integer_part or "%" in decimal_part:
|
| 143 |
+
percent_suffix = " percent"
|
| 144 |
+
integer_part = integer_part.replace("%", "")
|
| 145 |
+
decimal_part = decimal_part.replace("%", "")
|
| 146 |
+
|
| 147 |
+
integer_part = integer_part.zfill(3 * ((len(integer_part) - 1) // 3 + 1))
|
| 148 |
+
|
| 149 |
+
parts = []
|
| 150 |
+
for i in range(0, len(integer_part), 3):
|
| 151 |
+
chunk = int(integer_part[i : i + 3])
|
| 152 |
+
if chunk > 0:
|
| 153 |
+
part = self.spell_number(chunk)
|
| 154 |
+
unit = self.thousands[len(integer_part[i:]) // 3 - 1]
|
| 155 |
+
if unit:
|
| 156 |
+
part += " " + unit
|
| 157 |
+
parts.append(part)
|
| 158 |
+
|
| 159 |
+
spelled_integer = " ".join(parts)
|
| 160 |
+
|
| 161 |
+
# Format the spelt-out number based on conditions, such as:
|
| 162 |
+
# If it has decimal parts, currency symbol, minus prefix, etc
|
| 163 |
+
if decimal_part == "00":
|
| 164 |
+
return (
|
| 165 |
+
f"{minus_prefix}{spelled_integer}{percent_suffix}{currency_symbol}"
|
| 166 |
+
if minus_prefix or currency_symbol
|
| 167 |
+
else f"{spelled_integer}{percent_suffix}"
|
| 168 |
+
)
|
| 169 |
+
else:
|
| 170 |
+
spelled_decimal = " ".join([self.spell_number(int(digit)) for digit in decimal_part])
|
| 171 |
+
return (
|
| 172 |
+
f"{minus_prefix}{spelled_integer} point {spelled_decimal}{percent_suffix}{currency_symbol}"
|
| 173 |
+
if minus_prefix or currency_symbol
|
| 174 |
+
else f"{minus_prefix}{spelled_integer} point {spelled_decimal}{percent_suffix}"
|
| 175 |
+
)
|
| 176 |
+
|
| 177 |
+
def __call__(self, text):
|
| 178 |
+
"""
|
| 179 |
+
Convert numbers / number-like quantities in a string to their spelt-out counterparts
|
| 180 |
+
"""
|
| 181 |
+
# Form part of the pattern for all currency symbols
|
| 182 |
+
pattern = r"(?<!\w)(-?\$?\€?\£?\¢?\¥?\₹?\₽?\฿?\₺?\₴?\₣?\₡?\₱?\₪?\₮?\₩?\₦?\₫?\﷼?\d+(?:\.\d{1,2})?%?)(?!\w)"
|
| 183 |
+
|
| 184 |
+
# Find and replace commas in numbers (15,000 -> 15000, etc)
|
| 185 |
+
text = re.sub(r"(\d+,\d+)", lambda match: match.group(1).replace(",", ""), text)
|
| 186 |
+
|
| 187 |
+
# Use regex to find and replace numbers in the text
|
| 188 |
+
converted_text = re.sub(pattern, lambda match: self.convert(match.group(1)), text)
|
| 189 |
+
converted_text = re.sub(" +", " ", converted_text)
|
| 190 |
+
|
| 191 |
+
return converted_text
|
LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/speecht5/tokenization_speecht5.py
ADDED
|
@@ -0,0 +1,166 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2023 The Facebook Inc. and The HuggingFace Inc. team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
"""Tokenization class for SpeechT5."""
|
| 15 |
+
|
| 16 |
+
from typing import Any
|
| 17 |
+
|
| 18 |
+
from ...tokenization_utils_sentencepiece import SentencePieceBackend
|
| 19 |
+
from ...utils import logging
|
| 20 |
+
from ...utils.import_utils import requires
|
| 21 |
+
from .number_normalizer import EnglishNumberNormalizer
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
logger = logging.get_logger(__name__)
|
| 25 |
+
|
| 26 |
+
VOCAB_FILES_NAMES = {"vocab_file": "spm_char.model"}
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
@requires(backends=("sentencepiece",))
|
| 30 |
+
class SpeechT5Tokenizer(SentencePieceBackend):
|
| 31 |
+
"""
|
| 32 |
+
Construct a SpeechT5 tokenizer. Based on [SentencePiece](https://github.com/google/sentencepiece).
|
| 33 |
+
|
| 34 |
+
This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
|
| 35 |
+
this superclass for more information regarding those methods.
|
| 36 |
+
|
| 37 |
+
Args:
|
| 38 |
+
vocab_file (`str`):
|
| 39 |
+
[SentencePiece](https://github.com/google/sentencepiece) file (generally has a *.spm* extension) that
|
| 40 |
+
contains the vocabulary necessary to instantiate a tokenizer.
|
| 41 |
+
bos_token (`str`, *optional*, defaults to `"<s>"`):
|
| 42 |
+
The begin of sequence token.
|
| 43 |
+
eos_token (`str`, *optional*, defaults to `"</s>"`):
|
| 44 |
+
The end of sequence token.
|
| 45 |
+
unk_token (`str`, *optional*, defaults to `"<unk>"`):
|
| 46 |
+
The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
|
| 47 |
+
token instead.
|
| 48 |
+
pad_token (`str`, *optional*, defaults to `"<pad>"`):
|
| 49 |
+
The token used for padding, for example when batching sequences of different lengths.
|
| 50 |
+
normalize (`bool`, *optional*, defaults to `False`):
|
| 51 |
+
Whether to convert numeric quantities in the text to their spelt-out english counterparts.
|
| 52 |
+
sp_model_kwargs (`dict`, *optional*):
|
| 53 |
+
Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for
|
| 54 |
+
SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things,
|
| 55 |
+
to set:
|
| 56 |
+
|
| 57 |
+
- `enable_sampling`: Enable subword regularization.
|
| 58 |
+
- `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout.
|
| 59 |
+
|
| 60 |
+
- `nbest_size = {0,1}`: No sampling is performed.
|
| 61 |
+
- `nbest_size > 1`: samples from the nbest_size results.
|
| 62 |
+
- `nbest_size < 0`: assuming that nbest_size is infinite and samples from the all hypothesis (lattice)
|
| 63 |
+
using forward-filtering-and-backward-sampling algorithm.
|
| 64 |
+
|
| 65 |
+
- `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for
|
| 66 |
+
BPE-dropout.
|
| 67 |
+
|
| 68 |
+
Attributes:
|
| 69 |
+
sp_model (`SentencePieceProcessor`):
|
| 70 |
+
The *SentencePiece* processor that is used for every conversion (string, tokens and IDs).
|
| 71 |
+
"""
|
| 72 |
+
|
| 73 |
+
vocab_files_names = VOCAB_FILES_NAMES
|
| 74 |
+
model_input_names = ["input_ids", "attention_mask"]
|
| 75 |
+
is_fast = False
|
| 76 |
+
|
| 77 |
+
def __init__(
|
| 78 |
+
self,
|
| 79 |
+
vocab_file,
|
| 80 |
+
bos_token="<s>",
|
| 81 |
+
eos_token="</s>",
|
| 82 |
+
unk_token="<unk>",
|
| 83 |
+
pad_token="<pad>",
|
| 84 |
+
normalize=False,
|
| 85 |
+
sp_model_kwargs: dict[str, Any] | None = None,
|
| 86 |
+
**kwargs,
|
| 87 |
+
) -> None:
|
| 88 |
+
self.normalize = normalize
|
| 89 |
+
self._normalizer = None
|
| 90 |
+
|
| 91 |
+
# Prepare sp_model_kwargs for parent class
|
| 92 |
+
if sp_model_kwargs is not None:
|
| 93 |
+
kwargs["sp_model_kwargs"] = sp_model_kwargs
|
| 94 |
+
|
| 95 |
+
# Call parent init (which will load sp_model)
|
| 96 |
+
super().__init__(
|
| 97 |
+
vocab_file=vocab_file,
|
| 98 |
+
bos_token=bos_token,
|
| 99 |
+
eos_token=eos_token,
|
| 100 |
+
unk_token=unk_token,
|
| 101 |
+
pad_token=pad_token,
|
| 102 |
+
normalize=normalize,
|
| 103 |
+
**kwargs,
|
| 104 |
+
)
|
| 105 |
+
|
| 106 |
+
def prepare_for_tokenization(self, text, is_split_into_words=False, **kwargs):
|
| 107 |
+
normalize = kwargs.pop("normalize", self.normalize)
|
| 108 |
+
if is_split_into_words:
|
| 109 |
+
text = " " + text
|
| 110 |
+
if normalize:
|
| 111 |
+
text = self.normalizer(text)
|
| 112 |
+
return (text, kwargs)
|
| 113 |
+
|
| 114 |
+
@property
|
| 115 |
+
def normalizer(self):
|
| 116 |
+
if self._normalizer is None:
|
| 117 |
+
self._normalizer = EnglishNumberNormalizer()
|
| 118 |
+
return self._normalizer
|
| 119 |
+
|
| 120 |
+
@normalizer.setter
|
| 121 |
+
def normalizer(self, value):
|
| 122 |
+
self._normalizer = value
|
| 123 |
+
|
| 124 |
+
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None) -> list[int]:
|
| 125 |
+
"""Build model inputs from a sequence by appending eos_token_id."""
|
| 126 |
+
if token_ids_1 is None:
|
| 127 |
+
return token_ids_0 + [self.eos_token_id]
|
| 128 |
+
# We don't expect to process pairs, but leave the pair logic for API consistency
|
| 129 |
+
return token_ids_0 + token_ids_1 + [self.eos_token_id]
|
| 130 |
+
|
| 131 |
+
def get_special_tokens_mask(
|
| 132 |
+
self, token_ids_0: list[int], token_ids_1: list[int] | None = None, already_has_special_tokens: bool = False
|
| 133 |
+
) -> list[int]:
|
| 134 |
+
if already_has_special_tokens:
|
| 135 |
+
return super().get_special_tokens_mask(
|
| 136 |
+
token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
|
| 137 |
+
)
|
| 138 |
+
|
| 139 |
+
suffix_ones = [1]
|
| 140 |
+
if token_ids_1 is None:
|
| 141 |
+
return ([0] * len(token_ids_0)) + suffix_ones
|
| 142 |
+
return ([0] * len(token_ids_0)) + ([0] * len(token_ids_1)) + suffix_ones
|
| 143 |
+
|
| 144 |
+
def create_token_type_ids_from_sequences(
|
| 145 |
+
self, token_ids_0: list[int], token_ids_1: list[int] | None = None
|
| 146 |
+
) -> list[int]:
|
| 147 |
+
"""
|
| 148 |
+
Create a mask from the two sequences passed to be used in a sequence-pair classification task. SpeechT5 does not
|
| 149 |
+
make use of token type ids, therefore a list of zeros is returned.
|
| 150 |
+
|
| 151 |
+
Args:
|
| 152 |
+
token_ids_0 (`list[int]`):
|
| 153 |
+
List of IDs.
|
| 154 |
+
token_ids_1 (`list[int]`, *optional*):
|
| 155 |
+
Optional second list of IDs for sequence pairs.
|
| 156 |
+
|
| 157 |
+
Returns:
|
| 158 |
+
`list[int]`: List of zeros.
|
| 159 |
+
"""
|
| 160 |
+
eos = [self.eos_token_id]
|
| 161 |
+
if token_ids_1 is None:
|
| 162 |
+
return len(token_ids_0 + eos) * [0]
|
| 163 |
+
return len(token_ids_0 + token_ids_1 + eos) * [0]
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
__all__ = ["SpeechT5Tokenizer"]
|
LTA_openwebtext_dualt/mini_owt_logdirichlet/runs/owt_t5_elftokenized_full_len1024_C1_to_1024_pow1_d768_l12_h12_gbs512_2x8gpu_50ep_lr4e3_ema0p9999_elfopt_not5_bottleneck128_unfixed_norm_stateprobadd_selfcond_ce_fast_20260610_020108/step_053000.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c3a0ce15a3f8e0441fca84965ff658de402bc494bd94b53730430287ab2ab2df
|
| 3 |
+
size 927700322
|
LTA_openwebtext_dualt/mini_owt_logdirichlet/runs/owt_t5_elftokenized_full_len1024_C1_to_1024_pow1_d768_l12_h12_gbs512_2x8gpu_50ep_lr4e3_ema0p9999_elfopt_not5_bottleneck128_unfixed_norm_stateprobadd_selfcond_ce_fast_20260610_020108/step_163000.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c731b03de97e220aab47f37b3d6c191aa23591d0c4f973a4b38e93730358b2bf
|
| 3 |
+
size 927700322
|
LTA_openwebtext_dualt/mini_owt_logdirichlet/runs/owt_t5_elftokenized_full_len1024_C1_to_1024_pow1_d768_l12_h12_gbs512_2x8gpu_50ep_lr4e3_ema0p9999_elfopt_not5_bottleneck128_unfixed_norm_stateprobadd_selfcond_ce_fast_20260610_020108/step_172000.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:aa1869d751dd1db17309c432203e4d0978a22ab1fbe9065646c5d04cfe9baa67
|
| 3 |
+
size 927700322
|
LTA_openwebtext_dualt/mini_owt_logdirichlet/runs/owt_t5_elftokenized_full_len1024_C1_to_1024_pow1_d768_l12_h12_gbs512_2x8gpu_50ep_lr4e3_ema0p9999_elfopt_not5_bottleneck128_unfixed_norm_stateprobadd_selfcond_ce_fast_20260610_020108/step_182000.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:2ca8affe97a9e4ab92c98e52693f27b329e99dd9122eb9b7672ab56618aaf840
|
| 3 |
+
size 927700322
|