from typing import Callable, Optional, Union |
|
|
|
|
|
import torch |
|
|
import torch.distributions |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
|
|
|
from ...activations import ACT2FN |
|
|
from ...cache_utils import Cache, DynamicCache |
|
|
from ...generation import GenerationMixin |
|
|
from ...masking_utils import create_causal_mask |
|
|
from ...modeling_flash_attention_utils import FlashAttentionKwargs |
|
|
from ...modeling_layers import GradientCheckpointingLayer |
|
|
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast |
|
|
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update |
|
|
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel |
|
|
from ...processing_utils import Unpack |
|
|
from ...utils import TransformersKwargs, auto_docstring, can_return_tuple |
|
|
from ...utils.deprecation import deprecate_kwarg |
|
|
from ...utils.generic import OutputRecorder, check_model_inputs |
|
|
from .configuration_blt import ( |
|
|
BltConfig, |
|
|
BltGlobalTransformerConfig, |
|
|
BltLocalDecoderConfig, |
|
|
BltLocalEncoderConfig, |
|
|
BltPatcherConfig, |
|
|
) |
|
|
|
|
|
|
|
|
class BltMLP(nn.Module): |
|
|
def __init__(self, config): |
|
|
super().__init__() |
|
|
self.config = config |
|
|
self.hidden_size = config.hidden_size |
|
|
self.intermediate_size = config.intermediate_size |
|
|
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) |
|
|
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) |
|
|
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) |
|
|
|
|
|
self.act_fn = ACT2FN[config.hidden_act] |
|
|
|
|
|
def forward(self, x): |
|
|
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) |
|
|
return down_proj |
|
|
|
|
|
|
|
|
class BltRMSNorm(nn.Module): |
|
|
def __init__(self, hidden_size, eps=1e-6): |
|
|
""" |
|
|
BltRMSNorm is equivalent to T5LayerNorm |
|
|
""" |
|
|
super().__init__() |
|
|
self.weight = nn.Parameter(torch.ones(hidden_size)) |
|
|
self.variance_epsilon = eps |
|
|
|
|
|
def forward(self, hidden_states): |
|
|
input_dtype = hidden_states.dtype |
|
|
hidden_states = hidden_states.to(torch.float32) |
|
|
variance = hidden_states.pow(2).mean(-1, keepdim=True) |
|
|
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) |
|
|
return self.weight * hidden_states.to(input_dtype) |
|
|
|
|
|
def extra_repr(self): |
|
|
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" |
|
|
|
|
|
|
|
|
class BltRotaryEmbedding(nn.Module): |
|
|
inv_freq: torch.Tensor |
|
|
|
|
|
def __init__(self, config: BltConfig, device=None): |
|
|
super().__init__() |
|
|
|
|
|
if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict): |
|
|
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) |
|
|
else: |
|
|
self.rope_type = "default" |
|
|
self.max_seq_len_cached = config.max_position_embeddings |
|
|
self.original_max_seq_len = config.max_position_embeddings |
|
|
|
|
|
self.config = config |
|
|
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] |
|
|
|
|
|
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) |
|
|
self.register_buffer("inv_freq", inv_freq, persistent=False) |
|
|
self.original_inv_freq = self.inv_freq |
|
|
|
|
|
@torch.no_grad() |
|
|
@dynamic_rope_update |
|
|
def forward(self, x, position_ids): |
|
|
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) |
|
|
position_ids_expanded = position_ids[:, None, :].float() |
|
|
|
|
|
device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" |
|
|
with torch.autocast(device_type=device_type, enabled=False): |
|
|
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) |
|
|
emb = torch.repeat_interleave(freqs, 2, dim=-1) |
|
|
cos = emb.cos() * self.attention_scaling |
|
|
sin = emb.sin() * self.attention_scaling |
|
|
|
|
|
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) |
|
|
|
|
|
|
|
|
|
|
|
class BltTransformerLayer(GradientCheckpointingLayer): |
|
|
def __init__(self, config, layer_idx: int): |
|
|
super().__init__() |
|
|
self.hidden_size = config.hidden_size |
|
|
|
|
|
self.self_attn = BltSelfAttention(config=config, layer_idx=layer_idx) |
|
|
self.mlp = BltMLP(config) |
|
|
self.input_layernorm = BltRMSNorm(config.hidden_size, eps=config.rms_norm_eps) |
|
|
self.post_attention_layernorm = BltRMSNorm(config.hidden_size, eps=config.rms_norm_eps) |
|
|
|
|
|
self.layer_idx = layer_idx |
|
|
|
|
|
@deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") |
|
|
def forward( |
|
|
self, |
|
|
hidden_states: torch.Tensor, |
|
|
cross_attention_states: Optional[torch.Tensor] = None, |
|
|
cross_attention_mask: Optional[torch.Tensor] = None, |
|
|
attention_mask: Optional[torch.Tensor] = None, |
|
|
full_text_row_masked_out_mask: Optional[tuple[torch.Tensor, torch.Tensor]] = None, |
|
|
position_ids: Optional[torch.LongTensor] = None, |
|
|
past_key_values: Optional[Cache] = None, |
|
|
use_cache: Optional[bool] = False, |
|
|
cache_position: Optional[torch.LongTensor] = None, |
|
|
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, |
|
|
**kwargs: Unpack[FlashAttentionKwargs], |
|
|
    ) -> torch.Tensor:
|
|
""" |
|
|
Args: |
|
|
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` |
|
|
attention_mask (`torch.FloatTensor`, *optional*): |
|
|
attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, |
|
|
query_sequence_length, key_sequence_length)` if default attention is used. |
|
|
|
|
|
use_cache (`bool`, *optional*): |
|
|
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding |
|
|
(see `past_key_values`). |
|
|
past_key_values (`Cache`, *optional*): cached past key and value projection states |
|
|
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): |
|
|
Indices depicting the position of the input sequence tokens in the sequence |
|
|
position_embeddings (`tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): |
|
|
Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, |
|
|
with `head_dim` being the embedding dimension of each attention head. |
|
|
kwargs (`dict`, *optional*): |
|
|
                Arbitrary kwargs to be ignored, used for FSDP and other methods that inject code
|
|
into the model |
|
|
""" |
|
|
residual = hidden_states |
|
|
|
|
|
hidden_states = self.input_layernorm(hidden_states) |
|
|
|
|
|
|
|
|
hidden_states, self_attn_weights = self.self_attn( |
|
|
hidden_states=hidden_states, |
|
|
attention_mask=attention_mask, |
|
|
position_ids=position_ids, |
|
|
past_key_values=past_key_values, |
|
|
use_cache=use_cache, |
|
|
cache_position=cache_position, |
|
|
position_embeddings=position_embeddings, |
|
|
**kwargs, |
|
|
) |
|
|
hidden_states = residual + hidden_states |
|
|
|
|
|
|
|
|
residual = hidden_states |
|
|
hidden_states = self.post_attention_layernorm(hidden_states) |
|
|
hidden_states = self.mlp(hidden_states) |
|
|
hidden_states = residual + hidden_states |
|
|
|
|
|
return hidden_states |
|
|
|
|
|
|
|
|
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: |
|
|
""" |
|
|
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, |
|
|
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) |
|
|
""" |
|
|
batch, num_key_value_heads, slen, head_dim = hidden_states.shape |
|
|
if n_rep == 1: |
|
|
return hidden_states |
|
|
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) |
|
|
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) |
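

# Illustrative helper (not part of the original model code): a minimal sketch of how `repeat_kv`
# expands grouped key/value heads so they can attend against a larger number of query heads.
def _example_repeat_kv():
    kv = torch.randn(1, 2, 16, 64)  # (batch, num_key_value_heads=2, seq_len, head_dim)
    expanded = repeat_kv(kv, n_rep=4)  # each KV head is reused by 4 query heads
    assert expanded.shape == (1, 8, 16, 64)  # (batch, num_attention_heads=8, seq_len, head_dim)
    return expanded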
|
|
|
|
|
|
|
|
def eager_attention_forward( |
|
|
module: nn.Module, |
|
|
query: torch.Tensor, |
|
|
key: torch.Tensor, |
|
|
value: torch.Tensor, |
|
|
attention_mask: Optional[torch.Tensor], |
|
|
scaling: float, |
|
|
dropout: float = 0.0, |
|
|
**kwargs: Unpack[TransformersKwargs], |
|
|
): |
|
|
key_states = repeat_kv(key, module.num_key_value_groups) |
|
|
value_states = repeat_kv(value, module.num_key_value_groups) |
|
|
|
|
|
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling |
|
|
if attention_mask is not None: |
|
|
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] |
|
|
attn_weights = attn_weights + causal_mask |
|
|
|
|
|
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) |
|
|
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) |
|
|
attn_output = torch.matmul(attn_weights, value_states) |
|
|
attn_output = attn_output.transpose(1, 2).contiguous() |
|
|
|
|
|
return attn_output, attn_weights |
|
|
|
|
|
|
|
|
def rotate_half(x):
    """Rotates each adjacent pair of dims (interleaved RoPE variant): [x0, x1, x2, x3, ...] -> [-x1, x0, -x3, x2, ...]."""
    x1 = x[..., ::2]
    x2 = x[..., 1::2]
    rot_x = torch.stack([-x2, x1], dim=-1).flatten(-2)
    return rot_x
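

# Illustrative helper (not part of the original model code): this module uses the *interleaved*
# RoPE convention, rotating each adjacent pair of dims rather than the two halves of the vector.
def _example_rotate_half():
    x = torch.tensor([1.0, 2.0, 3.0, 4.0])
    # the pairs (1, 2) and (3, 4) become (-2, 1) and (-4, 3)
    assert torch.equal(rotate_half(x), torch.tensor([-2.0, 1.0, -4.0, 3.0]))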
|
|
|
|
|
|
|
|
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): |
|
|
"""Applies Rotary Position Embedding to the query and key tensors. |
|
|
|
|
|
Args: |
|
|
q (`torch.Tensor`): The query tensor. |
|
|
k (`torch.Tensor`): The key tensor. |
|
|
cos (`torch.Tensor`): The cosine part of the rotary embedding. |
|
|
sin (`torch.Tensor`): The sine part of the rotary embedding. |
|
|
position_ids (`torch.Tensor`, *optional*): |
|
|
Deprecated and unused. |
|
|
unsqueeze_dim (`int`, *optional*, defaults to 1): |
|
|
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and |
|
|
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note |
|
|
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and |
|
|
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes |
|
|
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have |
|
|
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. |
|
|
Returns: |
|
|
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. |
|
|
""" |
|
|
cos = cos.unsqueeze(unsqueeze_dim) |
|
|
sin = sin.unsqueeze(unsqueeze_dim) |
|
|
q_embed = (q * cos) + (rotate_half(q) * sin) |
|
|
k_embed = (k * cos) + (rotate_half(k) * sin) |
|
|
return q_embed, k_embed |
|
|
|
|
|
|
|
|
class BltSelfAttention(nn.Module): |
|
|
def __init__(self, config: BltConfig, layer_idx: int): |
|
|
super().__init__() |
|
|
self.config = config |
|
|
self.num_heads = config.num_attention_heads |
|
|
self.dropout = config.dropout |
|
|
self.hidden_size = config.hidden_size |
|
|
self.num_key_value_heads = config.num_key_value_heads |
|
|
self.head_dim = config.hidden_size // self.num_heads |
|
|
self.num_key_value_groups = self.num_heads // self.num_key_value_heads |
|
|
self.scaling = self.head_dim**-0.5 |
|
|
self.rope_theta = config.rope_theta |
|
|
self.layer_idx = layer_idx |
|
|
|
|
|
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) |
|
|
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) |
|
|
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) |
|
|
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) |
|
|
self.is_causal = True |
|
|
|
|
|
@deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") |
|
|
def forward( |
|
|
self, |
|
|
hidden_states: torch.Tensor, |
|
|
attention_mask: torch.Tensor, |
|
|
position_embeddings: torch.Tensor, |
|
|
use_cache: bool = False, |
|
|
past_key_values=None, |
|
|
cache_position=None, |
|
|
**kwargs, |
|
|
): |
|
|
bsz, q_len, _ = hidden_states.size() |
|
|
|
|
|
query_states = self.q_proj(hidden_states) |
|
|
key_states = self.k_proj(hidden_states) |
|
|
value_states = self.v_proj(hidden_states) |
|
|
|
|
|
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) |
|
|
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) |
|
|
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) |
|
|
|
|
|
cos, sin = position_embeddings |
|
|
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) |
|
|
|
|
|
if past_key_values is not None: |
|
|
|
|
|
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} |
|
|
key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) |
|
|
|
|
|
attention_interface: Callable = eager_attention_forward |
|
|
|
|
|
if self.config._attn_implementation != "eager": |
|
|
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] |
|
|
|
|
|
attn_output, attn_weights = attention_interface( |
|
|
self, |
|
|
query_states, |
|
|
key_states, |
|
|
value_states, |
|
|
attention_mask, |
|
|
dropout=0.0 if not self.training else self.dropout, |
|
|
scaling=self.scaling, |
|
|
**kwargs, |
|
|
) |
|
|
|
|
|
attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() |
|
|
attn_output = self.o_proj(attn_output) |
|
|
|
|
|
return attn_output, attn_weights |
|
|
|
|
|
|
|
|
class BltCrossAttention(nn.Module): |
|
|
"""Cross-attention module for Blt, following transformers style""" |
|
|
|
|
|
def __init__(self, config: BltConfig, layer_idx: int, hidden_size: Optional[int] = None): |
|
|
super().__init__() |
|
|
self.config = config |
|
|
self.num_heads = self.config.num_attention_heads |
|
|
self.num_key_value_heads = self.config.num_key_value_heads |
|
|
self.dropout = config.dropout |
|
|
self.hidden_size = config.hidden_size |
|
|
self.head_dim = config.hidden_size // self.num_heads |
|
|
self.layer_idx = layer_idx |
|
|
self.num_key_value_groups = self.num_heads // self.num_key_value_heads |
|
|
self.scaling = self.head_dim**-0.5 |
|
|
|
|
|
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) |
|
|
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) |
|
|
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) |
|
|
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) |
|
|
self.q_norm = BltRMSNorm(self.hidden_size, eps=config.rms_norm_eps) |
|
|
self.k_norm = BltRMSNorm(self.hidden_size, eps=config.rms_norm_eps) |
|
|
self.is_causal = False |
|
|
|
|
|
@deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") |
|
|
def forward( |
|
|
self, |
|
|
hidden_states: torch.Tensor, |
|
|
cross_attention_states: Optional[torch.Tensor] = None, |
|
|
past_key_values: Optional[Cache] = None, |
|
|
attention_mask: Optional[torch.Tensor] = None, |
|
|
cache_position: Optional[torch.LongTensor] = None, |
|
|
**kwargs: Unpack[TransformersKwargs], |
|
|
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
|
"""Input shape: Batch x Time x Channel""" |
|
|
bsz, q_len, _ = hidden_states.size() |
|
|
query_states = self.q_norm(hidden_states) |
|
|
query_states = self.q_proj(query_states) |
|
|
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) |
|
|
|
|
|
if cross_attention_states is not None: |
|
|
cross_attention_states = self.k_norm(cross_attention_states) |
|
|
key_states = self.k_proj(cross_attention_states) |
|
|
value_states = self.v_proj(cross_attention_states) |
|
|
key_states = key_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2) |
|
|
value_states = value_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2) |
|
|
if past_key_values is not None: |
|
|
key_states, value_states = past_key_values.update( |
|
|
key_states, value_states, self.layer_idx, {"cache_position": cache_position} |
|
|
) |
|
|
elif cache_position[0] != 0: |
|
|
key_states, value_states = ( |
|
|
past_key_values.layers[self.layer_idx].keys, |
|
|
past_key_values.layers[self.layer_idx].values, |
|
|
) |
|
|
else: |
|
|
            raise ValueError(
                "Cross attention layer found neither `cross_attention_states` nor cached key/value states!"
            )
|
|
attention_interface: Callable = eager_attention_forward |
|
|
|
|
|
if self.config._attn_implementation != "eager": |
|
|
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] |
|
|
|
|
|
attn_output, attn_weights = attention_interface( |
|
|
self, |
|
|
query_states, |
|
|
key_states, |
|
|
value_states, |
|
|
attention_mask, |
|
|
dropout=0.0 if not self.training else self.dropout, |
|
|
scaling=self.scaling, |
|
|
**kwargs, |
|
|
) |
|
|
attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() |
|
|
attn_output = self.o_proj(attn_output) |
|
|
attn_output = attn_output + hidden_states |
|
|
return attn_output, attn_weights |
|
|
|
|
|
|
|
|
@auto_docstring |
|
|
class BltPreTrainedModel(PreTrainedModel): |
|
|
config: BltConfig |
|
|
base_model_prefix = "" |
|
|
supports_gradient_checkpointing = True |
|
|
_no_split_modules = ["BltTransformerLayer"] |
|
|
_can_compile_fullgraph = False |
|
|
_supports_sdpa = True |
|
|
_supports_flash_attn = False |
|
|
_supports_flex_attn = False |
|
|
_supports_attention_backend = False |
|
|
_can_record_outputs = { |
|
|
"hidden_states": OutputRecorder(BltTransformerLayer, index=0, layer_name="local_decoder"), |
|
|
"attentions": OutputRecorder(BltSelfAttention, index=1, layer_name="local_decoder"), |
|
|
} |
|
|
|
|
|
|
|
|
class BltLocalEncoder(BltPreTrainedModel): |
|
|
config: BltLocalEncoderConfig |
|
|
_can_record_outputs = { |
|
|
"encoder_attentions": OutputRecorder(BltSelfAttention, index=1, layer_name="local_encoder"), |
|
|
} |
|
|
|
|
|
def __init__(self, config: BltLocalEncoderConfig): |
|
|
super().__init__(config) |
|
|
self.gradient_checkpointing = False |
|
|
self.config = config |
|
|
self.layers = nn.ModuleList( |
|
|
[BltTransformerLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] |
|
|
) |
|
|
self.rotary_emb = BltRotaryEmbedding(config=config) |
|
|
self.patch_embedding_projection = nn.Linear( |
|
|
in_features=config.hidden_size, |
|
|
out_features=config.hidden_size * config.cross_attn_k, |
|
|
bias=False, |
|
|
) |
|
|
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) |
|
|
self.cross_attn_layers = nn.ModuleList() |
|
|
layers_to_add = config.num_hidden_layers if config.cross_attn_all_layers else 1 |
|
|
for layer_idx in range(layers_to_add): |
|
|
self.cross_attn_layers.append( |
|
|
BltCrossAttention(config=config, layer_idx=layer_idx, hidden_size=config.hidden_size) |
|
|
) |
|
|
|
|
|
self.post_init() |
|
|
|
|
|
def forward( |
|
|
self, |
|
|
input_ids: Optional[torch.LongTensor] = None, |
|
|
inputs_embeds: Optional[torch.Tensor] = None, |
|
|
patch_embeds: Optional[torch.Tensor] = None, |
|
|
attention_mask: Optional[torch.Tensor] = None, |
|
|
position_ids: Optional[torch.LongTensor] = None, |
|
|
past_key_values: Optional[Cache] = None, |
|
|
cache_position: Optional[torch.LongTensor] = None, |
|
|
encoder_attention_mask: Optional[torch.Tensor] = None, |
|
|
num_patches: Optional[int] = None, |
|
|
patch_ids: Optional[torch.Tensor] = None, |
|
|
**kwargs: Unpack[TransformersKwargs], |
|
|
): |
|
|
if inputs_embeds is None: |
|
|
inputs_embeds = self.embed_tokens(input_ids) |
|
|
|
|
|
batch_size = inputs_embeds.shape[0] |
|
|
hidden_states = F.dropout(inputs_embeds, p=self.config.dropout, training=self.training) |
|
|
|
|
|
if position_ids is None: |
|
|
position_ids = ( |
|
|
torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device).unsqueeze(0).expand(batch_size, -1) |
|
|
) |
|
|
|
|
|
position_embeddings = self.rotary_emb(hidden_states, position_ids) |
|
|
hidden_states = F.dropout(hidden_states, p=self.config.dropout, training=self.training) |
|
|
|
|
|
for idx, layer in enumerate(self.layers): |
|
|
hidden_states = layer( |
|
|
hidden_states, |
|
|
position_embeddings=position_embeddings, |
|
|
attention_mask=attention_mask, |
|
|
past_key_values=past_key_values, |
|
|
cache_position=cache_position, |
|
|
**kwargs, |
|
|
) |
|
|
if idx == len(self.layers) - 1 or self.config.cross_attn_all_layers: |
|
|
patch_embeds = self.patch_reduce(hidden_states, num_patches, patch_ids) |
|
|
patch_embeds = self.patch_embedding_projection(patch_embeds) |
|
|
patch_embeds = patch_embeds.reshape( |
|
|
batch_size, patch_embeds.shape[1] * self.config.cross_attn_k, self.config.hidden_size |
|
|
) |
|
|
layer_idx = idx if self.config.cross_attn_all_layers else 0 |
|
|
cross_attention_output, _ = self.cross_attn_layers[layer_idx]( |
|
|
hidden_states=patch_embeds, |
|
|
cross_attention_states=hidden_states, |
|
|
attention_mask=encoder_attention_mask, |
|
|
**kwargs, |
|
|
) |
|
|
patch_embeds = patch_embeds + cross_attention_output |
|
|
encoder_cross_states = patch_embeds |
|
|
return hidden_states, encoder_cross_states |
|
|
|
|
|
    def patch_reduce(self, hidden_states, max_num_patches, patch_ids):
        """
        Reduce variable-length patches to a single embedding per patch.

        Note: this works with a variable number of patches for different sequences in the batch.
        It handles variable-length patches by assuming that `patch_lengths` is 0 for any extra
        patches on the *right*. Any embeddings on the right that are not allocated to a patch
        (i.e. if sum(patch_lengths[i]) < seq_len for some i) are sent to a dummy patch, which is
        trimmed before returning.
        """
|
|
batch_size = hidden_states.shape[0] |
|
|
embedding_dim = hidden_states.shape[-1] |
|
|
|
|
|
patch_ids = patch_ids.unsqueeze(-1).expand(-1, -1, hidden_states.shape[-1]) |
|
|
|
|
|
reduced_embeddings = torch.zeros( |
|
|
(batch_size, max_num_patches, embedding_dim), dtype=hidden_states.dtype, device=hidden_states.device |
|
|
) |
|
|
|
|
|
reduced_embeddings = reduced_embeddings.scatter_reduce( |
|
|
src=hidden_states, |
|
|
dim=1, |
|
|
index=patch_ids, |
|
|
reduce="amax", |
|
|
include_self=False, |
|
|
) |
|
|
reduced_embeddings = reduced_embeddings[:, :max_num_patches, :] |
|
|
|
|
|
return reduced_embeddings |
|
|
|
|
|
|
|
|
class BltLocalDecoder(BltPreTrainedModel): |
|
|
config: BltLocalDecoderConfig |
|
|
|
|
|
def __init__(self, config: BltLocalDecoderConfig): |
|
|
super().__init__(config) |
|
|
self.gradient_checkpointing = False |
|
|
self.config = config |
|
|
self.cross_attn_decoder = True |
|
|
self.layers = nn.ModuleList( |
|
|
[BltTransformerLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] |
|
|
) |
|
|
self.rotary_emb = BltRotaryEmbedding(config=config) |
|
|
self.patch_embedding_projection = nn.Linear( |
|
|
in_features=config.hidden_size_global, |
|
|
out_features=config.hidden_size * config.cross_attn_k, |
|
|
bias=False, |
|
|
) |
|
|
self.norm = BltRMSNorm(config.hidden_size, eps=config.rms_norm_eps) |
|
|
self.cross_attn_layers = nn.ModuleList() |
|
|
layers_to_add = config.num_hidden_layers if config.cross_attn_all_layers else 1 |
|
|
for layer_idx in range(layers_to_add): |
|
|
self.cross_attn_layers.append( |
|
|
BltCrossAttention(config=config, layer_idx=layer_idx, hidden_size=config.hidden_size) |
|
|
) |
|
|
|
|
|
self.post_init() |
|
|
|
|
|
@check_model_inputs() |
|
|
def forward( |
|
|
self, |
|
|
input_ids: Optional[torch.LongTensor] = None, |
|
|
inputs_embeds: Optional[torch.Tensor] = None, |
|
|
patch_embeds: Optional[torch.Tensor] = None, |
|
|
attention_mask: Optional[torch.Tensor] = None, |
|
|
position_ids: Optional[torch.LongTensor] = None, |
|
|
past_key_values: Optional[Cache] = None, |
|
|
cache_position: Optional[torch.LongTensor] = None, |
|
|
encoder_attention_mask: Optional[torch.Tensor] = None, |
|
|
**kwargs: Unpack[TransformersKwargs], |
|
|
): |
|
|
batch_size = inputs_embeds.shape[0] |
|
|
hidden_states = inputs_embeds |
|
|
patch_embeds = self.patch_embedding_projection(patch_embeds) |
|
|
patch_embeds = patch_embeds.reshape( |
|
|
batch_size, patch_embeds.shape[1] * self.config.cross_attn_k, self.config.hidden_size |
|
|
) |
|
|
|
|
|
if patch_embeds is not None and not self.cross_attn_decoder: |
|
|
hidden_states = hidden_states + patch_embeds |
|
|
|
|
|
if position_ids is None: |
|
|
position_ids = ( |
|
|
torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device).unsqueeze(0).expand(batch_size, -1) |
|
|
) |
|
|
|
|
|
position_embeddings = self.rotary_emb(hidden_states, position_ids) |
|
|
hidden_states = F.dropout(hidden_states, p=self.config.dropout, training=self.training) |
|
|
|
|
|
for i, layer in enumerate(self.layers): |
|
|
if i == 0 or self.config.cross_attn_all_layers: |
|
|
cross_attention_output, _ = self.cross_attn_layers[i]( |
|
|
hidden_states=hidden_states, |
|
|
cross_attention_states=patch_embeds, |
|
|
attention_mask=encoder_attention_mask, |
|
|
**kwargs, |
|
|
) |
|
|
hidden_states = hidden_states + cross_attention_output |
|
|
hidden_states = layer( |
|
|
hidden_states, |
|
|
position_embeddings=position_embeddings, |
|
|
attention_mask=attention_mask, |
|
|
past_key_values=past_key_values, |
|
|
cache_position=cache_position, |
|
|
**kwargs, |
|
|
) |
|
|
        hidden_states = self.norm(hidden_states)
        return hidden_states
|
|
|
|
|
|
|
|
class BltGlobalTransformer(BltPreTrainedModel): |
|
|
config: BltGlobalTransformerConfig |
|
|
_can_record_outputs = { |
|
|
"global_attentions": OutputRecorder(BltSelfAttention, index=1, layer_name="global_transformer"), |
|
|
} |
|
|
|
|
|
def __init__(self, config: BltGlobalTransformerConfig): |
|
|
super().__init__(config) |
|
|
self.config = config |
|
|
self.layers = nn.ModuleList() |
|
|
for layer_idx in range(config.num_hidden_layers): |
|
|
self.layers.append(BltTransformerLayer(config, layer_idx)) |
|
|
self.rotary_emb = BltRotaryEmbedding(config=config) |
|
|
|
|
|
|
|
|
if getattr(config, "encoder_cross_output_size", None) is not None: |
|
|
self.token_embedding_projection = nn.Linear( |
|
|
config.encoder_cross_output_size, config.hidden_size, bias=False |
|
|
) |
|
|
else: |
|
|
self.token_embedding_projection = nn.Identity() |
|
|
|
|
|
self.post_init() |
|
|
|
|
|
def forward( |
|
|
self, |
|
|
input_embeds: torch.Tensor, |
|
|
attention_mask: Optional[torch.Tensor] = None, |
|
|
position_ids: Optional[torch.LongTensor] = None, |
|
|
past_key_values: Optional[Cache] = None, |
|
|
cache_position: Optional[torch.LongTensor] = None, |
|
|
**kwargs: Unpack[TransformersKwargs], |
|
|
): |
|
|
batch_size, seq_len, _ = input_embeds.shape |
|
|
hidden_states = self.token_embedding_projection(input_embeds) |
|
|
hidden_states = F.dropout(hidden_states, p=self.config.dropout, training=self.training) |
|
|
if position_ids is None: |
|
|
position_ids = ( |
|
|
torch.arange(input_embeds.shape[1], device=input_embeds.device).unsqueeze(0).expand(batch_size, -1) |
|
|
) |
|
|
position_embeddings = self.rotary_emb(hidden_states, position_ids) |
|
|
for i, layer in enumerate(self.layers): |
|
|
hidden_states = layer( |
|
|
hidden_states, |
|
|
position_embeddings=position_embeddings, |
|
|
attention_mask=attention_mask, |
|
|
past_key_values=past_key_values, |
|
|
cache_position=cache_position, |
|
|
**kwargs, |
|
|
) |
|
|
return hidden_states |
|
|
|
|
|
|
|
|
def process_patch_lengths(patch_lengths: torch.Tensor, max_patch_length: Optional[int]) -> torch.Tensor: |
|
|
""" |
|
|
Splits patch lengths into smaller segments if they exceed `max_patch_length`. |
|
|
Pads the result to uniform length across the batch. |
|
|
|
|
|
Args: |
|
|
patch_lengths (torch.Tensor): [batch_size, num_patches] tensor of patch lengths. |
|
|
max_patch_length (int, optional): Maximum allowed length per patch. |
|
|
|
|
|
Returns: |
|
|
torch.Tensor: [batch_size, max_len] tensor of split and padded patch lengths. |
|
|
""" |
|
|
if max_patch_length is None: |
|
|
return patch_lengths |
|
|
|
|
|
batch_size = patch_lengths.size(0) |
|
|
processed = [] |
|
|
|
|
|
for seq in patch_lengths: |
|
|
splits = [] |
|
|
for length in seq[seq > 0]: |
|
|
length = length.item() |
|
|
full_chunks, remainder = divmod(length, max_patch_length) |
|
|
splits.extend([max_patch_length] * full_chunks) |
|
|
if remainder: |
|
|
splits.append(remainder) |
|
|
processed.append(splits) |
|
|
|
|
|
|
|
|
max_len = max(len(splits) for splits in processed) |
|
|
padded = torch.zeros((batch_size, max_len), dtype=patch_lengths.dtype, device=patch_lengths.device) |
|
|
|
|
|
for i, splits in enumerate(processed): |
|
|
if splits: |
|
|
padded[i, : len(splits)] = torch.tensor(splits, dtype=patch_lengths.dtype, device=patch_lengths.device) |
|
|
|
|
|
|
|
|
if (padded != 0).any(dim=0).sum() < padded.shape[1]: |
|
|
last_nonzero = (padded != 0).any(dim=0).nonzero().max().item() + 1 |
|
|
padded = padded[:, :last_nonzero] |
|
|
|
|
|
return padded |
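

# Illustrative helper (not part of the original model code): a worked example of how
# `process_patch_lengths` splits patches longer than `max_patch_length` and right-pads the batch.
def _example_process_patch_lengths():
    lengths = torch.tensor([[5, 3, 0], [2, 0, 0]])
    out = process_patch_lengths(lengths, max_patch_length=2)
    # 5 -> [2, 2, 1] and 3 -> [2, 1]; the second row is zero-padded on the right
    assert out.tolist() == [[2, 2, 1, 2, 1], [2, 0, 0, 0, 0]]
    return out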
|
|
|
|
|
|
|
|
class BltPatcher(BltPreTrainedModel): |
|
|
config: BltPatcherConfig |
|
|
|
|
|
def __init__(self, config: BltPatcherConfig): |
|
|
super().__init__(config) |
|
|
self.rotary_emb = BltRotaryEmbedding(config=self.config) |
|
|
self.layers = nn.ModuleList() |
|
|
for layer_idx in range(self.config.num_hidden_layers): |
|
|
self.layers.append(BltTransformerLayer(self.config, layer_idx)) |
|
|
self.embed_tokens = nn.Embedding(self.config.vocab_size, self.config.hidden_size) |
|
|
self.norm = BltRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps) |
|
|
self.lm_head = nn.Linear( |
|
|
self.config.hidden_size, |
|
|
self.config.vocab_size, |
|
|
bias=False, |
|
|
) |
|
|
|
|
|
def forward( |
|
|
self, |
|
|
input_ids: Optional[torch.LongTensor] = None, |
|
|
attention_mask: Optional[torch.Tensor] = None, |
|
|
position_ids: Optional[torch.LongTensor] = None, |
|
|
past_key_values: Optional[Cache] = None, |
|
|
inputs_embeds: Optional[torch.FloatTensor] = None, |
|
|
use_cache: Optional[bool] = None, |
|
|
cache_position: Optional[torch.LongTensor] = None, |
|
|
patch_size: Optional[int] = None, |
|
|
threshold: Optional[float] = None, |
|
|
max_patch_length: Optional[int] = None, |
|
|
**kwargs: Unpack[TransformersKwargs], |
|
|
): |
|
|
if (input_ids is None) ^ (inputs_embeds is not None): |
|
|
raise ValueError("You must specify exactly one of input_ids or inputs_embeds") |
|
|
|
|
|
if inputs_embeds is None: |
|
|
inputs_embeds = self.embed_tokens(input_ids) |
|
|
|
|
|
if use_cache and past_key_values is None: |
|
|
past_key_values = DynamicCache() |
|
|
|
|
|
if cache_position is None: |
|
|
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 |
|
|
cache_position = torch.arange( |
|
|
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device |
|
|
) |
|
|
|
|
|
if position_ids is None: |
|
|
position_ids = cache_position.unsqueeze(0) |
|
|
|
|
|
causal_mask = create_causal_mask( |
|
|
config=self.config, |
|
|
input_embeds=inputs_embeds, |
|
|
attention_mask=attention_mask, |
|
|
cache_position=cache_position, |
|
|
past_key_values=past_key_values, |
|
|
position_ids=position_ids, |
|
|
) |
|
|
|
|
|
hidden_states = inputs_embeds |
|
|
position_embeddings = self.rotary_emb(hidden_states, position_ids) |
|
|
|
|
|
for layer in self.layers: |
|
|
hidden_states = layer(hidden_states, position_embeddings=position_embeddings, attention_mask=causal_mask) |
|
|
|
|
|
logits = self.lm_head(self.norm(hidden_states)) |
|
|
prediction_entropies = torch.distributions.Categorical(logits=logits).entropy() |
|
|
|
|
|
batch_size, sequence_length = inputs_embeds.shape[:2] |
|
|
if patch_size is not None: |
|
|
patch_lengths = self.patch_lengths_from_entropies( |
|
|
entropies=prediction_entropies, |
|
|
sequence_length=sequence_length, |
|
|
patch_size=patch_size, |
|
|
threshold=threshold, |
|
|
) |
|
|
else: |
|
|
patch_lengths = torch.ones( |
|
|
(batch_size, sequence_length), dtype=inputs_embeds.dtype, device=inputs_embeds.device |
|
|
) |
|
|
patch_lengths = process_patch_lengths(patch_lengths, max_patch_length) |
|
|
return prediction_entropies, patch_lengths, logits |
|
|
|
|
|
@staticmethod |
|
|
def patch_lengths_from_entropies( |
|
|
entropies, |
|
|
sequence_length, |
|
|
patch_size=None, |
|
|
threshold=None, |
|
|
    ):
        """
        Computes patch lengths from token entropies.

        Positions 0 and 1 always start a patch; after that, a new patch starts right after any
        token whose entropy exceeds `threshold`. Patch lengths are then derived from the
        resulting patch start positions.
        """
|
|
|
|
|
batch_size = entropies.shape[0] |
|
|
|
|
|
|
|
|
init_tokens = ( |
|
|
torch.tensor([0, 1], dtype=torch.long, device=entropies.device).unsqueeze(0).repeat(batch_size, 1) |
|
|
) |
|
|
offset = init_tokens.shape[1] |
|
|
|
|
|
|
|
|
entropies = entropies[:, 1:] |
|
|
|
|
|
|
|
|
patch_mask = entropies > threshold |
|
|
|
|
|
seq_len = patch_mask.shape[1] |
|
|
|
|
|
|
|
|
token_indices = torch.arange(seq_len, device=entropies.device).unsqueeze(0).expand(batch_size, -1) |
|
|
sentinel = torch.full_like(token_indices, seq_len) |
|
|
padded_indices = torch.cat([token_indices, sentinel], dim=1) |
|
|
|
|
|
|
|
|
padded_mask = torch.cat([patch_mask, ~patch_mask], dim=1) |
|
|
|
|
|
|
|
|
patch_starts = padded_indices[padded_mask].reshape(batch_size, seq_len) |
|
|
max_valid_patches = patch_mask.sum(dim=1).max() |
|
|
patch_starts = patch_starts[:, :max_valid_patches] |
|
|
|
|
|
|
|
|
patch_start_ids = torch.cat((init_tokens, patch_starts + offset), dim=1) |
|
|
|
|
|
|
|
|
last_token = torch.full_like(patch_start_ids[:, :1], sequence_length - 1) |
|
|
patch_ends = torch.cat((patch_start_ids[:, 1:] - 1, last_token), dim=1) |
|
|
|
|
|
patch_lengths = patch_ends - patch_start_ids + 1 |
|
|
|
|
|
return patch_lengths |
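

# Illustrative helper (not part of the original model code): entropy-threshold patching on a toy
# sequence. Positions 0 and 1 always open patches; afterwards a patch starts right after any
# token whose entropy exceeds the threshold.
def _example_patch_lengths_from_entropies():
    entropies = torch.tensor([[0.1, 0.9, 0.2, 0.8, 0.3, 0.4]])
    lengths = BltPatcher.patch_lengths_from_entropies(entropies, sequence_length=6, threshold=0.5)
    # patch starts are [0, 1, 2, 4] -> lengths [1, 1, 2, 2], which sum to the sequence length
    assert lengths.tolist() == [[1, 1, 2, 2]]
    return lengths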
|
|
|
|
|
|
|
|
def rolling_polynomial_hash(token_tensor, prime: int = 1000000007): |
|
|
""" |
|
|
A polynomial rolling hash algorithm that converts sequences |
|
|
of tokens into hash values. The hash is computed as: |
|
|
hash = (token_0 * prime^0 + token_1 * prime^1 + ... + token_n * prime^n) |
|
|
|
|
|
The rolling hash allows the model to efficiently |
|
|
identify and encode recurring byte-level patterns in the input text. |
|
|
|
|
|
Args: |
|
|
token_tensor (torch.Tensor): [batch_size, seq_len, group_size] containing token IDs to hash |
|
|
prime (int): Prime number used as the base for the polynomial hash. |
|
|
|
|
|
Returns: |
|
|
torch.Tensor: Hash values of shape [batch_size, seq_len] where each value |
|
|
represents the hash of the corresponding token group |
|
|
|
|
|
Example: |
|
|
>>> tokens = torch.tensor([[1, 2, 3], [4, 5, 6]]) |
|
|
>>> hashes = rolling_polynomial_hash(tokens, prime=31) |
|
|
>>> # hash[0] = 1*31^0 + 2*31^1 + 3*31^2 |
|
|
>>> # hash[1] = 4*31^0 + 5*31^1 + 6*31^2 |
|
|
""" |
|
|
prime_tensor = torch.tensor(prime, dtype=torch.int64, device=token_tensor.device) |
|
|
powers = torch.arange(token_tensor.shape[-1], device=token_tensor.device) |
|
|
prime_powers = prime_tensor**powers |
|
|
return torch.sum(token_tensor * prime_powers, dim=-1) |
|
|
|
|
|
|
|
|
def byte_group_hash_function( |
|
|
token_ids: torch.Tensor, group_size: int = 2, prime: int = 1000000007, max_hash: int = 30000 |
|
|
):
    """Hash each sliding group of `group_size` tokens and map the hashes into the range [0, max_hash)."""
|
|
with torch.no_grad(): |
|
|
batch_size, seq_len = token_ids.shape |
|
|
|
|
|
padding = torch.zeros(batch_size, group_size - 1, dtype=torch.int64, device=token_ids.device) |
|
|
padded_tokens = torch.cat([padding, token_ids], dim=1) |
|
|
|
|
|
|
|
|
windows = padded_tokens.unfold(1, group_size, 1) |
|
|
hashes = rolling_polynomial_hash(windows, prime) |
|
|
hash_values = hashes % max_hash |
|
|
|
|
|
return hash_values |
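

# Illustrative helper (not part of the original model code): a worked example of the n-gram
# hashing with a small prime so the numbers are easy to follow.
def _example_byte_group_hash():
    token_ids = torch.tensor([[3, 1, 4, 1]])
    hashes = byte_group_hash_function(token_ids, group_size=2, prime=31, max_hash=1000)
    # left-padded windows are (0, 3), (3, 1), (1, 4), (4, 1); hash = t0 * 31**0 + t1 * 31**1
    assert hashes.tolist() == [[93, 34, 125, 35]]
    return hashes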
|
|
|
|
|
|
|
|
def compute_hash_embeddings( |
|
|
local_encoder_tokens: torch.Tensor, |
|
|
local_encoder, |
|
|
encoder_hash_tok_embedding: nn.Embedding, |
|
|
encoder_hash_byte_group_nb_functions: int, |
|
|
encoder_hash_byte_group_size: list, |
|
|
encoder_hash_byte_group_vocab: int, |
|
|
) -> torch.Tensor: |
|
|
"""Compute token embeddings enhanced with hash-based embeddings.""" |
|
|
|
|
|
primes = [ |
|
|
1000000007, |
|
|
5915587277, |
|
|
1500450271, |
|
|
3267000013, |
|
|
5754853343, |
|
|
4093082899, |
|
|
9576890767, |
|
|
3628273133, |
|
|
2860486313, |
|
|
5463458053, |
|
|
3367900313, |
|
|
] |
|
|
|
|
|
embeddings = local_encoder.embed_tokens(local_encoder_tokens) |
|
|
embedding_idx = 0 |
|
|
for func_nb in range(encoder_hash_byte_group_nb_functions): |
|
|
prime = primes[func_nb % len(primes)] |
|
|
for group_size in encoder_hash_byte_group_size: |
|
|
hash_ids = byte_group_hash_function(local_encoder_tokens, group_size, prime, encoder_hash_byte_group_vocab) |
|
|
|
|
|
offset_hash_ids = hash_ids + embedding_idx * encoder_hash_byte_group_vocab |
|
|
embeddings += encoder_hash_tok_embedding(offset_hash_ids) |
|
|
embedding_idx += 1 |
|
|
|
|
|
return embeddings |
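

# Layout note (illustrative, consistent with how BltModel sizes `encoder_hash_tok_embedding`
# below): the single hash embedding table is partitioned into nb_functions * len(group_sizes)
# blocks of `encoder_hash_byte_group_vocab` rows each. For example, with 2 hash functions, group
# sizes [3, 4] and a vocab of 30000, `embedding_idx` runs over 0..3 and the row offsets are
# 0, 30000, 60000 and 90000, so the table needs 4 * 30000 = 120000 rows.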
|
|
|
|
|
|
|
|
def _prepare_patch_cross_attention_mask( |
|
|
patch_ids: torch.Tensor, |
|
|
num_patches: int, |
|
|
sequence_length: int, |
|
|
patches_as_queries: bool = False, |
|
|
cross_attn_k: int = 1, |
|
|
dtype: torch.dtype = torch.float32, |
|
|
) -> torch.Tensor:
|
|
""" |
|
|
Prepare cross-attention mask for patch-based attention, following mllama's robust approach. |
|
|
|
|
|
This function creates masks that control which patches can attend to which other patches, |
|
|
with support for query/key role swapping and cross-attention multipliers. |
|
|
|
|
|
Args: |
|
|
patch_ids (torch.Tensor): Tensor of shape [batch_size, seq_len] containing patch ids. |
|
|
num_patches (int): Total number of patches. |
|
|
sequence_length (int): Length of the sequence. |
|
|
patches_as_queries (bool): If True, patches are used as queries, otherwise as keys. |
|
|
cross_attn_k (int): Cross-attention multiplier for repeating patches. |
|
|
dtype (torch.dtype): Data type for the output mask. |
|
|
|
|
|
    Returns:
        torch.Tensor: `cross_attention_mask`, a 4D additive mask of shape [batch_size, 1, q_len, kv_len]
            with 0.0 at allowed positions and the dtype's minimum value at masked positions.
|
|
""" |
|
|
batch_size, seq_len = patch_ids.shape |
|
|
device = patch_ids.device |
|
|
|
|
|
|
|
|
if patches_as_queries: |
|
|
q_len = num_patches * cross_attn_k |
|
|
kv_len = sequence_length |
|
|
|
|
|
q_patch_ids = ( |
|
|
torch.arange(num_patches, device=device) |
|
|
.unsqueeze(0) |
|
|
.unsqueeze(-1) |
|
|
.expand(batch_size, num_patches, seq_len) |
|
|
) |
|
|
kv_patch_ids = patch_ids.unsqueeze(1).expand(batch_size, num_patches, seq_len) |
|
|
else: |
|
|
q_len = sequence_length |
|
|
kv_len = num_patches * cross_attn_k |
|
|
|
|
|
q_patch_ids = patch_ids.unsqueeze(-1).expand(batch_size, seq_len, num_patches) |
|
|
kv_patch_ids = ( |
|
|
torch.arange(num_patches, device=device).unsqueeze(0).unsqueeze(0).expand(batch_size, seq_len, num_patches) |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
cross_attention_mask = q_patch_ids == kv_patch_ids |
|
|
|
|
|
|
|
|
repeat_dim = 1 if patches_as_queries else -1 |
|
|
cross_attention_mask = cross_attention_mask.repeat_interleave(cross_attn_k, dim=repeat_dim) |
|
|
|
|
|
|
|
|
expected_shape = (batch_size, q_len, kv_len) |
|
|
if cross_attention_mask.shape != expected_shape: |
|
|
raise ValueError( |
|
|
f"Cross attention mask shape {cross_attention_mask.shape} doesn't match expected {expected_shape}" |
|
|
) |
|
|
|
|
|
|
|
|
cross_attention_mask = cross_attention_mask.unsqueeze(1) |
|
|
|
|
|
|
|
|
|
|
|
inverted_cross_attn_mask = 1.0 - cross_attention_mask.to(dtype) |
|
|
cross_attention_mask = inverted_cross_attn_mask.masked_fill( |
|
|
inverted_cross_attn_mask.to(torch.bool), torch.finfo(dtype).min |
|
|
) |
|
|
|
|
|
return cross_attention_mask |
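

# Illustrative helper (not part of the original model code): a tiny decoder-style mask in which
# each of 3 byte positions may only attend to its own patch (patch ids [0, 0, 1] over 2 patches).
def _example_patch_cross_attention_mask():
    patch_ids = torch.tensor([[0, 0, 1]])
    mask = _prepare_patch_cross_attention_mask(patch_ids, num_patches=2, sequence_length=3)
    assert mask.shape == (1, 1, 3, 2)  # (batch, 1, q_len=seq_len, kv_len=num_patches * cross_attn_k)
    # allowed positions are 0.0, masked positions are torch.finfo(torch.float32).min
    assert (mask[0, 0] == 0).tolist() == [[True, False], [True, False], [False, True]]
    return mask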
|
|
|
|
|
|
|
|
class BltModel(BltPreTrainedModel): |
|
|
def __init__(self, config: BltConfig): |
|
|
super().__init__(config) |
|
|
self.gradient_checkpointing = False |
|
|
|
|
|
self.config = config |
|
|
self.local_encoder = BltLocalEncoder(config.encoder_config) |
|
|
self.global_transformer = BltGlobalTransformer(config.global_config) |
|
|
self.local_decoder = BltLocalDecoder(config.decoder_config) |
|
|
num_embeddings = config.encoder_hash_byte_group_nb_functions * len(config.encoder_hash_byte_group_size) |
|
|
total_vocab_size = config.encoder_hash_byte_group_vocab * num_embeddings |
|
|
self.encoder_hash_tok_embedding = nn.Embedding(total_vocab_size, config.encoder_config.hidden_size) |
|
|
if self.config.patch_in_forward: |
|
|
self.patcher = BltPatcher(config.patcher_config) |
|
|
self.patcher.eval() |
|
|
for param in self.patcher.parameters(): |
|
|
param.requires_grad = False |
|
|
else: |
|
|
self.patcher = None |
|
|
self.post_init() |
|
|
|
|
|
@check_model_inputs() |
|
|
def forward( |
|
|
self, |
|
|
input_ids: Optional[torch.LongTensor] = None, |
|
|
patch_lengths: Optional[torch.Tensor] = None, |
|
|
attention_mask: Optional[torch.Tensor] = None, |
|
|
position_ids: Optional[torch.LongTensor] = None, |
|
|
past_key_values: Optional[Cache] = None, |
|
|
inputs_embeds: Optional[torch.FloatTensor] = None, |
|
|
use_cache: Optional[bool] = None, |
|
|
cache_position: Optional[torch.LongTensor] = None, |
|
|
**kwargs: Unpack[TransformersKwargs], |
|
|
) -> BaseModelOutputWithPast: |
|
|
if (input_ids is None) ^ (inputs_embeds is not None): |
|
|
raise ValueError("You must specify exactly one of input_ids or inputs_embeds") |
|
|
|
|
|
|
|
|
if inputs_embeds is not None: |
|
|
encoder_embeds = inputs_embeds |
|
|
batch_size, sequence_length, _ = inputs_embeds.shape |
|
|
else: |
|
|
batch_size, sequence_length = input_ids.shape |
|
|
encoder_embeds = compute_hash_embeddings( |
|
|
input_ids, |
|
|
self.local_encoder, |
|
|
self.encoder_hash_tok_embedding, |
|
|
self.config.encoder_hash_byte_group_nb_functions, |
|
|
self.config.encoder_hash_byte_group_size, |
|
|
self.config.encoder_hash_byte_group_vocab, |
|
|
) |
|
|
|
|
|
if patch_lengths is None: |
|
|
if self.config.patching_mode == "entropy" and self.patcher is not None: |
|
|
if input_ids is None: |
|
|
raise ValueError("input_ids is required for entropy-based patching") |
|
|
_, patch_lengths, _ = self.patcher( |
|
|
input_ids, |
|
|
patch_size=self.config.patch_size, |
|
|
threshold=self.config.patching_threshold, |
|
|
max_patch_length=self.config.max_patch_length, |
|
|
patching_batch_size=self.config.patching_batch_size, |
|
|
device=input_ids.device, |
|
|
) |
|
|
else: |
|
|
device = input_ids.device if input_ids is not None else inputs_embeds.device |
|
|
dtype = input_ids.dtype if input_ids is not None else inputs_embeds.dtype |
|
|
patch_lengths = process_patch_lengths( |
|
|
torch.ones((batch_size, sequence_length + 1), dtype=dtype, device=device), |
|
|
self.config.max_patch_length, |
|
|
) |
|
|
patch_ids = self._patch_ids_from_lengths(patch_lengths, sequence_length) |
|
|
if cache_position is None: |
|
|
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 |
|
|
cache_position = torch.arange( |
|
|
past_seen_tokens, past_seen_tokens + encoder_embeds.shape[1], device=encoder_embeds.device |
|
|
) |
|
|
if position_ids is None: |
|
|
position_ids = cache_position.unsqueeze(0) |
|
|
|
|
|
causal_mask = create_causal_mask( |
|
|
config=self.config, |
|
|
input_embeds=encoder_embeds, |
|
|
attention_mask=attention_mask, |
|
|
cache_position=cache_position, |
|
|
past_key_values=past_key_values, |
|
|
position_ids=position_ids, |
|
|
) |
|
|
|
|
|
cross_attn_mask_enc = _prepare_patch_cross_attention_mask( |
|
|
patch_ids=patch_ids, |
|
|
num_patches=patch_lengths.shape[1], |
|
|
sequence_length=sequence_length, |
|
|
patches_as_queries=True, |
|
|
cross_attn_k=self.config.cross_attn_k, |
|
|
dtype=encoder_embeds.dtype, |
|
|
) |
|
|
encoder_hidden_states, encoder_cross_states = self.local_encoder( |
|
|
input_ids=input_ids, |
|
|
inputs_embeds=encoder_embeds, |
|
|
attention_mask=causal_mask, |
|
|
position_ids=position_ids, |
|
|
encoder_attention_mask=cross_attn_mask_enc, |
|
|
num_patches=patch_lengths.shape[1], |
|
|
patch_ids=patch_ids, |
|
|
**kwargs, |
|
|
) |
|
|
encoder_cross_states = encoder_cross_states.view(batch_size, patch_lengths.shape[1], -1) |
|
|
global_cache_position = torch.arange(0, encoder_cross_states.shape[1], device=encoder_cross_states.device) |
|
|
global_position_ids = global_cache_position.unsqueeze(0) |
|
|
global_causal_mask = create_causal_mask( |
|
|
config=self.config, |
|
|
input_embeds=encoder_cross_states, |
|
|
attention_mask=None, |
|
|
cache_position=global_cache_position, |
|
|
past_key_values=None, |
|
|
position_ids=None, |
|
|
) |
|
|
|
|
|
global_hidden_states = self.global_transformer( |
|
|
input_embeds=encoder_cross_states, |
|
|
attention_mask=global_causal_mask, |
|
|
position_ids=global_position_ids, |
|
|
**kwargs, |
|
|
) |
|
|
decoder_patch_ids = self._patch_ids_from_lengths(patch_lengths[:, 1:], sequence_length) |
|
|
cross_attn_mask_dec = _prepare_patch_cross_attention_mask( |
|
|
patch_ids=decoder_patch_ids, |
|
|
num_patches=patch_lengths.shape[1], |
|
|
sequence_length=sequence_length, |
|
|
patches_as_queries=False, |
|
|
cross_attn_k=self.config.cross_attn_k, |
|
|
dtype=encoder_embeds.dtype, |
|
|
) |
|
|
output = self.local_decoder( |
|
|
input_ids=input_ids, |
|
|
inputs_embeds=encoder_hidden_states, |
|
|
patch_embeds=global_hidden_states, |
|
|
attention_mask=causal_mask, |
|
|
position_ids=position_ids, |
|
|
past_key_values=past_key_values, |
|
|
cache_position=cache_position, |
|
|
encoder_attention_mask=cross_attn_mask_dec, |
|
|
**kwargs, |
|
|
) |
|
|
return BaseModelOutputWithPast( |
|
|
last_hidden_state=output, |
|
|
past_key_values=past_key_values, |
|
|
) |
|
|
|
|
|
def get_input_embeddings(self): |
|
|
return self.local_encoder.embed_tokens |
|
|
|
|
|
def set_input_embeddings(self, value): |
|
|
self.local_encoder.embed_tokens = value |
|
|
|
|
|
def _patch_ids_from_lengths(self, patch_lengths: torch.Tensor, seq_len: int) -> torch.Tensor: |
|
|
batch_size = patch_lengths.shape[0] |
|
|
patch_starts = torch.cat( |
|
|
[ |
|
|
torch.zeros(batch_size, 1, dtype=patch_lengths.dtype, device=patch_lengths.device), |
|
|
patch_lengths.cumsum(dim=-1)[:, :-1], |
|
|
], |
|
|
dim=-1, |
|
|
) |
|
|
token_positions = torch.arange(seq_len, device=patch_lengths.device) |
|
|
return (patch_starts.unsqueeze(1) <= token_positions.unsqueeze(0).unsqueeze(-1)).sum(dim=-1) - 1 |
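
    # Illustrative note (not part of the original model code): for patch_lengths [[2, 3]] and
    # seq_len 5, the cumulative patch starts are [0, 2], so `_patch_ids_from_lengths` maps the
    # five byte positions to patch ids [0, 0, 1, 1, 1].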
|
|
|
|
|
|
|
|
@auto_docstring( |
|
|
custom_intro=""" |
|
|
The Blt Text Model with a language modeling head on top. |
|
|
""" |
|
|
) |
|
|
class BltForCausalLM(BltPreTrainedModel, GenerationMixin): |
|
|
config: BltConfig |
|
|
_can_compile_fullgraph = False |
|
|
base_model_prefix = "model" |
|
|
_tied_weights_keys = ["lm_head.weight"] |
|
|
|
|
|
def __init__(self, config: BltConfig): |
|
|
super().__init__(config.get_text_config()) |
|
|
self.text_config = config.get_text_config() |
|
|
self.vocab_size = config.vocab_size |
|
|
self.model = BltModel(config) |
|
|
self.lm_head = nn.Linear(config.decoder_config.hidden_size, config.vocab_size, bias=False) |
|
|
|
|
|
self.post_init() |
|
|
|
|
|
@can_return_tuple |
|
|
@auto_docstring |
|
|
def forward( |
|
|
self, |
|
|
input_ids: Optional[torch.LongTensor] = None, |
|
|
attention_mask: Optional[torch.Tensor] = None, |
|
|
position_ids: Optional[torch.LongTensor] = None, |
|
|
cross_attention_states: Optional[torch.LongTensor] = None, |
|
|
cross_attention_mask: Optional[torch.LongTensor] = None, |
|
|
full_text_row_masked_out_mask: Optional[tuple[torch.Tensor, torch.Tensor]] = None, |
|
|
past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None, |
|
|
inputs_embeds: Optional[torch.FloatTensor] = None, |
|
|
labels: Optional[torch.LongTensor] = None, |
|
|
use_cache: Optional[bool] = None, |
|
|
cache_position: Optional[torch.LongTensor] = None, |
|
|
logits_to_keep: Union[int, torch.Tensor] = 0, |
|
|
**kwargs: Unpack[TransformersKwargs], |
|
|
) -> Union[tuple, CausalLMOutputWithPast]: |
|
|
r""" |
|
|
cross_attention_states (`torch.FloatTensor`, *optional*): |
|
|
Output of the vision model, used for cross-attention. This tensor contains the processed image features that |
|
|
the language model will attend to. |
|
|
cross_attention_mask (`torch.Tensor` of shape `(batch_size, seq_length, max_num_images, max_num_tiles)`, *optional*): |
|
|
Cross-attention mask to control the interaction between text tokens and image tiles. |
|
|
This 4D tensor defines which image tiles each text token should attend to. |
|
|
|
|
|
For each text token (in seq_length): |
|
|
- 1 indicates the token **should attend** to the corresponding image tile |
|
|
- 0 indicates the token **should not attend** to the corresponding image tile |
|
|
full_text_row_masked_out_mask (`tuple[torch.Tensor, torch.Tensor]`, *optional*): |
|
|
A tuple containing two tensors that mask out rows in the cross-attention mechanism: |
|
|
- The first tensor has shape `(batch_size, 1, seq_length, 1)` and contains values of 0 or 1. |
|
|
A value of 0 indicates that the corresponding text token's entire row in the cross-attention |
|
|
matrix should be masked out (all image tokens ignored). |
|
|
- The second tensor has the same shape and is used internally to apply the masking during |
|
|
the forward pass of cross-attention layers. |
|
|
This mask is derived from the cross_attention_mask and is used to handle cases where a text token |
|
|
should not attend to any image token. |
|
|
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): |
|
|
Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., |
|
|
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored |
|
|
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. |
|
|
|
|
|
Example: |
|
|
|
|
|
```python |
|
|
>>> from transformers import AutoTokenizer, BltForCausalLM |
|
|
|
|
|
>>> model = BltForCausalLM.from_pretrained("Llama-3.2-11B-Vision") |
|
|
>>> tokenizer = AutoTokenizer.from_pretrained("Llama-3.2-11B-Vision") |
|
|
|
|
|
>>> prompt = "If I had to write a haiku, it would be:" |
|
|
>>> inputs = tokenizer(prompt, return_tensors="pt") |
|
|
|
|
|
>>> # Generate |
|
|
>>> generate_ids = model.generate(inputs.input_ids, max_length=40, do_sample=True, temperature=0.6) |
|
|
>>> result = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] |
|
|
>>> print(result) |
|
|
If I had to write a haiku, it would be: "Snowflakes gently fall" - simple, yet peaceful. |
|
|
I love the idea of snowflakes gently falling, each one |
|
|
``` |
|
|
""" |
|
|
|
|
|
outputs = self.model( |
|
|
input_ids=input_ids, |
|
|
attention_mask=attention_mask, |
|
|
position_ids=position_ids, |
|
|
cross_attention_mask=cross_attention_mask, |
|
|
full_text_row_masked_out_mask=full_text_row_masked_out_mask, |
|
|
past_key_values=past_key_values, |
|
|
inputs_embeds=inputs_embeds, |
|
|
use_cache=use_cache, |
|
|
cache_position=cache_position, |
|
|
**kwargs, |
|
|
) |
|
|
|
|
|
hidden_states = outputs.last_hidden_state |
|
|
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep |
|
|
logits = self.lm_head(hidden_states[:, slice_indices, :]).float() |
|
|
|
|
|
loss = None |
|
|
if labels is not None: |
|
|
loss = self.loss_function(logits, labels, self.vocab_size, **kwargs) |
|
|
|
|
|
return CausalLMOutputWithPast( |
|
|
loss=loss, |
|
|
logits=logits, |
|
|
past_key_values=outputs.past_key_values, |
|
|
hidden_states=outputs.hidden_states, |
|
|
attentions=outputs.attentions, |
|
|
) |
|
|
|
|
|
|
|
|
__all__ = ["BltPreTrainedModel", "BltModel", "BltPatcher", "BltForCausalLM"] |
|
|
|