| """PyTorch Sarvam MoE model.""" |
|
|
| import math |
| from typing import List, Optional, Tuple, Union |
|
|
| import torch |
| import torch.nn.functional as F |
| from torch import nn |
|
|
| from transformers.activations import ACT2FN |
| from transformers.cache_utils import Cache, DynamicCache |
| from transformers.modeling_attn_mask_utils import ( |
| AttentionMaskConverter, |
| _prepare_4d_attention_mask, |
| _prepare_4d_causal_attention_mask, |
| _prepare_4d_causal_attention_mask_for_sdpa, |
| ) |
| from transformers.modeling_outputs import MoeModelOutputWithPast |
| from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS |
| from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update |
| from transformers.modeling_utils import PreTrainedModel |
| from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS |
| from transformers.utils import ( |
| is_flash_attn_2_available, |
| is_flash_attn_greater_or_equal_2_10, |
| logging, |
| ) |
| from transformers.generation.utils import GenerationMixin |
| from dataclasses import dataclass |
| from transformers.utils import ModelOutput |
|
|
|
|
| if is_flash_attn_2_available(): |
| from flash_attn import flash_attn_func, flash_attn_varlen_func |
| from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input |
|
|
| from .configuration_sarvam_moe import SarvamMoEConfig |
|
|
# Module-level logger created through transformers' logging utilities.
logger = logging.get_logger(__name__)


# Name of the configuration class; not referenced in the visible code (presumably consumed by doc helpers).
_CONFIG_FOR_DOC = "SarvamMoEConfig"
|
|
|
|
@dataclass
class SarvamMoECausalLMOutputWithPast(ModelOutput):
    """Output container for the SarvamMoE causal language-model head.

    In addition to the standard causal-LM fields it carries MoE extras: two auxiliary
    losses (``z_loss`` and ``aux_loss``) and the router logits. Every field defaults to
    ``None`` so a producer only fills in what it computed.
    """

    # Loss value, if one was computed.
    loss: Optional[torch.FloatTensor] = None
    # Output scores of the LM head.
    logits: Optional[torch.FloatTensor] = None
    # Key/value cache for incremental decoding.
    past_key_values: Optional[Cache] = None
    # Optional per-layer hidden states.
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    # Optional per-layer attention probabilities.
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    # Auxiliary router z-loss, if computed.
    z_loss: Optional[torch.FloatTensor] = None
    # Auxiliary load-balancing loss, if computed.
    aux_loss: Optional[torch.FloatTensor] = None
    # Optional router logits.
    router_logits: Optional[Tuple[torch.FloatTensor]] = None
|
|
|
|
class SarvamMoEModelOutputWithPast(MoeModelOutputWithPast):
    """Base-model output type for SarvamMoE; adds nothing to transformers' ``MoeModelOutputWithPast``."""
|
|
|
|
| def _get_unpad_data(attention_mask): |
| seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) |
| indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() |
| max_seqlen_in_batch = seqlens_in_batch.max().item() |
| cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) |
| return indices, cu_seqlens, max_seqlen_in_batch |
|
|
|
|
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
    """Expand a padding mask into a 4-D attention mask.

    Thin delegating wrapper around transformers' ``_prepare_4d_attention_mask``; all
    arguments are forwarded unchanged by keyword.
    """
    expanded_mask = _prepare_4d_attention_mask(mask=mask, dtype=dtype, tgt_len=tgt_len)
    return expanded_mask
|
|
|
|
def _make_causal_mask(
    input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
):
    """Build a causal attention mask.

    Thin delegating wrapper around ``AttentionMaskConverter._make_causal_mask`` from
    transformers; all arguments are forwarded unchanged by keyword.
    """
    make_mask = AttentionMaskConverter._make_causal_mask
    causal_mask = make_mask(
        input_ids_shape=input_ids_shape,
        dtype=dtype,
        device=device,
        past_key_values_length=past_key_values_length,
    )
    return causal_mask
