diff --git "a/modeling_jamba.py" "b/modeling_jamba.py" new file mode 100644--- /dev/null +++ "b/modeling_jamba.py" @@ -0,0 +1,4438 @@ +# coding=utf-8 +# Copyright 2024 AI21 Labs Ltd. and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch Jamba model.""" +import inspect +import math +import copy +import warnings +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Tuple, Union +import time +from collections import OrderedDict +from functools import partial +import numpy as np +import os + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache, DynamicCache +from transformers.modeling_attn_mask_utils import ( + _prepare_4d_causal_attention_mask, + _prepare_4d_causal_attention_mask_for_sdpa, +) +from transformers.modeling_outputs import ( + MoeCausalLMOutputWithPast, + MoeModelOutputWithPast, + SequenceClassifierOutputWithPast, +) +from transformers.modeling_utils import PreTrainedModel +from transformers.pytorch_utils import is_torch_greater_or_equal_than_1_13 +from transformers.utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_greater_or_equal_2_10, + logging, + replace_return_docstrings, +) +from transformers.utils.import_utils import is_torch_fx_available +from .configuration_jamba import JambaConfig +from torch.utils.checkpoint import checkpoint + + +# try except block so it'll work with trust_remote_code. Later we can have `if is_flash_attn_2_available():` +try: + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa + + _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters) + + from einops import rearrange, repeat, reduce, pack, unpack + from einops.layers.torch import Rearrange +except ImportError: + pass + +# This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph. +# It means that the function will not be traced through and simply appear as a node in the graph. +if is_torch_fx_available(): + if not is_torch_greater_or_equal_than_1_13: + import torch.fx + + _prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask) + +# try except block so it'll work with trust_remote_code. 
Later we can have `if is_mamba_ssm_available():` +# try: +from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn, selective_scan_fn +from mamba_ssm.ops.triton.selective_state_update import selective_state_update +# except ImportError: +# selective_state_update, selective_scan_fn, mamba_inner_fn = None, None, None + +# try except block so it'll work with trust_remote_code. Later we can have `if is_causal_conv1d_available():` +# try: +from causal_conv1d import causal_conv1d_fn, causal_conv1d_update +# except ImportError: +# causal_conv1d_update, causal_conv1d_fn = None, None + + +from .mamba2 import Mamba2, Mamba2_Multihead, Mamba2_Fused +# from .retention import MultiScaleRetention +# from .gla import GatedLinearAttention +# from fla.layers.gated_deltanet import GatedDeltaNet + +is_fast_path_available = all( + (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn) +) + +# from .fused_rotary_embedding import apply_rotary_emb_func, RoPECache, build_rope_cache +# torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=False, enable_cudnn=True) + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "JambaConfig" + + +def pad_at_dim(t, pad: Tuple[int, int], dim = -1, value = 0.): + if pad == (0, 0): + return t + + dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1) + zeros = ((0, 0) * dims_from_right) + return F.pad(t, (*zeros, *pad), value = value) + +# Adapted from transformers.models.mixtral.modeling_mixtral.load_balancing_loss_func +def load_balancing_loss_func( + gate_logits: torch.Tensor, num_experts: torch.Tensor = None, top_k=2, attention_mask: Optional[torch.Tensor] = None +) -> float: + r""" + Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch. + + See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss + function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between + experts is too unbalanced. + + Args: + gate_logits (Union[`torch.Tensor`, Tuple[torch.Tensor]): + Logits from the `router`, should be a tuple of model.config.num_hidden_layers tensors of + shape [batch_size X sequence_length, num_experts]. + attention_mask (`torch.Tensor`, None): + The attention_mask used in forward function + shape [batch_size X sequence_length] if not None. + num_experts (`int`, *optional*): + Number of experts + + Returns: + The auxiliary loss. 
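+
+ Example (an illustrative sketch added here for clarity; the layer count, token count and
+ expert count are hypothetical, not taken from any real configuration):
+
+ >>> # two router layers, 8 tokens each, 4 experts, top-2 routing
+ >>> gate_logits = (torch.randn(8, 4), torch.randn(8, 4))
+ >>> aux_loss = load_balancing_loss_func(gate_logits, num_experts=4, top_k=2)
+ >>> aux_loss.shape  # a scalar auxiliary-loss tensor
+ torch.Size([])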
+ """ + if gate_logits is None or not isinstance(gate_logits, tuple): + return 0 + + if isinstance(gate_logits, tuple): + compute_device = gate_logits[0].device + concatenated_gate_logits = torch.cat( + [layer_gate.to(compute_device) for layer_gate in gate_logits if layer_gate.shape[1] > 1], dim=0 + ) + + routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1) + + _, selected_experts = torch.topk(routing_weights, top_k, dim=-1) + + expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts) + + if attention_mask is None: + # Compute the percentage of tokens routed to each experts + tokens_per_expert = torch.mean(expert_mask.float(), dim=0) + + # Compute the average probability of routing to these experts + router_prob_per_expert = torch.mean(routing_weights, dim=0) + else: + batch_size, sequence_length = attention_mask.shape + num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length) + + # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask + expert_attention_mask = ( + attention_mask[None, :, :, None, None] + .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts)) + .reshape(-1, top_k, num_experts) + .to(compute_device) + ) + + # Compute the percentage of tokens routed to each experts + tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum( + expert_attention_mask, dim=0 + ) + + # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert + router_per_expert_attention_mask = ( + attention_mask[None, :, :, None] + .expand((num_hidden_layers, batch_size, sequence_length, num_experts)) + .reshape(-1, num_experts) + .to(compute_device) + ) + + # Compute the average probability of routing to these experts + router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum( + router_per_expert_attention_mask, dim=0 + ) + + overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0)) + return overall_loss * num_experts + + +# Copied from transformers.models.llama.modeling_llama._get_unpad_data +def _get_unpad_data(attention_mask): + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +### Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Jamba +class JambaRMSNorm(nn.Module): + def __init__(self, hidden_size, learnable_weight=True, eps=1e-6): + """ + JambaRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + if learnable_weight: + self.weight = nn.Parameter(torch.ones(hidden_size)) + else: + self.weight = None + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + + if self.weight is not None: + return self.weight * hidden_states.to(input_dtype) + else: + return hidden_states.to(input_dtype) + + +class PerheadJambaRMSNorm(nn.Module): + def __init__(self, hidden_size, num_heads, eps=1e-6): + """ + For per-head kq normalization + """ + super().__init__() + self.weight = 
nn.Parameter(torch.ones(1, num_heads, 1, hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + # assert 1==0, f"hiddens_states shape: {hidden_states.shape}" # [bsz, num_heads, seq_len, head_dim] + assert hidden_states.shape[1] == self.weight.shape[1], f"hidden_state: {hidden_states.shape}, weight: {self.weight.shape}" + assert hidden_states.shape[3] == self.weight.shape[3], f"hidden_state: {hidden_states.shape}, weight: {self.weight.shape}" + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + + # variance = hidden_states.pow(2).mean(-1, keepdim=True) + # hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + # return self.weight * hidden_states.to(input_dtype) + + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + +class JambaOnlyNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + JambaRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + # self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return hidden_states.to(input_dtype) + + +class LlamaRotaryEmbedding(nn.Module): + def __init__(self, config, dim, base=10000, device=None, scaling_factor=1.0): + super().__init__() + self.scaling_factor = scaling_factor + self.dim = dim + self.base = base + self.config = config + + self.rope_type = config.rope_type + + self.factor = 2 + + max_position_embeddings = self.config.max_position_embeddings + + if config.rope_type is None or config.rope_type == "default": + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) + self.max_seq_len_cached = max_position_embeddings + + elif config.rope_type == 'ntk': + assert self.config.orig_max_position_embeddings is not None + orig_max_position_embeddings = self.config.orig_max_position_embeddings + + base = base * ((self.factor * max_position_embeddings / orig_max_position_embeddings) - (self.factor - 1)) ** (self.dim / (self.dim - 2)) + inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) + + self.max_seq_len_cached = orig_max_position_embeddings + + elif config.rope_type == 'dynamic_ntk': + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) + self.original_inv_freq = inv_freq + self.max_seq_len_cached = self.config.orig_max_position_embeddings + + else: + raise ValueError(f"Not support rope_type: {config.rope_type}") + + self.register_buffer("inv_freq", inv_freq, persistent=False) + + + def _dynamic_frequency_update(self, position_ids, device): + """ + dynamic RoPE layers should recompute `inv_freq` in the following situations: + 1 - growing beyond the cached sequence length (allow scaling) + 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) + """ + + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: # growth + base = self.base * ((self.factor * seq_len / self.config.orig_max_position_embeddings) - (self.factor - 1)) ** (self.dim / (self.dim - 2)) + inv_freq = 1.0 
/ (base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) + + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.max_seq_len_cached = seq_len + + if seq_len < self.config.orig_max_position_embeddings and self.max_seq_len_cached > self.config.orig_max_position_embeddings: # reset + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + self.max_seq_len_cached = self.config.orig_max_position_embeddings + + + + @torch.no_grad() + def forward(self, x, position_ids): + if self.rope_type == 'dynamic_ntk': + self._dynamic_frequency_update(position_ids, device=x.device) + + # x: [bs, num_attention_heads, seq_len, head_size] + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 since bfloat16 loses precision on long contexts + # See https://github.com/huggingface/transformers/pull/29285 + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + if q is not None: + q_embed = (q * cos) + (rotate_half(q) * sin) + + else: + q_embed = None + + if k is not None: + k_embed = (k * cos) + (rotate_half(k) * sin) + else: + k_embed = None + return q_embed, k_embed + +# Copied from transformers.models.llama.modeling_llama.repeat_kv +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + + +class HybridMambaAttentionDynamicCache(DynamicCache): + """ + A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache + (which has a constant shape regardless of seq_len). + + This cache has two sets of lists of tensors: `key_cache` and `value_cache` for attention cache and `conv_states` + and `ssm_states` for mamba cache. Each of these lists has `num_layers` tensors. The expected shape for each tensor + For attention layers, `key_cache` and `value_cache` have a shape of `(batch_size, num_heads, seq_len, head_dim)`, + while `conv_states` and `ssm_states` have a shape of `(batch_size, 0)` (empty tensors). + For mamba layers, `key_cache` and `value_cache` have a shape of `(batch_size, 0)` (empty tensors), + while `conv_states` represents the convolution state and has a shape of `(batch_size, d_inner, d_conv)`, + and `ssm_states` represents the ssm state and has a shape of `(batch_size, d_inner, d_state)`. + """ + + def __init__(self, config, batch_size, dtype=torch.float16, device=None, layer_type=None): + self.dtype = dtype + # self.layers_block_type = config.layers_block_type + self.has_previous_state = False # only used by mamba + intermediate_size = config.mamba_expand * config.hidden_size + ssm_state_size = config.mamba_d_state + conv_kernel_size = config.mamba_d_conv + self.conv_states = [] + self.ssm_states = [] + + self.layer_type = layer_type + + for i in range(config.num_hidden_layers): + if layer_type is None: + is_attn = True if (i>=config.attn_layer_offset) and ((i - config.attn_layer_offset) % config.attn_layer_period == 0) else False + has_mamba_state = i in config.hybrid_block_indices or not is_attn or config.fused_multihead_config is not None + else: + has_mamba_state = self.layer_type[i] == 'h' or self.layer_type[i] == 'm' + + if has_mamba_state: + if hasattr(config, 'conv_dim'): + conv_dim = config.conv_dim[str(i)] + else: + conv_dim = intermediate_size + self.conv_states += [ + torch.zeros(batch_size, conv_dim, conv_kernel_size, device=device, dtype=dtype) + ] + self.ssm_states += [ + torch.zeros(batch_size, intermediate_size, ssm_state_size, device=device, dtype=dtype) + ] + else: + self.conv_states += [torch.tensor([[]] * batch_size, device=device)] + self.ssm_states += [torch.tensor([[]] * batch_size, device=device)] + + self.key_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] + self.value_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] + + self.mamba_past_length = [0 for _ in range(config.num_hidden_layers)] + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Update the cache + if self.key_cache[layer_idx].shape[-1] == 0: + self.key_cache[layer_idx] = key_states + self.value_cache[layer_idx] = value_states + else: + self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=2) + self.value_cache[layer_idx] = 
torch.cat([self.value_cache[layer_idx], value_states], dim=2) + + return self.key_cache[layer_idx], self.value_cache[layer_idx] + + def reorder_cache(self, beam_idx: torch.LongTensor): + """Reorders the cache for beam search, given the selected beam indices.""" + for layer_idx in range(len(self.key_cache)): + device = self.key_cache[layer_idx].device + self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device)) + device = self.value_cache[layer_idx].device + self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device)) + + device = self.conv_states[layer_idx].device + self.conv_states[layer_idx] = self.conv_states[layer_idx].index_select(0, beam_idx.to(device)) + device = self.ssm_states[layer_idx].device + self.ssm_states[layer_idx] = self.ssm_states[layer_idx].index_select(0, beam_idx.to(device)) + + def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: + """Returns the sequence length of the cached states. A layer index can be optionally passed.""" + # take any layer that contains cache and not empty tensor + + if self.layer_type[layer_idx] == 'm': + return self.mamba_past_length[layer_idx] + + if self.key_cache[layer_idx].shape[-1] == 0: + return 0 + + return self.key_cache[layer_idx].shape[-2] + + def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]: + raise NotImplementedError("HybridMambaAttentionDynamicCache does not have a legacy cache equivalent.") + + @classmethod + def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None) -> "DynamicCache": + raise NotImplementedError("HybridMambaAttentionDynamicCache does not have a legacy cache equivalent.") + + +@dataclass +class MambaCacheParams: + seqlen_offset: int = 0 + conv_states: Dict[int, torch.Tensor] = field(default_factory=dict) + ssm_states: Dict[int, torch.Tensor] = field(default_factory=dict) + + + +# Adapted from transformers.models.mistral.modeling_mistral.MistralAttention with Mistral->Jamba +class JambaAttention(nn.Module): + """ + Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer + and "Generating Long Sequences with Sparse Transformers". + """ + + def __init__(self, config: JambaConfig, layer_idx: Optional[int] = None, compact_gating=False, reuse_kv=False, attn_only_wo_proj=False, use_linear_attn=False, input_hidden_size=None, output_hidden_size=None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " + "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." 
+ ) + + # self.hidden_size = config.hidden_size + self.hidden_size = config.attn_hidden_size if config.attn_hidden_size > 0 else config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + + self.compact_gating = compact_gating + + self.attn_only_wo_proj = attn_only_wo_proj + + self.kq_head_dim = config.kq_head_dim if config.kq_head_dim > 0 else self.head_dim + self.v_head_dim = config.v_head_dim if config.v_head_dim > 0 else self.head_dim + + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.is_causal = True + self.attention_dropout = config.attention_dropout + + if (self.head_dim * self.num_heads) != self.hidden_size and self.kq_head_dim == self.head_dim: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." + ) + + if not self.attn_only_wo_proj: + self.q_proj = nn.Linear(self.hidden_size if input_hidden_size is None else input_hidden_size, self.num_heads * self.kq_head_dim, bias=False) + + self.reuse_kv = reuse_kv + + if not self.attn_only_wo_proj and not self.reuse_kv: + self.k_proj = nn.Linear(self.hidden_size if input_hidden_size is None else input_hidden_size, self.num_key_value_heads * self.kq_head_dim, bias=False) + self.v_proj = nn.Linear(self.hidden_size if input_hidden_size is None else input_hidden_size, self.num_key_value_heads * self.v_head_dim, bias=False) + + if output_hidden_size is None: + output_hidden_size = self.hidden_size + + if not self.attn_only_wo_proj and not self.compact_gating: + self.o_proj = nn.Linear(self.num_heads * self.v_head_dim, output_hidden_size, bias=False) + + if self.config.use_nGPT and (self.config.nGPT_config is None or self.config.nGPT_config['qk_norm']): + self.k_norm = None + self.q_norm = None + else: + if self.config.kq_norm == "rms": + self.k_norm = JambaRMSNorm(self.kq_head_dim) + self.q_norm = JambaRMSNorm(self.kq_head_dim) + elif self.config.kq_norm == "perhead-rms": + self.k_norm = PerheadJambaRMSNorm(self.kq_head_dim, self.num_key_value_heads) + self.q_norm = PerheadJambaRMSNorm(self.kq_head_dim, self.num_heads) + elif self.config.kq_norm == "none": + self.k_norm = None + self.q_norm = None + else: + raise NotImplementedError(f"Unknown kq_norm: {self.config.kq_norm}") + + if self.config.rope: + # print("===> Using Rotary Position Embedding") + self._init_rope() + + self.use_linear_attn = use_linear_attn + if self.use_linear_attn: + self.linear_attn_qk_act = F.silu + self.linear_attn_norm = JambaRMSNorm(self.v_head_dim * self.num_heads, eps=config.rms_norm_eps) + + if self.config.use_nGPT and (self.config.nGPT_config is None or self.config.nGPT_config['qk_norm']): + self.sqk_init_value = 1.0 + self.sqk_init_scaling = 1.0 / self.config.hidden_size ** 0.5 + self.sqk = torch.nn.Parameter(self.sqk_init_scaling*torch.ones(self.config.hidden_size, dtype=torch.float32)) + + + def justnorm(self, x): + #return F.normalize(x, p=2, dim=-1) + res = x / x.norm(p=2, dim=-1, keepdim=True) + return res + + def nGPT_qk_norm(self, query_states, key_states, flash_attn=False): + if flash_attn: + sqk = (self.sqk * (self.sqk_init_value/self.sqk_init_scaling)).view(1, 1, self.config.num_attention_heads, self.config.hidden_size // self.config.num_attention_heads) + else: + sqk = (self.sqk * 
(self.sqk_init_value/self.sqk_init_scaling)).view(1, self.config.num_attention_heads, 1, self.config.hidden_size // self.config.num_attention_heads) + query_states = sqk * self.justnorm(query_states) + key_states = sqk * self.justnorm(key_states) + + return query_states, key_states + + + def _init_rope(self): + # assert 1==0, f"max_position_embeddings: {self.max_position_embeddings}" + self.rotary_emb = LlamaRotaryEmbedding( + config=self.config, + dim=self.kq_head_dim, + base=self.rope_theta, + device=torch.device("cuda"), + ) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + kv_last_layer = None, + # kv_proj_last_layer = None, + use_swa=False, + query_states = None, + key_states=None, + value_states=None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + raise NotImplementedError("JambaAttention is an abstract class. Use one of the subclasses.") + + + +# Adapted from transformers.models.mistral.modeling_mistral.MistralFlashAttention2 with Mistral->Jamba +class JambaFlashAttention2(JambaAttention): + """ + Jamba flash attention module. This module inherits from `JambaAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: torch.Tensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + kv_last_layer=None, + # kv_proj_last_layer = None, + use_swa=False, + query_states = None, + key_states=None, + value_states=None, + **kwargs, + ): + + + + # print(f"Flash Attn - layer_idx: {self.layer_idx}, attn_mask is none: {attention_mask is None}") + # print(f"layer_idx: {self.layer_idx}, use_swq: {use_swa}") + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. 
Please make sure use `attention_mask` instead.`" + ) + + # overwrite attention_mask with padding_mask + attention_mask = kwargs.pop("padding_mask") + + if self.attn_only_wo_proj: + assert query_states is not None + bsz, q_len, _ = query_states.size() + else: + bsz, q_len, _ = hidden_states.size() + + if not self.attn_only_wo_proj: + query_states = self.q_proj(hidden_states) + + if self.config.use_nGPT and 'extra_grad' in self.config.nGPT_config and self.config.nGPT_config['extra_grad']: + query_states = query_states / self.q_proj.weight.norm(p=2, dim=1) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.kq_head_dim).transpose(1, 2).contiguous() + + if self.q_norm is not None: + query_states = self.q_norm(query_states) + + # we do kq_norm first before rope according to + # https://github.com/huggingface/transformers/blob/6c1d0b069de22d7ed8aa83f733c25045eea0585d/src/transformers/models/cohere/modeling_cohere.py#L568 + if self.config.rope: + if self.attn_only_wo_proj: + cos, sin = self.rotary_emb(query_states, position_ids) + else: + cos, sin = self.rotary_emb(hidden_states, position_ids) + query_states, _ = apply_rotary_pos_emb(query_states, None, cos, sin) + + + if self.reuse_kv: + assert kv_last_layer is not None + key_states, value_states = kv_last_layer # (batch, num_heads, slen, head_dim) + + else: + if not self.attn_only_wo_proj: + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + if self.config.use_nGPT and 'extra_grad' in self.config.nGPT_config and self.config.nGPT_config['extra_grad']: + key_states = key_states / self.k_proj.weight.norm(p=2, dim=1) + value_states = value_states / self.v_proj.weight.norm(p=2, dim=1) + + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.kq_head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.v_head_dim).transpose(1, 2) + + if self.k_norm is not None: + key_states = self.k_norm(key_states) + + if self.config.rope: + # cos, sin = self.rotary_emb(hidden_states, position_ids) + _, key_states = apply_rotary_pos_emb(None, key_states, cos, sin) + + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None and not self.reuse_kv: + if self.layer_idx is None: + raise ValueError( + f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " + "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " + "with a layer index." + ) + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + + use_sliding_windows = ( + _flash_supports_window_size + and getattr(self.config, "sliding_window", None) is not None + # and kv_seq_len > (self.config.sliding_window + self.config.num_memory_tokens if self.config.num_memory_tokens > 0 else self.config.sliding_window) + and kv_seq_len > self.config.sliding_window + and use_swa + ) + + if not _flash_supports_window_size: + logger.warning_once( + "The current flash attention version does not support sliding window attention, for a more memory efficient implementation" + " make sure to upgrade flash-attn library." 
+ ) + + swa_processed_flag = False + if past_key_value is not None and use_cache and not self.reuse_kv: + # if self.reuse_kv: + # for kv_group in self.config.kv_reuse_group: + # if self.layer_idx in kv_group: + # break + # kv_layer_idx = kv_group[0] + # else: + # kv_layer_idx = self.layer_idx + + kv_layer_idx = self.layer_idx + + cache_has_contents = past_key_value.get_seq_length(kv_layer_idx) > 0 + + if ( + getattr(self.config, "sliding_window", None) is not None + # and kv_seq_len > (self.config.sliding_window + self.config.num_memory_tokens if self.config.num_memory_tokens > 0 else self.config.sliding_window) + and kv_seq_len > self.config.sliding_window + and cache_has_contents + and use_swa + ): + slicing_tokens = 1 - self.config.sliding_window + + past_key = past_key_value[kv_layer_idx][0] + past_value = past_key_value[kv_layer_idx][1] + + # if self.config.num_memory_tokens > 0: + # # num_fetched_memory_tokens = min(kv_seq_len - self.config.sliding_window, self.config.num_memory_tokens) + # num_fetched_memory_tokens = self.config.num_memory_tokens + + # past_key = torch.cat([past_key[:, :, :num_fetched_memory_tokens, :], past_key[:, :, slicing_tokens:, :]], dim=-2).contiguous() + # past_value = torch.cat([past_value[:, :, :num_fetched_memory_tokens, :], past_value[:, :, slicing_tokens:, :]], dim=-2).contiguous() + + # # print(past_key.shape, past_value.shape) + + # else: + past_key = past_key[:, :, slicing_tokens:, :].contiguous() + past_value = past_value[:, :, slicing_tokens:, :].contiguous() + + ### only keep sliding_window tokens in kv cache: Removed as this will impact the kv_seq_len calculation, resulting in errors for all swa cases + past_key_value.key_cache[kv_layer_idx] = past_key + past_key_value.value_cache[kv_layer_idx] = past_value + + # if past_key.shape[-2] != self.config.sliding_window - 1 and self.config.num_memory_tokens <= 0: + # raise ValueError( + # f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got" + # f" {past_key.shape}" + # ) + + if attention_mask is not None: + attention_mask = attention_mask[:, slicing_tokens:] + attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1) + + swa_processed_flag = True + + key_states, value_states = past_key_value.update(key_states, value_states, kv_layer_idx) + + # repeat k/v heads if n_kv_heads < n_heads + key_states_no_repeat = key_states + value_states_no_repeat = value_states + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + dropout_rate = 0.0 if not self.training else self.attention_dropout + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in float16 just to be sure everything works as expected. + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." 
+ ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + # if self.config.visual_attn: + # attn_scores = torch.matmul(query_states, key_states.transpose(-2, -1)) + + # # Apply the attention mask (if provided) + # if attention_mask is not None: + # attn_scores = attn_scores.masked_fill(attention_mask == 0, float('-inf')) + + # else: + # q_len, k_len = attn_scores.size(-2), attn_scores.size(-1) + # causal_mask = torch.tril(torch.ones((q_len, k_len), device=attn_scores.device)).view(1, 1, q_len, k_len) + # attn_scores = attn_scores.masked_fill(causal_mask == 0, float('-inf')) + + # self.attn_weights = torch.nn.functional.softmax(attn_scores, dim=-1).detach() + + if self.config.visual_attn: + if hasattr(self, 'attn_weights'): + del self.attn_weights + torch.cuda.empty_cache() + + attention_mask_viz = None + q_len = query_states.shape[-2] + k_len = key_states.shape[-2] + + attn_scores = torch.matmul(query_states, key_states.transpose(-2, -1)) + + if self.config.sliding_window is not None and self.config.global_attn_idx is not None and self.layer_idx not in self.config.global_attn_idx: + attention_mask_viz = torch.zeros((q_len, k_len), device=attn_scores.device) + + # Sliding window attention + sliding_window_size = self.config.sliding_window # Assuming you have this defined in config + for i in range(q_len): + start = max(0, i - sliding_window_size) + end = min(k_len, i + sliding_window_size + 1) + attention_mask_viz[i, start:end] = 1 + + # Ensure the first 'num_memory_tokens' are visible to all tokens + num_memory_tokens = self.config.num_memory_tokens + attention_mask_viz[:, :num_memory_tokens] = 1 + + causal_mask = torch.tril(torch.ones((q_len, k_len), device=attn_scores.device)) + attention_mask_viz = attention_mask_viz * causal_mask # Combine sliding window and causal mask + + # Apply the attention mask (if provided) + if attention_mask_viz is not None: + attn_scores = attn_scores.masked_fill(attention_mask_viz == 0, float('-inf')) + + else: + q_len, k_len = attn_scores.size(-2), attn_scores.size(-1) + causal_mask = torch.tril(torch.ones((q_len, k_len), device=attn_scores.device)).view(1, 1, q_len, k_len) + attn_scores = attn_scores.masked_fill(causal_mask == 0, float('-inf')) + + + if not self.config.visual_entropy: + self.attn_weights_before_softmax = attn_scores.detach() + self.attn_weights = torch.nn.functional.softmax(attn_scores, dim=-1).detach() + + else: ### get entropy only + attn_weights = torch.nn.functional.softmax(attn_scores, dim=-1).detach() + + entropy, _ = compute_average_attention_entropy([[attn_weights]]) + self.attn_entropy = np.mean(entropy, axis=-1, keepdims=True)[0,0] + + ##### Theoretical attention calculation ##### + # attn_scores = torch.matmul(query_states, key_states.transpose(-2, -1)) + # head_dim = query_states.size(-1) + # attn_scores = attn_scores / torch.sqrt(torch.tensor(head_dim, dtype=torch.float32)) + + # attention_mask = torch.tril(torch.ones(query_states.shape[-2], query_states.shape[-2])).unsqueeze(0).unsqueeze(0).to(attn_scores) # (1, 1, slen, slen) + # attention_mask = attention_mask.expand(query_states.shape[0], query_states.shape[1], query_states.shape[-2], query_states.shape[-2]) # (batch_size, num_heads, slen, slen) + # attn_scores = attn_scores.masked_fill(attention_mask == 0, float('-inf')) + + # attn_weights = F.softmax(attn_scores, dim=-1) + # attn_output = torch.matmul(attn_weights, value_states) + + # attn_output = attn_output.transpose(1, 
2).contiguous() + # attn_output = attn_output.reshape(bsz, q_len, self.v_head_dim * self.num_heads) + + + # Reashape to the expected shape for Flash Attention + query_states = query_states.transpose(1, 2) # (batch, slen, num_heads, head_dim) + key_states = key_states.transpose(1, 2) # (batch, slen, num_heads, head_dim) + value_states = value_states.transpose(1, 2) # (batch, slen, num_heads, head_dim) + + if self.config.use_nGPT and (self.config.nGPT_config is None or self.config.nGPT_config['qk_norm']): + query_states, key_states = self.nGPT_qk_norm(query_states, key_states, flash_attn=True) + + sqrt_head_dim = (self.config.hidden_size / self.config.num_attention_heads) ** 0.5 + softmax_scale = sqrt_head_dim + else: + softmax_scale = None + + # attention_mask = torch.ones(query_states.shape[0], query_states.shape[1], device=query_states.device) + + attn_output = self._flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + dropout=dropout_rate, + use_sliding_windows=use_sliding_windows and not swa_processed_flag, + softmax_scale=softmax_scale + ) + + v_dim = value_states.shape[-2] * value_states.shape[-1] + attn_output = attn_output.reshape(bsz, q_len, v_dim).contiguous() + + # if self.layer_idx == 0: + # print(attn_output.mean(-1)) + # input() + + if self.attn_only_wo_proj: + return attn_output, (key_states_no_repeat, value_states_no_repeat) + + if not self.compact_gating: + if self.config.use_nGPT and 'extra_grad' in self.config.nGPT_config and self.config.nGPT_config['extra_grad']: + attn_output = attn_output / self.o_proj.weight.norm(p=2, dim=0) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value, (key_states_no_repeat, value_states_no_repeat) + + def _flash_attention_forward( + self, + query_states, + key_states, + value_states, + attention_mask, + query_length, + dropout=0.0, + softmax_scale=None, + use_sliding_windows=False, + ): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. + + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + attention_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`int`, *optional*): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) + use_sliding_windows (`bool`, *optional*): + Whether to activate sliding window attention. + """ + if not self._flash_attn_uses_top_left_mask: + causal = self.is_causal + else: + # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__. 
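+ # Added clarification: with flash_attn<2.1 the causal mask is top-left aligned, which is
+ # only correct when q_seqlen == k_seqlen. Disabling the causal flag for the single-token
+ # decode step (query_length == 1) is safe, since one query attending to the whole cached
+ # prefix needs no causal masking anyway.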
+ causal = self.is_causal and query_length != 1 + + # Contains at least one padding token in the sequence + + if attention_mask is not None: + if value_states.shape[-1] == query_states.shape[-1] * 2: + value_states1 = value_states[...,:query_states.shape[-1]] + + batch_size = query_states.shape[0] + + query_states1, key_states1, value_states1, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( + query_states, key_states, value_states1, attention_mask, query_length + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + if not use_sliding_windows: + attn_output_unpad1 = flash_attn_varlen_func( + query_states1, + key_states1, + value_states1, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + else: + attn_output_unpad1 = flash_attn_varlen_func( + query_states1, + key_states1, + value_states1, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + window_size=(self.config.sliding_window, self.config.sliding_window), + ) + + attn_output1 = pad_input(attn_output_unpad1, indices_q, batch_size, query_length) + + value_states2 = value_states[...,query_states.shape[-1]:] + + query_states2, key_states2, value_states2, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( + query_states, key_states, value_states2, attention_mask, query_length + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + if not use_sliding_windows: + attn_output_unpad2 = flash_attn_varlen_func( + query_states2, + key_states2, + value_states2, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + else: + attn_output_unpad2 = flash_attn_varlen_func( + query_states2, + key_states2, + value_states2, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + window_size=(self.config.sliding_window, self.config.sliding_window), + ) + + attn_output2 = pad_input(attn_output_unpad2, indices_q, batch_size, query_length) + + attn_output = torch.cat([attn_output1, attn_output2], dim=-1) + + else: + batch_size = query_states.shape[0] + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( + query_states, key_states, value_states, attention_mask, query_length + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + if not use_sliding_windows: + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + else: + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, 
+ causal=causal, + window_size=(self.config.sliding_window, self.config.sliding_window), + ) + + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) + else: + if value_states.shape[-1] == query_states.shape[-1] * 2: + if not use_sliding_windows: + attn_output1 = flash_attn_func( + query_states, + key_states, + value_states[...,:query_states.shape[-1]], + dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + + attn_output2 = flash_attn_func( + query_states, + key_states, + value_states[...,query_states.shape[-1]:], + dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + + attn_output = torch.cat([attn_output1, attn_output2], dim=-1) + + else: + attn_output1 = flash_attn_func( + query_states, + key_states, + value_states[...,:query_states.shape[-1]], + dropout, + softmax_scale=softmax_scale, + causal=causal, + window_size=(self.config.sliding_window, self.config.sliding_window), + ) + + attn_output2 = flash_attn_func( + query_states, + key_states, + value_states[...,query_states.shape[-1]:], + dropout, + softmax_scale=softmax_scale, + causal=causal, + window_size=(self.config.sliding_window, self.config.sliding_window), + ) + + attn_output = torch.cat([attn_output1, attn_output2], dim=-1) + + else: + if not use_sliding_windows: + attn_output = flash_attn_func( + query_states, + key_states, + value_states, + dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + else: + attn_output = flash_attn_func( + query_states, + key_states, + value_states, + dropout, + softmax_scale=softmax_scale, + causal=causal, + window_size=(self.config.sliding_window, self.config.sliding_window), + ) + + return attn_output + + def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): + batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape + + # On the first iteration we need to properly re-create the padding mask + # by slicing it on the proper place + if kv_seq_len != attention_mask.shape[-1]: + attention_mask_num_tokens = attention_mask.shape[-1] + attention_mask = attention_mask[:, attention_mask_num_tokens - kv_seq_len :] + + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + + if not self.training and not type(key_layer) == torch.Tensor: ## this is for handling Mamba2 with output type + key_layer = torch.tensor(key_layer.clone()) + value_layer = torch.tensor(value_layer.clone()) + query_layer = torch.tensor(query_layer.clone()) + + key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k) + value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k) + + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k + ) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. 
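+ # Worked example (added comment, hypothetical values): with left padding,
+ # attention_mask = [[0, 0, 1, 1, 1]] and query_length = 2, the slice below keeps [[1, 1]],
+ # i.e. the mask entries aligned with the new query tokens; earlier positions are already
+ # represented by the cached key/value states.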
+ attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + + + +# class JambaFused_MHA(JambaAttention): +# """ +# Jamba flash attention module. This module inherits from `JambaAttention` as the weights of the module stays +# untouched. The only required change would be on the forward pass where it needs to correctly call the public API of +# flash attention and deal with padding tokens in case the input contains any of them. +# """ + +# # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ +# def __init__(self, *args, **kwargs): +# super().__init__(*args, **kwargs) + +# from .fused_mha_with_cache import fused_mha_interface + +# self.fused_mha_interface = fused_mha_interface + +# self.init_kv_cache(max_batch_size=1, max_seq_len=8000) + + +# def init_kv_cache(self, max_batch_size, max_seq_len, page_size=-1): +# if page_size is not None and page_size > 0: +# batch_max_pages = (max_seq_len + page_size - 1) // page_size +# cache_max_pages = (max_batch_size * max_seq_len + page_size - 1) // page_size +# self.k_cache = torch.zeros(cache_max_pages, page_size, self.num_heads, self.kq_head_dim).to(self.q_proj.weight) +# self.v_cache = torch.zeros(cache_max_pages, page_size, self.num_heads, self.v_head_dim).to(self.q_proj.weight) + +# self.page_table = torch.zeros(max_batch_size, batch_max_pages, device=self.q_proj.weight.device, dtype=torch.int32) +# else: +# self.k_cache = torch.zeros(max_batch_size, max_seq_len, self.num_heads, self.kq_head_dim).to(self.q_proj.weight) +# self.v_cache = torch.zeros(max_batch_size, max_seq_len, self.num_heads, self.v_head_dim).to(self.q_proj.weight) + +# self.page_table = None + +# self.max_seq_len = max_seq_len + + +# def forward( +# self, +# hidden_states: torch.Tensor = None, +# attention_mask: Optional[torch.Tensor] = None, +# position_ids: Optional[torch.LongTensor] = None, +# past_key_value: Optional[Cache] = None, +# output_attentions: bool = False, +# use_cache: bool = False, +# kv_last_layer=None, +# # kv_proj_last_layer = None, +# use_swa=False, +# query_states = None, +# key_states=None, +# value_states=None, +# **kwargs, +# ): + +# # print(f"Flash Attn - layer_idx: {self.layer_idx}, attn_mask is none: {attention_mask is None}") +# # print(f"layer_idx: {self.layer_idx}, use_swq: {use_swa}") +# if "padding_mask" in kwargs: +# warnings.warn( +# "Passing `padding_mask` is deprecated and will be removed in v4.37. 
Please make sure use `attention_mask` instead.`" +# ) + +# # overwrite attention_mask with padding_mask +# attention_mask = kwargs.pop("padding_mask") + +# if self.attn_only_wo_proj: +# assert query_states is not None +# bsz, q_len, _ = query_states.size() +# else: +# bsz, q_len, _ = hidden_states.size() + +# if not self.attn_only_wo_proj: +# query_states = self.q_proj(hidden_states) + +# if self.config.use_nGPT and 'extra_grad' in self.config.nGPT_config and self.config.nGPT_config['extra_grad']: +# query_states = query_states / self.q_proj.weight.norm(p=2, dim=1) + +# query_states = query_states.view(bsz, q_len, self.num_heads, self.kq_head_dim).transpose(1, 2).contiguous() + +# if self.q_norm is not None: +# query_states = self.q_norm(query_states) + +# # we do kq_norm first before rope according to +# # https://github.com/huggingface/transformers/blob/6c1d0b069de22d7ed8aa83f733c25045eea0585d/src/transformers/models/cohere/modeling_cohere.py#L568 +# if self.config.rope: +# if self.attn_only_wo_proj: +# cos, sin = self.rotary_emb(query_states, position_ids) +# else: +# cos, sin = self.rotary_emb(hidden_states, position_ids) +# query_states, _ = apply_rotary_pos_emb(query_states, None, cos, sin) + + +# if self.reuse_kv: +# assert kv_last_layer is not None +# key_states, value_states = kv_last_layer # (batch, num_heads, slen, head_dim) + +# else: +# if not self.attn_only_wo_proj: +# key_states = self.k_proj(hidden_states) +# value_states = self.v_proj(hidden_states) + +# if self.config.use_nGPT and 'extra_grad' in self.config.nGPT_config and self.config.nGPT_config['extra_grad']: +# key_states = key_states / self.k_proj.weight.norm(p=2, dim=1) +# value_states = value_states / self.v_proj.weight.norm(p=2, dim=1) + +# key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.kq_head_dim).transpose(1, 2) +# value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.v_head_dim).transpose(1, 2) + +# if self.k_norm is not None: +# key_states = self.k_norm(key_states) + +# if self.config.rope: +# # cos, sin = self.rotary_emb(hidden_states, position_ids) +# _, key_states = apply_rotary_pos_emb(None, key_states, cos, sin) + + +# kv_seq_len = key_states.shape[-2] +# if past_key_value is not None and not self.reuse_kv: +# if self.layer_idx is None: +# raise ValueError( +# f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " +# "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " +# "with a layer index." 
+# ) +# kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + +# use_sliding_windows = ( +# _flash_supports_window_size +# and getattr(self.config, "sliding_window", None) is not None +# # and kv_seq_len > (self.config.sliding_window + self.config.num_memory_tokens if self.config.num_memory_tokens > 0 else self.config.sliding_window) +# and kv_seq_len > self.config.sliding_window +# and use_swa +# ) + +# key_states_no_repeat = key_states +# value_states_no_repeat = value_states + +# key_states = repeat_kv(key_states, self.num_key_value_groups) +# value_states = repeat_kv(value_states, self.num_key_value_groups) +# dropout_rate = 0.0 if not self.training else self.attention_dropout + +# # Reashape to the expected shape for Flash Attention +# query_states = query_states.transpose(1, 2) # (batch, slen, num_heads, head_dim) +# key_states = key_states.transpose(1, 2) # (batch, slen, num_heads, head_dim) +# value_states = value_states.transpose(1, 2) # (batch, slen, num_heads, head_dim) + +# if self.config.use_nGPT and (self.config.nGPT_config is None or self.config.nGPT_config['qk_norm']): +# query_states, key_states = self.nGPT_qk_norm(query_states, key_states, flash_attn=True) + +# sqrt_head_dim = (self.config.hidden_size / self.config.num_attention_heads) ** 0.5 +# softmax_scale = sqrt_head_dim +# else: +# softmax_scale = None + +# # attention_mask = torch.ones(query_states.shape[0], query_states.shape[1], device=query_states.device) + +# if self.k_cache.device != query_states.device: +# self.k_cache = self.k_cache.to(query_states) +# self.v_cache = self.v_cache.to(query_states) + +# attn_output = self.fused_mha_interface( +# query_states, +# key_states, +# value_states, +# k_cache=self.k_cache, +# v_cache=self.v_cache, +# page_table=self.page_table, +# max_seq_len=self.max_seq_len +# ) + +# v_dim = value_states.shape[-2] * value_states.shape[-1] +# attn_output = attn_output.reshape(bsz, q_len, v_dim).contiguous() + +# # if self.layer_idx == 0: +# # print(attn_output.mean(-1)) +# # input() + +# if self.attn_only_wo_proj: +# return attn_output, (key_states_no_repeat, value_states_no_repeat) + +# if not self.compact_gating: +# if self.config.use_nGPT and 'extra_grad' in self.config.nGPT_config and self.config.nGPT_config['extra_grad']: +# attn_output = attn_output / self.o_proj.weight.norm(p=2, dim=0) + +# attn_output = self.o_proj(attn_output) + +# if not output_attentions: +# attn_weights = None + +# return attn_output, attn_weights, past_key_value, (key_states_no_repeat, value_states_no_repeat) + + + +JAMBA_ATTENTION_CLASSES = { + "flash_attention_2": JambaFlashAttention2, + # "fused_mha": JambaFused_MHA, +} + +# Adapted from transformers.models.mamba.modeling_mamba.MambaMixer +class JambaMambaMixer(nn.Module): + """ + Compute ∆, A, B, C, and D the state space parameters and compute the `contextualized_states`. 
+ A, D are input independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective) + ∆, B, C are input-dependent (this is a key difference between Mamba and the linear time invariant S4, + and is why Mamba is called **selective** state spaces) + """ + + def __init__(self, config: JambaConfig, layer_idx): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.hidden_size = config.hidden_size + self.ssm_state_size = config.mamba_d_state + self.conv_kernel_size = config.mamba_d_conv + + if config.mamba_latent_size is not None: + self.intermediate_size = config.mamba_latent_size + else: + self.intermediate_size = int(config.mamba_expand * config.hidden_size) + + self.time_step_rank = config.mamba_dt_rank + self.use_conv_bias = config.mamba_conv_bias + self.use_bias = config.mamba_proj_bias + + self.conv1d = nn.Conv1d( + in_channels=self.intermediate_size, + out_channels=self.intermediate_size, + bias=self.use_conv_bias, + kernel_size=self.conv_kernel_size, + groups=self.intermediate_size, + padding=self.conv_kernel_size - 1, + ) + + self.activation = config.hidden_act + self.act = ACT2FN[config.hidden_act] + self.apply_inner_layernorms = config.mamba_inner_layernorms + + self.use_fast_kernels = config.use_mamba_kernels + + self.in_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=self.use_bias) + + self.x_proj = nn.Linear(self.intermediate_size, self.time_step_rank + self.ssm_state_size * 2, bias=False) + + self.dt_proj = nn.Linear(self.time_step_rank, self.intermediate_size, bias=True) + + A = torch.arange(1, self.ssm_state_size + 1, dtype=torch.float32)[None, :] + A = A.expand(self.intermediate_size, -1).contiguous() + self.A_log = nn.Parameter(torch.log(A)) + + self.D = nn.Parameter(torch.ones(self.intermediate_size)) + + self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=self.use_bias) + + if self.apply_inner_layernorms: + self.dt_layernorm = JambaRMSNorm(self.time_step_rank, eps=config.rms_norm_eps) + self.B_layernorm = JambaRMSNorm(self.ssm_state_size, eps=config.rms_norm_eps) + self.C_layernorm = JambaRMSNorm(self.ssm_state_size, eps=config.rms_norm_eps) + else: + self.dt_layernorm = None + self.B_layernorm = None + self.C_layernorm = None + + if not is_fast_path_available: + logger.warning_once( + "The fast path is not available because on of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)`" + " is None. To install follow https://github.com/state-spaces/mamba/#installation and" + " https://github.com/Dao-AILab/causal-conv1d. 
If you want to use the naive implementation, set `use_mamba_kernels=False` in the model config" + ) + + + self.mamba_multihead_config = config.mamba_multihead_config + if self.mamba_multihead_config is not None: + assert self.mamba_multihead_config['alpha_mode'] == 'sparsity' or self.mamba_multihead_config['alpha_mode'] == 'cummax' + + if self.mamba_multihead_config['alpha_mode'] == 'cummax': + self.learned_dt_scale = nn.Parameter(torch.ones(1)) + + if self.mamba_multihead_config['alpha_mode'] == 'sparsity': + if 'use_learned_thres' in self.mamba_multihead_config and self.mamba_multihead_config['use_learned_thres']: + self.learned_thres = nn.Parameter(torch.zeros(self.intermediate_size)) + self.smooth_factor = self.mamba_multihead_config['smooth_factor'] + self.detach_dt = self.mamba_multihead_config['detach_dt'] + + if 'use_cummax' in self.mamba_multihead_config and self.mamba_multihead_config['use_cummax']: + self.use_cummax = True + self.cummax_lower_bound = self.mamba_multihead_config['cummax_lower_bound'] + else: + self.use_cummax = False + + else: + self.learned_thres = None + self.smooth_factor = None + self.detach_dt = None + + self.sparsity_split = self.mamba_multihead_config['sparsity_split'] + self.sparsity_ratio = self.mamba_multihead_config['sparsity_ratio'] + + + def _apply_layernorms(self, dt, B, C): + if self.dt_layernorm is not None: + dt = self.dt_layernorm(dt) + if self.B_layernorm is not None: + B = self.B_layernorm(B) + if self.C_layernorm is not None: + C = self.C_layernorm(C) + return dt, B, C + + def cuda_kernels_forward(self, hidden_states: torch.Tensor, cache_params: HybridMambaAttentionDynamicCache = None): + # 1. Gated MLP's linear projection + projected_states = self.in_proj(hidden_states).transpose(1, 2) + + if self.config.use_nGPT and 'extra_grad' in self.config.nGPT_config and self.config.nGPT_config['extra_grad']: + projected_states = projected_states / self.in_proj.weight.norm(p=2, dim=1, keepdim=True) + + if ( + self.training and cache_params is None and not self.apply_inner_layernorms + ): # Doesn't support outputting the states -> used for training + contextualized_states = mamba_inner_fn( + projected_states, + self.conv1d.weight, + self.conv1d.bias if self.use_conv_bias else None, + self.x_proj.weight, + self.dt_proj.weight, + self.out_proj.weight, + self.out_proj.bias.float() if self.use_bias else None, + -torch.exp(self.A_log.float()), + None, # input-dependent B + None, # input-dependent C + self.D.float(), + delta_bias=self.dt_proj.bias.float(), + delta_softplus=True, + ) + + else: + batch_size, seq_len, _ = hidden_states.shape + + use_precomputed_states = ( + cache_params is not None + and cache_params.has_previous_state + and seq_len == 1 + and cache_params.conv_states[self.layer_idx].shape[0] + == cache_params.ssm_states[self.layer_idx].shape[0] + == batch_size + ) + + # print(hidden_states.shape, use_precomputed_states) + # input() + + hidden_states, gate = projected_states.chunk(2, dim=1) + + # 2. 
Convolution sequence transformation + conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2)) + if use_precomputed_states: + hidden_states = causal_conv1d_update( + hidden_states.squeeze(-1), + cache_params.conv_states[self.layer_idx], + conv_weights, + self.conv1d.bias, + self.activation, + ) + hidden_states = hidden_states.unsqueeze(-1) + + cache_params.mamba_past_length[self.layer_idx] += seq_len + else: + if cache_params is not None: + conv_states = nn.functional.pad( + hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0) + ) + cache_params.conv_states[self.layer_idx].copy_(conv_states) + + cache_params.mamba_past_length[self.layer_idx] += seq_len + + hidden_states = causal_conv1d_fn( + hidden_states, conv_weights, self.conv1d.bias, activation=self.activation + ) + + # 3. State Space Model sequence transformation + # 3.a. input varying initialization of time_step, B and C + ssm_parameters = self.x_proj(hidden_states.transpose(1, 2)) + + if self.config.use_nGPT and 'extra_grad' in self.config.nGPT_config and self.config.nGPT_config['extra_grad']: + ssm_parameters = ssm_parameters / self.x_proj.weight.norm(p=2, dim=1) + + time_step, B, C = torch.split( + ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1 + ) + time_step, B, C = self._apply_layernorms(time_step, B, C) + + # Here we need to apply dt_proj without the bias, as the bias is added in the selective scan kernel. + # This is a hack to apply dt_proj while still using the forward pass of `torch.nn.Linear`, which is needed + # in order to make quantization work. Quantization code replaces `torch.nn.Linear` layers with quantized + # linear layers, and requires to call the forward pass directly. + # The original code here was: ```discrete_time_step = self.dt_proj.weight @ time_step.transpose(1, 2)``` + if hasattr(self.dt_proj, "base_layer"): + # In case of LoRA, we need to access the base layer to get the weight + time_proj_bias = self.dt_proj.base_layer.bias + self.dt_proj.base_layer.bias = None + else: + time_proj_bias = self.dt_proj.bias + self.dt_proj.bias = None + discrete_time_step = self.dt_proj(time_step).transpose(1, 2) # [batch, intermediate_size, seq_len] + if hasattr(self.dt_proj, "base_layer"): + self.dt_proj.base_layer.bias = time_proj_bias + else: + self.dt_proj.bias = time_proj_bias + + A = -torch.exp(self.A_log.float()) + # 3.c perform the recurrence y ← SSM(A, B, C)(x) + time_proj_bias = time_proj_bias.float() if time_proj_bias is not None else None + + + if self.mamba_multihead_config is not None and self.mamba_multihead_config['alpha_mode'] == 'cummax': ### todo: implement this in the fused kernel + discrete_time_step = discrete_time_step.transpose(1, 2) # [batch, seq_len, intermediate_size] + + if time_proj_bias is not None: + discrete_time_step = discrete_time_step + time_proj_bias + + discrete_time_step = torch.nn.functional.softmax(discrete_time_step, dim=-1) + discrete_time_step = torch.cumsum(discrete_time_step, dim=-1) + discrete_time_step = discrete_time_step * self.learned_dt_scale + + discrete_time_step = discrete_time_step.transpose(1, 2).to(hidden_states) # [batch, intermediate_size, seq_len] + + time_proj_bias = None + + if self.mamba_multihead_config is not None and self.mamba_multihead_config['alpha_mode'] == 'sparsity': + discrete_time_step = discrete_time_step.transpose(1, 2) # [batch, seq_len, intermediate_size] + + if time_proj_bias is not None: + discrete_time_step = discrete_time_step + time_proj_bias + + 
if self.learned_thres is not None: + discrete_time_step = self.sparsify_learned_thres(discrete_time_step) + else: + discrete_time_step = self.split_and_sparsify(discrete_time_step, self.sparsity_split, self.sparsity_ratio) + + discrete_time_step = discrete_time_step.transpose(1, 2).to(hidden_states) # [batch, intermediate_size, seq_len] + + time_proj_bias = None + + + if use_precomputed_states: + scan_outputs = selective_state_update( + cache_params.ssm_states[self.layer_idx], + hidden_states[..., 0], + discrete_time_step[..., 0], + A, + B[:, 0], + C[:, 0], + self.D, + gate[..., 0], + time_proj_bias, + dt_softplus=True, + ).unsqueeze(-1) + else: + compute_attn_mat = hasattr(self.config, 'compute_attn_mat') and self.config.compute_attn_mat + + outputs = selective_scan_fn( + hidden_states, + discrete_time_step, + A, + B.transpose(1, 2), + C.transpose(1, 2), + self.D.float(), + gate, + delta_bias=time_proj_bias, + delta_softplus=True, + return_last_state=True, + compute_attn_mat=compute_attn_mat + ) + + if compute_attn_mat: + if len(outputs) == 4: + scan_outputs, attn_mat, ssm_state, _ = outputs + else: + scan_outputs, attn_mat, ssm_state = outputs + + if not self.config.visual_entropy: + setattr(self, f'attn_mat_0', attn_mat.detach()) + setattr(self, f'delta_0', nn.functional.softplus(discrete_time_step.transpose(1,2) + time_proj_bias)) + else: + attn_mat = attn_mat.abs() / (attn_mat.abs().sum(dim=-1, keepdim=True) + 1e-6) + entropy, _ = compute_average_attention_entropy([[attn_mat]]) + self.ssm_entropy = entropy[0,0] + else: + if len(outputs) == 3: + scan_outputs, ssm_state, _ = outputs + else: + scan_outputs, ssm_state = outputs + + if ssm_state is not None and cache_params is not None: + cache_params.ssm_states[self.layer_idx].copy_(ssm_state) + + scan_outputs = scan_outputs.transpose(1, 2) + + if self.config.use_nGPT and 'extra_grad' in self.config.nGPT_config and self.config.nGPT_config['extra_grad']: + scan_outputs = scan_outputs / self.out_proj.weight.norm(p=2, dim=0) + + contextualized_states = self.out_proj(scan_outputs) + + return contextualized_states + + + def sparsify_learned_thres(self, dt): + """ + Args: + dt: Tensor of shape [bs, seq_len, nheads] + Returns: + pruned_dt: Pruned tensor with the same shape as dt + """ + # Compute sigmoid scores + + if self.use_cummax: + learned_thres = torch.nn.functional.softmax(self.learned_thres, dim=-1) + learned_thres = torch.cumsum(learned_thres, dim=-1) - self.cummax_lower_bound ## keep the dt_normalized larger than 1 - self.cummax_lower_bound + + dt_normalized = (dt - dt.min(dim=-1, keepdim=True)[0]) / (dt.max(dim=-1, keepdim=True)[0] - dt.min(dim=-1, keepdim=True)[0]) + + scores = torch.sigmoid((dt_normalized.detach() - self.learned_thres) / self.smooth_factor) + + else: + if self.detach_dt: + scores = torch.sigmoid((dt.detach() - self.learned_thres) / self.smooth_factor) + else: + scores = torch.sigmoid((dt - self.learned_thres) / self.smooth_factor) + + # Generate binary mask for pruning (forward pass) + mask = (scores >= 0.5).float() + + # Apply mask in the forward pass and backward using sigmoid + pruned_dt = (dt * mask - dt * scores).detach() + dt * scores + + # print(pruned_dt.mean()) + + return pruned_dt + + + def split_and_sparsify(self, dt, sparsity_split, sparsity_ratio): + """ + dt: a torch.Tensor of shape [bs, seq_len, dim] + sparsity_split: list of ratios (e.g., [0.4, 0.3, 0.3]) that sum to 1 + and define how to split dt along the last dimension + sparsity_ratio: list of ratios (e.g., [0.2, 0.5, 0.3]) that sum to 1 + and 
define how many time steps (along seq_len) to keep + """ + bs, seq_len, dim = dt.shape + + assert sum(sparsity_split) == 1 + + # Compute the exact split sizes (watching out for integer rounding) + split_sizes = [int(r * dim) for r in sparsity_split] + # Fix potential off-by-one rounding in the last split + split_sizes[-1] = dim - sum(split_sizes[:-1]) + + # Split the original tensor along the last dimension + splitted_tensors = torch.split(dt, split_sizes, dim=-1) + + results = [] + for i, sub_tensor in enumerate(splitted_tensors): + # sub_tensor has shape [bs, seq_len, split_dim_i] + k = int(sparsity_ratio[i] * seq_len) + + ### Strategy 1: keep at least one token + k = max(k, 1) + + ### Strategy 2: the #tokens is the same as training + # if self.config.orig_max_position_embeddings is not None: + # k = int(self.config.orig_max_position_embeddings * self.sparsity_ratio[i]) + # else: + # assert self.config.max_position_embeddings is not None + # k = int(self.config.max_position_embeddings * self.sparsity_ratio[i]) + + # k = min(seq_len, k) + + # print(self.config.max_position_embeddings, sparsity_ratio[i], seq_len, k) + + # 1) Average over the feature dimension (the last dim), + # resulting in shape [bs, seq_len] + averaged_values = sub_tensor.mean(dim=-1) + + # 2) Get top-k indices (along seq_len = dim=1) + topk_values, _ = torch.topk(averaged_values, k=k, dim=1) + # The smallest value among the top-k per batch element + threshold = topk_values[:, -1].unsqueeze(-1) # shape [bs, 1] + + # 3) Create a mask of shape [bs, seq_len] => True if >= threshold + averaged_mask = (averaged_values >= threshold) + + # 4) Expand that mask back to [bs, seq_len, split_dim_i] + mask_3d = averaged_mask.unsqueeze(-1).expand_as(sub_tensor) + + # 5) Zero out everything that is not in top-k + sparsified_sub = sub_tensor * mask_3d + + # print((sparsified_sub == 0).float().mean().item()) + # input() + + results.append(sparsified_sub) + + # Concatenate the results back along the last dimension + output = torch.cat(results, dim=-1) + return output + + # fmt: off + def slow_forward(self, input_states, cache_params: MambaCacheParams = None): + batch_size, seq_len, _ = input_states.shape + dtype = input_states.dtype + # 1. Gated MLP's linear projection + projected_states = self.in_proj(input_states).transpose(1, 2) # [batch, 2 * intermediate_size, seq_len] + hidden_states, gate = projected_states.chunk(2, dim=1) + + # 2. 
Convolution sequence transformation + if cache_params is not None: + if self.training: + # In training mode, we don't want to perform in-place operations on ssm_state so we can compute the backwards pass + ssm_state = cache_params.ssm_states[self.layer_idx].clone() + else: + ssm_state = cache_params.ssm_states[self.layer_idx] + + if cache_params.seqlen_offset > 0: + conv_state = cache_params.conv_states[self.layer_idx] # [batch, intermediate_size, conv_kernel_size] + conv_state = torch.roll(conv_state, shifts=-1, dims=-1) + conv_state[:, :, -1] = hidden_states[:, :, 0] + cache_params.conv_states[self.layer_idx].copy_(conv_state) + hidden_states = torch.sum(conv_state * self.conv1d.weight[:, 0, :], dim=-1) + if self.use_conv_bias: + hidden_states += self.conv1d.bias + hidden_states = self.act(hidden_states).to(dtype).unsqueeze(-1) # [batch, intermediate_size, 1] : decoding + else: + conv_state = nn.functional.pad( + hidden_states, + (self.conv_kernel_size - hidden_states.shape[-1], 0) + ) + cache_params.conv_states[self.layer_idx].copy_(conv_state) + hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len]) # [batch, intermediate_size, seq_len] + else: + ssm_state = torch.zeros( + (batch_size, self.intermediate_size, self.ssm_state_size), + device=hidden_states.device, dtype=dtype + ) + hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len]) # [batch, intermediate_size, seq_len] + + # 3. State Space Model sequence transformation + # 3.a. Selection: [batch, seq_len, self.time_step_rank + self.ssm_state_size * 2] + ssm_parameters = self.x_proj(hidden_states.transpose(1, 2)) + time_step, B, C = torch.split( + ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1 + ) + time_step, B, C = self._apply_layernorms(time_step, B, C) + discrete_time_step = self.dt_proj(time_step) # [batch, seq_len, intermediate_size] + discrete_time_step = nn.functional.softplus(discrete_time_step).transpose(1, 2) # [batch, intermediate_size, seq_len] + + # 3.b. Discretization: B and C to [batch, seq_len, intermediate_size, ssm_state_size] (SRAM) + A = -torch.exp(self.A_log.float()) # [intermediate_size, ssm_state_size] + discrete_A = torch.exp(A[None, :, None, :] * discrete_time_step[:, :, :, None]) # [batch, intermediate_size, seq_len, ssm_state_size] + discrete_B = discrete_time_step[:, :, :, None] * B[:, None, :, :].float() # [batch, intermediade_size, seq_len, ssm_state_size] + deltaB_u = discrete_B * hidden_states[:, :, :, None].float() + + # 3.c perform the recurrence y ← SSM(A, B, C)(x) + scan_outputs = [] + for i in range(seq_len): + ssm_state = discrete_A[:, :, i, :] * ssm_state + deltaB_u[:, :, i, :] # [batch, intermediade_size, ssm_state_size] + scan_output = torch.matmul(ssm_state.to(dtype), C[:, i, :].unsqueeze(-1)) # [batch, intermediade_size, 1] + scan_outputs.append(scan_output[:, :, 0]) + scan_output = torch.stack(scan_outputs, dim=-1) # [batch, intermediade_size, seq_len] + scan_output = scan_output + (hidden_states * self.D[None, :, None]) + scan_output = (scan_output * self.act(gate)) + + if cache_params is not None: + cache_params.ssm_states[self.layer_idx].copy_(ssm_state) + + # 4. 
Final linear projection + contextualized_states = self.out_proj(scan_output.transpose(1, 2)) # [batch, seq_len, hidden_size] + return contextualized_states + # fmt: on + + + def mixer_forward(self, hidden_states, cache_params: HybridMambaAttentionDynamicCache = None): + if self.use_fast_kernels: + if not is_fast_path_available or "cuda" not in self.x_proj.weight.device.type: + raise ValueError( + "Fast Mamba kernels are not available. Make sure to they are installed and that the mamba module is on a CUDA device" + ) + return self.cuda_kernels_forward(hidden_states, cache_params) + return self.slow_forward(hidden_states, cache_params) + + + def forward( + self, + hidden_states: torch.Tensor, + past_key_value: Optional[HybridMambaAttentionDynamicCache] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]]]: + + res = self.mixer_forward(hidden_states, cache_params=past_key_value) + + return res, past_key_value + + + +class JambaMLP(nn.Module): + def __init__(self, config: JambaConfig, layer_idx: int): + super().__init__() + self.config = config + self.act_fn_name = config.mlp_hidden_act + self.act_fn = ACT2FN[self.act_fn_name] + + if config.ffn_expand_ratio is not None: + self.ffn_dim = int(config.ffn_expand_ratio * config.hidden_size) // 128 * 128 + else: + self.ffn_dim = config.intermediate_size + + self.hidden_dim = config.hidden_size + + self.layer_idx = layer_idx + + if self.act_fn_name == "silu": + self.gate_proj = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) + self.down_proj = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False) + self.up_proj = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) + + if self.config.use_nGPT and (self.config.nGPT_config is None or self.config.nGPT_config['learned_scaling']): + self.suv_init_value = 1.0 + self.suv_init_scaling = 1.0 + + self.su = torch.nn.Parameter(self.suv_init_scaling*torch.ones(self.ffn_dim, dtype=torch.float32)) + self.sv = torch.nn.Parameter(self.suv_init_scaling*torch.ones(self.ffn_dim, dtype=torch.float32)) + + self.hash_grid = None + + + def forward(self, x): + if self.config.use_nGPT and (self.config.nGPT_config is None or self.config.nGPT_config['learned_scaling']): + assert self.hash_grid is None, "Not implemented hash_grid + nGPT; please implement" + # assert self.config.repeat_ffn is None, "Not implemented repeat_ffn + nGPT; please implement" + + su = (self.su * ((self.suv_init_value/self.suv_init_scaling) * (self.config.hidden_size ** 0.5))) + u = self.up_proj(x) * su + + sv = (self.sv * ((self.suv_init_value/self.suv_init_scaling) * (self.config.hidden_size ** 0.5))) + v = self.gate_proj(x) * sv + out = self.down_proj(u * self.act_fn(v)) + return out + + elif self.config.use_nGPT and self.config.nGPT_config is not None and 'extra_grad' in self.config.nGPT_config and self.config.nGPT_config['extra_grad']: + assert self.hash_grid is None, "Not implemented hash_grid + nGPT; please implement" + # assert self.config.repeat_ffn is None, "Not implemented repeat_ffn + nGPT; please implement" + + if self.act_fn_name == "silu": + gate_output = self.gate_proj(x) + gate_output = gate_output / self.gate_proj.weight.norm(p=2, dim=1) + # print('gate_proj:', self.gate_proj.weight.norm(p=2, dim=1).max(), self.gate_proj.weight.norm(p=2, dim=1).min()) + + up_output = self.up_proj(x) + up_output = up_output / self.up_proj.weight.norm(p=2, dim=1) + # print('up_proj:', self.up_proj.weight.norm(p=2, dim=1).max(), self.up_proj.weight.norm(p=2, dim=1).min()) + + output = self.act_fn(gate_output) * up_output + output 
= output / self.down_proj.weight.norm(p=2, dim=0) + # print('down_proj:', self.down_proj.weight.norm(p=2, dim=0).max(), self.down_proj.weight.norm(p=2, dim=0).min()) + + output = self.down_proj(output) + + return output + + elif self.act_fn_name == "relu2": + raise NotImplementedError(f"Haven't supported relu2 yet") + else: + raise NotImplementedError(f"No such hidden_act: {self.act_fn_name}") + + else: + if self.act_fn_name == "silu": + output = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + elif self.act_fn_name == "relu2": + output = self.down_proj(self.act_fn(self.up_proj(x))) + else: + raise NotImplementedError(f"No such hidden_act: {self.act_fn_name}") + + if self.hash_grid is not None: + if next(self.hash_grid.parameters()).dtype != torch.float32: + self.hash_grid = self.hash_grid.float() + + B, T, C = x.shape # (b, n, hidden size) + hashing_input = x.float().view(B*T, 1, -1) # (b * n, 1, hidden size) + meta_tokens = self.hash_grid(hashing_input) # (b * n, 1, hidden size) + meta_tokens = meta_tokens.to(x).view(B, T, C) + + # print(hashing_input.shape, meta_tokens.shape) + + output = self.mlp_pre_avg_layernorm1(output) + self.mlp_pre_avg_layernorm2(meta_tokens) + + return output + + +# Adapted from transformers.models.mixtral.modeling_mixtral.MixtralSparseMoeBlock with Mistral->Jamba +class JambaSparseMoeBlock(nn.Module): + """ + This implementation is + strictly equivalent to standard MoE with full capacity (no + dropped tokens). It's faster since it formulates MoE operations + in terms of block-sparse operations to accomodate imbalanced + assignments of tokens to experts, whereas standard MoE either + (1) drop tokens at the cost of reduced performance or (2) set + capacity factor to number of experts and thus waste computation + and memory on padding. 
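+
+    Routing sketch (illustrative only; names follow the `forward` method below): tokens are
+    flattened to shape [batch * seq_len, hidden_dim], the router produces logits of shape
+    [batch * seq_len, num_experts], and a softmax followed by top-k keeps `num_experts_per_tok`
+    experts per token, each expert running only on the tokens routed to it:
+
+        routing_weights = softmax(router(hidden_states), dim=-1)
+        routing_weights, selected_experts = topk(routing_weights, k=top_k, dim=-1)
+        out[token] = sum_j routing_weights[token, j] * experts[selected_experts[token, j]](hidden_states[token])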
+ """ + + def __init__(self, config: JambaConfig, num_experts: int, num_experts_per_tok: int, layer_idx: int): + super().__init__() + self.hidden_dim = config.hidden_size + self.ffn_dim = config.intermediate_size + + self.layer_idx = layer_idx + + # these values are decided on runtime depending on the layer index + self.num_experts = num_experts + self.top_k = num_experts_per_tok + + if num_experts > 1: + # expert routing + self.router = nn.Linear(self.hidden_dim, self.num_experts, bias=False) + else: + self.router = None + + self.experts = nn.ModuleList([JambaMLP(config, layer_idx=layer_idx) for _ in range(self.num_experts)]) + + def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ """ + if len(hidden_states.shape) == 3: + batch_size, sequence_length, hidden_dim = hidden_states.shape + bs_times_seq_len = batch_size * sequence_length + elif len(hidden_states.shape) == 2: + assert self.num_experts == 1 + bs_times_seq_len, hidden_dim = hidden_states.shape + else: + batch_size, sequence_length, _, hidden_dim = hidden_states.shape + bs_times_seq_len = batch_size * sequence_length + + if self.num_experts == 1: + # in this case we have a single MLP block and don't need to do any routing + final_hidden_states = self.experts[0](hidden_states) + + router_logits = torch.ones( + (bs_times_seq_len, 1), + device=hidden_states.device, + dtype=hidden_states.dtype, + requires_grad=hidden_states.requires_grad, + ) + return final_hidden_states, router_logits + + # in this case we have multiple experts and need to do routing + hidden_states = hidden_states.view(-1, hidden_dim) + # router_logits: (batch * sequence_length, n_experts) + router_logits = self.router(hidden_states) + routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) + routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) + # we cast back to the input dtype + routing_weights = routing_weights.to(hidden_states.dtype) + + final_hidden_states = torch.zeros( + (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device + ) + + # One hot encode the selected experts to create an expert mask + # this will be used to easily index which expert is going to be sollicitated + expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0) + + # Loop over all available experts in the model and perform the computation on each expert + for expert_idx in range(self.num_experts): + expert_layer = self.experts[expert_idx] + idx, top_x = torch.where(expert_mask[expert_idx]) + + if top_x.shape[0] == 0: + continue + + # in torch it is faster to index using lists than torch tensors + top_x_list = top_x.tolist() + idx_list = idx.tolist() + + # Index the correct hidden states and compute the expert hidden state for + # the current expert. We need to make sure to multiply the output hidden + # states by `routing_weights` on the corresponding tokens (top-1 and top-2) + current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim) + current_hidden_states = expert_layer(current_state) * routing_weights[top_x_list, idx_list, None] + + # However `index_add_` only support torch tensors for indexing so we'll use + # the `top_x` tensor here. 
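+            # A toy sketch of the `index_add_` semantics relied on here (not executed): with
+            # dim=0, `final_hidden_states.index_add_(0, top_x, src)` performs, for every i,
+            #     final_hidden_states[top_x[i]] += src[i]
+            # so each expert's outputs are scattered back onto the rows of the tokens routed
+            # to it, accumulating contributions across the top-k experts per token.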
+ final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) + final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) + return final_hidden_states, router_logits + + + + +class JambaAttentionDecoderLayer(nn.Module): + def __init__(self, config: JambaConfig, num_experts: int, layer_idx: int, reuse_kv: bool = False): + super().__init__() + + self.config = config + + self.layer_idx = layer_idx + + self.self_attn = JAMBA_ATTENTION_CLASSES[config.attn_implementation](config, layer_idx, reuse_kv=reuse_kv) + + self.reuse_kv = reuse_kv + + if self.config.intermediate_size > 0: + num_experts_per_tok = config.num_experts_per_tok if num_experts > 1 else 1 + self.moe = JambaSparseMoeBlock(config, num_experts=num_experts, num_experts_per_tok=num_experts_per_tok, layer_idx=layer_idx) + else: + self.moe = None + + if self.config.use_nGPT and (self.config.nGPT_config is None or self.config.nGPT_config['post_norm']): + self.attn_alpha_init_value = 0.05 + self.attn_alpha_init_scaling = 1.0 / self.config.hidden_size ** 0.5 + self.attn_alpha = torch.nn.Parameter(self.attn_alpha_init_scaling*torch.ones(self.config.hidden_size, dtype=torch.float32)) + + self.mlp_alpha_init_value = 0.05 + self.mlp_alpha_init_scaling = 1.0 / self.config.hidden_size ** 0.5 + self.mlp_alpha = torch.nn.Parameter(self.mlp_alpha_init_scaling*torch.ones(self.config.hidden_size, dtype=torch.float32)) + + self.input_layernorm = None + self.pre_moe_layernorm = None + + else: + self.input_layernorm = JambaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.pre_moe_layernorm = JambaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + + self.hash_grid = None + + + def justnorm(self, x): + #return F.normalize(x, p=2, dim=-1) + res = x / x.norm(p=2, dim=-1, keepdim=True) + return res + + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + attention_mask_raw: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + output_router_logits: Optional[bool] = False, + use_cache: Optional[bool] = False, + kv_last_layer = None, + # kv_proj_last_layer = None, + use_swa=False, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, sequence_length)` where padding elements are indicated by 0. + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_router_logits (`bool`, *optional*): + Whether or not to return the logits of all the routers. They are useful for computing the router loss, and + should not be returned during inference. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). 
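+
+        Returns (a sketch of how the output tuple is assembled at the end of this method):
+        `(hidden_states,)`, extended with `self_attn_weights` if `output_attentions`,
+        `present_key_value` if `use_cache`, and `router_logits` if `output_router_logits`;
+        the current layer's key/value states (`current_kv`) are always appended last so that
+        downstream layers configured with `reuse_kv` can consume them.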
+ """ + ori_b, ori_n = hidden_states.shape[0], hidden_states.shape[1] + + if self.hash_grid is not None and ori_n >= self.config.hash_grid_config['meta_every']: + ## insert meta tokens every several tokens + meta_every = self.config.hash_grid_config['meta_every'] + + trunc_seq_len = math.floor(ori_n / meta_every) * meta_every + + # hidden_states = pad_at_dim(hidden_states, (0, next_seq_len - ori_n), dim = -2, value = 0.) + if trunc_seq_len < ori_n: + hidden_states_orig = hidden_states + hidden_states = hidden_states_orig[:, :trunc_seq_len, :] + hidden_states_remain = hidden_states_orig[:, trunc_seq_len:, :] + + hidden_states = rearrange(hidden_states, 'b (n m) d -> (b n) m d', m = meta_every) # m is the segment length + + if next(self.hash_grid.parameters()).dtype != torch.float32: + self.hash_grid = self.hash_grid.float() + + hashing_input = hidden_states[:,-1:,:].float() # (b * n, 1, hidden size) + meta_tokens = self.hash_grid(hashing_input) # (b * n, num_memory_tokens per segment, hidden size) + meta_tokens = meta_tokens.to(hidden_states) + + hidden_states, mem_packed_shape = pack((hidden_states, meta_tokens), 'b * d') + + hidden_states = rearrange(hidden_states, '(b n) m d -> b (n m) d', b = ori_b) + + if trunc_seq_len < ori_n: + hidden_states = torch.cat([hidden_states, hidden_states_remain], dim = 1) + + # print(hashing_input.shape, meta_tokens.shape, hidden_states.shape) + + if position_ids is not None and position_ids.shape[1] != hidden_states.shape[1]: + position_ids = torch.arange(hidden_states.shape[1], device=hidden_states.device).unsqueeze(0) + + residual = hidden_states + + if self.input_layernorm is not None: + hidden_states = self.input_layernorm(hidden_states) + + hidden_states, self_attn_weights, present_key_value, current_kv = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + kv_last_layer=kv_last_layer if self.reuse_kv else None, + use_swa=use_swa, + ) + + if self.config.save_input_output: + self.saved_input_output = {} + self.saved_input_output['seq_mixer_input'] = residual.clone().detach() + self.saved_input_output['seq_mixer_output'] = hidden_states.clone().detach() + + if self.config.use_nGPT and (self.config.nGPT_config is None or self.config.nGPT_config['post_norm']): + lr = self.attn_alpha * (self.attn_alpha_init_value / self.attn_alpha_init_scaling) + lr = torch.abs(lr) + + residual = self.justnorm(residual) + hidden_states = self.justnorm(hidden_states) + + hidden_states = residual + lr * (hidden_states - residual) + hidden_states = self.justnorm(hidden_states) + + else: + hidden_states = residual + hidden_states + + ## remove meta tokens + if self.hash_grid is not None and not self.config.hash_grid_config['global'] and ori_n >= self.config.hash_grid_config['meta_every']: + if trunc_seq_len < ori_n: + split_len = int(ori_n // self.config.hash_grid_config['meta_every'] * (self.config.hash_grid_config['n_meta_tokens'] + self.config.hash_grid_config['meta_every'])) + hidden_states_orig = hidden_states + hidden_states = hidden_states_orig[:, :split_len, :] + hidden_states_remain = hidden_states_orig[:, split_len:, :] + + hidden_states = rearrange(hidden_states, 'b (n m) d -> (b n) m d', m = (self.config.hash_grid_config['n_meta_tokens'] + self.config.hash_grid_config['meta_every'])) + + hidden_states, _ = unpack(hidden_states, mem_packed_shape, 'b * d') + + hidden_states = rearrange(hidden_states, '(b n) m d -> 
b (n m) d', b = ori_b) + + if trunc_seq_len < ori_n: + hidden_states = torch.cat([hidden_states, hidden_states_remain], dim = 1) + assert hidden_states.shape[1] == ori_n + + if self.moe is not None: + if self.config.repeat_ffn is not None: + for _ in range(self.config.repeat_ffn): + residual = hidden_states + if self.pre_moe_layernorm is not None: + hidden_states = self.pre_moe_layernorm(hidden_states) + hidden_states, router_logits = self.moe(hidden_states) + + if self.config.use_nGPT and (self.config.nGPT_config is None or self.config.nGPT_config['post_norm']): + lr = self.mlp_alpha * (self.mlp_alpha_init_value / self.mlp_alpha_init_scaling) + lr = torch.abs(lr) + + residual = self.justnorm(residual) + hidden_states = self.justnorm(hidden_states) + + hidden_states = residual + lr * (hidden_states - residual) + hidden_states = self.justnorm(hidden_states) + + else: + hidden_states = residual + hidden_states + + else: + residual = hidden_states + if self.pre_moe_layernorm is not None: + hidden_states = self.pre_moe_layernorm(hidden_states) + hidden_states, router_logits = self.moe(hidden_states) + + if self.config.use_nGPT and (self.config.nGPT_config is None or self.config.nGPT_config['post_norm']): + lr = self.mlp_alpha * (self.mlp_alpha_init_value / self.mlp_alpha_init_scaling) + lr = torch.abs(lr) + + residual = self.justnorm(residual) # normally, normalization is not needed + hidden_states = self.justnorm(hidden_states) + + hidden_states = residual + lr * (hidden_states - residual) + hidden_states = self.justnorm(hidden_states) + + else: + hidden_states = residual + hidden_states + else: + router_logits = None + + if self.config.save_input_output: + self.saved_input_output['moe_input'] = residual.clone().detach() + self.saved_input_output['moe_output'] = hidden_states.clone().detach() + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + if output_router_logits: + outputs += (router_logits,) + + outputs += (current_kv,) + + return outputs + + + +class FFNDecoderLayer(nn.Module): + def __init__(self, config: JambaConfig, num_experts: int, layer_idx: int, reuse_kv: bool = False): + super().__init__() + + self.config = config + + self.layer_idx = layer_idx + + num_experts_per_tok = config.num_experts_per_tok if num_experts > 1 else 1 + self.moe = JambaSparseMoeBlock(config, num_experts=num_experts, num_experts_per_tok=num_experts_per_tok, layer_idx=layer_idx) + + if self.config.use_nGPT and (self.config.nGPT_config is None or self.config.nGPT_config['post_norm']): + self.attn_alpha_init_value = 0.05 + self.attn_alpha_init_scaling = 1.0 / self.config.hidden_size ** 0.5 + self.attn_alpha = torch.nn.Parameter(self.attn_alpha_init_scaling*torch.ones(self.config.hidden_size, dtype=torch.float32)) + + self.mlp_alpha_init_value = 0.05 + self.mlp_alpha_init_scaling = 1.0 / self.config.hidden_size ** 0.5 + self.mlp_alpha = torch.nn.Parameter(self.mlp_alpha_init_scaling*torch.ones(self.config.hidden_size, dtype=torch.float32)) + + self.pre_moe_layernorm = None + + else: + self.pre_moe_layernorm = JambaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + + def justnorm(self, x): + #return F.normalize(x, p=2, dim=-1) + res = x / x.norm(p=2, dim=-1, keepdim=True) + return res + + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + attention_mask_raw: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: 
Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + output_router_logits: Optional[bool] = False, + use_cache: Optional[bool] = False, + kv_last_layer = None, + # kv_proj_last_layer = None, + use_swa=False, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, sequence_length)` where padding elements are indicated by 0. + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_router_logits (`bool`, *optional*): + Whether or not to return the logits of all the routers. They are useful for computing the router loss, and + should not be returned during inference. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + """ + + if self.config.repeat_ffn is not None: + for _ in range(self.config.repeat_ffn): + residual = hidden_states + if self.pre_moe_layernorm is not None: + hidden_states = self.pre_moe_layernorm(hidden_states) + hidden_states, router_logits = self.moe(hidden_states) + + if self.config.use_nGPT and (self.config.nGPT_config is None or self.config.nGPT_config['post_norm']): + lr = self.mlp_alpha * (self.mlp_alpha_init_value / self.mlp_alpha_init_scaling) + lr = torch.abs(lr) + + residual = self.justnorm(residual) + hidden_states = self.justnorm(hidden_states) + + hidden_states = residual + lr * (hidden_states - residual) + hidden_states = self.justnorm(hidden_states) + + else: + hidden_states = residual + hidden_states + + else: + residual = hidden_states + if self.pre_moe_layernorm is not None: + hidden_states = self.pre_moe_layernorm(hidden_states) + hidden_states, router_logits = self.moe(hidden_states) + + if self.config.use_nGPT and (self.config.nGPT_config is None or self.config.nGPT_config['post_norm']): + lr = self.mlp_alpha * (self.mlp_alpha_init_value / self.mlp_alpha_init_scaling) + lr = torch.abs(lr) + + residual = self.justnorm(residual) # normally, normalization is not needed + hidden_states = self.justnorm(hidden_states) + + hidden_states = residual + lr * (hidden_states - residual) + hidden_states = self.justnorm(hidden_states) + + else: + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (None,) + + if use_cache: + outputs += (None,) + + if output_router_logits: + outputs += (router_logits,) + + outputs += (kv_last_layer,) + + return outputs + + + +class JambaMambaDecoderLayer(nn.Module): + def __init__(self, config: JambaConfig, num_experts: int, layer_idx: int, reuse_kv: bool = False): + super().__init__() + + self.config = config + self.layer_idx = layer_idx + + if config.use_mamba2: + if config.fused_multihead_config is not None: + config.attn_op = JAMBA_ATTENTION_CLASSES[config.attn_implementation] + Mamba_OP = Mamba2_Fused + elif config.mamba_multihead_config is not None: + if config.mamba_multihead_config['alpha_mode'] 
in ['sparsity', 'cummax']: + Mamba_OP = Mamba2 + else: + Mamba_OP = Mamba2_Multihead + else: + Mamba_OP = Mamba2 + + else: + Mamba_OP = JambaMambaMixer + + self.Mamba_OP = Mamba_OP + + self.reuse_kv = False + self.mamba = Mamba_OP(config=config, layer_idx=layer_idx) + + + self.intermediate_size = config.intermediate_size + if self.intermediate_size > 0: + num_experts_per_tok = config.num_experts_per_tok if num_experts > 1 else 1 + self.moe = JambaSparseMoeBlock(config, num_experts=num_experts, num_experts_per_tok=num_experts_per_tok, layer_idx=layer_idx) + + + if self.config.use_nGPT and (self.config.nGPT_config is None or self.config.nGPT_config['post_norm']): + self.attn_alpha_init_value = 0.05 + self.attn_alpha_init_scaling = 1.0 / self.config.hidden_size ** 0.5 + self.attn_alpha = torch.nn.Parameter(self.attn_alpha_init_scaling*torch.ones(self.config.hidden_size, dtype=torch.float32)) + + self.mlp_alpha_init_value = 0.05 + self.mlp_alpha_init_scaling = 1.0 / self.config.hidden_size ** 0.5 + self.mlp_alpha = torch.nn.Parameter(self.mlp_alpha_init_scaling*torch.ones(self.config.hidden_size, dtype=torch.float32)) + + self.input_layernorm = None + self.pre_moe_layernorm = None + else: + self.input_layernorm = JambaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + if self.intermediate_size > 0: + self.pre_moe_layernorm = JambaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + else: + self.pre_moe_layernorm = None + + self.hash_grid = None + + self.meta_added_flag = False + + + def justnorm(self, x): + #return F.normalize(x, p=2, dim=-1) + res = x / x.norm(p=2, dim=-1, keepdim=True) + return res + + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + attention_mask_raw: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[HybridMambaAttentionDynamicCache] = None, + output_attentions: Optional[bool] = False, + output_router_logits: Optional[bool] = False, + use_cache: Optional[bool] = False, + kv_last_layer = None, + use_swa=False, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, sequence_length)` where padding elements are indicated by 0. + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_router_logits (`bool`, *optional*): + Whether or not to return the logits of all the routers. They are useful for computing the router loss, and + should not be returned during inference. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). 
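+
+        Returns (as assembled at the end of this method): `(hidden_states,)`, extended with
+        `present_key_value` if `use_cache` and `router_logits` if `output_router_logits`;
+        a trailing key/value slot is always appended and is `None` for this Mamba layer,
+        mirroring the `current_kv` entry returned by the attention decoder layers.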
+ """ + + ori_b, ori_n = hidden_states.shape[0], hidden_states.shape[1] + if self.hash_grid is not None and ori_n >= self.config.hash_grid_config['meta_every']: + ## insert meta tokens every several tokens + meta_every = self.config.hash_grid_config['meta_every'] + + trunc_seq_len = math.floor(ori_n / meta_every) * meta_every + + # hidden_states = pad_at_dim(hidden_states, (0, next_seq_len - ori_n), dim = -2, value = 0.) + if trunc_seq_len < ori_n: + hidden_states_orig = hidden_states + hidden_states = hidden_states_orig[:, :trunc_seq_len, :] + hidden_states_remain = hidden_states_orig[:, trunc_seq_len:, :] + + hidden_states = rearrange(hidden_states, 'b (n m) d -> (b n) m d', m = meta_every) # m is the segment length + + if next(self.hash_grid.parameters()).dtype != torch.float32: + self.hash_grid = self.hash_grid.float() + + hashing_input = hidden_states[:,-1:,:].float() # (b * n, 1, hidden size) + meta_tokens = self.hash_grid(hashing_input) # (b * n, num_memory_tokens per segment, hidden size) + meta_tokens = meta_tokens.to(hidden_states) + + hidden_states, mem_packed_shape = pack((hidden_states, meta_tokens), 'b * d') + + hidden_states = rearrange(hidden_states, '(b n) m d -> b (n m) d', b = ori_b) + + if trunc_seq_len < ori_n: + hidden_states = torch.cat([hidden_states, hidden_states_remain], dim = 1) + + if position_ids is not None and position_ids.shape[1] != hidden_states.shape[1]: + position_ids = torch.arange(hidden_states.shape[1], device=hidden_states.device).unsqueeze(0) + + residual = hidden_states + + if self.input_layernorm is not None: + hidden_states = self.input_layernorm(hidden_states) + + + hidden_states, present_key_value = self.mamba( + hidden_states=hidden_states, + past_key_value=past_key_value, + attention_mask=attention_mask + ) + attn_key_value = None + + if self.config.use_nGPT and (self.config.nGPT_config is None or self.config.nGPT_config['post_norm']): + lr = self.attn_alpha * (self.attn_alpha_init_value / self.attn_alpha_init_scaling) + lr = torch.abs(lr) + + residual = self.justnorm(residual) + hidden_states = self.justnorm(hidden_states) + + hidden_states = residual + lr * (hidden_states - residual) + hidden_states = self.justnorm(hidden_states) + else: + hidden_states = residual + hidden_states + + + ## remove meta tokens + if self.hash_grid is not None and not self.config.hash_grid_config['global'] and ori_n >= self.config.hash_grid_config['meta_every']: + if trunc_seq_len < ori_n: + split_len = int(ori_n // self.config.hash_grid_config['meta_every'] * (self.config.hash_grid_config['n_meta_tokens'] + self.config.hash_grid_config['meta_every'])) + hidden_states_orig = hidden_states + hidden_states = hidden_states_orig[:, :split_len, :] + hidden_states_remain = hidden_states_orig[:, split_len:, :] + + hidden_states = rearrange(hidden_states, 'b (n m) d -> (b n) m d', m = (self.config.hash_grid_config['n_meta_tokens'] + self.config.hash_grid_config['meta_every'])) + + hidden_states, _ = unpack(hidden_states, mem_packed_shape, 'b * d') + + hidden_states = rearrange(hidden_states, '(b n) m d -> b (n m) d', b = ori_b) + + if trunc_seq_len < ori_n: + hidden_states = torch.cat([hidden_states, hidden_states_remain], dim = 1) + assert hidden_states.shape[1] == ori_n + + + if self.intermediate_size > 0: + residual = hidden_states + + if self.pre_moe_layernorm is not None: + hidden_states = self.pre_moe_layernorm(hidden_states) + + hidden_states, router_logits = self.moe(hidden_states) + + if self.config.save_input_output: + self.saved_input_output['moe_input'] = 
residual.clone().detach() + self.saved_input_output['moe_output'] = hidden_states.clone().detach() + + if self.config.use_nGPT and (self.config.nGPT_config is None or self.config.nGPT_config['post_norm']): + lr = self.mlp_alpha * (self.mlp_alpha_init_value / self.mlp_alpha_init_scaling) + lr = torch.abs(lr) + + residual = self.justnorm(residual) # normally, normalization is not needed + hidden_states = self.justnorm(hidden_states) + + hidden_states = residual + lr * (hidden_states - residual) + hidden_states = self.justnorm(hidden_states) + else: + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + # if output_attentions: + # outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + if output_router_logits: + outputs += (router_logits,) + + outputs += (attn_key_value,) + + return outputs + + def _get_past_seqlen(self, past_key_value, seqlen): + if past_key_value is None: + return seqlen + past_seqlen = past_key_value.get_seq_length() + + if past_seqlen == 0: + return seqlen + + return past_seqlen + + # if past_key_value.attention_layer_idx is None: + # return seqlen + # if self.mamba.layer_idx < past_key_value.attention_layer_idx: + # return past_seqlen + 1 + # return past_seqlen + + + +class JambaHybridDecoderLayer(nn.Module): + def __init__(self, config: JambaConfig, num_experts: int, layer_idx: int, reuse_kv: bool = False): + super().__init__() + + self.config = config + + self.layer_idx = layer_idx + + if config.use_mamba2: + Mamba_OP = Mamba2 + else: + Mamba_OP = JambaMambaMixer + + if config.hybrid_decoder_layer == 'mamba': + self.mamba = Mamba_OP(config=config, layer_idx=layer_idx) + elif config.hybrid_decoder_layer == 'deltanet': + # from fla.layers.delta_net import DeltaNet + from .delta_net import DeltaNet + import torch._dynamo + torch._dynamo.config.suppress_errors = True + if self.config.other_args is not None: + assert 'expand_v' in self.config.other_args and 'expand_k' in self.config.other_args + self.gla = DeltaNet(hidden_size=config.hidden_size, num_heads=config.num_attention_heads, layer_idx=layer_idx, expand_v=self.config.other_args['expand_v'], expand_k=self.config.other_args['expand_k'], config=self.config) + else: + self.gla = DeltaNet(hidden_size=config.hidden_size, num_heads=config.num_attention_heads, layer_idx=layer_idx, config=self.config) + elif config.hybrid_decoder_layer == 'gated_deltanet': + # from fla.layers.gated_deltanet import GatedDeltaNet + from .gated_deltanet import GatedDeltaNet + import torch._dynamo + torch._dynamo.config.suppress_errors = True + if self.config.other_args is not None: + assert 'num_heads' in self.config.other_args + self.gla = GatedDeltaNet(hidden_size=config.hidden_size, layer_idx=layer_idx, num_heads=self.config.other_args['num_heads'], head_dim=config.hidden_size//config.num_attention_heads, expand_v=1, config=self.config) + else: + self.gla = GatedDeltaNet(hidden_size=config.hidden_size, layer_idx=layer_idx, num_heads=config.num_attention_heads, head_dim=config.hidden_size//config.num_attention_heads, expand_v=1, config=self.config) + else: + raise ValueError(f"No such hybrid_decoder_layer:{config.hybrid_decoder_layer}") + + self.pure_linear_attn = config.pure_linear_attn + self.self_attn_type = config.self_attn_type + if self.pure_linear_attn: + self.self_attn = None + else: + self.self_attn = JAMBA_ATTENTION_CLASSES[config.attn_implementation](config, layer_idx, reuse_kv=reuse_kv) + + self.config = config + self.share_kv = config.share_kv + + self.reuse_kv = reuse_kv + + 
self.compact_gating = config.compact_gating + + if self.compact_gating: + self.W_G = nn.Parameter(torch.randn(config.hidden_size, config.hidden_size) / config.hidden_size) + self.W_O = nn.Parameter(torch.randn(config.hidden_size, config.hidden_size) / config.hidden_size) + self.swish = lambda x: x * torch.sigmoid(x) + + if self.config.intermediate_size > 0: + num_experts_per_tok = config.num_experts_per_tok if num_experts > 1 else 1 + self.moe = JambaSparseMoeBlock(config, num_experts=num_experts, num_experts_per_tok=num_experts_per_tok, layer_idx=layer_idx) + self.pre_moe_layernorm = JambaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) if not config.fully_parallel_jamba else None + else: + self.moe = None + self.pre_moe_layernorm = None + + self.input_layernorm = JambaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + if not self.pure_linear_attn: + self.reduce_method = config.reduce_method + assert self.reduce_method in ["concat", "mean", "reduce_concat"] + if self.reduce_method == "concat": + self.am_merge = nn.Linear(2 * config.hidden_size, config.hidden_size, bias=False) + if self.reduce_method == "reduce_concat": + attn_dim = int(config.hidden_size * config.reduce_attn_ratio) + mamba_dim = config.hidden_size - attn_dim + self.a_reduce = nn.Linear(config.hidden_size, attn_dim, bias=False) + self.m_reduce = nn.Linear(config.hidden_size, mamba_dim, bias=False) + + self.sequential_jamba = config.sequential_jamba + if self.sequential_jamba: + self.sequential_jamba_norm = JambaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + else: + self.pre_avg_layernorm1 = JambaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.pre_avg_layernorm2 = JambaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + self.fully_parallel_jamba = config.fully_parallel_jamba + if self.fully_parallel_jamba: + self.pre_avg_layernorm3 = JambaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + assert not (self.fully_parallel_jamba and self.sequential_jamba) + else: + # assert not config.sequential_jamba and not config.fully_parallel_jamba + pass + + if self.config.layerwise_memory_token: + assert self.config.num_memory_tokens > 0 + self.memory_tokens = nn.Parameter(torch.randn(self.config.num_memory_tokens, self.config.hidden_size)) + else: + self.memory_tokens = None + + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + attention_mask_raw: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[HybridMambaAttentionDynamicCache] = None, + output_attentions: Optional[bool] = False, + output_router_logits: Optional[bool] = False, + use_cache: Optional[bool] = False, + kv_last_layer = None, + # kv_proj_last_layer=None, + gla_past_key_values = None, + use_swa=False, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, sequence_length)` where padding elements are indicated by 0. 
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_router_logits (`bool`, *optional*): + Whether or not to return the logits of all the routers. They are useful for computing the router loss, and + should not be returned during inference. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + """ + + # assert past_key_value is None, "We don't support past_key_value for HybridDecoderLayer yet!" + # assert gla_past_key_values is None, "We don't support gla_past_key_values for HybridDecoderLayer yet!" + + if self.memory_tokens is not None: + hidden_states = hidden_states[:,self.config.num_memory_tokens:,...] + mem = repeat(self.memory_tokens, 'n d -> b n d', b = hidden_states.shape[0]) # prepend the memory to every segment of m by repeating the memory tokens + hidden_states, mem_packed_shape = pack((mem, hidden_states), 'b * d') + + if type(hidden_states) == tuple: + assert self.config.hybrid_decoder_layer in ['rwkv'] + hidden_states, v_first = hidden_states + else: + v_first = None + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + hidden_states_input = hidden_states + + if self.config.hybrid_decoder_layer == 'mamba': + hybrid_op_hidden_states, mamba_present_key_value = self.mamba( + hidden_states=hidden_states, + past_key_value=past_key_value, + ) + + elif self.config.hybrid_decoder_layer in ['gla', 'gated_deltanet', 'deltanet', 'lightning_attn']: + hybrid_op_hidden_states, _, gla_past_key_values = self.gla( + hidden_states=hidden_states, + attention_mask=attention_mask_raw, + past_key_values=gla_past_key_values, + ) + + else: + raise ValueError(f"No such hybrid_decoder_layer:{self.config.hybrid_decoder_layer}") + + + if not self.pure_linear_attn and self.sequential_jamba: + hidden_states = residual + hybrid_op_hidden_states + residual = hidden_states + hidden_states = self.sequential_jamba_norm(hidden_states) + + + if self.self_attn is not None: + if self.self_attn_type is not None: + if self.self_attn_type == 'mamba': + self_attn_hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + past_key_value=past_key_value, + ) + + elif self.self_attn_type == 'retention': + self_attn_hidden_states = self.self_attn( + X=hidden_states, + ) + + elif self.self_attn_type in ['gla', 'gated_deltanet', 'deltanet', 'lightning_attn']: + self_attn_hidden_states, _, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask_raw, + past_key_values=gla_past_key_values, + ) + # elif self.self_attn_type in ['rwkv']: + # self_attn_hidden_states, _, _, v_first = self.self_attn( + # hidden_states=hidden_states, + # attention_mask=attention_mask_raw, + # past_key_values=gla_past_key_values, + # v_first=v_first, + # ) + + else: + raise ValueError(f"No such self_attn:{self.self_attn_type}") + + self_attn_weights = self_attn_present_key_value = current_kv = None + + else: + self_attn_hidden_states, self_attn_weights, self_attn_present_key_value, current_kv = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + kv_last_layer=kv_last_layer if self.reuse_kv 
else None, + use_swa=use_swa, + ) + else: + self_attn_hidden_states = self_attn_weights = self_attn_present_key_value = current_kv = None + + if self.share_kv: + K_present, V_present = current_kv # (batch_size, num_head, seq_len, dim_per_head) + + # bsz, q_len = K_present.shape[0], K_present.shape[2] + # K_present = K_present.transpose(1, 2).contiguous().reshape(bsz, q_len, -1) + # V_present = V_present.transpose(1, 2).contiguous().reshape(bsz, q_len, -1) + # print(K_present.shape, V_present.shape) + else: + K_present, V_present = None, None + + if self.pure_linear_attn: + hidden_states = residual + hybrid_op_hidden_states + + elif self.sequential_jamba: + hidden_states = self_attn_hidden_states + + # residual connection + assert residual.shape == hidden_states.shape + hidden_states = residual + hidden_states + + elif not self.fully_parallel_jamba: + if self.reduce_method == "concat": + hidden_states = self.am_merge(torch.cat([hybrid_op_hidden_states, self_attn_hidden_states], dim=-1)) + elif self.reduce_method == "mean": + hidden_states = (self.pre_avg_layernorm1(hybrid_op_hidden_states) + self.pre_avg_layernorm2(self_attn_hidden_states)) / 2 + elif self.reduce_method == "reduce_concat": + hidden_states = torch.cat([self.a_reduce(self_attn_hidden_states), self.m_reduce(hybrid_op_hidden_states)], dim=-1) + + if self.compact_gating: + hidden_states = (self.swish(hidden_states_input @ self.W_G) * hidden_states) @ self.W_O + + # residual connection + assert residual.shape == hidden_states.shape + hidden_states = residual + hidden_states + + if self.moe is not None: + if not self.pure_linear_attn and self.fully_parallel_jamba: + hidden_states_ffn, router_logits = self.moe(hidden_states) + + hidden_states = (self.pre_avg_layernorm1(hybrid_op_hidden_states) + self.pre_avg_layernorm2(self_attn_hidden_states) + self.pre_avg_layernorm3(hidden_states_ffn)) / 3 + + # residual connection + assert residual.shape == hidden_states.shape + hidden_states = residual + hidden_states + + else: + residual = hidden_states + hidden_states = self.pre_moe_layernorm(hidden_states) + + hidden_states, router_logits = self.moe(hidden_states) + + hidden_states = residual + hidden_states + else: + router_logits = None + + if v_first is not None and self.layer_idx != self.config.num_hidden_layers - 1: + outputs = ((hidden_states, v_first),) + else: + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (self_attn_present_key_value,) + + if output_router_logits: + outputs += (router_logits,) + + # if self.config.hybrid_decoder_layer == 'gla': + # outputs += (gla_past_key_values,) + + outputs += (current_kv,) + + + return outputs + + + +# Adapted from transformers.models.mistral.modeling_mistral.MistralPreTrainedModel with Mistral->Jamba +class JambaPreTrainedModel(PreTrainedModel): + config_class = JambaConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["JambaAttentionDecoderLayer", "JambaMambaDecoderLayer"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_cache_class = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, (nn.Linear, nn.Conv1d)): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + 
module.weight.data[module.padding_idx].zero_() + + @staticmethod + def _convert_to_standard_cache( + past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]], batch_size: int + ) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]: + """ + Standardizes the format of the cache so as to match most implementations, i.e. have the seqlen as the third dim + also for mamba layers + """ + attn_layer_index = [k.shape == v.shape for k, v in past_key_value].index(True) + seqlen = past_key_value[attn_layer_index][0].shape[2] + standard_past_key_value = () + for k, v in past_key_value: + if k.shape != v.shape: + # mamba layer + # expand doesn't use more memory, so it's fine to do it here + standard_past_key_value += ((k.expand(-1, -1, seqlen, -1), v.expand(-1, -1, seqlen, -1)),) + else: + standard_past_key_value += ((k, v),) + return standard_past_key_value + + @staticmethod + def _convert_to_jamba_cache( + past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]], + ) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]: + """ + Converts the cache to the format expected by Jamba, i.e. dummy seqlen dimension with size 1 for mamba layers + """ + jamba_past_key_value = () + for k, v in past_key_value: + if k.shape != v.shape: + # mamba layer + jamba_past_key_value += ((k[:, :, :1, :], v[:, :, :1, :]),) + else: + jamba_past_key_value += ((k, v),) + return jamba_past_key_value + + + + +def shift_zeros_to_front(attention_mask, hidden_states, position_ids): + """ + Move all zero entries in 'attention_mask' to the front of the sequence + and reorder 'hidden_states' and 'position_ids' accordingly, preserving + the order of zeros and the order of ones. + + Args: + attention_mask: (batch_size, seq_len), values in {0, 1}. + hidden_states: (batch_size, seq_len, dim). + position_ids: (batch_size, seq_len). + + Returns: + shifted_mask: (batch_size, seq_len) with zeros at the front. + shifted_states: (batch_size, seq_len, dim) reordered accordingly. + shifted_position_ids: (batch_size, seq_len) reordered accordingly. + """ + B, L = attention_mask.shape + D = hidden_states.shape[-1] + + shifted_mask = torch.empty_like(attention_mask) + shifted_states = torch.empty_like(hidden_states) + shifted_position_ids = torch.empty_like(position_ids) + + # Process each batch row independently + for b in range(B): + row_mask = attention_mask[b] # (seq_len,) + row_states = hidden_states[b] # (seq_len, dim) + row_pos = position_ids[b] # (seq_len,) + + # Find positions of zeros and ones + zero_indices = torch.where(row_mask == 0)[0] + one_indices = torch.where(row_mask == 1)[0] + + # Concatenate zero indices (in order) then one indices + new_order = torch.cat([zero_indices, one_indices], dim=0) + + # Reorder mask and states + shifted_mask[b] = row_mask[new_order] + shifted_states[b] = row_states[new_order] + shifted_position_ids[b] = row_pos[new_order] + + return shifted_mask, shifted_states, shifted_position_ids + + +# Adapted from transformers.models.mistral.modeling_mistral.MistralModel with MISTRAL->JAMBA, Mistral->Jamba +class JambaModel(JambaPreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`JambaDecoderLayer`] + + Args: + config: JambaConfig + """ + + def __init__(self, config: JambaConfig): + super().__init__(config) + # assert 1==0, config.attn_implementation + # assert config.attn_implementation == "sdpa", "Xin: only supports sdpa attention for now" + config.attn_implementation = config.attn_implementation_new + config._attn_implementation = config.attn_implementation_new + + self.config = config + + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + + self.inter_layer_kv_reuse = config.kv_reuse_every_i_layer > 0 or config.kv_reuse_group is not None + self.kv_reuse_group = config.kv_reuse_group + self.kv_reuse_every_i_layer = config.kv_reuse_every_i_layer + + decoder_layers = [] + + hybrid_decoder_op = JambaHybridDecoderLayer + + if self.kv_reuse_group is not None: + self.kv_reuse_group = [{'producer': group[0], 'consumer': group[1:]} for group in self.kv_reuse_group] + + self.use_gla_kv = False + + layer_type = [] + for i in range(config.num_hidden_layers): + # is_attn = True if (i - self.config.attn_layer_offset) % self.config.attn_layer_period == 0 else False # original imple + is_attn = True if (i>=self.config.attn_layer_offset) and ((i - self.config.attn_layer_offset) % self.config.attn_layer_period == 0) else False + is_expert = True if (i - self.config.expert_layer_offset) % self.config.expert_layer_period == 0 else False + + num_experts = self.config.num_experts if is_expert else 1 + + if self.inter_layer_kv_reuse: + if self.kv_reuse_group is not None: + reuse_kv = False + for group_id, item in enumerate(self.kv_reuse_group): + if i in item['consumer']: + reuse_kv = True + + else: + if i % config.kv_reuse_every_i_layer == 0: + reuse_kv = False + else: + reuse_kv = True + else: + reuse_kv = False + + self.moe_config = config.moe_config + + if config.layer_types is not None: + if config.layer_types[i] == 'h': + layer_type.append('h') + decoder_layer = hybrid_decoder_op(config, num_experts=num_experts, layer_idx=i, reuse_kv=reuse_kv) + elif config.layer_types[i] in ['gdn', 'deltanet']: + if config.layer_types[i] == 'gdn' or config.layer_types[i] == 'gated_deltanet': + layer_type.append('m') + config_new = copy.deepcopy(config) + config_new.pure_linear_attn = True + config_new.hybrid_decoder_layer = 'gated_deltanet' + decoder_layer = hybrid_decoder_op(config_new, num_experts=num_experts, layer_idx=i, reuse_kv=reuse_kv) + elif config.layer_types[i] == 'deltanet': + layer_type.append('m') + config_new = copy.deepcopy(config) + config_new.pure_linear_attn = True + config_new.hybrid_decoder_layer = 'deltanet' + decoder_layer = hybrid_decoder_op(config_new, num_experts=num_experts, layer_idx=i, reuse_kv=reuse_kv) + elif config.layer_types[i] == 'm': + layer_type.append('m') + config_new = copy.deepcopy(config) + config_new.intermediate_size = 0 + decoder_layer = JambaMambaDecoderLayer(config_new, num_experts=num_experts, layer_idx=i, reuse_kv=reuse_kv) + elif config.layer_types[i] == 'm2': + layer_type.append('m') + config_new = copy.deepcopy(config) + config_new.intermediate_size = 0 + config_new.use_mamba2 = True + decoder_layer = JambaMambaDecoderLayer(config_new, num_experts=num_experts, layer_idx=i, reuse_kv=reuse_kv) + elif config.layer_types[i] == 'a': + layer_type.append('a') + decoder_layer = JambaAttentionDecoderLayer(config, num_experts=num_experts, layer_idx=i, reuse_kv=reuse_kv) + elif config.layer_types[i] == 'f': + 
layer_type.append('a') + decoder_layer = FFNDecoderLayer(config, num_experts=num_experts, layer_idx=i, reuse_kv=reuse_kv) + else: + raise ValueError(f"Unknown layer type {layer_type}") + + if config.layer_types[i] in ['gdn', 'gated_deltanet', 'deltanet', 'm_orig', 'mamba_orig']: + self.use_gla_kv = True + + else: + if i in config.hybrid_block_indices: + layer_type.append('h') + decoder_layer = hybrid_decoder_op(config, num_experts=num_experts, layer_idx=i, reuse_kv=reuse_kv) + + if config.hybrid_decoder_layer in ['gla', 'lightning_attn', 'deltanet', 'gated_deltanet', 'rwkv']: + self.use_gla_kv = True + else: + if is_attn: + layer_type.append('a') + decoder_layer = JambaAttentionDecoderLayer(config, num_experts=num_experts, layer_idx=i, reuse_kv=reuse_kv) + else: + layer_type.append('m' if config.fused_multihead_config is None else 'h') + decoder_layer = JambaMambaDecoderLayer(config, num_experts=num_experts, layer_idx=i, reuse_kv=reuse_kv) + + decoder_layers.append(decoder_layer) + config.layer_type = layer_type + + if config.sliding_window is not None: + self.sliding_window = config.sliding_window + self.global_attn_idx = config.global_attn_idx + else: + self.sliding_window = None + self.global_attn_idx = None + + if not any(isinstance(layer, JambaAttentionDecoderLayer) for layer in decoder_layers): + # raise ValueError("At least one layer in the decoder must be an attention layer") + self._attn_layer_index = [] + else: + self._attn_layer_index = [isinstance(layer, JambaAttentionDecoderLayer) for layer in decoder_layers].index( + True + ) + + if not any(isinstance(layer, JambaMambaDecoderLayer) for layer in decoder_layers): + # raise ValueError("At least one layer in the decoder must be a Mamba layer") + self._mamba_layer_index = [] + else: + self._mamba_layer_index = [isinstance(layer, JambaMambaDecoderLayer) for layer in decoder_layers].index(True) + + # if ( + # decoder_layers[self._mamba_layer_index].mamba.ssm_state_size + # == decoder_layers[self._mamba_layer_index].mamba.conv_kernel_size + # ): + # raise ValueError("Mamba state size and convolution size must be different") + + self.layers = nn.ModuleList(decoder_layers) + + self._attn_implementation = config.attn_implementation + + if self.config.use_nGPT and (self.config.nGPT_config is None or self.config.nGPT_config['post_norm']): + self.final_layernorm = None + else: + self.final_layernorm = JambaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + if self.config.num_memory_tokens > 0: + self.memory_tokens = nn.Parameter(torch.randn(self.config.num_memory_tokens, self.config.hidden_size)) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + # self.rope_cache: Optional[RoPECache] = None + + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[List[torch.FloatTensor], HybridMambaAttentionDynamicCache]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + return_dict: Optional[bool] = None, + gla_past_key_values=None, + ) -> Union[Tuple, MoeModelOutputWithPast]: + output_attentions = output_attentions if 
output_attentions is not None else self.config.output_attentions + output_router_logits = ( + output_router_logits if output_router_logits is not None else self.config.output_router_logits + ) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + past_key_values_length = 0 + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + if self.config.mamba_attnaug_config is not None: + seq_length = seq_length // self.config.mamba_attnaug_config['ds_ratio'] + + if use_cache: + # if isinstance(past_key_values, Cache) and not isinstance( + # past_key_values, HybridMambaAttentionDynamicCache + # ): + # past_key_values = HybridMambaAttentionDynamicCache.from_legacy_cache(past_key_values.to_legacy_cache()) + # use_legacy_cache = not isinstance(past_key_values, HybridMambaAttentionDynamicCache) + # if use_legacy_cache: + # past_key_values = HybridMambaAttentionDynamicCache.from_legacy_cache(past_key_values) + + use_legacy_cache = False + # past_key_values_length = past_key_values.get_usable_length(seq_length, self._attn_layer_index) + if past_key_values is not None: + past_key_values_length = past_key_values.get_usable_length(seq_length, 0) + else: + use_cache = False + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + ) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + if self.config.num_memory_tokens > 0 and past_key_values is not None and past_key_values.get_seq_length() == 0: + position_ids = position_ids.view(-1, seq_length + self.config.num_memory_tokens).long() + else: + position_ids = position_ids.view(-1, seq_length).long() + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + ori_b, ori_n = inputs_embeds.shape[0], inputs_embeds.shape[1] + + if self.config.num_memory_tokens > 0 and (past_key_values is None or past_key_values.get_seq_length() == 0): + if self.config.memory_tokens_interspersed_every > 0: + mem_every = self.config.memory_tokens_interspersed_every + next_seq_len = math.ceil(ori_n / mem_every) * mem_every + + # print(f"before padding: {inputs_embeds.shape}") + inputs_embeds = pad_at_dim(inputs_embeds, (0, next_seq_len - ori_n), dim = -2, value = 0.) 
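+ # [Editor's sketch] Shape walk-through for the interspersed-memory-token path, using
+ # hypothetical sizes (ori_n = 10, mem_every = 4, num_memory_tokens = 2) that are not taken
+ # from any real config:
+ #   next_seq_len = ceil(10 / 4) * 4 = 12, so inputs_embeds is right-padded from (b, 10, d)
+ #   to (b, 12, d); the rearrange below splits it into 3 segments per sample -> (b*3, 4, d);
+ #   pack((mem, inputs_embeds), 'b * d') prepends the 2 memory tokens to every segment,
+ #   giving (b*3, 6, d); the second rearrange then folds the segments back to (b, 18, d).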
+ # print(f"after padding: {inputs_embeds.shape}") + # assert 1==0 + inputs_embeds = rearrange(inputs_embeds, 'b (n m) d -> (b n) m d', m = mem_every) # m is the segment length + + mem = repeat(self.memory_tokens, 'n d -> b n d', b = inputs_embeds.shape[0]) # prepend the memory to every segment of m by repeating the memory tokens + inputs_embeds, mem_packed_shape = pack((mem, inputs_embeds), 'b * d') + + if self.config.memory_tokens_interspersed_every > 0: + inputs_embeds = rearrange(inputs_embeds, '(b n) m d -> b (n m) d', b = ori_b) + # # removing the last (next_seq_len - n) tokens + # inputs_embeds = inputs_embeds[:, :-next_seq_len + n, :] + + # assert 1==0, f"inputs_embeds shape: {inputs_embeds.shape}" + + if position_ids is not None and position_ids.shape[1] != inputs_embeds.shape[1]: + # print(f"position_ids shape: {position_ids.shape}") + position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device).unsqueeze(0) + # assert 1==0, f"position_ids shape: {position_ids.shape}" + + if inputs_embeds.shape[1] > 1 and attention_mask is not None and (attention_mask == 0).any(): + attention_mask, inputs_embeds, position_ids = shift_zeros_to_front(attention_mask, inputs_embeds, position_ids) + + if attention_mask is not None and attention_mask.shape[1] < inputs_embeds.shape[1]: + assert attention_mask.shape[1] + self.config.num_memory_tokens == inputs_embeds.shape[1] + attention_mask = torch.cat([torch.ones(inputs_embeds.shape[0], self.config.num_memory_tokens, device=attention_mask.device), attention_mask], dim=1) + + attention_mask_raw = attention_mask + + if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache: + is_padding_right = attention_mask[:, -1].sum().item() != batch_size + if is_padding_right: + raise ValueError( + "You are attempting to perform batched generation with padding_side='right'" + " this may lead to unexpected behaviour for Flash Attention version of Jamba. Make sure to " + " call `tokenizer.padding_side = 'left'` before tokenizing the input. " + ) + + if self._attn_implementation == "flash_attention_2" or self._attn_implementation == "flex": + # assert 1==0, f"attention mask: {attention_mask.shape}, {attention_mask[:10, :10]}" + # 2d mask is passed through the layers + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + attention_mask_swa = attention_mask + + elif self._attn_implementation == "sdpa" and not output_attentions: + # output_attentions=True can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. 
+ + # print(attention_mask.shape, inputs_embeds.shape, seq_length, past_key_values_length) + # input() + + attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask, + (batch_size, seq_length), + inputs_embeds, + past_key_values_length, + ) + + if self.sliding_window is not None: + attention_mask_swa = _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask, + (batch_size, seq_length), + inputs_embeds, + past_key_values_length, + sliding_window=self.sliding_window + ) + + # print(f'Attention mask shape: {attention_mask.shape}, {attention_mask[120:130, 120:130]}') + # print(f'Attention swa mask shape: {attention_mask_swa.shape}, {attention_mask_swa[0, 0, 120:130, 120:130]}') + # assert 1==0 + + else: + + # 4d mask is passed through the layers + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, + (batch_size, seq_length), + inputs_embeds, + past_key_values_length, + # sliding_window=self.config.sliding_window, + ) + + + if self.sliding_window is not None: + attention_mask_swa = _prepare_4d_causal_attention_mask( + attention_mask, + (batch_size, seq_length), + inputs_embeds, + past_key_values_length, + sliding_window=self.sliding_window + ) + + # print(f'Attention mask shape: {attention_mask.shape}, {attention_mask[120:130, 120:130]}') + # print(f'Attention swa mask shape: {attention_mask_swa.shape}, {attention_mask_swa[120:130, 120:130]}') + # assert 1==0 + + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_router_logits = () if output_router_logits else None + next_decoder_cache = None + + kv_last_layer = None + # kv_proj_last_layer = None + + shared_kv_cache_dict = {} + + kwargs = {} + if self.use_gla_kv and use_cache: + kwargs['gla_past_key_values'] = gla_past_key_values + + for i, decoder_layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.inter_layer_kv_reuse and self.kv_reuse_group is not None: + no_reuse_flag = True + for group_id, item in enumerate(self.kv_reuse_group): + if i in item['consumer']: + kv_last_layer = shared_kv_cache_dict[group_id] + no_reuse_flag = False + # print(f'[Layer-{i}]: Reuse KV cache from Layer-{self.kv_reuse_group[group_id]["producer"]}') + break + + if no_reuse_flag: + kv_last_layer = None + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + attention_mask if (self.sliding_window is None or i in self.global_attn_idx) else attention_mask_swa, + attention_mask_raw, + position_ids, + past_key_values, + output_attentions, + output_router_logits, + use_cache, + kv_last_layer, + # kv_proj_last_layer, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask if (self.sliding_window is None or i in self.global_attn_idx) else attention_mask_swa, + attention_mask_raw=attention_mask_raw, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + output_router_logits=output_router_logits, + use_cache=use_cache, + kv_last_layer=kv_last_layer if self.inter_layer_kv_reuse else None, + use_swa=self.sliding_window is not None and i not in self.global_attn_idx, + **kwargs, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if 
output_router_logits: + all_router_logits += (layer_outputs[3],) + + if self.inter_layer_kv_reuse: + kv_last_layer = layer_outputs[-1] + + if self.kv_reuse_group is not None: + for group_id, item in enumerate(self.kv_reuse_group): + if i == item['producer']: + shared_kv_cache_dict[group_id] = kv_last_layer + # print(f'[Layer-{i}]: Produce KV for group-{group_id}') + break + + del shared_kv_cache_dict + + if self.final_layernorm is not None: + hidden_states = self.final_layernorm(hidden_states) + + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.config.hash_grid_config is not None and self.config.hash_grid_config['global'] and ori_n >= self.config.hash_grid_config['meta_every']: + if 'meta_loss' not in self.config.hash_grid_config or not self.config.hash_grid_config['meta_loss'] or not self.training: + bs = hidden_states.shape[0] + + meta_every = self.config.hash_grid_config['meta_every'] + trunc_seq_len = math.floor(ori_n / meta_every) * meta_every + + if trunc_seq_len < ori_n: + split_len = int(ori_n // self.config.hash_grid_config['meta_every'] * (self.config.hash_grid_config['n_meta_tokens'] + self.config.hash_grid_config['meta_every'])) + hidden_states_orig = hidden_states + hidden_states = hidden_states_orig[:, :split_len, :] + hidden_states_remain = hidden_states_orig[:, split_len:, :] + + hidden_states = rearrange(hidden_states, 'b (n m) d -> (b n) m d', m = (self.config.hash_grid_config['n_meta_tokens'] + self.config.hash_grid_config['meta_every'])) + + hidden_states = hidden_states[:, :hidden_states.shape[1] - self.config.hash_grid_config['n_meta_tokens'], :] + + hidden_states = rearrange(hidden_states, '(b n) m d -> b (n m) d', b = bs) + + if trunc_seq_len < ori_n: + hidden_states = torch.cat([hidden_states, hidden_states_remain], dim = 1) + assert hidden_states.shape[1] == ori_n + + + if self.config.num_memory_tokens > 0 and (past_key_values is None or past_key_values.get_seq_length() == 0): + if self.config.memory_tokens_interspersed_every > 0: + hidden_states = rearrange(hidden_states, 'b (n m) d -> (b n) m d', m = (self.config.num_memory_tokens + self.config.memory_tokens_interspersed_every)) + + mem, hidden_states = unpack(hidden_states, mem_packed_shape, 'b * d') + + if self.config.memory_tokens_interspersed_every > 0: + hidden_states = rearrange(hidden_states, '(b n) m d -> b (n m) d', b = ori_b) + + hidden_states = hidden_states[:, :ori_n, :] + + if past_key_values and not past_key_values.has_previous_state: + past_key_values.has_previous_state = True + + next_cache = None + if use_cache: + next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache + + if not return_dict: + return tuple( + v + for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits] + if v is not None + ) + return MoeModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache if gla_past_key_values is None else (next_cache, gla_past_key_values), + hidden_states=all_hidden_states, + attentions=all_self_attns, + router_logits=all_router_logits, + ) + + +# Adapted from transformers.models.mixtral.modeling_mixtral.MixtralForCausalLM with MIXTRAL->JAMBA, Mixtral->Jamba +class JambaForCausalLM(JambaPreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config: JambaConfig): + super().__init__(config) + self.config = config + self.model = JambaModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, 
bias=False) + self.router_aux_loss_coef = config.router_aux_loss_coef + self.num_experts = config.num_experts + self.num_experts_per_tok = config.num_experts_per_tok + # Initialize weights and apply final processing + self.post_init() + + self.use_nGPT = self.config.use_nGPT + + if self.config.use_nGPT and (self.config.nGPT_config is None or self.config.nGPT_config['post_norm']): + self.sz_init_value = 1.00 + self.sz_init_scaling = 1.0 / self.config.hidden_size ** 0.5 + self.sz = torch.nn.Parameter(self.sz_init_scaling*torch.ones(self.config.vocab_size, dtype=torch.float32)) + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + def justnorm(self, x, idim=-1): + dtype = x.dtype + x = x.float() + res = (x / x.norm(p=2, dim=idim, keepdim=True)).to(dtype=dtype) + return res + + def normalize_matrices_cont(self, optimizer=None): + self.lm_head.weight.data.copy_(self.justnorm(self.lm_head.weight.data, 0)) # V, n_embd + + qkv_name = ["q_proj", "k_proj", "v_proj"] + mlp_name = ["gate_proj", "up_proj"] + + qkv_norm_list = [] + mlp_norm_list = [] + + for layer_id in range(self.config.num_hidden_layers): + # Access the layer using indexing instead of dot notation + layer = self.model.layers[layer_id] + + # Get the Q, K, V projection weights and calculate their norm + norm = torch.cat([ + layer.self_attn.q_proj.weight, + layer.self_attn.k_proj.weight, + layer.self_attn.v_proj.weight # Fixed to use v_proj instead of q_proj again + ], dim=0).norm(p=2, dim=0, keepdim=True) + qkv_norm_list.append(norm) + + # Access the first expert in MoE and calculate the norm for up_proj and gate_proj + norm = torch.cat([ + layer.moe.experts[0].up_proj.weight, + layer.moe.experts[0].gate_proj.weight + ], dim=0).norm(p=2, dim=0, keepdim=True) + mlp_norm_list.append(norm) + + qkv_id = 0 + mlp_id = 0 + for name, param in self.model.named_parameters(): + if any(keyword in name for keyword in qkv_name): + # print('Weight norm:', name, qkv_id) + param.data.copy_(param.data/qkv_norm_list[qkv_id]) + if 'v_proj' in name: + qkv_id += 1 + + elif any(keyword in name for keyword in mlp_name): + # print('Weight norm:', name, mlp_id) + param.data.copy_(param.data/mlp_norm_list[mlp_id]) + if 'up_proj' in name: + mlp_id += 1 + + if optimizer is not None: + vocab_size = self.vocab_size + qkv_dim = in_proj_dim = in_proj_dim2 = intermediate_dim = mlp_dim = x_proj_dim = -1 + + for name, param in self.model.layers.named_parameters(): + if 'q_proj' in name: + qkv_dim = param.shape[0] + if 'in_proj' in name: + if in_proj_dim !=-1 and in_proj_dim != param.shape[0]: ## cross-layer kv reuse + in_proj_dim2 = param.shape[0] + else: + in_proj_dim = param.shape[0] + if 'out_proj' in name: + intermediate_dim = param.shape[1] + if 'up_proj' in name: + mlp_dim = param.shape[0] + + # print(qkv_dim, in_proj_dim, in_proj_dim2, intermediate_dim, mlp_dim, x_proj_dim) + + qkvo_cnt = 0 + mlp_cnt = 0 + embed_head_cnt = 0 + + qkv_layer_id = 0 + mlp_layer_id = 0 + + if (hasattr(optimizer, 'shard_fp32_from_float16_groups') == False): # if False, then fused_adam + for param_group in optimizer.param_groups: + for param in param_group['params']: + idim = -1 + if param.shape == (vocab_size, self.config.hidden_size): + if 
embed_head_cnt == 0: ## embedding, processed above + embed_head_cnt += 1 + else: + idim = 0 # lm_head + param.data.copy_(self.justnorm(param.data, idim)) + # print('Optimizer norm: lm_head') + + if param.shape == (qkv_dim, self.config.hidden_size): # q,k,v proj or o_proj of attn + if qkvo_cnt < 3: # q,k,v proj of attn + qkvo_cnt += 1 + idim = 0 + + qkv_norm = qkv_norm_list[qkv_layer_id] + + param.data.copy_(param.data/qkv_norm) + + # print('Optimizer qkv norm:', qkv_layer_id) + + else: # o_proj of attn + qkvo_cnt = 0 + idim = -1 + + qkv_layer_id += 1 + + if param.shape == (in_proj_dim, self.config.hidden_size) or param.shape == (in_proj_dim2, self.config.hidden_size): # in proj of hymba + idim = 0 + + if param.shape == (self.config.hidden_size, intermediate_dim): idim = -1 # output proj of hymba + if param.shape == (mlp_dim, self.config.hidden_size): + idim = 0 # mlp's up proj / gate proj + + mlp_cnt += 1 + mlp_norm = mlp_norm_list[mlp_layer_id] + + param.data.copy_(param.data/mlp_norm) + + # print('Optimizer mlp norm:', mlp_layer_id) + + if mlp_cnt == 2: + mlp_layer_id += 1 + mlp_cnt = 0 + + if param.shape == (self.config.hidden_size, mlp_dim): idim = -1 # mlp's down proj + if param.shape == (x_proj_dim, intermediate_dim): idim = -1 # x_proj + + # if (idim >= 0): + # param.data.copy_(self.justnorm(param.data, idim)) + # else: + # if len(param.shape) == 2: + # print('Not normalized:', param.shape) + + # print("Finish weight normalization.") + + def normalize_matrices(self, optimizer=None): + if self.config.nGPT_config is not None and 'cont_ngpt' in self.config.nGPT_config and self.config.nGPT_config['cont_ngpt']: + self.normalize_matrices_cont(optimizer=optimizer) + return + + if self.config.nGPT_config is None or self.config.nGPT_config['init_norm']: + self.model.embed_tokens.weight.data.copy_(self.justnorm(self.model.embed_tokens.weight.data, 1)) # V, n_embd + + if optimizer is not None: ## normalize embedding in optimizer + vocab_size = self.vocab_size + + if (hasattr(optimizer, 'shard_fp32_from_float16_groups') == False): # if False, then fused_adam + for param_group in optimizer.param_groups: + for param in param_group['params']: + if param.shape == (vocab_size, self.config.hidden_size): # embedding + param.data.copy_(self.justnorm(param.data, 1)) + break + + if self.config.nGPT_config is None or self.config.nGPT_config['weight_norm']: + self.lm_head.weight.data.copy_(self.justnorm(self.lm_head.weight.data, 1)) # V, n_embd + + if self.config.nGPT_config is not None and 'weight_norm_out' in self.config.nGPT_config and self.config.nGPT_config['weight_norm_out']: + keywords_d0 = ["down_proj", "out_proj", "o_proj", "gate_proj", "up_proj", "in_proj", "q_proj", "k_proj", "v_proj", "x_proj"] + keywords_d1 = [] + else: + keywords_d0 = ["down_proj", "out_proj", "o_proj"] + keywords_d1 = ["gate_proj", "up_proj", "in_proj", "q_proj", "k_proj", "v_proj", "x_proj", "b_proj", "a_proj"] # , "g_proj"] + + # if self.config.nGPT_config is not None and (self.config.nGPT_config['norm_bc'] or self.config.nGPT_config['norm_ssm_input']): + # keywords_d1.append("x_proj") + + for name, param in self.model.named_parameters(): + if any(keyword in name for keyword in keywords_d0): + param.data.copy_(self.justnorm(param.data, 0)) + elif any(keyword in name for keyword in keywords_d1): + param.data.copy_(self.justnorm(param.data, 1)) + # else: + # print(name) + + if optimizer is not None: + vocab_size = self.vocab_size + q_dim = k_dim = v_dim = o_interm_dim = in_proj_dim = in_proj_dim2 = intermediate_dim = mlp_dim 
= x_proj_dim = b_proj_dim = a_proj_dim = g_proj_dim = -1 + for name, param in self.model.layers.named_parameters(): + if 'q_proj' in name: + q_dim = param.shape[0] + if 'k_proj' in name: + if param.shape[0] != q_dim: ## only keep k_dim for layers with GQA + k_dim = param.shape[0] + if 'v_proj' in name: + if param.shape[0] != q_dim: ## only keep v_dim for layers with GQA + v_dim = param.shape[0] + + if 'in_proj' in name: + if in_proj_dim !=-1 and in_proj_dim != param.shape[0]: ## cross-layer kv reuse + in_proj_dim2 = param.shape[0] + else: + in_proj_dim = param.shape[0] + + if 'o_proj' in name: + o_interm_dim = param.shape[1] + + if 'out_proj' in name: + intermediate_dim = param.shape[1] + + if 'up_proj' in name: + mlp_dim = param.shape[0] + + if 'x_proj' in name: + x_proj_dim = param.shape[0] + + if 'b_proj' in name: + b_proj_dim = param.shape[0] + + if 'a_proj' in name: + a_proj_dim = param.shape[0] + + # if 'g_proj' in name: + # g_proj_dim = param.shape[0] + + if o_interm_dim != -1: + assert o_interm_dim == q_dim, "only support expand_v = 1 => o_interm_dim == q_dim == hidden size for now" + + if k_dim !=-1 and v_dim != -1: + assert v_dim == k_dim, "only support v_dim == k_dim for now" + + qkvo_cnt = 0 + embed_head_cnt = 0 + if (hasattr(optimizer, 'shard_fp32_from_float16_groups') == False): # if False, then fused_adam + for param_group in optimizer.param_groups: + for param in param_group['params']: + idim = -1 + if param.shape == (vocab_size, self.config.hidden_size): + if embed_head_cnt == 0: ## embedding, processed above + embed_head_cnt += 1 + else: + idim = 1 # lm_head + + if param.shape == (q_dim, self.config.hidden_size): # q,k,v proj or o_proj of attn + if qkvo_cnt < 3: # q,k,v proj of attn + qkvo_cnt += 1 + idim = 1 + else: # o_proj of attn + qkvo_cnt = 0 + idim = 0 + + if k_dim != -1: + if param.shape == (k_dim, self.config.hidden_size): # k,v proj, still count qkvo_cnt to handle both GQA/non-GQA + qkvo_cnt += 1 + idim = 1 + + if param.shape == (in_proj_dim, self.config.hidden_size) or param.shape == (in_proj_dim2, self.config.hidden_size): # in proj of hymba + idim = 1 + + if param.shape == (self.config.hidden_size, intermediate_dim): idim = 0 # output proj of hymba + if param.shape == (mlp_dim, self.config.hidden_size): idim = 1 # mlp's up proj / gate proj + if param.shape == (self.config.hidden_size, mlp_dim): idim = 0 # mlp's down proj + if param.shape == (x_proj_dim, intermediate_dim): idim = 1 # x_proj + + if param.shape == (b_proj_dim, self.config.hidden_size): idim = 1 # b_proj + if param.shape == (a_proj_dim, self.config.hidden_size): idim = 1 # a_proj + # if param.shape == (g_proj_dim, self.config.hidden_size): idim = 1 # g_proj + + if (idim >= 0): + param.data.copy_(self.justnorm(param.data, idim)) + + # else: + # if len(param.shape) == 2: + # print('Not normalized:', param.shape) + + # print("Finish weight normalization.") + + + @replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + # Ignore copy + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + return_dict: Optional[bool] = None, + 
calc_logits_for_entire_prompt: Optional[bool] = True, + gla_past_key_values=None, + ) -> Union[Tuple, MoeCausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + calc_logits_for_entire_prompt (`bool`, *optional*): + Whether or not to calculate the logits for the entire prompt, or just the last token. Only last token + logits are needed for generation, and calculating them only for that token can save memory, + which becomes pretty significant for long sequences. + + Returns: + ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_router_logits = ( + output_router_logits if output_router_logits is not None else self.config.output_router_logits + ) + + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + output_router_logits=output_router_logits, + gla_past_key_values=gla_past_key_values, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + if calc_logits_for_entire_prompt: + logits = self.lm_head(hidden_states) + else: + logits = self.lm_head(hidden_states[..., -1:, :]) + + if self.config.use_nGPT and self.config.nGPT_config is not None and 'extra_grad' in self.config.nGPT_config and self.config.nGPT_config['extra_grad']: + logits = logits / self.lm_head.weight.norm(p=2, dim=1) + # print('lm_head:', self.lm_head.weight.norm(p=2, dim=1).max(), self.lm_head.weight.norm(p=2, dim=1).min()) + + logits = logits.float() + + if self.config.use_nGPT and (self.config.nGPT_config is None or self.config.nGPT_config['post_norm']): + sz = self.sz * (self.sz_init_value/self.sz_init_scaling) + + if sz.shape[-1] != logits.shape[-1]: + padding_size = logits.shape[-1] - sz.shape[-1] + sz = torch.cat([sz, torch.zeros(*sz.shape[:-1], padding_size, device=sz.device)], dim=-1) + + logits = sz * logits + + loss = None + if labels is not None: + if self.config.hash_grid_config is not None and 'meta_loss' in self.config.hash_grid_config and self.config.hash_grid_config['meta_loss'] and self.training: + bs, ori_n = labels.shape + segment_length = self.config.hash_grid_config['meta_every'] + insert_length = self.config.hash_grid_config['n_meta_tokens'] + + # Ensure the sequence length is divisible by the segment length + assert ori_n % segment_length == 0, "Sequence length must be divisible by segment length (64)." 
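+ # [Editor's sketch] Worked example of the label re-packing below, with hypothetical sizes
+ # (segment_length = 4, insert_length = n_meta_tokens = 2, ori_n = 8) chosen only for
+ # illustration:
+ #   labels           [t0 t1 t2 t3 t4 t5 t6 t7]
+ #   segments         [[t0 t1 t2 t3], [t4 t5 t6 t7]]
+ #   insertion_tokens [[t4 t5], [0 0]]   (first tokens of the *next* segment, zeros for the last)
+ #   result_segments  [[t0 t1 t2 t3 t4 t5], [t4 t5 t6 t7 0 0]]
+ #   flattened        [t0 t1 t2 t3 t4 t5 t4 t5 t6 t7 0 0], then the last n_meta_tokens are dropped.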
+ + # Calculate the number of segments + num_segments = ori_n // segment_length + + # Reshape into segments [bs, num_segments, segment_length] + segments = labels.view(bs, num_segments, segment_length) + + # Prepare insertion tokens by slicing the first tokens of the next segment + insertion_tokens = segments[:, 1:, :insert_length] # Tokens from the next segments + insertion_tokens = torch.cat( + [insertion_tokens, torch.zeros(bs, 1, insert_length, dtype=labels.dtype, device=labels.device)], dim=1 + ) # Pad the last segment's insertion with zeros (no next segment) + + # Concatenate the segments with their respective insertion tokens + result_segments = torch.cat([segments, insertion_tokens], dim=2) + + # Reshape back to the original batch sequence format + labels = result_segments.view(bs, -1) + + labels = labels[:, :-self.config.hash_grid_config['n_meta_tokens']] + logits = logits[:, :-self.config.hash_grid_config['n_meta_tokens']] + + # print(segments.shape, insertion_tokens.shape, result_segments.shape) + # print(logits.shape, labels.shape) + # input() + + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + + aux_loss = None + if output_router_logits: + aux_loss = load_balancing_loss_func( + outputs.router_logits if return_dict else outputs[-1], + self.num_experts, + self.num_experts_per_tok, + attention_mask, + ) + if labels is not None: + loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure to reside in the same device + + if not return_dict: + output = (logits,) + outputs[1:] + if output_router_logits: + output = (aux_loss,) + output + return (loss,) + output if loss is not None else output + + # print("hidden_states.shape:", hidden_states.shape, "input_ids.shape:", input_ids.shape, "logits.shape:", logits.shape) + + return MoeCausalLMOutputWithPast( + loss=loss, + aux_loss=aux_loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + router_logits=outputs.router_logits, + ) + + def get_init_cache(self, batch_size=1): + past_key_values = HybridMambaAttentionDynamicCache( + self.config, batch_size, self.dtype, device=self.device, layer_type=self.config.layer_type + ) + + if self.model.use_gla_kv: + if self.config.layer_types is not None and ('m_orig' in self.config.layer_types or 'mamba_orig' in self.config.layer_types): + from mamba_ssm.utils.generation import InferenceParams + gla_past_key_values = InferenceParams(max_seqlen=8000, max_batch_size=1) ## Note: this is hard-coded; should be updated for general cases + gla_past_key_values.seqlen_offset = 1 ## this is only for measuring decoding speed + else: + from fla.models.utils import Cache as fla_cache + gla_past_key_values = fla_cache.from_legacy_cache(None) + else: + gla_past_key_values = None + + return past_key_values, gla_past_key_values + + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + output_router_logits=False, + **kwargs, + ): + # Omit tokens covered by past_key_values + + if self.config.num_memory_tokens > 0: + attention_mask = torch.cat([torch.ones(input_ids.shape[0], 
self.config.num_memory_tokens, device=attention_mask.device), attention_mask], dim=1) + + if self.model.use_gla_kv and past_key_values is not None: + past_key_values, gla_past_key_values = past_key_values + else: + gla_past_key_values = None + + if past_key_values is not None: + # the cache may be in the standard format (e.g. in contrastive search), convert to Jamba's format if needed + if isinstance(past_key_values, Tuple): + if past_key_values[self.model._mamba_layer_index][0].shape[2] > 1: + past_key_values = self._convert_to_jamba_cache(past_key_values) + + if isinstance(past_key_values, Cache): + # if not isinstance(past_key_values, HybridMambaAttentionDynamicCache): + # past_key_values = HybridMambaAttentionDynamicCache.from_legacy_cache( + # past_key_values.to_legacy_cache() + # ) + cache_length = past_key_values.get_seq_length() + past_length = past_key_values.seen_tokens + max_cache_length = past_key_values.get_max_length() + + past_length = cache_length + + else: + cache_length = past_length = past_key_values[self.model._attn_layer_index][0].shape[2] + max_cache_length = None + + # Keep only the unprocessed tokens: + # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where + # some of the inputs are exclusively passed as part of the cache (e.g. when passing inputs_embeds as + # input) + + if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: + input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] + + # 2 - If the past_length is smaller than input_ids.shape[1], then input_ids holds all input tokens. We can discard + # input_ids based on the past_length. + elif self.config.num_memory_tokens <= 0 and past_length < input_ids.shape[1]: + input_ids = input_ids[:, past_length:] + + elif self.config.num_memory_tokens > 0 and past_length < input_ids.shape[1] + self.config.num_memory_tokens: + new_query_id = past_length - self.config.num_memory_tokens + input_ids = input_ids[:, new_query_id:] + + if self.config.sliding_window is not None and (self.config.global_attn_idx is None or len(self.config.global_attn_idx) == 0): + input_ids = input_ids[:, -1:] + + # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. + + # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. 
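+ # [Editor's note] Small numeric illustration of the cropping below, with hypothetical values:
+ # if max_cache_length = 4, cache_length = 4 and one new token arrives, 4 + 1 > 4, so the
+ # attention mask is truncated to its last 4 columns to stay within the cache window.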
+ if ( + max_cache_length is not None + and attention_mask is not None + and cache_length + input_ids.shape[1] > max_cache_length + ): + attention_mask = attention_mask[:, -max_cache_length:] + else: + past_key_values = HybridMambaAttentionDynamicCache( + self.config, input_ids.shape[0], self.dtype, device=self.device, layer_type=self.config.layer_type + ) + + if self.model.use_gla_kv: + if self.config.layer_types is not None: #todo: This is not accurate + if 'm_orig' in self.config.layer_types or 'mamba_orig' in self.config.layer_types: + from mamba_ssm.utils.generation import InferenceParams + gla_past_key_values = InferenceParams(max_seqlen=8000, max_batch_size=1) ## Note: this is hard-coded; should be updated for general cases + else: + from fla.models.utils import Cache as fla_cache + gla_past_key_values = fla_cache.from_legacy_cache(None) + else: + gla_past_key_values = None + + # print(attention_mask.shape) + # print(attention_mask) + position_ids = kwargs.get("position_ids", None) + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values.get_seq_length() > 0: + position_ids = position_ids[:, -input_ids.shape[1] :] + + # print(position_ids) + # input() + + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "attention_mask": attention_mask, + "output_router_logits": output_router_logits, + "calc_logits_for_entire_prompt": self.config.calc_logits_for_entire_prompt, + "gla_past_key_values": gla_past_key_values, + } + ) + return model_inputs + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += ( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), + ) + return reordered_past + + + +# Copied from transformers.models.mixtral.modeling_mixtral.MixtralForSequenceClassification with Mixtral->Jamba, MIXTRAL->JAMBA +class JambaForSequenceClassification(JambaPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = JambaModel(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence 
classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility + sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 + sequence_lengths = sequence_lengths % input_ids.shape[-1] + sequence_lengths = sequence_lengths.to(logits.device) + else: + sequence_lengths = -1 + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + + loss = None + if labels is not None: + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + + + + +def compute_average_attention_entropy(attention_maps_list): + + def entropy(p, unlogit=False): + """Compute the entropy of a probability distribution""" + exponent = 2 + if unlogit: + p = torch.pow(p, exponent) + plogp = p * torch.log(p) + # print("p", p[0,0,:10,:10]) + # print("plogp", plogp[0,0,:10,:10]) + plogp[p == 0] = 0 + # print("plogp", plogp[0,0,:10,:10]) + # assert 1==0 + res = -plogp.sum(dim=-1).mean(dim=-1) + # assert 1==0, res[0,0,:10] + return res + + # Find the minimum sequence length across all prompts + min_seq_length = min(maps[0].size(-1) for maps in attention_maps_list) + + # Initialize a list to store the clipped and averaged attention maps + avg_attention_entropy = [] + + layerwise_head_variance = [] + + # 
Get the number of layers from the first prompt's attention maps + num_layers = len(attention_maps_list[0]) + + # assert 1==0, attention_maps_list[0][0][0,0,:10,:10] + + for layer in range(num_layers): + # Collect all attention maps for this layer + layer_maps = [maps[layer][:, :, :min_seq_length, :min_seq_length] for maps in attention_maps_list] + + # Stack the tensors along a new dimension + # stacked_maps = torch.stack(layer_maps) + layer_entropy = torch.concat([entropy(x, unlogit=False) for x in layer_maps], dim=0) + + # Compute the average along the new dimension + layer_entropy_avg = layer_entropy.mean(dim=0, keepdim=True) + + head_variance = torch.var(layer_entropy, dim=1, unbiased=False) + head_variance = head_variance.mean(dim=0, keepdim=True) + + avg_attention_entropy.append(layer_entropy_avg) + + layerwise_head_variance.append(head_variance) + + avg_attention_entropy = torch.cat(avg_attention_entropy, dim=0) + layerwise_head_variance = torch.cat(layerwise_head_variance, dim=0) + + return avg_attention_entropy.cpu().numpy(), layerwise_head_variance.cpu().numpy() \ No newline at end of file
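+
+# [Editor's sketch] Illustrative, commented-out usage of compute_average_attention_entropy.
+# `model`, `ids_a` and `ids_b` are placeholders, and this assumes every decoder layer of the
+# chosen configuration actually returns attention weights (pure Mamba / linear-attention layers
+# contribute None entries, which this helper does not handle):
+#
+#     out_a = model(ids_a, output_attentions=True, use_cache=False)
+#     out_b = model(ids_b, output_attentions=True, use_cache=False)
+#     avg_entropy, head_entropy_var = compute_average_attention_entropy(
+#         [out_a.attentions, out_b.attentions]
+#     )
+#     # avg_entropy:      numpy array of shape (num_layers, num_heads), mean entropy per head
+#     # head_entropy_var: numpy array of shape (num_layers,), variance of per-head entropy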