"""PyTorch Sarvam MoE model."""
import math
from typing import List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
from torch import nn
from transformers.activations import ACT2FN
from transformers.cache_utils import Cache, DynamicCache
from transformers.modeling_attn_mask_utils import (
AttentionMaskConverter,
_prepare_4d_attention_mask,
_prepare_4d_causal_attention_mask,
_prepare_4d_causal_attention_mask_for_sdpa,
)
from transformers.modeling_outputs import MoeModelOutputWithPast
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
from transformers.utils import (
is_flash_attn_2_available,
is_flash_attn_greater_or_equal_2_10,
logging,
)
from transformers.generation.utils import GenerationMixin
from dataclasses import dataclass
from transformers.utils import ModelOutput
if is_flash_attn_2_available():
from flash_attn import flash_attn_func, flash_attn_varlen_func
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input
from .configuration_sarvam_moe import SarvamMoEConfig
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "SarvamMoEConfig"
@dataclass
class SarvamMoECausalLMOutputWithPast(ModelOutput):
loss: Optional[torch.FloatTensor] = None
logits: Optional[torch.FloatTensor] = None
past_key_values: Optional[Cache] = None
hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
attentions: Optional[tuple[torch.FloatTensor, ...]] = None
z_loss: Optional[torch.FloatTensor] = None
aux_loss: Optional[torch.FloatTensor] = None
router_logits: Optional[tuple[torch.FloatTensor]] = None
class SarvamMoEModelOutputWithPast(MoeModelOutputWithPast):
pass
def _get_unpad_data(attention_mask):
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
max_seqlen_in_batch = seqlens_in_batch.max().item()
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
return indices, cu_seqlens, max_seqlen_in_batch
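# Worked example (sketch, toy values only): for a padded batch of two sequences with
# lengths 2 and 3, _get_unpad_data returns the flat indices of the real tokens, the
# cumulative sequence lengths expected by flash-attn varlen kernels, and the longest length.
#   mask = torch.tensor([[1, 1, 0], [1, 1, 1]])
#   indices, cu_seqlens, max_len = _get_unpad_data(mask)
#   # indices -> tensor([0, 1, 3, 4, 5]); cu_seqlens -> tensor([0, 2, 5]); max_len -> 3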
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
return _prepare_4d_attention_mask(mask=mask, dtype=dtype, tgt_len=tgt_len)
def _make_causal_mask(
input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
):
return AttentionMaskConverter._make_causal_mask(
input_ids_shape=input_ids_shape, dtype=dtype, device=device, past_key_values_length=past_key_values_length
)
class SarvamMoERMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)
ALL_LAYERNORM_LAYERS.append(SarvamMoERMSNorm)
class SarvamMoERotaryEmbedding(nn.Module):
def __init__(self, config: SarvamMoEConfig, device=None):
super().__init__()
self.config = config
self.max_seq_len_cached = config.max_position_embeddings
self.original_max_seq_len = config.max_position_embeddings
rope_scaling = getattr(config, "rope_scaling", None)
if rope_scaling is None:
self.rope_type = "default"
inv_freq, self.attention_scaling = self.compute_default_rope_parameters(
config, device
)
else:
self.rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", "default"))
if self.rope_type == "default":
inv_freq, self.attention_scaling = self.compute_default_rope_parameters(
config, device
)
else:
rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
inv_freq, self.attention_scaling = rope_init_fn(config, device)
self.register_buffer("inv_freq", inv_freq, persistent=False)
self.original_inv_freq = self.inv_freq
@staticmethod
def compute_default_rope_parameters(
config: SarvamMoEConfig,
device: Optional[torch.device] = None,
seq_len: Optional[int] = None,
) -> Tuple[torch.Tensor, float]:
"""
        Compute the default RoPE parameters (classic rotary embedding).
        Mirrors the Hugging Face default implementation: derives the inverse frequencies
        from `rope_theta` and the head dimension and returns `(inv_freq, attention_scaling)`.
"""
base = config.rope_theta
dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
inv_freq = 1.0 / (
base
** (
torch.arange(0, dim, 2, dtype=torch.int64, device=device)
.to(dtype=torch.float32)
/ dim
)
)
attention_factor = 1.0
return inv_freq, attention_factor
@torch.no_grad()
@dynamic_rope_update
def forward(self, x, position_ids):
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
position_ids_expanded = position_ids[:, None, :].float()
device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos() * self.attention_scaling
sin = emb.sin() * self.attention_scaling
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
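# Shape sketch for the rotary cache (illustrative, not tied to a specific checkpoint):
# with head dimension d and base rope_theta b, inv_freq[i] = 1 / b**(2i / d) and has shape (d // 2,).
# For position_ids of shape (batch, seq_len), forward() returns cos and sin of shape
# (batch, seq_len, d), already scaled by attention_scaling and cast to the input dtype.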
def rotate_half(x):
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
rotary_dim = cos.shape[-1]
q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:]
q_embed = (q_rot * cos) + (rotate_half(q_rot) * sin)
k_embed = (k_rot * cos) + (rotate_half(k_rot) * sin)
q_embed = torch.cat([q_embed, q_pass], dim=-1)
k_embed = torch.cat([k_embed, k_pass], dim=-1)
return q_embed, k_embed
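# Partial-rotary sketch: only the first cos.shape[-1] channels of each head are rotated;
# any remaining channels pass through unchanged. With q of shape (batch, heads, seq, head_dim)
# and cos/sin of shape (batch, seq, rotary_dim), unsqueeze_dim=1 broadcasts cos/sin over heads:
#   q_embed, k_embed = apply_rotary_pos_emb(q, k, cos, sin)   # same shapes as q and k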
class SarvamMoEMLP(nn.Module):
def __init__(self, config: SarvamMoEConfig, intermediate_size: int):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.intermediate_size = intermediate_size
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
self.act_fn = ACT2FN[config.hidden_act]
def forward(self, x):
return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
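# Gated-MLP sketch (assuming config.hidden_act == "silu", i.e. a SwiGLU-style block):
#   y = down_proj(silu(gate_proj(x)) * up_proj(x))
# gate_proj and up_proj map hidden_size -> intermediate_size and down_proj maps back,
# so the output keeps the same trailing dimension as the input.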
class SarvamMoEGate(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.top_k = config.num_experts_per_tok
self.num_experts = config.num_experts
self.n_group = config.n_group
self.topk_group = config.topk_group
self.gating_dim = config.hidden_size
self.weight = nn.Parameter(torch.empty((self.num_experts, self.gating_dim)))
self.routed_scaling_factor = config.routed_scaling_factor
self.score_function = config.score_function
        # Ideally, expert_bias would be registered as a buffer, but vLLM complains about that,
        # so it is stored as a frozen (non-trainable) parameter instead:
        # self.register_buffer("expert_bias", torch.zeros((self.num_experts)))
self.expert_bias = nn.Parameter(
torch.zeros((self.num_experts)),
requires_grad=False,
)
self.reset_parameters()
def reset_parameters(self) -> None:
import torch.nn.init as init
init.kaiming_uniform_(self.weight, a=math.sqrt(5))
def group_limited_topk(self, scores: torch.Tensor):
num_tokens, _ = scores.size()
group_scores = scores.view(num_tokens, self.n_group, -1).topk(2, dim=-1)[0].sum(dim=-1)
group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1]
group_mask = torch.zeros_like(group_scores)
group_mask.scatter_(1, group_idx, 1)
score_mask = (
group_mask.unsqueeze(-1)
.expand(num_tokens, self.n_group, self.num_experts // self.n_group)
.reshape(num_tokens, -1)
)
masked_scores = scores.masked_fill(~score_mask.bool(), float("-inf"))
probs, top_indices = torch.topk(masked_scores, k=self.top_k, dim=-1)
return probs, top_indices
def forward(self, hidden_states):
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
logits = F.linear(hidden_states.type(torch.float32), self.weight.type(torch.float32))
scores = torch.sigmoid(logits.float()).type_as(logits)
scores_for_routing = scores + self.expert_bias
_, topk_idx = self.group_limited_topk(scores_for_routing)
scores = torch.gather(scores, dim=1, index=topk_idx).type_as(logits)
topk_weight = scores / (scores.sum(dim=-1, keepdim=True) + 1e-20) if self.top_k > 1 else scores
topk_weight = topk_weight * self.routed_scaling_factor
return topk_idx, topk_weight, logits
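# Routing sketch (toy numbers, not the real config): with num_experts=8, n_group=4,
# topk_group=2 and top_k=2, sigmoid scores are first shifted by expert_bias, experts are
# split into 4 groups of 2, each group is ranked by the sum of its top-2 biased scores,
# the best 2 groups are kept, and the final top-2 experts are chosen from those groups.
# The returned topk_weight uses the *unbiased* scores, renormalised over the selected
# experts and multiplied by routed_scaling_factor; `logits` are the raw router outputs.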
class SarvamMoEExperts(nn.ModuleList):
def __init__(self, config: SarvamMoEConfig):
# one MLP per expert
experts = [
SarvamMoEMLP(config=config, intermediate_size=config.moe_intermediate_size)
for _ in range(config.num_experts)
]
super().__init__(experts)
self.config = config
self.num_experts_per_tok = config.num_experts_per_tok
def forward(
self,
hidden_states: torch.Tensor,
top_k_index: torch.LongTensor,
top_k_weights: torch.Tensor,
) -> torch.Tensor:
"""
hidden_states: (tokens, hidden_size) or (batch * seq, hidden_size)
top_k_index: (tokens, top_k)
top_k_weights: (tokens, top_k)
"""
tokens, hidden_dim = hidden_states.shape
flat_topk_idx = top_k_index.view(-1)
if self.training:
            # training path: repeat each token top_k times and send every copy through its routed expert
x = hidden_states.repeat_interleave(self.num_experts_per_tok, dim=0)
y = torch.empty_like(x)
for i, expert in enumerate(self):
mask = flat_topk_idx == i
if mask.any():
y[mask] = expert(x[mask])
y = (y.view(*top_k_weights.shape, -1) * top_k_weights.unsqueeze(-1)).sum(dim=1)
return y.to(hidden_states.dtype)
        # inference path: sort token copies by expert id so each expert runs once on a contiguous slice
num_experts = len(self)
cnts = top_k_index.new_zeros((tokens, num_experts))
cnts.scatter_(1, top_k_index, 1)
tokens_per_expert = cnts.sum(dim=0)
idxs = top_k_index.view(-1).argsort()
sorted_tokens = hidden_states[idxs // top_k_index.shape[1]]
tokens_per_expert = tokens_per_expert.cpu().numpy().tolist()
outputs = []
start_idx = 0
for i, num_tokens in enumerate(tokens_per_expert):
end_idx = start_idx + num_tokens
if num_tokens == 0:
continue
expert = self[i]
tokens_for_expert = sorted_tokens[start_idx:end_idx]
expert_out = expert(tokens_for_expert)
outputs.append(expert_out.to(hidden_states.device))
start_idx = end_idx
outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)
new_x = torch.empty_like(outs)
new_x[idxs] = outs
final_out = (
new_x.view(*top_k_index.shape, -1)
.type(top_k_weights.dtype)
.mul_(top_k_weights.unsqueeze(dim=-1))
.sum(dim=1)
.type(new_x.dtype)
)
return final_out
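# Dispatch sketch for the inference path above (shapes only, no real weights): for `tokens`
# rows routed to `top_k` experts each, the flattened (tokens * top_k,) expert ids are
# argsorted so every expert processes one contiguous slice of the sorted token copies;
# the outputs are scattered back to the original order, multiplied by top_k_weights, and
# summed over the top_k dimension to give a (tokens, hidden_size) result.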
class SarvamMoESparseMoeBlock(nn.Module):
def __init__(self, config: SarvamMoEConfig):
super().__init__()
self.config = config
self.num_experts_per_tok = config.num_experts_per_tok
        # container holding one MLP per routed expert
self.experts = SarvamMoEExperts(config)
self.gate = SarvamMoEGate(config)
if config.num_shared_experts is not None:
self.shared_experts = SarvamMoEMLP(
config=config,
intermediate_size=config.moe_intermediate_size * config.num_shared_experts,
)
def forward(self, hidden_states):
identity = hidden_states
bsz, seq_len, h = hidden_states.shape
topk_idx, topk_weight, router_logits = self.gate(hidden_states)
# flatten batch+seq for experts
flat_hidden = hidden_states.view(-1, h)
flat_topk_idx = topk_idx.view(-1, topk_idx.shape[-1])
flat_topk_weight = topk_weight.view(-1, topk_weight.shape[-1])
y = self.experts(flat_hidden, flat_topk_idx, flat_topk_weight)
y = y.view(bsz, seq_len, h)
if self.config.num_shared_experts is not None:
y = y + self.shared_experts(identity)
# router logits shape: (bsz, seq_len, num_experts)
router_info = (
router_logits.view(bsz, seq_len, -1),
topk_idx.view(bsz, seq_len, -1),
)
return y, router_info
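# Block-level data flow (sketch): hidden_states (bsz, seq, h) -> gate -> per-token expert
# ids/weights -> routed experts on the flattened (bsz*seq, h) tokens -> reshaped back to
# (bsz, seq, h). When num_shared_experts is set, a dense shared-expert MLP of width
# moe_intermediate_size * num_shared_experts is always applied and added to the routed output.
# router_info carries (router_logits, topk_idx), both reshaped to (bsz, seq, ...).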
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
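# Example (sketch): with 4 KV heads repeated n_rep=3 times for 12 query heads,
#   repeat_kv(torch.randn(2, 4, 10, 64), 3).shape == (2, 12, 10, 64)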
class SarvamMoEAttention(nn.Module):
    is_causal = True  # causal flag required by the vLLM / Transformers attention backends
def __init__(self, config: SarvamMoEConfig, layer_idx: Optional[int] = None):
super().__init__()
self.config = config
self.layer_idx = layer_idx
if layer_idx is None:
logger.warning_once(
f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
"to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
"when creating this class."
)
self.attention_dropout = config.attention_dropout
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = config.head_dim or self.hidden_size // self.num_heads
partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
self.rope_dim = int(self.head_dim * partial_rotary_factor)
self.num_key_value_heads = config.num_key_value_heads
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
self.max_position_embeddings = config.max_position_embeddings
self.rope_theta = config.rope_theta
self.query_key_value = nn.Linear(
self.hidden_size,
(self.num_heads + 2 * self.num_key_value_heads) * self.head_dim,
bias=config.use_qkv_bias,
)
if self.config.use_qk_norm:
self.query_layernorm = SarvamMoERMSNorm(self.head_dim, eps=config.rms_norm_eps)
self.key_layernorm = SarvamMoERMSNorm(self.head_dim, eps=config.rms_norm_eps)
self.dense = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.use_bias)
self.scaling = self.head_dim**-0.5
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
**kwargs,
):
bsz, q_len, _ = hidden_states.size()
qkv = self.query_key_value(hidden_states)
qkv = qkv.view(
bsz, q_len, self.num_heads + 2 * self.num_key_value_heads, self.head_dim
)
query_states, key_states, value_states = qkv.split(
[self.num_heads, self.num_key_value_heads, self.num_key_value_heads],
dim=-2,
)
query_states = query_states.transpose(1, 2).contiguous()
key_states = key_states.transpose(1, 2).contiguous()
value_states = value_states.transpose(1, 2).contiguous()
if self.config.use_qk_norm:
query_states = self.query_layernorm(query_states)
key_states = self.key_layernorm(key_states)
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin
)
if past_key_value is not None:
if self.layer_idx is None:
raise ValueError(
"When using cache, SarvamMoEAttention must be initialized with layer_idx."
)
cache_kwargs = {"sin": sin, "cos": cos}
key_states, value_states = past_key_value.update(
key_states, value_states, self.layer_idx, cache_kwargs
)
# NOTE: vLLM will set config._attn_implementation = "vllm"
if self.config._attn_implementation == "vllm":
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
attn_output, attn_weights = attention_interface(
self,
query_states,
key_states,
value_states,
attention_mask,
dropout=0.0 if not self.training else self.attention_dropout,
scaling=self.scaling,
**kwargs,
)
            # vLLM backend may return [B, H, L, Dh], [B, L, hidden], or [B*L, hidden]
if attn_output.dim() == 4:
# [B, H, L, Dh] -> [B, L, hidden]
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(bsz, q_len, -1)
elif attn_output.dim() == 3:
if attn_output.shape[0] != bsz or attn_output.shape[1] != q_len:
raise ValueError(
f"Unexpected vLLM attention output shape {attn_output.shape}, "
f"expected (bsz={bsz}, q_len={q_len}, hidden=*)"
)
elif attn_output.dim() == 2:
attn_output = attn_output.view(bsz, q_len, -1)
else:
raise ValueError(
f"Unsupported vLLM attention output rank {attn_output.dim()} "
f"with shape {attn_output.shape}"
)
attn_output = self.dense(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
kv_seq_len = key_states.shape[-2]
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
raise ValueError(
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
f" {attn_weights.size()}"
)
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
)
attn_weights = attn_weights + attention_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
attn_output = torch.matmul(attn_weights, value_states)
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, -1)
attn_output = self.dense(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
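# The eager path above computes standard scaled dot-product attention with grouped-query
# KV heads: softmax(Q @ K^T / sqrt(head_dim) + mask) @ V, with K and V expanded from
# num_key_value_heads to num_heads via repeat_kv before the matmuls.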
class SarvamMoEFlashAttention2(SarvamMoEAttention):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
output_attentions = False
bsz, q_len, _ = hidden_states.size()
qkv = self.query_key_value(hidden_states)
qkv = qkv.view(bsz, q_len, self.num_heads + 2 * self.num_key_value_heads, self.head_dim)
query_states, key_states, value_states = qkv.split(
[self.num_heads, self.num_key_value_heads, self.num_key_value_heads], dim=-2
)
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
if self.config.use_qk_norm:
query_states = self.query_layernorm(query_states)
key_states = self.key_layernorm(key_states)
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
cache_kwargs = {"sin": sin, "cos": cos}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
dropout_rate = self.attention_dropout if self.training else 0.0
input_dtype = query_states.dtype
if input_dtype == torch.float32:
if hasattr(self.config, "_pre_quantization_dtype"):
target_dtype = self.config._pre_quantization_dtype
elif torch.is_autocast_enabled():
target_dtype = torch.get_autocast_gpu_dtype()
else:
target_dtype = self.query_key_value.weight.dtype
logger.warning_once(
f"The input hidden states seems to be silently casted in float32, this might be related to"
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
f" {target_dtype}."
)
query_states = query_states.to(target_dtype)
key_states = key_states.to(target_dtype)
value_states = value_states.to(target_dtype)
attn_output = self._flash_attention_forward(
query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
)
attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
attn_output = self.dense(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
def _flash_attention_forward(
self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
):
if not self._flash_attn_uses_top_left_mask:
causal = self.is_causal
else:
causal = self.is_causal and query_length != 1
if attention_mask is not None:
batch_size = query_states.shape[0]
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
query_states, key_states, value_states, attention_mask, query_length
)
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
attn_output_unpad = flash_attn_varlen_func(
query_states,
key_states,
value_states,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_in_batch_q,
max_seqlen_k=max_seqlen_in_batch_k,
dropout_p=dropout,
softmax_scale=softmax_scale,
causal=causal,
)
attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
else:
attn_output = flash_attn_func(
query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
)
return attn_output
def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
key_layer = index_first_axis(
key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
)
value_layer = index_first_axis(
value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
)
if query_length == kv_seq_len:
query_layer = index_first_axis(
query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
)
cu_seqlens_q = cu_seqlens_k
max_seqlen_in_batch_q = max_seqlen_in_batch_k
indices_q = indices_k
elif query_length == 1:
max_seqlen_in_batch_q = 1
cu_seqlens_q = torch.arange(
batch_size + 1, dtype=torch.int32, device=query_layer.device
)
indices_q = cu_seqlens_q[:-1]
query_layer = query_layer.squeeze(1)
else:
attention_mask = attention_mask[:, -query_length:]
query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
return (
query_layer,
key_layer,
value_layer,
indices_q,
(cu_seqlens_q, cu_seqlens_k),
(max_seqlen_in_batch_q, max_seqlen_in_batch_k),
)
class SarvamMoESdpaAttention(SarvamMoEAttention):
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if output_attentions:
return super().forward(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
**kwargs,
)
bsz, q_len, _ = hidden_states.size()
qkv = self.query_key_value(hidden_states)
qkv = qkv.view(bsz, q_len, self.num_heads + 2 * self.num_key_value_heads, self.head_dim)
query_states, key_states, value_states = qkv.split(
[self.num_heads, self.num_key_value_heads, self.num_key_value_heads], dim=-2
)
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
if self.config.use_qk_norm:
query_states = self.query_layernorm(query_states)
key_states = self.key_layernorm(key_states)
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
cache_kwargs = {"sin": sin, "cos": cos}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
if attention_mask is not None:
kv_seq_len = key_states.shape[-2]
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
)
if query_states.device.type == "cuda" and attention_mask is not None:
query_states = query_states.contiguous()
key_states = key_states.contiguous()
value_states = value_states.contiguous()
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=attention_mask,
dropout_p=self.attention_dropout if self.training else 0.0,
is_causal=self.is_causal and attention_mask is None and q_len > 1,
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, -1)
attn_output = self.dense(attn_output)
return attn_output, None, past_key_value
ATTENTION_CLASSES = {
"eager": SarvamMoEAttention,
"flash_attention_2": SarvamMoEFlashAttention2,
"sdpa": SarvamMoESdpaAttention,
"vllm": SarvamMoEAttention,
}
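# Backend selection sketch: each decoder layer picks its attention class from this mapping
# via config._attn_implementation, which Transformers sets from the `attn_implementation`
# argument (e.g. AutoModelForCausalLM.from_pretrained(..., attn_implementation="sdpa"));
# vLLM overrides it to "vllm" and then routes through ALL_ATTENTION_FUNCTIONS instead.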
class SarvamMoEDecoderLayer(nn.Module):
def __init__(self, config: SarvamMoEConfig, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
self.attention = ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
self.mlp = (
SarvamMoESparseMoeBlock(config)
if (config.num_experts is not None and layer_idx >= config.first_k_dense_replace)
else SarvamMoEMLP(config=config, intermediate_size=config.intermediate_size)
)
self.input_layernorm = SarvamMoERMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = SarvamMoERMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
output_router_logits: Optional[bool] = False,
use_cache: Optional[bool] = False,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
**kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
hidden_states, self_attn_weights, present_key_value = self.attention(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
position_embeddings=position_embeddings,
use_cache=use_cache,
**kwargs,
)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
if isinstance(hidden_states, tuple):
hidden_states, router_logits = hidden_states
else:
router_logits = None
hidden_states = residual + hidden_states.to(residual.device)
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights,)
if use_cache:
outputs += (present_key_value,)
if output_router_logits:
outputs += (router_logits,)
return outputs
class SarvamMoEPreTrainedModel(PreTrainedModel):
config_class = SarvamMoEConfig
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["SarvamMoEDecoderLayer"]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_cache_class = True
def _init_weights(self, module):
std = self.config.initializer_range
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
class SarvamMoEModel(SarvamMoEPreTrainedModel):
_supports_attention_backend = True
def __init__(self, config: SarvamMoEConfig):
super().__init__(config)
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [SarvamMoEDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
self._use_sdpa = config._attn_implementation == "sdpa"
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
self.norm = SarvamMoERMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.rotary_emb = SarvamMoERotaryEmbedding(config=config)
self.gradient_checkpointing = False
self.post_init()
def get_input_embeddings(self):
return self.word_embeddings
def set_input_embeddings(self, value):
self.word_embeddings = value
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
output_router_logits: Optional[bool] = None,
return_dict: Optional[bool] = None,
**kwargs,
) -> Union[Tuple, SarvamMoEModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
output_router_logits = (
output_router_logits if output_router_logits is not None else self.config.output_router_logits
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
batch_size, seq_length = input_ids.shape[:2]
elif inputs_embeds is not None:
batch_size, seq_length = inputs_embeds.shape[:2]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`transformers."
)
use_cache = False
if use_cache and past_key_values is None:
past_key_values = DynamicCache()
if inputs_embeds is None:
inputs_embeds = self.word_embeddings(input_ids)
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
if position_ids is None:
position_ids = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)
position_ids = position_ids.unsqueeze(0)
if self._use_flash_attention_2:
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
elif self._use_sdpa and not output_attentions:
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask,
(batch_size, seq_length),
inputs_embeds,
past_seen_tokens,
)
else:
attention_mask = _prepare_4d_causal_attention_mask(
attention_mask, (batch_size, seq_length), inputs_embeds, past_seen_tokens
)
hidden_states = inputs_embeds
position_embeddings = self.rotary_emb(hidden_states, position_ids)
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
all_router_logits = () if output_router_logits else None
next_decoder_cache = None
        for decoder_layer in self.layers:
if output_hidden_states:
all_hidden_states += (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
attention_mask,
position_ids,
past_key_values,
output_attentions,
output_router_logits,
use_cache,
position_embeddings,
**kwargs,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
output_router_logits=output_router_logits,
use_cache=use_cache,
position_embeddings=position_embeddings,
**kwargs,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
if output_attentions:
all_self_attns += (layer_outputs[1],)
if output_router_logits and layer_outputs[-1] is not None:
all_router_logits += (layer_outputs[-1],)
hidden_states = self.norm(hidden_states)
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = None
if use_cache:
next_cache = next_decoder_cache
if not return_dict:
return tuple(
v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits] if v is not None
)
return SarvamMoEModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
router_logits=all_router_logits,
)
class SarvamMoEForCausalLM(SarvamMoEPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config: SarvamMoEConfig):
super().__init__(config)
self.model = SarvamMoEModel(config)
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
self.post_init()
def get_input_embeddings(self):
return self.model.word_embeddings
def set_input_embeddings(self, value):
self.model.word_embeddings = value
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def set_decoder(self, decoder):
self.model = decoder
def get_decoder(self):
return self.model
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
output_router_logits: Optional[bool] = None,
return_dict: Optional[bool] = None,
**kwargs,
    ) -> Union[Tuple, SarvamMoECausalLMOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
output_router_logits = (
output_router_logits if output_router_logits is not None else self.config.output_router_logits
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
output_router_logits=output_router_logits,
return_dict=return_dict,
**kwargs,
)
loss = None
aux_loss = None
hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
logits = logits.float()
if labels is not None:
loss = self.loss_function(logits, labels, self.config.vocab_size, **kwargs)
if not return_dict:
output = (logits,) + outputs[1:]
if output_router_logits:
output = (aux_loss,) + output
return (loss,) + output if loss is not None else output
return SarvamMoECausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
aux_loss=aux_loss,
router_logits=outputs.router_logits,
)
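# Usage sketch (hypothetical repo path, shown for illustration only): this module is meant
# to be loaded through the auto classes with trust_remote_code, e.g.
#   from transformers import AutoModelForCausalLM, AutoTokenizer
#   tok = AutoTokenizer.from_pretrained("<repo-or-local-path>", trust_remote_code=True)
#   model = AutoModelForCausalLM.from_pretrained("<repo-or-local-path>", trust_remote_code=True)
#   out = model.generate(**tok("Hello", return_tensors="pt"), max_new_tokens=16)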