|
|
|
|
class SarvamMoERMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension with a learned scale.

    The statistic is computed in float32. The normalized values are cast back to the
    input dtype before being multiplied by ``weight``.
    """

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        original_dtype = hidden_states.dtype
        x32 = hidden_states.to(torch.float32)
        mean_square = x32.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.variance_epsilon)
        normalized = (x32 * inv_rms).to(original_dtype)
        return self.weight * normalized
|
|
|
|
# Register the custom RMS norm in transformers' ALL_LAYERNORM_LAYERS list. Presumably this lets utilities
# that consult the list (e.g. weight-decay parameter grouping) treat it like nn.LayerNorm. TODO verify.
ALL_LAYERNORM_LAYERS.append(SarvamMoERMSNorm)
|
|
|
|
class SarvamMoERotaryEmbedding(nn.Module):
    """Rotary position embedding (RoPE) that produces per-position ``cos``/``sin`` tables.

    Inverse frequencies come from ``compute_default_rope_parameters`` when there is no
    ``rope_scaling`` or its type is ``"default"``. Otherwise they come from transformers'
    ``ROPE_INIT_FUNCTIONS[rope_type]``. ``forward`` is wrapped by ``dynamic_rope_update``,
    which reads ``rope_type``, ``max_seq_len_cached``, ``original_max_seq_len`` and
    ``original_inv_freq`` from this module.
    """

    def __init__(self, config: SarvamMoEConfig, device=None):
        super().__init__()
        self.config = config
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings
        rope_scaling = getattr(config, "rope_scaling", None)
        if rope_scaling is None:
            self.rope_type = "default"
        else:
            self.rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", "default"))
        if self.rope_type == "default":
            inv_freq, self.attention_scaling = self.compute_default_rope_parameters(config, device)
        else:
            rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
            inv_freq, self.attention_scaling = rope_init_fn(config, device)
        # Non-persistent: derived from the config, so it is not stored in checkpoints.
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        # Reference to the initial frequencies (presumably restored by dynamic_rope_update).
        self.original_inv_freq = self.inv_freq

    @staticmethod
    def compute_default_rope_parameters(
        config: SarvamMoEConfig,
        device: Optional[torch.device] = None,
        seq_len: Optional[int] = None,
    ) -> Tuple[torch.Tensor, float]:
        """
        Default RoPE parameters (classic rotary embedding).

        Mirrors HF's default implementation. It uses `rope_theta` and rotates only the
        first `head_dim * partial_rotary_factor` features of each head, the same width the
        `ROPE_INIT_FUNCTIONS` variants and `SarvamMoEAttention.rope_dim` use.
        `apply_rotary_pos_emb` passes the remaining features through unchanged.

        Returns:
            (inv_freq, attention_scaling): `inv_freq` has `rotary_dim // 2` entries and
            `attention_scaling` is 1.0. `seq_len` is accepted for signature compatibility
            and unused.
        """
        base = config.rope_theta
        partial_rotary_factor = getattr(config, "partial_rotary_factor", 1.0)
        head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
        dim = int(head_dim * partial_rotary_factor)
        inv_freq = 1.0 / (
            base
            ** (
                torch.arange(0, dim, 2, dtype=torch.int64, device=device)
                .to(dtype=torch.float32)
                / dim
            )
        )
        attention_factor = 1.0
        return inv_freq, attention_factor

    @torch.no_grad()
    @dynamic_rope_update
    def forward(self, x, position_ids):
        """Return ``(cos, sin)`` for ``position_ids``, cast to ``x.dtype``.

        ``position_ids`` is indexed as ``(batch, seq)``. The outputs have shape
        ``(batch, seq, 2 * len(inv_freq))`` and are scaled by ``attention_scaling``.
        """
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
        position_ids_expanded = position_ids[:, None, :].float()

        # Compute the angles in float32 with autocast disabled (MPS falls back to the "cpu" autocast type).
        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
|
|
|
def rotate_half(x):
    """Map the last-dim halves ``[a, b]`` of ``x`` to ``[-b, a]`` (the RoPE rotation helper).

    The split point is ``x.shape[-1] // 2``, so for an odd width the second half is the longer one.
    """
    split = x.shape[-1] // 2
    first_half = x[..., :split]
    second_half = x[..., split:]
    return torch.cat((-second_half, first_half), dim=-1)
|
|
|
|
def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    """Apply rotary position embeddings to queries and keys.

    ``cos``/``sin`` are unsqueezed at ``unsqueeze_dim`` so they broadcast over the head
    axis. Only the first ``cos.shape[-1]`` features of the last dimension are rotated. The
    remaining features are concatenated back unchanged, which supports partial rotary
    embeddings.

    Returns:
        ``(q_embed, k_embed)``.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    rotary_dim = cos.shape[-1]

    def _rotate(t):
        rotated_part = t[..., :rotary_dim]
        passthrough = t[..., rotary_dim:]
        embedded = (rotated_part * cos) + (rotate_half(rotated_part) * sin)
        return torch.cat([embedded, passthrough], dim=-1)

    return _rotate(q), _rotate(k)
|
|
|
|
class SarvamMoEMLP(nn.Module):
    """Gated feed-forward block: ``down_proj(act(gate_proj(x)) * up_proj(x))``.

    All three projections are bias-free. The activation is ``ACT2FN[config.hidden_act]``.
    The hidden width is passed explicitly so the same class serves dense layers, routed
    experts and shared experts.
    """

    def __init__(self, config: SarvamMoEConfig, intermediate_size: int):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        gated = self.act_fn(self.gate_proj(x))
        projected = self.up_proj(x)
        return self.down_proj(gated * projected)
