# Tarka-Embedding-350M-V1 / LFM_bidirectional_model.py
from collections.abc import Callable
from typing import Any, Optional, Union
import torch
import torch.nn.functional as F
from torch import nn
from transformers.cache_utils import Cache
from transformers.generation import GenerationMixin
from transformers.integrations import use_kernel_forward_from_hub
from transformers.masking_utils import create_causal_mask
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
from transformers.modeling_layers import GradientCheckpointingLayer
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
)
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from transformers.models.lfm2.configuration_lfm2 import Lfm2Config
from transformers.models.lfm2.modeling_lfm2 import (
Lfm2HybridConvCache,
Lfm2MLP,
Lfm2Model,
Lfm2RMSNorm,
Lfm2RotaryEmbedding,
Lfm2ShortConv,
Lfm2DecoderLayer,
apply_rotary_pos_emb,
repeat_kv,
)
from transformers.processing_utils import Unpack
from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple
from transformers.utils.generic import check_model_inputs
from transformers.utils.import_utils import is_causal_conv1d_available
if is_causal_conv1d_available():
from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
else:
causal_conv1d_fn, causal_conv1d_update = None, None
class Lfm2BidirectionalConfig(Lfm2Config):
    """Lfm2Config extended with embedding-specific fields: `pooling` and `temperature`."""

    model_type = "lfm2_bidirec"

    def __init__(self, pooling="avg", temperature=1.0, **kwargs):
        self.pooling = pooling
        self.temperature = temperature
        super().__init__(**kwargs)
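

# Illustrative sketch (not called anywhere in this module): how the two extra
# config fields are meant to be used. `pooling` selects how per-token states are
# reduced to one vector per sequence and `temperature` scales downstream
# similarity scores; both are assumptions about the embedding head built on top
# of this encoder, which itself only returns per-token hidden states.
def _example_bidirectional_config() -> "Lfm2BidirectionalConfig":
    """Build a config that keeps all LFM2 defaults and adds the embedding knobs."""
    return Lfm2BidirectionalConfig(pooling="avg", temperature=1.0)
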
def eager_attention_forward(
module: nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor],
scaling: float,
dropout: float = 0.0,
**kwargs: Unpack[TransformersKwargs],
):
key_states = repeat_kv(key, module.num_key_value_groups)
value_states = repeat_kv(value, module.num_key_value_groups)
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(
query.dtype
)
attn_weights = nn.functional.dropout(
attn_weights, p=dropout, training=module.training
)
attn_output = torch.matmul(attn_weights, value_states)
attn_output = attn_output.transpose(1, 2).contiguous()
return attn_output, attn_weights
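

# Shape sketch for the eager path (assumption: a toy module exposing
# `num_key_value_groups` is enough to exercise the function; nothing here runs
# at import time).
def _example_eager_attention_shapes() -> tuple[torch.Size, torch.Size]:
    """Run eager_attention_forward on random tensors and return the output shapes."""

    class _ToyModule(nn.Module):
        def __init__(self):
            super().__init__()
            self.num_key_value_groups = 2  # 4 query heads share 2 key/value heads

    query = torch.randn(1, 4, 8, 16)  # [batch, num_heads, seq_len, head_dim]
    key = torch.randn(1, 2, 8, 16)    # [batch, num_kv_heads, seq_len, head_dim]
    value = torch.randn(1, 2, 8, 16)
    attn_output, attn_weights = eager_attention_forward(
        _ToyModule(), query, key, value, attention_mask=None, scaling=16**-0.5
    )
    # attn_output: [1, 8, 4, 16] (batch, seq_len, num_heads, head_dim), ready to be
    # flattened back to the hidden size; attn_weights: [1, 4, 8, 8].
    return attn_output.shape, attn_weights.shape
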
class Lfm2Attention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(self, config: Lfm2Config, layer_idx: int):
super().__init__()
self.config = config
self.layer_idx = layer_idx
self.head_dim = getattr(
config, "head_dim", config.hidden_size // config.num_attention_heads
)
self.num_key_value_groups = (
config.num_attention_heads // config.num_key_value_heads
)
self.scaling = self.head_dim**-0.5
self.is_causal = True
self.q_proj = nn.Linear(
config.hidden_size, config.num_attention_heads * self.head_dim, bias=False
)
self.k_proj = nn.Linear(
config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False
)
self.v_proj = nn.Linear(
config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False
)
self.out_proj = nn.Linear(
config.num_attention_heads * self.head_dim, config.hidden_size, bias=False
)
self.q_layernorm = Lfm2RMSNorm(self.head_dim, eps=config.norm_eps)
self.k_layernorm = Lfm2RMSNorm(self.head_dim, eps=config.norm_eps)
def forward(
self,
hidden_states: torch.Tensor,
position_embeddings: tuple[torch.Tensor, torch.Tensor],
attention_mask: Optional[torch.Tensor],
past_key_values: Optional[Lfm2HybridConvCache] = None,
cache_position: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
**kwargs,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.head_dim)
query_states = self.q_layernorm(
self.q_proj(hidden_states).view(*hidden_shape)
).transpose(1, 2)
key_states = self.k_layernorm(
self.k_proj(hidden_states).view(*hidden_shape)
).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(*hidden_shape).transpose(1, 2)
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin
)
if past_key_values is not None:
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_values.update(
key_states, value_states, self.layer_idx, cache_kwargs
)
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
attention_interface = ALL_ATTENTION_FUNCTIONS[
self.config._attn_implementation
]
attn_output, attn_weights = attention_interface(
self,
query_states,
key_states,
value_states,
attention_mask,
dropout=0.0,
scaling=self.scaling,
**kwargs,
)
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
output = self.out_proj(attn_output)
return output, attn_weights
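

# Hedged sketch of a single bidirectional attention block (assumption: a small
# Lfm2 config with defaults for every field not listed below is valid; real
# checkpoints define these fields in config.json).
def _example_bidirectional_attention(seq_len: int = 6) -> torch.Size:
    """Run one Lfm2Attention layer with no mask and no cache; return the output shape."""
    config = Lfm2BidirectionalConfig(
        hidden_size=64, num_attention_heads=4, num_key_value_heads=2
    )
    attention = Lfm2Attention(config, layer_idx=0)
    attention.is_causal = False  # bidirectional: no causal triangle inside the kernels
    rotary = Lfm2RotaryEmbedding(config=config)
    hidden_states = torch.randn(1, seq_len, config.hidden_size)
    position_ids = torch.arange(seq_len).unsqueeze(0)
    position_embeddings = rotary(hidden_states, position_ids=position_ids)
    output, _ = attention(
        hidden_states,
        position_embeddings=position_embeddings,
        attention_mask=None,
    )
    return output.shape  # [1, seq_len, hidden_size]
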
def apply_mask_to_padding_states(hidden_states, attention_mask):
"""
Tunes out the hidden states for padding tokens, see https://github.com/state-spaces/mamba/issues/66
"""
# NOTE: attention mask is a 2D boolean tensor
if (
attention_mask is not None
and attention_mask.shape[1] > 1
and attention_mask.shape[0] > 1
):
dtype = hidden_states.dtype
hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype)
return hidden_states
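

# Minimal sketch of the padding helper above: with more than one sequence in the
# batch, hidden states at padded positions are zeroed before the conv layers see them.
def _example_mask_padding() -> torch.Tensor:
    """Zero the hidden states of padded tokens in a toy batch of two sequences."""
    hidden_states = torch.ones(2, 4, 8)                          # [batch, seq_len, hidden]
    attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])  # 1 = token, 0 = padding
    return apply_mask_to_padding_states(hidden_states, attention_mask)
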
kernel_modules = (causal_conv1d_fn, causal_conv1d_update)
is_fast_path_available = all(kernel_modules)
@auto_docstring
class Lfm2PreTrainedModel(PreTrainedModel):
config: Lfm2Config
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["Lfm2DecoderLayer"]
_skip_keys_device_placement = ["past_key_values"]
_supports_flash_attn = True
_supports_sdpa = True
_supports_flex_attn = True
_can_compile_fullgraph = False
_supports_attention_backend = True
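    # Module types whose intermediate outputs are recorded when callers request
    # hidden states or attention maps.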
_can_record_outputs = {
"hidden_states": Lfm2DecoderLayer,
"attentions": Lfm2Attention,
}
@auto_docstring
class Lfm2BidirectionalModel(Lfm2PreTrainedModel):
    """LFM2 backbone with non-causal attention layers, intended as a text-embedding encoder."""

    config_class = Lfm2BidirectionalConfig

    def __init__(self, config: Lfm2BidirectionalConfig):
super().__init__(config)
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = nn.Embedding(
config.vocab_size, config.hidden_size, self.padding_idx
)
self.layers = nn.ModuleList(
[
Lfm2DecoderLayer(config, layer_idx)
for layer_idx in range(config.num_hidden_layers)
]
)
self.rotary_emb = Lfm2RotaryEmbedding(config=config)
self.gradient_checkpointing = False
self.embedding_norm = Lfm2RMSNorm(config.hidden_size, eps=config.norm_eps)
# Initialize weights and apply final processing
self.post_init()
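        # Flip every attention layer to non-causal so each token attends to the
        # full sequence; this is what makes the encoder bidirectional.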
for layer in self.layers:
if layer.is_attention_layer:
layer.self_attn.is_causal = False
def _update_causal_mask(
self,
attention_mask: torch.Tensor,
input_tensor: torch.Tensor,
cache_position: torch.Tensor,
past_key_values: Cache,
):
        if self.config._attn_implementation == "flash_attention_2":
            if attention_mask is not None and (attention_mask == 0.0).any():
                return attention_mask
            return None
        # Without a padding mask there is nothing to mask: attention is fully bidirectional.
        if attention_mask is None:
            return None
        # Expand the 2D padding mask to a 4D additive mask; no causal triangle is applied,
        # so every attention layer sees the whole sequence.
        causal_mask = _prepare_4d_attention_mask(
            attention_mask,
            dtype=input_tensor.dtype,
        )
        return causal_mask
@auto_docstring
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Lfm2HybridConvCache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[TransformersKwargs],
) -> BaseModelOutputWithPast:
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError(
"You must specify exactly one of input_ids or inputs_embeds"
)
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
if use_cache and past_key_values is None:
batch_size = inputs_embeds.shape[0]
past_key_values = Lfm2HybridConvCache(
config=self.config,
max_batch_size=batch_size,
dtype=self.dtype,
device=self.device,
)
if cache_position is None:
past_seen_tokens = (
past_key_values.get_seq_length() if past_key_values is not None else 0
)
cache_position = torch.arange(
past_seen_tokens,
past_seen_tokens + inputs_embeds.shape[1],
device=inputs_embeds.device,
)
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
causal_mask = self._update_causal_mask(
input_tensor=inputs_embeds,
attention_mask=attention_mask,
cache_position=cache_position,
past_key_values=past_key_values,
)
# Skip masking for decoding stage. We check shape here to be compile-friendly
linear_attention = attention_mask if inputs_embeds.shape[1] != 1 else None
hidden_states = inputs_embeds
position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)
# decoder layers
for decoder_layer in self.layers[: self.config.num_hidden_layers]:
layer_mask = (
causal_mask if decoder_layer.is_attention_layer else linear_attention
)
hidden_states = decoder_layer(
hidden_states,
attention_mask=layer_mask,
position_embeddings=position_embeddings,
position_ids=position_ids,
past_key_values=past_key_values,
cache_position=cache_position,
**kwargs,
)
hidden_states = self.embedding_norm(hidden_states)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=past_key_values,
)
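

# Hedged sketch of the pooling step implied by `pooling="avg"` (assumption: the
# embedding head performs a masked mean over `last_hidden_state` followed by L2
# normalisation; this module itself only returns per-token states).
def _example_masked_mean_pooling(
    last_hidden_state: torch.Tensor, attention_mask: torch.Tensor
) -> torch.Tensor:
    """Average the hidden states of real (non-padded) tokens and L2-normalise."""
    mask = attention_mask.unsqueeze(-1).to(last_hidden_state.dtype)  # [batch, seq_len, 1]
    summed = (last_hidden_state * mask).sum(dim=1)                   # [batch, hidden]
    counts = mask.sum(dim=1).clamp(min=1e-9)                         # [batch, 1]
    return F.normalize(summed / counts, dim=-1)


if __name__ == "__main__":
    # Smoke test on random tensors only; loading a real checkpoint and tokenizer
    # is out of scope for this sketch.
    pooled = _example_masked_mean_pooling(
        torch.randn(2, 6, 64), torch.tensor([[1, 1, 1, 1, 0, 0], [1, 1, 1, 1, 1, 1]])
    )
    print(pooled.shape)  # torch.Size([2, 64])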