|
|
|
|
class SarvamMoEGate(nn.Module):
    """Sigmoid router with group-limited top-k expert selection.

    Scores are ``sigmoid(hidden @ weight.T)`` computed in float32. Experts are split into
    ``n_group`` equal groups. Only the ``topk_group`` groups with the highest sum of their
    two best (bias-adjusted) scores stay eligible, and the final ``top_k`` experts are
    chosen among those. ``expert_bias`` affects only which experts are chosen. The returned
    weights are the unbiased scores of the chosen experts, renormalized to sum to 1 when
    ``top_k > 1`` and multiplied by ``routed_scaling_factor``.
    """

    def __init__(self, config):
        super().__init__()
        num_experts = config.num_experts
        self.config = config
        self.top_k = config.num_experts_per_tok
        self.num_experts = num_experts
        self.n_group = config.n_group
        self.topk_group = config.topk_group
        self.gating_dim = config.hidden_size
        self.weight = nn.Parameter(torch.empty((num_experts, self.gating_dim)))
        self.routed_scaling_factor = config.routed_scaling_factor
        # Stored for reference; forward() always applies a sigmoid.
        self.score_function = config.score_function
        # Frozen per-expert selection bias (requires_grad=False).
        self.expert_bias = nn.Parameter(
            torch.zeros((num_experts)),
            requires_grad=False,
        )
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Kaiming-uniform init of the router weight (``a=sqrt(5)``, as ``nn.Linear`` uses)."""
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))

    def group_limited_topk(self, scores: torch.Tensor):
        """Return ``(values, indices)`` of the top-k experts restricted to the best groups."""
        num_tokens, _num_scores = scores.size()
        experts_per_group = self.num_experts // self.n_group

        # Rate every group by the sum of its two strongest expert scores.
        best_two, _ = scores.view(num_tokens, self.n_group, -1).topk(2, dim=-1)
        group_scores = best_two.sum(dim=-1)

        # Mark the `topk_group` winning groups for each token.
        _, kept_groups = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)
        group_mask = torch.zeros_like(group_scores)
        group_mask.scatter_(1, kept_groups, 1)

        # Spread the group mask over the experts and hide experts of losing groups.
        expert_mask = group_mask.unsqueeze(-1).expand(num_tokens, self.n_group, experts_per_group)
        expert_mask = expert_mask.reshape(num_tokens, -1)
        eligible_scores = scores.masked_fill(~expert_mask.bool(), float("-inf"))
        probs, top_indices = torch.topk(eligible_scores, k=self.top_k, dim=-1)
        return probs, top_indices

    def forward(self, hidden_states):
        """Route tokens to experts.

        Returns:
            ``(topk_idx, topk_weight, logits)``, all flattened over tokens: ``topk_idx`` and
            ``topk_weight`` are ``(tokens, top_k)`` and ``logits`` is ``(tokens, num_experts)``.
        """
        flat_hidden = hidden_states.view(-1, hidden_states.shape[-1])
        logits = F.linear(flat_hidden.type(torch.float32), self.weight.type(torch.float32))
        probs = torch.sigmoid(logits.float()).type_as(logits)

        _, topk_idx = self.group_limited_topk(probs + self.expert_bias)
        chosen = torch.gather(probs, dim=1, index=topk_idx).type_as(logits)
        if self.top_k > 1:
            chosen = chosen / (chosen.sum(dim=-1, keepdim=True) + 1e-20)
        topk_weight = chosen * self.routed_scaling_factor
        return topk_idx, topk_weight, logits
|
|
|
|
class SarvamMoEExperts(nn.ModuleList):
    """The routed experts of a sparse MoE layer.

    Holds ``config.num_experts`` ``SarvamMoEMLP`` modules (``self[i]`` is expert ``i``).
    ``forward`` evaluates each token on the experts listed in ``top_k_index`` and returns
    the ``top_k_weights``-weighted sum of their outputs.
    """

    def __init__(self, config: SarvamMoEConfig):
        experts = [
            SarvamMoEMLP(config=config, intermediate_size=config.moe_intermediate_size)
            for _ in range(config.num_experts)
        ]
        super().__init__(experts)
        self.config = config
        self.num_experts_per_tok = config.num_experts_per_tok

    def forward(
        self,
        hidden_states: torch.Tensor,
        top_k_index: torch.LongTensor,
        top_k_weights: torch.Tensor,
    ) -> torch.Tensor:
        """
        hidden_states: (tokens, hidden_size) or (batch * seq, hidden_size)
        top_k_index: (tokens, top_k)
        top_k_weights: (tokens, top_k)

        Returns a (tokens, hidden_size) tensor. An empty batch (tokens == 0) yields an empty
        (0, hidden_size) result instead of failing on an ambiguous -1 reshape.
        """
        tokens, hidden_dim = hidden_states.shape
        # Take top_k from the routing tensor so repeat/view always agree with top_k_weights
        # (the gate produces config.num_experts_per_tok columns, so this equals that value).
        top_k = top_k_index.shape[1]
        flat_topk_idx = top_k_index.view(-1)

        if self.training:
            # One row per (token, selected expert) pair; each expert fills its rows via a boolean mask.
            x = hidden_states.repeat_interleave(top_k, dim=0)
            y = torch.empty_like(x)
            for i, expert in enumerate(self):
                mask = flat_topk_idx == i
                if mask.any():
                    y[mask] = expert(x[mask])
            y = (y.view(tokens, top_k, hidden_dim) * top_k_weights.unsqueeze(-1)).sum(dim=1)
            return y.to(hidden_states.dtype)

        # Inference: sort the (token, expert) pairs by expert id so every expert runs once on a
        # contiguous slice of its tokens.
        num_experts = len(self)
        cnts = top_k_index.new_zeros((tokens, num_experts))
        cnts.scatter_(1, top_k_index, 1)
        # Single host sync to get per-expert counts as Python ints for slicing.
        tokens_per_expert = cnts.sum(dim=0).tolist()

        idxs = flat_topk_idx.argsort()
        # Flattened position p belongs to token p // top_k.
        sorted_tokens = hidden_states[idxs // top_k]

        outputs = []
        start_idx = 0
        for i, num_tokens in enumerate(tokens_per_expert):
            if num_tokens == 0:
                continue
            end_idx = start_idx + num_tokens
            expert_out = self[i](sorted_tokens[start_idx:end_idx])
            outputs.append(expert_out.to(hidden_states.device))
            start_idx = end_idx

        if outputs:
            outs = torch.cat(outputs, dim=0)
        else:
            # Keep a 2-D (0, hidden_dim) shape so the reshape below stays well defined.
            outs = sorted_tokens.new_empty((0, hidden_dim))
        # Undo the sort: row j of `outs` belongs to flattened position idxs[j].
        new_x = torch.empty_like(outs)
        new_x[idxs] = outs

        final_out = (
            new_x.view(tokens, top_k, hidden_dim)
            .type(top_k_weights.dtype)
            .mul_(top_k_weights.unsqueeze(dim=-1))
            .sum(dim=1)
            .type(new_x.dtype)
        )
        return final_out
|
|
|
|
class SarvamMoESparseMoeBlock(nn.Module):
    """Sparse mixture-of-experts feed-forward block.

    ``SarvamMoEGate`` selects experts per token and ``SarvamMoEExperts`` mixes their
    outputs. When ``config.num_shared_experts`` is set, a dense shared MLP of width
    ``moe_intermediate_size * num_shared_experts`` is applied to the same input and added.
    """

    def __init__(self, config: SarvamMoEConfig):
        super().__init__()
        self.config = config
        self.num_experts_per_tok = config.num_experts_per_tok
        self.experts = SarvamMoEExperts(config)
        self.gate = SarvamMoEGate(config)
        if config.num_shared_experts is not None:
            shared_width = config.moe_intermediate_size * config.num_shared_experts
            self.shared_experts = SarvamMoEMLP(config=config, intermediate_size=shared_width)

    def forward(self, hidden_states):
        """Return ``(output, (router_logits, topk_idx))``.

        The output has the input's ``(bsz, seq_len, hidden)`` shape. The router tensors are
        reshaped to ``(bsz, seq_len, -1)``.
        """
        bsz, seq_len, hidden = hidden_states.shape
        expert_ids, expert_weights, router_logits = self.gate(hidden_states)

        routed = self.experts(
            hidden_states.view(-1, hidden),
            expert_ids.view(-1, expert_ids.shape[-1]),
            expert_weights.view(-1, expert_weights.shape[-1]),
        ).view(bsz, seq_len, hidden)

        if self.config.num_shared_experts is not None:
            routed = routed + self.shared_experts(hidden_states)

        return routed, (
            router_logits.view(bsz, seq_len, -1),
            expert_ids.view(bsz, seq_len, -1),
        )
|
|
|
|
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat each key/value head ``n_rep`` times along dim 1 (grouped-query attention).

    ``(batch, kv_heads, seq, head_dim)`` becomes ``(batch, kv_heads * n_rep, seq, head_dim)``,
    with each head's copies adjacent (like ``repeat_interleave`` on dim 1). For ``n_rep == 1``
    the input tensor itself is returned.
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    widened = hidden_states.unsqueeze(2).expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return widened.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
|
|
|
class SarvamMoEAttention(nn.Module):
    """Eager multi-head attention with a fused QKV projection and grouped key/value heads.

    ``query_key_value`` produces ``num_heads`` query heads plus ``num_key_value_heads`` key
    and value heads each. Key/value heads are repeated via ``repeat_kv`` so that
    ``num_key_value_groups`` query heads share one. An optional per-head RMS norm
    (``use_qk_norm``) is applied before RoPE. When ``config._attn_implementation == "vllm"``
    the kernel registered in ``ALL_ATTENTION_FUNCTIONS`` replaces the eager matmul/softmax path.
    """

    is_causal = True

    def __init__(self, config: SarvamMoEConfig, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        if layer_idx is None:
            logger.warning_once(
                f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
                "lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )
        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        # `head_dim` may be absent or None in the config. Fall back to hidden_size // num_heads,
        # the same rule the rotary embedding uses, so both modules agree on the head width.
        self.head_dim = getattr(config, "head_dim", None) or self.hidden_size // self.num_heads
        partial_rotary_factor = getattr(config, "partial_rotary_factor", 1.0)
        # Width of the rotated slice of each head. apply_rotary_pos_emb derives it from cos/sin itself.
        self.rope_dim = int(self.head_dim * partial_rotary_factor)
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.query_key_value = nn.Linear(
            self.hidden_size,
            (self.num_heads + 2 * self.num_key_value_heads) * self.head_dim,
            bias=config.use_qkv_bias,
        )
        if self.config.use_qk_norm:
            self.query_layernorm = SarvamMoERMSNorm(self.head_dim, eps=config.rms_norm_eps)
            self.key_layernorm = SarvamMoERMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.dense = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.use_bias)
        self.scaling = self.head_dim**-0.5

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        """Reshape ``(bsz, seq_len, num_heads * head_dim)`` to contiguous ``(bsz, num_heads, seq_len, head_dim)``."""
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs,
    ):
        """Attend over ``hidden_states`` of shape ``(bsz, q_len, hidden)``.

        ``position_embeddings`` must be the ``(cos, sin)`` pair from the rotary embedding.
        In the eager path ``attention_mask``, when given, must be an additive mask of shape
        ``(bsz, 1, q_len, kv_seq_len)``.

        Returns:
            ``(attn_output, attn_weights, past_key_value)``. ``attn_weights`` is ``None``
            unless ``output_attentions`` is set.

        Raises:
            ValueError: on cache use without ``layer_idx`` or on unexpected tensor shapes.
        """
        bsz, q_len, _ = hidden_states.size()
        qkv = self.query_key_value(hidden_states)
        qkv = qkv.view(
            bsz, q_len, self.num_heads + 2 * self.num_key_value_heads, self.head_dim
        )
        query_states, key_states, value_states = qkv.split(
            [self.num_heads, self.num_key_value_heads, self.num_key_value_heads],
            dim=-2,
        )
        # -> (bsz, heads, q_len, head_dim)
        query_states = query_states.transpose(1, 2).contiguous()
        key_states = key_states.transpose(1, 2).contiguous()
        value_states = value_states.transpose(1, 2).contiguous()
        if self.config.use_qk_norm:
            query_states = self.query_layernorm(query_states)
            key_states = self.key_layernorm(key_states)
        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(
            query_states, key_states, cos, sin
        )
        if past_key_value is not None:
            if self.layer_idx is None:
                raise ValueError(
                    "When using cache, SarvamMoEAttention must be initialized with layer_idx."
                )
            cache_kwargs = {"sin": sin, "cos": cos}
            key_states, value_states = past_key_value.update(
                key_states, value_states, self.layer_idx, cache_kwargs
            )

        if self.config._attn_implementation == "vllm":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
            attn_output, attn_weights = attention_interface(
                self,
                query_states,
                key_states,
                value_states,
                attention_mask,
                dropout=0.0 if not self.training else self.attention_dropout,
                scaling=self.scaling,
                **kwargs,
            )
            # The registered kernel's output layout is not visible here. Each supported rank is
            # normalized to (bsz, q_len, hidden); a 4-D result is treated as (bsz, heads, q_len, d).
            if attn_output.dim() == 4:
                attn_output = attn_output.transpose(1, 2).contiguous()
                attn_output = attn_output.view(bsz, q_len, -1)
            elif attn_output.dim() == 3:
                if attn_output.shape[0] != bsz or attn_output.shape[1] != q_len:
                    raise ValueError(
                        f"Unexpected vLLM attention output shape {attn_output.shape}, "
                        f"expected (bsz={bsz}, q_len={q_len}, hidden=*)"
                    )
            elif attn_output.dim() == 2:
                attn_output = attn_output.view(bsz, q_len, -1)
            else:
                raise ValueError(
                    f"Unsupported vLLM attention output rank {attn_output.dim()} "
                    f"with shape {attn_output.shape}"
                )
            attn_output = self.dense(attn_output)
            if not output_attentions:
                attn_weights = None
            return attn_output, attn_weights, past_key_value

        # Eager path: explicit scaled dot-product with an additive mask.
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)
        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
        kv_seq_len = key_states.shape[-2]
        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
                f" {attn_weights.size()}"
            )
        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights + attention_mask
        # Softmax in float32 for stability, then back to the activation dtype.
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
        attn_output = torch.matmul(attn_weights, value_states)
        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, -1)
        attn_output = self.dense(attn_output)
        if not output_attentions:
            attn_weights = None
        return attn_output, attn_weights, past_key_value
|
|
|
|
class SarvamMoEFlashAttention2(SarvamMoEAttention):
    """SarvamMoEAttention variant that computes attention with the `flash_attn` package.

    Projections, optional QK-norm, RoPE and cache updates follow the eager class. The kernel
    call is `flash_attn_varlen_func` on unpadded tensors when an attention mask is given and
    `flash_attn_func` otherwise. Attention weights are never returned (`output_attentions` is
    forced to False).
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # True when the installed flash_attn fails the `is_flash_attn_greater_or_equal_2_10` check.
        # Such versions presumably use a top-left-aligned causal mask (TODO confirm), so
        # `_flash_attention_forward` turns `causal` off for single-token queries.
        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Flash-attention forward.

        `attention_mask`, if given, is consumed by `_get_unpad_data` as a 2-D padding mask
        (assumed `(bsz, kv_seq_len)` with 1 for real tokens; TODO confirm against the caller).

        Returns:
            `(attn_output, None, past_key_value)`.
        """
        # Flash attention cannot return attention probabilities.
        output_attentions = False
        bsz, q_len, _ = hidden_states.size()
        qkv = self.query_key_value(hidden_states)
        qkv = qkv.view(bsz, q_len, self.num_heads + 2 * self.num_key_value_heads, self.head_dim)
        query_states, key_states, value_states = qkv.split(
            [self.num_heads, self.num_key_value_heads, self.num_key_value_heads], dim=-2
        )
        # (bsz, heads, q_len, head_dim) for QK-norm, RoPE and the cache update.
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)
        if self.config.use_qk_norm:
            query_states = self.query_layernorm(query_states)
            key_states = self.key_layernorm(key_states)
        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
        # Back to (bsz, seq, heads, head_dim), the layout handed to the flash_attn functions below.
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)
        dropout_rate = self.attention_dropout if self.training else 0.0
        input_dtype = query_states.dtype
        # flash_attn kernels are not fed float32 here: pick a half-precision target dtype from the
        # quantization config, the autocast dtype, or the projection weight dtype, and cast to it.
        if input_dtype == torch.float32:
            if hasattr(self.config, "_pre_quantization_dtype"):
                target_dtype = self.config._pre_quantization_dtype
            elif torch.is_autocast_enabled():
                target_dtype = torch.get_autocast_gpu_dtype()
            else:
                target_dtype = self.query_key_value.weight.dtype
            logger.warning_once(
                f"The input hidden states seems to be silently casted in float32, this might be related to"
                f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
                f" {target_dtype}."
            )
            query_states = query_states.to(target_dtype)
            key_states = key_states.to(target_dtype)
            value_states = value_states.to(target_dtype)
        attn_output = self._flash_attention_forward(
            query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
        )
        attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
        attn_output = self.dense(attn_output)
        if not output_attentions:
            attn_weights = None
        return attn_output, attn_weights, past_key_value

    def _flash_attention_forward(
        self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
    ):
        """Dispatch to the varlen kernel (with mask, after unpadding) or the dense kernel.

        `softmax_scale=None` is forwarded as-is to the flash_attn functions.
        """
        if not self._flash_attn_uses_top_left_mask:
            causal = self.is_causal
        else:
            # With a top-left-aligned mask a single query token must not be masked causally.
            causal = self.is_causal and query_length != 1
        if attention_mask is not None:
            batch_size = query_states.shape[0]
            query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
                query_states, key_states, value_states, attention_mask, query_length
            )
            cu_seqlens_q, cu_seqlens_k = cu_seq_lens
            max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
            attn_output_unpad = flash_attn_varlen_func(
                query_states,
                key_states,
                value_states,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_k=cu_seqlens_k,
                max_seqlen_q=max_seqlen_in_batch_q,
                max_seqlen_k=max_seqlen_in_batch_k,
                dropout_p=dropout,
                softmax_scale=softmax_scale,
                causal=causal,
            )
            # Scatter the packed output back to (batch_size, query_length, ...).
            attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
        else:
            attn_output = flash_attn_func(
                query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
            )
        return attn_output

    def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
        """Pack q/k/v into unpadded form for `flash_attn_varlen_func`.

        Keys and values are always gathered with the mask's non-padded indices. Queries are
        handled for three cases: the same length as the keys (reuse the key metadata), a single
        decode token (one query per sequence), or any other length (unpad with the last
        `query_length` columns of the mask).

        Returns:
            `(query, key, value, indices_q, (cu_seqlens_q, cu_seqlens_k),
            (max_seqlen_in_batch_q, max_seqlen_in_batch_k))`.
        """
        indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
        batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
        key_layer = index_first_axis(
            key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
        )
        value_layer = index_first_axis(
            value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
        )
        if query_length == kv_seq_len:
            query_layer = index_first_axis(
                query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
            )
            cu_seqlens_q = cu_seqlens_k
            max_seqlen_in_batch_q = max_seqlen_in_batch_k
            indices_q = indices_k
        elif query_length == 1:
            # One query token per sequence: cu_seqlens_q is simply 0, 1, ..., batch_size.
            max_seqlen_in_batch_q = 1
            cu_seqlens_q = torch.arange(
                batch_size + 1, dtype=torch.int32, device=query_layer.device
            )
            indices_q = cu_seqlens_q[:-1]
            query_layer = query_layer.squeeze(1)
        else:
            # Queries cover the last `query_length` positions of the key sequence.
            attention_mask = attention_mask[:, -query_length:]
            query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
        return (
            query_layer,
            key_layer,
            value_layer,
            indices_q,
            (cu_seqlens_q, cu_seqlens_k),
            (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
        )
|
|
|
|
class SarvamMoESdpaAttention(SarvamMoEAttention):
    """SarvamMoEAttention variant backed by ``torch.nn.functional.scaled_dot_product_attention``.

    SDPA does not return attention probabilities. When ``output_attentions`` is requested,
    the call is delegated to the eager ``SarvamMoEAttention.forward``.
    """

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Attend over ``hidden_states`` of shape ``(bsz, q_len, hidden)``.

        ``position_embeddings`` is the ``(cos, sin)`` pair from the rotary embedding.
        ``attention_mask``, when given, must have shape ``(bsz, 1, q_len, kv_seq_len)``.
        Without a mask, SDPA's built-in causal masking is used for ``q_len > 1``.

        Returns:
            ``(attn_output, None, past_key_value)``, or the eager result when
            ``output_attentions`` is set.
        """
        if output_attentions:
            # The eager path unpacks `position_embeddings` into (cos, sin), so it must be forwarded;
            # without it this fallback failed with a TypeError on `None`.
            return super().forward(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
                position_embeddings=position_embeddings,
                **kwargs,
            )
        bsz, q_len, _ = hidden_states.size()
        qkv = self.query_key_value(hidden_states)
        qkv = qkv.view(bsz, q_len, self.num_heads + 2 * self.num_key_value_heads, self.head_dim)
        query_states, key_states, value_states = qkv.split(
            [self.num_heads, self.num_key_value_heads, self.num_key_value_heads], dim=-2
        )
        # -> (bsz, heads, q_len, head_dim)
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)
        if self.config.use_qk_norm:
            query_states = self.query_layernorm(query_states)
            key_states = self.key_layernorm(key_states)
        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)
        if attention_mask is not None:
            kv_seq_len = key_states.shape[-2]
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                )
        # Make q/k/v contiguous on CUDA when a custom mask is supplied (the transposes above yield strided views).
        if query_states.device.type == "cuda" and attention_mask is not None:
            query_states = query_states.contiguous()
            key_states = key_states.contiguous()
            value_states = value_states.contiguous()
        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=attention_mask,
            dropout_p=self.attention_dropout if self.training else 0.0,
            is_causal=self.is_causal and attention_mask is None and q_len > 1,
        )
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, -1)
        attn_output = self.dense(attn_output)
        return attn_output, None, past_key_value
|
|
|
|
# Maps `config._attn_implementation` to the attention class that SarvamMoEDecoderLayer instantiates.
# "vllm" reuses the eager class: SarvamMoEAttention.forward branches on
# `config._attn_implementation == "vllm"` and calls the kernel registered in ALL_ATTENTION_FUNCTIONS.
ATTENTION_CLASSES = {
    "eager": SarvamMoEAttention,
    "flash_attention_2": SarvamMoEFlashAttention2,
    "sdpa": SarvamMoESdpaAttention,
    "vllm": SarvamMoEAttention,
}
|
|
|
|
class SarvamMoEDecoderLayer(nn.Module):
    """One pre-norm decoder block.

    The block computes RMSNorm, then attention, then a residual add. It then computes
    RMSNorm, the feed-forward network and a second residual add. The feed-forward network
    is a ``SarvamMoESparseMoeBlock`` when ``config.num_experts`` is set and
    ``layer_idx >= config.first_k_dense_replace``; otherwise it is a dense ``SarvamMoEMLP``.
    """

    def __init__(self, config: SarvamMoEConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size
        attention_cls = ATTENTION_CLASSES[config._attn_implementation]
        self.attention = attention_cls(config=config, layer_idx=layer_idx)
        use_sparse_ffn = config.num_experts is not None and layer_idx >= config.first_k_dense_replace
        if use_sparse_ffn:
            self.mlp = SarvamMoESparseMoeBlock(config)
        else:
            self.mlp = SarvamMoEMLP(config=config, intermediate_size=config.intermediate_size)
        self.input_layernorm = SarvamMoERMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = SarvamMoERMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        output_router_logits: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """Run the block.

        Returns a tuple that always starts with the new hidden states. The attention
        weights, the cache object and the router info are appended, in that order, when
        ``output_attentions``, ``use_cache`` and ``output_router_logits`` are set. Router info
        is ``None`` for dense layers.
        """
        attn_residual = hidden_states
        attn_out, self_attn_weights, present_key_value = self.attention(
            hidden_states=self.input_layernorm(hidden_states),
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            position_embeddings=position_embeddings,
            use_cache=use_cache,
            **kwargs,
        )
        hidden_states = attn_residual + attn_out

        ffn_residual = hidden_states
        ffn_out = self.mlp(self.post_attention_layernorm(hidden_states))
        # The sparse MoE block returns (output, router_info); the dense MLP returns a tensor.
        router_logits = None
        if isinstance(ffn_out, tuple):
            ffn_out, router_logits = ffn_out
        hidden_states = ffn_residual + ffn_out.to(ffn_residual.device)

        outputs = [hidden_states]
        if output_attentions:
            outputs.append(self_attn_weights)
        if use_cache:
            outputs.append(present_key_value)
        if output_router_logits:
            outputs.append(router_logits)
        return tuple(outputs)
|
|
class SarvamMoEPreTrainedModel(PreTrainedModel):
    """Base class connecting SarvamMoE modules to transformers' ``PreTrainedModel`` machinery.

    Declares the config class, the base-model prefix, the module that must not be split
    across devices, and support for gradient checkpointing, flash-attention-2, SDPA and
    cache classes.
    """

    config_class = SarvamMoEConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["SarvamMoEDecoderLayer"]
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_cache_class = True

    def _init_weights(self, module):
        """Initialize one module's parameters.

        ``nn.Linear`` and ``nn.Embedding`` weights are drawn from
        ``N(0, config.initializer_range)``, with zeroed biases and a zeroed padding row.
        ``SarvamMoERMSNorm`` scales are reset to 1 and ``SarvamMoEGate.expert_bias`` to 0,
        the same values their constructors use. This way those parameters are not left
        uninitialized when materialized as missing checkpoint keys. The gate's router
        weight keeps the Kaiming init from ``SarvamMoEGate.reset_parameters``.
        """
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, SarvamMoERMSNorm):
            module.weight.data.fill_(1.0)
        elif isinstance(module, SarvamMoEGate):
            module.expert_bias.data.zero_()
|
|
|
|
|
|
class SarvamMoEModel(SarvamMoEPreTrainedModel):
    """Bare SarvamMoE decoder stack: token embeddings, decoder layers and a final RMSNorm.

    ``forward`` returns the normalised last hidden state and, on request, the cache,
    per-layer hidden states, attention weights and router logits. The language-model
    head is added by ``SarvamMoEForCausalLM``.
    """

    _supports_attention_backend = True

    def __init__(self, config: SarvamMoEConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [SarvamMoEDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        # The configured attention backend decides which mask format forward() builds.
        self._use_sdpa = config._attn_implementation == "sdpa"
        self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
        self.norm = SarvamMoERMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = SarvamMoERotaryEmbedding(config=config)
        self.gradient_checkpointing = False
        self.post_init()

    def get_input_embeddings(self):
        """Return the token-embedding module."""
        return self.word_embeddings

    def set_input_embeddings(self, value):
        """Replace the token-embedding module."""
        self.word_embeddings = value

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_router_logits: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple, SarvamMoEModelOutputWithPast]:
        """Run the decoder stack.

        Args:
            input_ids: Token ids. Mutually exclusive with ``inputs_embeds``.
            attention_mask: Optional mask supplied by the caller. It is converted into
                the format expected by the configured attention backend before it
                reaches the decoder layers.
            position_ids: Optional positions. When ``None``, they default to
                consecutive positions starting after the tokens already held in
                ``past_key_values``.
            past_key_values: Cache from earlier steps. Its ``get_seq_length()`` is
                used. A fresh ``DynamicCache`` is created when ``use_cache`` is on and
                no cache is given.
            inputs_embeds: Precomputed input embeddings. Mutually exclusive with
                ``input_ids``.
            use_cache, output_attentions, output_hidden_states, output_router_logits,
            return_dict: Fall back to the matching config attribute when ``None``.
            **kwargs: Forwarded unchanged to every decoder layer.

        Returns:
            A ``SarvamMoEModelOutputWithPast`` when ``return_dict`` is true. Otherwise
            a tuple of the non-``None`` values among (last hidden state, cache, all
            hidden states, attention weights, router logits).

        Raises:
            ValueError: If both or neither of ``input_ids`` and ``inputs_embeds``
                are given.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        output_router_logits = (
            output_router_logits if output_router_logits is not None else self.config.output_router_logits
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            batch_size, seq_length = input_ids.shape[:2]
        elif inputs_embeds is not None:
            batch_size, seq_length = inputs_embeds.shape[:2]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache()

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        # New tokens are positioned after whatever the cache already holds.
        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
        if position_ids is None:
            position_ids = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            ).unsqueeze(0)

        if self._use_flash_attention_2:
            # Pass the caller's mask through unchanged, but drop it entirely when it
            # contains no zeros (nothing is masked out).
            attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
        elif self._use_sdpa and not output_attentions:
            attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
                attention_mask,
                (batch_size, seq_length),
                inputs_embeds,
                past_seen_tokens,
            )
        else:
            # Eager attention, or SDPA when attention weights are requested.
            attention_mask = _prepare_4d_causal_attention_mask(
                attention_mask, (batch_size, seq_length), inputs_embeds, past_seen_tokens
            )

        hidden_states = inputs_embeds
        # Rotary embeddings are computed once and shared by every layer.
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        all_router_logits = () if output_router_logits else None
        next_decoder_cache = None

        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                # The checkpoint wrapper receives the layer inputs positionally.
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    attention_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    output_router_logits,
                    use_cache,
                    position_embeddings,
                    **kwargs,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    output_router_logits=output_router_logits,
                    use_cache=use_cache,
                    position_embeddings=position_embeddings,
                    **kwargs,
                )

            # NOTE(review): the indexing below assumes the layer returns
            # (hidden_states, [attn_weights], [cache], ..., [router_logits]) with the
            # bracketed entries present only when requested. Confirm against
            # SarvamMoEDecoderLayer.forward.
            hidden_states = layer_outputs[0]
            if use_cache:
                next_decoder_cache = layer_outputs[2 if output_attentions else 1]
            if output_attentions:
                all_self_attns += (layer_outputs[1],)
            if output_router_logits and layer_outputs[-1] is not None:
                all_router_logits += (layer_outputs[-1],)

        hidden_states = self.norm(hidden_states)
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None

        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits]
                if v is not None
            )
        return SarvamMoEModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
            router_logits=all_router_logits,
        )
|
|
|
|
class SarvamMoEForCausalLM(SarvamMoEPreTrainedModel, GenerationMixin):
    """SarvamMoE decoder stack with a bias-free linear language-modelling head."""

    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config: SarvamMoEConfig):
        super().__init__(config)
        self.model = SarvamMoEModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.post_init()

    def get_input_embeddings(self):
        """Return the token-embedding module of the wrapped decoder."""
        return self.model.word_embeddings

    def set_input_embeddings(self, value):
        """Replace the token-embedding module of the wrapped decoder."""
        self.model.word_embeddings = value

    def get_output_embeddings(self):
        """Return the language-modelling head."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Replace the language-modelling head."""
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        """Replace the wrapped decoder stack."""
        self.model = decoder

    def get_decoder(self):
        """Return the wrapped decoder stack."""
        return self.model

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_router_logits: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs,
    ) -> Union[Tuple, SarvamMoECausalLMOutputWithPast]:
        """Run the decoder and project hidden states to vocabulary logits.

        Args:
            input_ids, attention_mask, position_ids, past_key_values, inputs_embeds,
            use_cache: Forwarded to ``SarvamMoEModel.forward``.
            labels: Optional targets. When given, ``self.loss_function`` computes the loss.
            output_attentions, output_hidden_states, output_router_logits,
            return_dict: Fall back to the matching config attribute when ``None``.
            logits_to_keep: Which sequence positions get logits.
                - An int ``k > 0`` keeps the last ``k`` positions.
                - ``0`` (the default) keeps every position, which is the previous
                  behaviour.
                - A tensor is used directly as an index along the sequence dimension.
                Generation can set this to 1 to avoid projecting the whole prompt.
            **kwargs: Forwarded to the decoder and to ``self.loss_function``.

        Returns:
            A ``SarvamMoECausalLMOutputWithPast`` when ``return_dict`` is true.
            Otherwise the tuple ``(logits, *decoder_outputs[1:])``, with ``aux_loss``
            prepended when ``output_router_logits`` is set and ``loss`` prepended
            when it is not ``None``.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        output_router_logits = (
            output_router_logits if output_router_logits is not None else self.config.output_router_logits
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            output_router_logits=output_router_logits,
            return_dict=return_dict,
            **kwargs,
        )

        loss = None
        # No router load-balancing loss is computed in this class. aux_loss stays
        # None and only keeps the output layout stable.
        aux_loss = None

        hidden_states = outputs[0]
        # Project only the requested positions. slice(-0, None) == slice(0, None),
        # so the default 0 keeps the whole sequence.
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])
        logits = logits.float()

        if labels is not None:
            loss = self.loss_function(logits, labels, self.config.vocab_size, **kwargs)

        if not return_dict:
            output = (logits,) + outputs[1:]
            if output_router_logits:
                output = (aux_loss,) + output
            return (loss,) + output if loss is not None else output

        return SarvamMoECausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            aux_loss=aux_loss,
            router_logits=outputs.router_logits,
        )