| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| """ PyTorch DeepSeek model with ScatterMoE optimization.""" |
| import math |
| import warnings |
| from typing import Any, Dict, List, Optional, Tuple, Union |
|
|
| import torch |
| import torch.nn.functional as F |
| import torch.utils.checkpoint |
| from torch import nn |
| from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss |
|
|
| from transformers.activations import ACT2FN |
| from transformers.cache_utils import Cache, DynamicCache |
| from transformers.modeling_attn_mask_utils import ( |
| AttentionMaskConverter, |
| _prepare_4d_attention_mask, |
| _prepare_4d_causal_attention_mask, |
| ) |
| from transformers.modeling_outputs import ( |
| BaseModelOutputWithPast, |
| CausalLMOutputWithPast, |
| SequenceClassifierOutputWithPast, |
| ) |
| from transformers.modeling_utils import PreTrainedModel |
| from transformers.pytorch_utils import ( |
| ALL_LAYERNORM_LAYERS, |
| is_torch_greater_or_equal_than_1_13, |
| ) |
| from transformers.utils import ( |
| add_start_docstrings, |
| add_start_docstrings_to_model_forward, |
| is_flash_attn_2_available, |
| is_flash_attn_greater_or_equal_2_10, |
| logging, |
| replace_return_docstrings, |
| ) |
| from transformers.utils.import_utils import is_torch_fx_available |
| from .configuration_scatterbrain_moonlight import ScatterbrainMoonlightConfig |
| import torch.distributed as dist |
| import numpy as np |
|
|
| |
| try: |
| from scattermoe import flatten_sort_count, parallel_linear |
| SCATTERMOE_AVAILABLE = True |
| except ImportError: |
| SCATTERMOE_AVAILABLE = False |
| warnings.warn("ScatterMoE not available. Install with: pip install scattermoe") |
|
|
| if is_flash_attn_2_available(): |
| from flash_attn import flash_attn_func, flash_attn_varlen_func |
| from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input |
|
|
|
|
| |
# Make `_prepare_4d_causal_attention_mask` a leaf function for torch.fx:
# symbolic tracing records a call to it instead of tracing through its body.
if is_torch_fx_available():
    if not is_torch_greater_or_equal_than_1_13:
        # NOTE(review): older torch presumably needs the explicit submodule
        # import before `torch.fx.wrap` is reachable — confirm.
        import torch.fx

    _prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask)


# Module logger obtained through the transformers logging utilities.
logger = logging.get_logger(__name__)

# Name of the config class; presumably consumed by the docstring helpers
# imported above (e.g. `replace_return_docstrings`) — confirm.
_CONFIG_FOR_DOC = "ScatterbrainMoonlightConfig"
|
|
|
|
| def _get_cache_length(cache, seq_length, layer_idx=None): |
| """Helper to get cache length, compatible with both old and new transformers API.""" |
| if hasattr(cache, 'get_usable_length'): |
| if layer_idx is not None: |
| return cache.get_usable_length(seq_length, layer_idx) |
| return cache.get_usable_length(seq_length) |
| elif hasattr(cache, 'get_seq_length'): |
| if layer_idx is not None: |
| return cache.get_seq_length(layer_idx) |
| return cache.get_seq_length() |
| return 0 |
|
|
|
|
class ExpandedDynamicCache(Cache):
    """Dynamic cache that supports arbitrary layer indices for virtual layers.

    Per-layer key/value slots are created on demand, so ``layer_idx`` may skip
    values or exceed the model's physical layer count; unfilled slots hold
    ``None`` and are reported as ``(None, None)``.
    """

    def __init__(self, num_hidden_layers: Optional[int] = None):
        # Stored only; the slot lists below grow lazily in `update`.
        self._num_hidden_layers = num_hidden_layers or 128
        self.key_cache: List[Optional[torch.Tensor]] = []
        self.value_cache: List[Optional[torch.Tensor]] = []
        self._seen_tokens = 0
        # NOTE(review): `super().__init__()` is not called; these attributes are
        # presumably set so newer `Cache` base-class code finds them — confirm.
        self.layers = None
        self.layer_class_to_replicate = None
        self.offloading = False

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[dict] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Append ``key_states``/``value_states`` to slot ``layer_idx`` along dim -2.

        Returns the full cached keys and values for that slot. ``cache_kwargs``
        is accepted for API compatibility and ignored. Only writes to slot 0
        advance ``seen_tokens``.
        """
        shortfall = layer_idx + 1 - len(self.key_cache)
        if shortfall > 0:
            self.key_cache.extend([None] * shortfall)
            self.value_cache.extend([None] * shortfall)

        previous_keys = self.key_cache[layer_idx]
        if previous_keys is None:
            merged_keys, merged_values = key_states, value_states
        else:
            merged_keys = torch.cat([previous_keys, key_states], dim=-2)
            merged_values = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
        self.key_cache[layer_idx] = merged_keys
        self.value_cache[layer_idx] = merged_values

        if layer_idx == 0:
            self._seen_tokens += key_states.shape[-2]

        return merged_keys, merged_values

    def get_seq_length(self, layer_idx: int = 0) -> int:
        """Number of cached positions in slot ``layer_idx`` (0 when empty or absent)."""
        if layer_idx >= len(self.key_cache):
            return 0
        cached = self.key_cache[layer_idx]
        return 0 if cached is None else cached.shape[-2]

    def get_max_cache_shape(self) -> Optional[int]:
        # Unbounded: the cache grows with every update.
        return None

    def get_max_length(self) -> Optional[int]:
        # Unbounded: the cache grows with every update.
        return None

    @property
    def seen_tokens(self) -> int:
        return self._seen_tokens

    def __len__(self) -> int:
        return len(self.key_cache)

    def __iter__(self):
        # One (keys, values) pair per slot; empty slots yield (None, None).
        for layer_idx in range(len(self)):
            yield self[layer_idx]

    def __getitem__(self, layer_idx: int):
        if layer_idx < len(self.key_cache):
            keys = self.key_cache[layer_idx]
            if keys is not None:
                return (keys, self.value_cache[layer_idx])
        return (None, None)

    def to_legacy_cache(self):
        """Tuple of per-slot ``(keys, values)`` pairs, ``(None, None)`` for empty slots."""
        return tuple(self[layer_idx] for layer_idx in range(len(self)))

    def reorder_cache(self, beam_idx: torch.LongTensor):
        """Reorder every populated slot along the batch axis (dim 0) by ``beam_idx``."""
        for layer_idx, keys in enumerate(self.key_cache):
            if keys is None:
                continue
            device = keys.device
            self.key_cache[layer_idx] = keys.index_select(0, beam_idx.to(device))
            self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))
|
|
|
|
| def _get_unpad_data(attention_mask): |
| seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) |
| indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() |
| max_seqlen_in_batch = seqlens_in_batch.max().item() |
| cu_seqlens = F.pad( |
| torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0) |
| ) |
| return ( |
| indices, |
| cu_seqlens, |
| max_seqlen_in_batch, |
| ) |
|
|
|
|
class ScatterbrainMoonlightRMSNorm(nn.Module):
    """RMS normalisation over the last dimension with a learned per-feature scale (no bias)."""

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        # Normalise in float32 for numerical stability, cast back to the input
        # dtype, and only then apply the learned scale.
        original_dtype = hidden_states.dtype
        as_fp32 = hidden_states.to(torch.float32)
        mean_square = as_fp32.pow(2).mean(-1, keepdim=True)
        normalised = as_fp32 * torch.rsqrt(mean_square + self.variance_epsilon)
        return self.weight * normalised.to(original_dtype)
|
|
|
|
# Register the custom norm in transformers' list of layer-norm classes.
# NOTE(review): presumably so utilities that special-case norm layers (e.g.
# weight-decay exclusion) also recognise this one — confirm.
ALL_LAYERNORM_LAYERS.append(ScatterbrainMoonlightRMSNorm)
|
|
|
|
class ScatterbrainMoonlightRotaryEmbedding(nn.Module):
    """Rotary position embedding with lazily (re)built cos/sin lookup tables."""

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()

        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        exponents = torch.arange(0, self.dim, 2).float().to(device) / self.dim
        self.register_buffer("inv_freq", 1.0 / (self.base ** exponents), persistent=False)

        # Build the tables once so the buffers exist, then mark the cache stale:
        # the first forward rebuilds it on the input's device and dtype.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings,
            device=self.inv_freq.device,
            dtype=torch.get_default_dtype(),
        )
        self.max_seq_len_cached = None

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """(Re)build ``cos_cached``/``sin_cached`` for positions ``0 .. seq_len - 1``."""
        self.max_seq_len_cached = seq_len
        positions = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
        angles = torch.outer(positions, self.inv_freq.to(positions.device))
        table = torch.cat((angles, angles), dim=-1)
        self.register_buffer("cos_cached", table.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", table.sin().to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        """Return ``(cos, sin)`` for the first ``seq_len`` positions, cast to ``x.dtype``."""
        stale = self.max_seq_len_cached is None or seq_len > self.max_seq_len_cached
        if stale:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

        cos = self.cos_cached[:seq_len].to(dtype=x.dtype)
        sin = self.sin_cached[:seq_len].to(dtype=x.dtype)
        return (cos, sin)
|
|
|
|
class ScatterbrainMoonlightLinearScalingRotaryEmbedding(ScatterbrainMoonlightRotaryEmbedding):
    """Rotary embedding with linear position interpolation.

    Position indices are divided by ``scaling_factor`` before the rotation
    angles are computed; everything else matches the base class.
    """

    def __init__(
        self,
        dim,
        max_position_embeddings=2048,
        base=10000,
        device=None,
        scaling_factor=1.0,
    ):
        # Must exist before the parent constructor, which calls the overridden
        # `_set_cos_sin_cache` below.
        self.scaling_factor = scaling_factor
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        t = torch.arange(
            self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
        )
        t = t / self.scaling_factor

        # Move inv_freq onto t's device (as the base class does) so a rebuild
        # requested for another device does not fail inside torch.outer.
        freqs = torch.outer(t, self.inv_freq.to(t.device))
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
|
|
|
|
class ScatterbrainMoonlightDynamicNTKScalingRotaryEmbedding(ScatterbrainMoonlightRotaryEmbedding):
    """Rotary embedding with dynamic NTK-style base rescaling.

    When a table longer than ``max_position_embeddings`` is requested, the RoPE
    base is enlarged as a function of the requested length and ``inv_freq`` is
    recomputed from it. The recomputed ``inv_freq`` buffer replaces the old one
    and is reused by later rebuilds.
    """

    def __init__(
        self,
        dim,
        max_position_embeddings=2048,
        base=10000,
        device=None,
        scaling_factor=1.0,
    ):
        # Must exist before the parent constructor, which calls the overridden
        # `_set_cos_sin_cache` below.
        self.scaling_factor = scaling_factor
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len

        if seq_len > self.max_position_embeddings:
            base = self.base * (
                (self.scaling_factor * seq_len / self.max_position_embeddings)
                - (self.scaling_factor - 1)
            ) ** (self.dim / (self.dim - 2))
            inv_freq = 1.0 / (
                base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)
            )
            self.register_buffer("inv_freq", inv_freq, persistent=False)

        t = torch.arange(
            self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
        )

        # Move inv_freq onto t's device (as the base class does) so a rebuild
        # requested for another device does not fail inside torch.outer.
        freqs = torch.outer(t, self.inv_freq.to(t.device))
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
|
|
|
|
def yarn_find_correction_dim(
    num_rotations, dim, base=10000, max_position_embeddings=2048
):
    """Return the (fractional) rotary dimension index whose frequency completes
    ``num_rotations`` full turns over ``max_position_embeddings`` positions.

    Solves ``max_position_embeddings / (2*pi * base**(2*d/dim)) == num_rotations``
    for ``d``.
    """
    numerator = dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))
    denominator = 2 * math.log(base)
    return numerator / denominator
|
|
|
|
def yarn_find_correction_range(
    low_rot, high_rot, dim, base=10000, max_position_embeddings=2048
):
    """Return integer bounds ``(low, high)`` of the YaRN correction band.

    ``low`` is the floored correction dimension for ``low_rot`` (clamped at 0)
    and ``high`` the ceiled one for ``high_rot`` (clamped at ``dim - 1``).
    """
    low_dim = yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)
    high_dim = yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings)
    return max(math.floor(low_dim), 0), min(math.ceil(high_dim), dim - 1)
|
|
|
|
def yarn_get_mscale(scale=1, mscale=1):
    """YaRN magnitude factor: ``1.0`` for ``scale <= 1``, else ``0.1 * mscale * ln(scale) + 1``."""
    return 1.0 if scale <= 1 else 0.1 * mscale * math.log(scale) + 1.0
|
|
|
|
def yarn_linear_ramp_mask(min, max, dim):
    """Float32 ramp over indices ``0 .. dim-1``: 0 at or below ``min``, 1 at or above ``max``.

    ``min``/``max`` shadow the builtins but are kept for signature compatibility.
    When they are equal, ``max`` is nudged by 0.001 to avoid dividing by zero.
    """
    lower, upper = min, max
    if lower == upper:
        upper += 0.001

    positions = torch.arange(dim, dtype=torch.float32)
    return torch.clamp((positions - lower) / (upper - lower), 0, 1)
|
|
|
|
class ScatterbrainMoonlightYarnRotaryEmbedding(ScatterbrainMoonlightRotaryEmbedding):
    """YaRN rotary embedding.

    Blends the original ("extrapolated") frequencies with frequencies divided
    by ``scaling_factor`` ("interpolated") using a linear ramp over the rotary
    dimensions, and scales the cos/sin tables by the ratio of two mscale values.
    """

    def __init__(
        self,
        dim,
        max_position_embeddings=2048,
        base=10000,
        device=None,
        scaling_factor=1.0,
        original_max_position_embeddings=4096,
        beta_fast=32,
        beta_slow=1,
        mscale=1,
        mscale_all_dim=0,
    ):
        # All YaRN settings must exist before the parent constructor, which
        # calls the overridden `_set_cos_sin_cache` below.
        self.scaling_factor = scaling_factor
        self.original_max_position_embeddings = original_max_position_embeddings
        self.beta_fast = beta_fast
        self.beta_slow = beta_slow
        self.mscale = mscale
        self.mscale_all_dim = mscale_all_dim
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        dim = self.dim

        exponents = torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim
        freq_extra = 1.0 / (self.base ** exponents)
        freq_inter = 1.0 / (self.scaling_factor * self.base ** exponents)

        low, high = yarn_find_correction_range(
            self.beta_fast,
            self.beta_slow,
            dim,
            self.base,
            self.original_max_position_embeddings,
        )
        # Weight of the unscaled frequencies: 1 below `low`, ramping to 0 at `high`.
        extra_weight = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2).to(
            device=device, dtype=torch.float32
        )
        inv_freq = freq_inter * (1 - extra_weight) + freq_extra * extra_weight
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        positions = torch.arange(seq_len, device=device, dtype=torch.float32)
        angles = torch.outer(positions, inv_freq)

        magnitude = float(
            yarn_get_mscale(self.scaling_factor, self.mscale)
            / yarn_get_mscale(self.scaling_factor, self.mscale_all_dim)
        )

        table = torch.cat((angles, angles), dim=-1)
        self.register_buffer(
            "cos_cached", (table.cos() * magnitude).to(dtype), persistent=False
        )
        self.register_buffer(
            "sin_cached", (table.sin() * magnitude).to(dtype), persistent=False
        )
|
|
|
|
def rotate_half(x):
    """Rotates half the hidden dims of the input: ``[a, b] -> [-b, a]`` on the last axis.

    The split point is ``x.shape[-1] // 2``, so for an odd size the second part
    is the longer one.
    """
    half = x.shape[-1] // 2
    front = x[..., :half]
    back = x[..., half:]
    return torch.cat((-back, front), dim=-1)
|
|
|
|
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
    """Apply rotary position embeddings to ``q`` and ``k``.

    ``cos``/``sin`` are gathered at ``position_ids`` and unsqueezed at
    ``unsqueeze_dim`` for broadcasting. ``q`` and ``k`` are unpacked as 4-D
    ``(b, h, s, d)``; their last axis is de-interleaved first (pairs
    ``(x0, x1), (x2, x3), ...`` become ``[x0, x2, ..., x1, x3, ...]``) so the
    half-split rotation of ``rotate_half`` acts on the original pairs.
    """

    def _deinterleave(t):
        b, h, s, d = t.shape
        return t.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)

    cos = cos[position_ids].unsqueeze(unsqueeze_dim)
    sin = sin[position_ids].unsqueeze(unsqueeze_dim)

    q = _deinterleave(q)
    k = _deinterleave(k)

    rotated_q = (q * cos) + (rotate_half(q) * sin)
    rotated_k = (k * cos) + (rotate_half(k) * sin)
    return rotated_q, rotated_k
|
|
|
|
class ScatterbrainMoonlightMLP(nn.Module):
    """Gated feed-forward block computing ``down(act(gate(x)) * up(x))`` with bias-free projections.

    ``hidden_size``/``intermediate_size`` override the config values when given.
    """

    def __init__(self, config, hidden_size=None, intermediate_size=None):
        super().__init__()
        self.config = config
        self.hidden_size = hidden_size if hidden_size is not None else config.hidden_size
        self.intermediate_size = (
            intermediate_size if intermediate_size is not None else config.intermediate_size
        )

        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        gated = self.act_fn(self.gate_proj(x))
        return self.down_proj(gated * self.up_proj(x))
|
|
|
|
class ScatterbrainMoonlightMoEGate(nn.Module):
    """Router that picks ``num_experts_per_tok`` routed experts for every token.

    Scores are ``sigmoid(hidden @ weight.T)`` computed in float32. With
    ``topk_method == "noaux_tc"`` selection is group-limited: experts are split
    into ``n_group`` equal groups, each group is ranked by the sum of its top-2
    bias-corrected scores, only the best ``topk_group`` groups stay eligible,
    and the top-k experts are chosen among them. The returned weights are the
    *uncorrected* sigmoid scores of the chosen experts, renormalised when
    ``norm_topk_prob`` and ``top_k > 1``, then multiplied by
    ``routed_scaling_factor``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.top_k = config.num_experts_per_tok
        self.n_routed_experts = config.n_routed_experts
        self.routed_scaling_factor = config.routed_scaling_factor
        self.scoring_func = config.scoring_func
        self.seq_aux = config.seq_aux
        self.topk_method = config.topk_method
        self.n_group = config.n_group
        self.topk_group = config.topk_group

        self.norm_topk_prob = config.norm_topk_prob
        self.gating_dim = config.hidden_size
        self.weight = nn.Parameter(
            torch.empty((self.n_routed_experts, self.gating_dim))
        )
        if self.topk_method == "noaux_tc":
            # Per-expert bias added to the scores only when *choosing* experts.
            self.e_score_correction_bias = nn.Parameter(
                torch.empty((self.n_routed_experts))
            )
        self.reset_parameters()

    def reset_parameters(self) -> None:
        import torch.nn.init as init

        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.topk_method == "noaux_tc":
            # torch.empty leaves uninitialised memory (possibly NaN/huge values);
            # start from zero, i.e. no correction. A checkpoint containing this
            # parameter overwrites it on load.
            init.zeros_(self.e_score_correction_bias)

    def forward(self, hidden_states):
        """Route tokens to experts.

        Args:
            hidden_states: ``(bsz, seq_len, hidden_size)`` tensor.

        Returns:
            ``(topk_idx, topk_weight)``, both of shape ``(bsz * seq_len, top_k)``;
            ``topk_weight`` is float32.
        """
        bsz, seq_len, h = hidden_states.shape
        hidden_states = hidden_states.view(-1, h)
        logits = F.linear(
            hidden_states.type(torch.float32), self.weight.type(torch.float32), None
        )
        if self.scoring_func == "sigmoid":
            scores = logits.sigmoid()
        else:
            raise NotImplementedError(
                f"insupportable scoring function for MoE gating: {self.scoring_func}"
            )

        if self.topk_method == "noaux_tc":
            # noaux_tc routing is only supported in eval mode here.
            assert not self.training
            scores_for_choice = scores.view(bsz * seq_len, -1) + self.e_score_correction_bias.unsqueeze(0)
            # Rank groups by the sum of their two best (bias-corrected) scores.
            group_scores = (
                scores_for_choice.view(bsz * seq_len, self.n_group, -1).topk(2, dim=-1)[0].sum(dim=-1)
            )
            group_idx = torch.topk(
                group_scores, k=self.topk_group, dim=-1, sorted=False
            )[1]
            group_mask = torch.zeros_like(group_scores)
            group_mask.scatter_(1, group_idx, 1)
            score_mask = (
                group_mask.unsqueeze(-1)
                .expand(
                    bsz * seq_len, self.n_group, self.n_routed_experts // self.n_group
                )
                .reshape(bsz * seq_len, -1)
            )
            # Experts outside the kept groups get score 0 and so are not picked.
            tmp_scores = scores_for_choice.masked_fill(~score_mask.bool(), 0.0)
            _, topk_idx = torch.topk(
                tmp_scores, k=self.top_k, dim=-1, sorted=False
            )
            # Weights come from the raw scores; the bias only affects selection.
            topk_weight = scores.gather(1, topk_idx)
        else:
            raise NotImplementedError(
                f"insupportable TopK function for MoE gating: {self.topk_method}"
            )

        if self.top_k > 1 and self.norm_topk_prob:
            denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
            topk_weight = topk_weight / denominator
        topk_weight = topk_weight * self.routed_scaling_factor

        return topk_idx, topk_weight
|
|
|
|
class ScatterbrainMoonlightMoE(nn.Module):
    """
    A mixed expert module using ScatterMoE for optimized expert computation.

    Uses stacked expert weights and ScatterMoE's parallel_linear for efficient
    GPU computation instead of nn.ModuleList with sequential expert calls.
    When ScatterMoE is not installed, a pure-PyTorch grouped loop is used.

    Expert weights use ``nn.Linear`` layout ``(out_features, in_features)``,
    stacked along a leading expert axis:

    - ``expert_gate_proj`` / ``expert_up_proj``: ``(n_routed_experts, moe_intermediate_size, hidden_size)``
    - ``expert_down_proj``: ``(n_routed_experts, hidden_size, moe_intermediate_size)``
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.num_experts_per_tok = config.num_experts_per_tok
        self.n_routed_experts = config.n_routed_experts
        self.hidden_size = config.hidden_size
        self.moe_intermediate_size = config.moe_intermediate_size

        self.expert_gate_proj = nn.Parameter(
            torch.empty(config.n_routed_experts, config.moe_intermediate_size, config.hidden_size)
        )
        self.expert_up_proj = nn.Parameter(
            torch.empty(config.n_routed_experts, config.moe_intermediate_size, config.hidden_size)
        )
        self.expert_down_proj = nn.Parameter(
            torch.empty(config.n_routed_experts, config.hidden_size, config.moe_intermediate_size)
        )
        self._init_expert_weights()

        self.gate = ScatterbrainMoonlightMoEGate(config)

        if config.n_shared_experts is not None:
            intermediate_size = config.moe_intermediate_size * config.n_shared_experts
            self.shared_experts = ScatterbrainMoonlightMLP(
                config=config, intermediate_size=intermediate_size
            )

    def _init_expert_weights(self) -> None:
        """Initialise each expert slice exactly like a standalone ``nn.Linear`` weight.

        ``kaiming_uniform_`` on the stacked 3-D tensor would compute
        ``fan_in = size(1) * size(2)`` and give a far smaller bound than a
        per-expert ``nn.Linear``; initialising every 2-D slice uses
        ``fan_in = in_features`` as intended.
        """
        with torch.no_grad():
            for param in (self.expert_gate_proj, self.expert_up_proj, self.expert_down_proj):
                for expert_idx in range(param.shape[0]):
                    nn.init.kaiming_uniform_(param[expert_idx], a=5**0.5)

    def forward(self, hidden_states):
        """Apply the routed experts, plus shared experts when configured.

        ``hidden_states`` is presumably ``(bsz, seq_len, hidden_size)`` (the gate
        unpacks three dimensions); the output has the same shape.
        """
        identity = hidden_states
        orig_shape = hidden_states.shape
        topk_idx, topk_weight = self.gate(hidden_states)
        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])

        if SCATTERMOE_AVAILABLE:
            y = self._forward_scattermoe(hidden_states, topk_idx, topk_weight)
        else:
            y = self._forward_loop(hidden_states, topk_idx, topk_weight)

        y = y.view(*orig_shape)

        if self.config.n_shared_experts is not None:
            y = y + self.shared_experts(identity)
        return y

    def _forward_scattermoe(
        self,
        hidden_states: torch.Tensor,
        selected_experts: torch.Tensor,
        routing_weights: torch.Tensor,
    ) -> torch.Tensor:
        """Forward pass using ScatterMoE Triton kernels."""
        sorted_expert_idxs, sorted_scattered_idxs, expert_offsets = flatten_sort_count(
            selected_experts, num_experts=self.n_routed_experts
        )

        # Weights are permuted to (E, in, out); NOTE(review): presumably the
        # layout parallel_linear expects — confirm against the scattermoe API.
        gate_out = parallel_linear(
            inputs=hidden_states,
            expert_weights=self.expert_gate_proj.permute(0, 2, 1),
            k=self.num_experts_per_tok,
            sorted_expert_idxs=sorted_expert_idxs,
            sorted_scattered_idxs=sorted_scattered_idxs,
            expert_offsets=expert_offsets,
            grouped_out=True,
        )

        up_out = parallel_linear(
            inputs=hidden_states,
            expert_weights=self.expert_up_proj.permute(0, 2, 1),
            k=self.num_experts_per_tok,
            sorted_expert_idxs=sorted_expert_idxs,
            sorted_scattered_idxs=sorted_scattered_idxs,
            expert_offsets=expert_offsets,
            grouped_out=True,
        )

        activated = F.silu(gate_out) * up_out

        # Down projection reads the grouped activations, applies the routing
        # weights via `gates`, and returns outputs in token order.
        output = parallel_linear(
            inputs=activated,
            expert_weights=self.expert_down_proj.permute(0, 2, 1),
            k=1,
            sorted_expert_idxs=sorted_expert_idxs,
            sorted_scattered_idxs=sorted_scattered_idxs,
            expert_offsets=expert_offsets,
            gates=routing_weights.to(self.expert_down_proj.dtype),
            grouped_in=True,
            grouped_out=False,
        )

        return output

    def _forward_loop(
        self,
        hidden_states: torch.Tensor,
        selected_experts: torch.Tensor,
        routing_weights: torch.Tensor,
    ) -> torch.Tensor:
        """Fallback forward pass using a Python loop over experts.

        Args:
            hidden_states: ``(num_tokens, hidden_size)``.
            selected_experts: ``(num_tokens, num_experts_per_tok)`` expert ids.
            routing_weights: ``(num_tokens, num_experts_per_tok)`` weights.

        Returns:
            ``(num_tokens, hidden_size)`` weighted sum of expert outputs, in the
            dtype of ``hidden_states``.
        """
        num_tokens = hidden_states.shape[0]

        flat_expert_indices = selected_experts.reshape(-1)
        flat_token_indices = (
            torch.arange(num_tokens, device=hidden_states.device)
            .unsqueeze(1)
            .expand(-1, self.num_experts_per_tok)
            .reshape(-1)
        )
        flat_routing_weights = routing_weights.reshape(-1)

        # Group (token, expert) assignments by expert so each expert sees one
        # contiguous slice.
        sorted_indices = torch.argsort(flat_expert_indices, stable=True)
        sorted_expert_indices = flat_expert_indices[sorted_indices]
        sorted_token_indices = flat_token_indices[sorted_indices]
        sorted_routing_weights = flat_routing_weights[sorted_indices]
        sorted_hidden = hidden_states[sorted_token_indices]

        expert_counts = torch.bincount(sorted_expert_indices, minlength=self.n_routed_experts)
        # One host transfer for all slice boundaries instead of two .item() per expert.
        boundaries = [0] + torch.cumsum(expert_counts, dim=0).tolist()

        final_hidden_states = torch.zeros_like(hidden_states)

        for expert_idx in range(self.n_routed_experts):
            start, end = boundaries[expert_idx], boundaries[expert_idx + 1]
            if start == end:
                continue

            expert_tokens = sorted_hidden[start:end]
            expert_weights = sorted_routing_weights[start:end].unsqueeze(-1)
            # Token ids are already in sorted order and line up with expert_tokens.
            token_indices = sorted_token_indices[start:end]

            gate_out = F.linear(expert_tokens, self.expert_gate_proj[expert_idx])
            up_out = F.linear(expert_tokens, self.expert_up_proj[expert_idx])
            expert_out = F.linear(F.silu(gate_out) * up_out, self.expert_down_proj[expert_idx])

            # Routing weights are float32; cast the product so index_add_ sees
            # the accumulator's dtype (e.g. bf16) instead of failing.
            weighted = (expert_out * expert_weights).to(final_hidden_states.dtype)
            final_hidden_states.index_add_(0, token_indices, weighted)

        return final_hidden_states
|
|
|
|
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat every key/value head ``n_rep`` times along axis 1.

    For a 4-D ``(batch, num_key_value_heads, seq_len, head_dim)`` tensor this
    matches ``torch.repeat_interleave(hidden_states, n_rep, dim=1)``. The input
    is returned unchanged when ``n_rep == 1``.
    """
    batch, kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    expanded = hidden_states.unsqueeze(2).expand(batch, kv_heads, n_rep, seq_len, head_dim)
    return expanded.reshape(batch, kv_heads * n_rep, seq_len, head_dim)
|
|
|
|
class ScatterbrainMoonlightAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper.

    Eager (matmul + softmax) implementation with a compressed key/value path:

    - Queries come from ``q_proj``, or from a low-rank ``q_a_proj`` -> RMSNorm ->
      ``q_b_proj`` bottleneck when ``q_lora_rank`` is set.
    - Keys/values come from a ``kv_lora_rank`` latent (``kv_a_proj_with_mqa`` ->
      RMSNorm -> ``kv_b_proj``). That projection also yields one rotary key
      slice shared by all heads.
    - Only the ``qk_rope_head_dim`` part of each query/key head is rotated.
    """

    def __init__(self, config: ScatterbrainMoonlightConfig, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        if layer_idx is None:
            logger.warning_once(
                f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
                "lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )

        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads

        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.q_lora_rank = config.q_lora_rank
        self.qk_rope_head_dim = config.qk_rope_head_dim
        self.kv_lora_rank = config.kv_lora_rank
        self.v_head_dim = config.v_head_dim
        self.qk_nope_head_dim = config.qk_nope_head_dim
        # Each query/key head is [non-rotary part | rotary part].
        self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim

        self.is_causal = True

        if self.q_lora_rank is None:
            self.q_proj = nn.Linear(
                self.hidden_size, self.num_heads * self.q_head_dim, bias=False
            )
        else:
            self.q_a_proj = nn.Linear(
                self.hidden_size, config.q_lora_rank, bias=config.attention_bias
            )
            self.q_a_layernorm = ScatterbrainMoonlightRMSNorm(config.q_lora_rank)
            self.q_b_proj = nn.Linear(
                config.q_lora_rank, self.num_heads * self.q_head_dim, bias=False
            )

        # Produces the compressed KV latent plus one shared rotary key slice.
        self.kv_a_proj_with_mqa = nn.Linear(
            self.hidden_size,
            config.kv_lora_rank + config.qk_rope_head_dim,
            bias=config.attention_bias,
        )
        self.kv_a_layernorm = ScatterbrainMoonlightRMSNorm(config.kv_lora_rank)
        # Per head: non-rotary key part followed by the value.
        self.kv_b_proj = nn.Linear(
            config.kv_lora_rank,
            self.num_heads
            * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim),
            bias=False,
        )

        self.o_proj = nn.Linear(
            self.num_heads * self.v_head_dim,
            self.hidden_size,
            bias=config.attention_bias,
        )
        self._init_rope()

        self.softmax_scale = self.q_head_dim ** (-0.5)
        if self.config.rope_scaling is not None:
            mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0)
            scaling_factor = self.config.rope_scaling["factor"]
            if mscale_all_dim:
                # YaRN-style attention temperature correction.
                mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)
                self.softmax_scale = self.softmax_scale * mscale * mscale

    def _init_rope(self):
        """Create ``self.rotary_emb`` according to ``config.rope_scaling``.

        Raises:
            ValueError: for an unknown ``rope_scaling["type"]``.
        """
        if self.config.rope_scaling is None:
            self.rotary_emb = ScatterbrainMoonlightRotaryEmbedding(
                self.qk_rope_head_dim,
                max_position_embeddings=self.max_position_embeddings,
                base=self.rope_theta,
            )
        else:
            scaling_type = self.config.rope_scaling["type"]
            scaling_factor = self.config.rope_scaling["factor"]
            if scaling_type == "linear":
                self.rotary_emb = ScatterbrainMoonlightLinearScalingRotaryEmbedding(
                    self.qk_rope_head_dim,
                    max_position_embeddings=self.max_position_embeddings,
                    scaling_factor=scaling_factor,
                    base=self.rope_theta,
                )
            elif scaling_type == "dynamic":
                self.rotary_emb = ScatterbrainMoonlightDynamicNTKScalingRotaryEmbedding(
                    self.qk_rope_head_dim,
                    max_position_embeddings=self.max_position_embeddings,
                    scaling_factor=scaling_factor,
                    base=self.rope_theta,
                )
            elif scaling_type == "yarn":
                # Forward only the YaRN keys actually present in the config.
                kwargs = {
                    key: self.config.rope_scaling[key]
                    for key in [
                        "original_max_position_embeddings",
                        "beta_fast",
                        "beta_slow",
                        "mscale",
                        "mscale_all_dim",
                    ]
                    if key in self.config.rope_scaling
                }
                self.rotary_emb = ScatterbrainMoonlightYarnRotaryEmbedding(
                    self.qk_rope_head_dim,
                    max_position_embeddings=self.max_position_embeddings,
                    scaling_factor=scaling_factor,
                    base=self.rope_theta,
                    **kwargs,
                )
            else:
                raise ValueError(f"Unknown RoPE scaling type {scaling_type}")

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        # Reshape (bsz, seq_len, num_heads * v_head_dim) -> (bsz, num_heads, seq_len, v_head_dim).
        # Not used by `forward` below.
        return (
            tensor.view(bsz, seq_len, self.num_heads, self.v_head_dim)
            .transpose(1, 2)
            .contiguous()
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Eager attention forward.

        Returns:
            ``(attn_output, attn_weights or None, past_key_value)``, where
            ``attn_output`` is ``(bsz, q_len, hidden_size)``.

        Raises:
            ValueError: when a cache is given but ``layer_idx`` is ``None``, or
                when the attention weights, mask or output have unexpected sizes.
        """
        if "padding_mask" in kwargs:
            warnings.warn(
                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
            )
        bsz, q_len, _ = hidden_states.size()

        if self.q_lora_rank is None:
            q = self.q_proj(hidden_states)
        else:
            q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
        q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)
        q_nope, q_pe = torch.split(
            q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
        )

        compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
        compressed_kv, k_pe = torch.split(
            compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
        )
        # Single rotary key "head", broadcast to all heads below.
        k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)
        kv = (
            self.kv_b_proj(self.kv_a_layernorm(compressed_kv))
            .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
            .transpose(1, 2)
        )

        k_nope, value_states = torch.split(
            kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1
        )
        kv_seq_len = value_states.shape[-2]
        if past_key_value is not None:
            if self.layer_idx is None:
                raise ValueError(
                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                    "with a layer index."
                )
            kv_seq_len += _get_cache_length(past_key_value, kv_seq_len, self.layer_idx)
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)

        q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)

        # Reassemble full heads as [non-rotary | rotary].
        query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
        query_states[:, :, :, : self.qk_nope_head_dim] = q_nope
        query_states[:, :, :, self.qk_nope_head_dim :] = q_pe

        key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
        key_states[:, :, :, : self.qk_nope_head_dim] = k_nope
        key_states[:, :, :, self.qk_nope_head_dim :] = k_pe
        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos}
            key_states, value_states = past_key_value.update(
                key_states, value_states, self.layer_idx, cache_kwargs
            )

        attn_weights = (
            torch.matmul(query_states, key_states.transpose(2, 3)) * self.softmax_scale
        )

        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
                f" {attn_weights.size()}"
            )
        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                )
            # Additive mask, broadcast over heads.
            attn_weights = attn_weights + attention_mask

        # Softmax in float32 for stability, then back to the activation dtype.
        attn_weights = nn.functional.softmax(
            attn_weights, dim=-1, dtype=torch.float32
        ).to(query_states.dtype)
        attn_weights = nn.functional.dropout(
            attn_weights, p=self.attention_dropout, training=self.training
        )
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.v_head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.v_head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2).contiguous()

        attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)

        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value
|
|
|
|
class ScatterbrainMoonlightFlashAttention2(ScatterbrainMoonlightAttention):
    """Flash Attention 2 implementation.

    Reuses the projection modules set up by ``ScatterbrainMoonlightAttention``
    and replaces only the attention computation with the flash-attn kernels.
    Attention probabilities are never materialised, so ``output_attentions`` is
    forced to ``False`` and the returned attention weights are always ``None``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Older flash-attn releases build a top-left aligned causal mask; when that
        # is the case ``causal`` is switched off for single-token queries below.
        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Compute attention for ``hidden_states`` with flash-attn.

        Args:
            hidden_states: input of shape ``(bsz, q_len, hidden)`` (unpacked via
                ``hidden_states.size()`` below).
            attention_mask: optional 2D padding mask; ``None`` means no padding.
            position_ids: positions passed to ``apply_rotary_pos_emb``.
            past_key_value: optional cache; updated in place via ``update``.
            output_attentions: ignored (always treated as ``False``).
            use_cache: unused here; cache updates happen whenever a cache is given.

        Returns:
            ``(attn_output, None, past_key_value)``.
        """
        if "padding_mask" in kwargs:
            warnings.warn(
                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
            )
            attention_mask = kwargs.pop("padding_mask")

        # flash-attn never produces the attention matrix.
        output_attentions = False

        bsz, q_len, _ = hidden_states.size()

        # Queries -> (bsz, num_heads, q_len, q_head_dim), split into the part
        # without rotary embedding and the rotary part.
        if self.q_lora_rank is None:
            q = self.q_proj(hidden_states)
        else:
            q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
        q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)
        q_nope, q_pe = torch.split(
            q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
        )

        # Keys/values are expanded from a compressed projection; the rotary key
        # part comes out as a single head (dim 1 of size 1).
        compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
        compressed_kv, k_pe = torch.split(
            compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
        )
        k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)
        kv = (
            self.kv_b_proj(self.kv_a_layernorm(compressed_kv))
            .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
            .transpose(1, 2)
        )
        k_nope, value_states = torch.split(
            kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1
        )

        kv_seq_len = value_states.shape[-2]
        if past_key_value is not None:
            kv_seq_len += _get_cache_length(past_key_value, kv_seq_len, self.layer_idx)

        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)

        query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
        query_states[:, :, :, : self.qk_nope_head_dim] = q_nope
        query_states[:, :, :, self.qk_nope_head_dim :] = q_pe

        key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
        key_states[:, :, :, : self.qk_nope_head_dim] = k_nope
        # k_pe has one head and is broadcast across all heads by this assignment.
        key_states[:, :, :, self.qk_nope_head_dim :] = k_pe

        # The kernels need a common head dim for q/k/v: pad v up to q_head_dim
        # here and slice the padding off the output afterwards.
        if self.q_head_dim != self.v_head_dim:
            value_states = F.pad(value_states, [0, self.q_head_dim - self.v_head_dim])

        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos}
            key_states, value_states = past_key_value.update(
                key_states, value_states, self.layer_idx, cache_kwargs
            )

        # flash-attn takes (bsz, seq_len, num_heads, head_dim).
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

        dropout_rate = self.attention_dropout if self.training else 0.0

        # flash-attn does not accept float32 inputs; cast to a half-precision dtype.
        if query_states.dtype == torch.float32:
            target_dtype = self._flash_target_dtype()

            logger.warning_once(
                f"The input hidden states seems to be silently casted in float32, this might be related to"
                f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
                f" {target_dtype}."
            )

            query_states = query_states.to(target_dtype)
            key_states = key_states.to(target_dtype)
            value_states = value_states.to(target_dtype)

        attn_output = self._flash_attention_forward(
            query_states,
            key_states,
            value_states,
            attention_mask,
            q_len,
            dropout=dropout_rate,
            softmax_scale=self.softmax_scale,
        )
        if self.q_head_dim != self.v_head_dim:
            attn_output = attn_output[:, :, :, : self.v_head_dim]

        attn_output = attn_output.reshape(
            bsz, q_len, self.num_heads * self.v_head_dim
        ).contiguous()
        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value

    def _flash_target_dtype(self):
        """Return the dtype float32 activations are cast to before flash-attn.

        Order of preference: the pre-quantization dtype recorded on the config,
        the active autocast dtype, then the dtype of the first query projection.
        """
        if hasattr(self.config, "_pre_quantization_dtype"):
            return self.config._pre_quantization_dtype
        if torch.is_autocast_enabled():
            # ``torch.get_autocast_gpu_dtype`` is deprecated from PyTorch 2.4 on;
            # prefer the device-generic API when it exists.
            if hasattr(torch, "get_autocast_dtype"):
                return torch.get_autocast_dtype("cuda")
            return torch.get_autocast_gpu_dtype()
        if self.q_lora_rank is None:
            return self.q_proj.weight.dtype
        return self.q_a_proj.weight.dtype

    def _flash_attention_forward(
        self,
        query_states,
        key_states,
        value_states,
        attention_mask,
        query_length,
        dropout=0.0,
        softmax_scale=None,
    ):
        """Dispatch to the padded or the variable-length flash-attn kernel.

        Inputs are ``(bsz, seq_len, num_heads, head_dim)``. With a padding mask
        the tokens are unpadded, run through ``flash_attn_varlen_func`` and padded
        back to ``(bsz, query_length, ...)``; otherwise ``flash_attn_func`` is
        called directly.
        """
        if not self._flash_attn_uses_top_left_mask:
            causal = self.is_causal
        else:
            # A top-left aligned mask would hide the whole cache from a single
            # decoding query, so causal masking is disabled in that case.
            causal = self.is_causal and query_length != 1

        if attention_mask is not None:
            batch_size = query_states.shape[0]
            (
                query_states,
                key_states,
                value_states,
                indices_q,
                cu_seq_lens,
                max_seq_lens,
            ) = self._upad_input(
                query_states, key_states, value_states, attention_mask, query_length
            )

            cu_seqlens_q, cu_seqlens_k = cu_seq_lens
            max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens

            attn_output_unpad = flash_attn_varlen_func(
                query_states,
                key_states,
                value_states,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_k=cu_seqlens_k,
                max_seqlen_q=max_seqlen_in_batch_q,
                max_seqlen_k=max_seqlen_in_batch_k,
                dropout_p=dropout,
                softmax_scale=softmax_scale,
                causal=causal,
            )

            attn_output = pad_input(
                attn_output_unpad, indices_q, batch_size, query_length
            )
        else:
            attn_output = flash_attn_func(
                query_states,
                key_states,
                value_states,
                dropout,
                softmax_scale=softmax_scale,
                causal=causal,
            )

        return attn_output

    def _upad_input(
        self, query_layer, key_layer, value_layer, attention_mask, query_length
    ):
        """Strip padded positions so the varlen kernel only sees real tokens.

        Returns the unpadded q/k/v, the query gather indices, and the
        ``(cu_seqlens_q, cu_seqlens_k)`` / ``(max_q, max_k)`` pairs the kernel needs.
        """
        indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
        batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape

        key_layer = index_first_axis(
            key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
            indices_k,
        )
        value_layer = index_first_axis(
            value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
            indices_k,
        )
        if query_length == kv_seq_len:
            # Queries cover the same positions as keys: reuse the key layout.
            query_layer = index_first_axis(
                query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim),
                indices_k,
            )
            cu_seqlens_q = cu_seqlens_k
            max_seqlen_in_batch_q = max_seqlen_in_batch_k
            indices_q = indices_k
        elif query_length == 1:
            # One query per sequence: each occupies exactly one slot.
            max_seqlen_in_batch_q = 1
            cu_seqlens_q = torch.arange(
                batch_size + 1, dtype=torch.int32, device=query_layer.device
            )
            indices_q = cu_seqlens_q[:-1]
            query_layer = query_layer.squeeze(1)
        else:
            # Queries are the trailing ``query_length`` positions of the mask.
            attention_mask = attention_mask[:, -query_length:]
            query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(
                query_layer, attention_mask
            )

        return (
            query_layer,
            key_layer,
            value_layer,
            indices_q,
            (cu_seqlens_q, cu_seqlens_k),
            (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
        )
|
|
|
|
# Lookup from ``config._attn_implementation`` to the attention module class
# that each decoder layer instantiates.
ATTENTION_CLASSES = dict(
    eager=ScatterbrainMoonlightAttention,
    flash_attention_2=ScatterbrainMoonlightFlashAttention2,
)
|
|
|
|
class ScatterbrainMoonlightDecoderLayer(nn.Module):
    """Pre-norm decoder block: attention sub-layer then MLP/MoE sub-layer.

    Each sub-layer normalises its input with RMSNorm and adds its result back
    onto the un-normalised residual stream.
    """

    def __init__(self, config: ScatterbrainMoonlightConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size

        attention_cls = ATTENTION_CLASSES[config._attn_implementation]
        self.self_attn = attention_cls(config=config, layer_idx=layer_idx)

        # Routed experts are used only past the first ``first_k_dense_replace``
        # layers and then on every ``moe_layer_freq``-th layer.
        is_moe_layer = (
            config.n_routed_experts is not None
            and layer_idx >= config.first_k_dense_replace
            and layer_idx % config.moe_layer_freq == 0
        )
        if is_moe_layer:
            self.mlp = ScatterbrainMoonlightMoE(config)
        else:
            self.mlp = ScatterbrainMoonlightMLP(config)

        self.input_layernorm = ScatterbrainMoonlightRMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )
        self.post_attention_layernorm = ScatterbrainMoonlightRMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        **kwargs,
    ) -> Tuple[
        torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
    ]:
        """Apply the block.

        Returns a tuple starting with the new hidden states, followed by the
        attention weights when ``output_attentions`` is set and by the cache
        object returned from attention when ``use_cache`` is set.
        """
        if "padding_mask" in kwargs:
            warnings.warn(
                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
            )

        # Attention sub-layer.
        attn_input = self.input_layernorm(hidden_states)
        attn_delta, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=attn_input,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            **kwargs,
        )
        after_attn = hidden_states + attn_delta

        # Feed-forward (dense MLP or MoE) sub-layer.
        ffn_delta = self.mlp(self.post_attention_layernorm(after_attn))
        block_output = after_attn + ffn_delta

        outputs = (block_output,)
        if output_attentions:
            outputs = outputs + (self_attn_weights,)
        if use_cache:
            outputs = outputs + (present_key_value,)
        return outputs
|
|
|
|
# Shared introductory text for model class docstrings; the name is kept from the
# DeepSeek code this module is based on (see the module docstring).
# NOTE(review): not referenced in this part of the file — confirm it is still
# used (e.g. by an add_start_docstrings decorator) before editing it.
DeepseekV3_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
"""
|
|
|
|
class ScatterbrainMoonlightPreTrainedModel(PreTrainedModel):
    """Base class wiring the Scatterbrain-Moonlight config into ``PreTrainedModel``.

    Declares the Hugging Face capability flags for the model family and the
    weight-initialisation rule applied by ``post_init``.
    """

    config_class = ScatterbrainMoonlightConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["ScatterbrainMoonlightDecoderLayer"]
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn_2 = True
    _supports_cache_class = True

    def _init_weights(self, module):
        """Normal(0, initializer_range) init for Linear/Embedding weights.

        Linear biases are zeroed, and the embedding row at ``padding_idx`` (if
        any) is zeroed. Other module types are left untouched.
        """
        std = self.config.initializer_range

        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            bias = module.bias
            if bias is not None:
                bias.data.zero_()
            return

        if isinstance(module, nn.Embedding):
            table = module.weight.data
            table.normal_(mean=0.0, std=std)
            pad_row = module.padding_idx
            if pad_row is not None:
                table[pad_row].zero_()
|
|
|
|
class ScatterbrainMoonlightModel(ScatterbrainMoonlightPreTrainedModel):
    """Decoder stack run as a prelude / looped body / coda schedule.

    ``self.layers[0]`` runs once, ``self.layers[1]`` runs
    ``num_loop_iterations`` times with shared weights, and ``self.layers[2]``
    runs once. Every invocation is given its own virtual index (written to
    ``layer.self_attn.layer_idx``), which attention uses to address the cache,
    so a cache of ``_num_virtual_layers`` slots is allocated.

    NOTE(review): only ``self.layers[0..2]`` are used by ``forward``; this
    assumes ``config.num_hidden_layers >= 3`` — confirm against the config.
    """

    def __init__(self, config: ScatterbrainMoonlightConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(
            config.vocab_size, config.hidden_size, self.padding_idx
        )
        self.layers = nn.ModuleList(
            [
                ScatterbrainMoonlightDecoderLayer(config, layer_idx)
                for layer_idx in range(config.num_hidden_layers)
            ]
        )
        self._attn_implementation = config._attn_implementation
        self.norm = ScatterbrainMoonlightRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.gradient_checkpointing = False

        # prelude (1) + looped body (num_loop_iterations) + coda (1)
        self.num_loop_iterations = getattr(config, 'num_loop_iterations', 1)
        self._num_virtual_layers = 1 + self.num_loop_iterations + 1

        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    def _prepare_decoder_attention_mask(
        self, attention_mask, input_shape, inputs_embeds, past_key_values_length
    ):
        """Build the 4D additive mask consumed by eager attention.

        ``_prepare_4d_causal_attention_mask`` is called for every query length,
        including single-token decoding steps: for those it still expands a 2D
        padding mask to ``(bsz, 1, 1, kv_len)``, so left-padded positions stay
        masked instead of being silently dropped.
        """
        combined_attention_mask = _prepare_4d_causal_attention_mask(
            attention_mask,
            input_shape,
            inputs_embeds,
            past_key_values_length,
        )

        if attention_mask is not None and combined_attention_mask is not None:
            combined_attention_mask = combined_attention_mask.to(attention_mask.device)

        return combined_attention_mask

    def _run_layer(
        self,
        layer,
        virtual_layer_idx,
        hidden_states,
        attention_mask,
        position_ids,
        past_key_values,
        output_attentions,
        use_cache,
    ):
        """Invoke ``layer`` as virtual layer ``virtual_layer_idx``; return its output tuple.

        Under gradient checkpointing the call goes through the checkpoint
        function. The cache-slot index is assigned *inside* the checkpointed
        callable so that a backward-time recomputation of a shared layer uses
        this invocation's index rather than whichever index was set last.
        """

        def _invoke(
            hidden_states,
            attention_mask,
            position_ids,
            past_key_value,
            output_attentions,
            use_cache,
        ):
            layer.self_attn.layer_idx = virtual_layer_idx
            return layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
            )

        args = (
            hidden_states,
            attention_mask,
            position_ids,
            past_key_values,
            output_attentions,
            use_cache,
        )
        if self.gradient_checkpointing and self.training:
            checkpoint_fn = getattr(self, "_gradient_checkpointing_func", None)
            if checkpoint_fn is None:
                # Fallback for transformers versions that only toggle the flag.
                return torch.utils.checkpoint.checkpoint(
                    _invoke, *args, use_reentrant=False
                )
            return checkpoint_fn(_invoke, *args)
        return _invoke(*args)

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Embed the inputs and run them through the virtual layer schedule.

        Exactly one of ``input_ids`` / ``inputs_embeds`` must be given
        (``ValueError`` otherwise). Returns ``BaseModelOutputWithPast``, or with
        ``return_dict=False`` a tuple of its non-``None`` fields in the order:
        last hidden state, cache, all hidden states, attentions.
        """
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time"
            )
        elif input_ids is not None:
            batch_size, seq_length = input_ids.shape[:2]
        elif inputs_embeds is not None:
            batch_size, seq_length = inputs_embeds.shape[:2]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        past_key_values_length = 0
        if use_cache:
            # NOTE(review): any cache that is not an ExpandedDynamicCache
            # (including None) is replaced by a fresh, empty one; tokens held in
            # a different cache type are not carried over.
            if not isinstance(past_key_values, ExpandedDynamicCache):
                past_key_values = ExpandedDynamicCache(self._num_virtual_layers)
            past_key_values_length = past_key_values.get_seq_length(0)

        if position_ids is None:
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            position_ids = torch.arange(
                past_key_values_length,
                seq_length + past_key_values_length,
                dtype=torch.long,
                device=device,
            )
            position_ids = position_ids.unsqueeze(0)

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if self._attn_implementation == "flash_attention_2":
            # flash-attn takes the 2D mask, and only when it actually masks something.
            attention_mask = (
                attention_mask
                if (attention_mask is not None and 0 in attention_mask)
                else None
            )
        else:
            attention_mask = self._prepare_decoder_attention_mask(
                attention_mask,
                (batch_size, seq_length),
                inputs_embeds,
                past_key_values_length,
            )

        hidden_states = inputs_embeds

        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = past_key_values if use_cache else None

        # Prelude, looped body (shared weights), coda; position == virtual index.
        schedule = (
            [self.layers[0]]
            + [self.layers[1]] * self.num_loop_iterations
            + [self.layers[2]]
        )
        for virtual_layer_idx, layer in enumerate(schedule):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            layer_outputs = self._run_layer(
                layer,
                virtual_layer_idx,
                hidden_states,
                attention_mask,
                position_ids,
                past_key_values,
                output_attentions,
                use_cache,
            )
            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, next_decoder_cache, all_hidden_states, all_self_attns]
                if v is not None
            )
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_decoder_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )
|
|
|
|
class ScatterbrainMoonlightForCausalLM(ScatterbrainMoonlightPreTrainedModel):
    """``ScatterbrainMoonlightModel`` followed by a bias-free linear LM head."""

    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.model = ScatterbrainMoonlightModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Run the decoder and LM head; optionally compute next-token loss.

        When ``labels`` is given, logits at position ``t`` are scored against
        ``labels[t + 1]`` with ``CrossEntropyLoss`` (default settings, so label
        ``-100`` is ignored). Logits are returned in float32.
        """
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)
        logits = logits.float()

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict token n.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        **kwargs,
    ):
        """Trim already-cached tokens and assemble the kwargs for ``forward``.

        ``inputs_embeds`` is used only on the first step, i.e. while the cache
        holds no tokens. This covers both ``past_key_values=None`` and an empty
        cache object supplied by ``generate()``. Later steps feed ``input_ids``.
        """
        past_length = 0
        if past_key_values is not None:
            past_length = past_key_values.get_seq_length(0)

            if input_ids.shape[1] > past_length:
                remove_prefix_length = past_length
            else:
                # Keep at least the last token.
                remove_prefix_length = input_ids.shape[1] - 1

            input_ids = input_ids[:, remove_prefix_length:]

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # Positions count only non-padding tokens; padded slots get a dummy 1.
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values is not None:
                position_ids = position_ids[:, -input_ids.shape[1] :]

        if inputs_embeds is not None and past_length == 0:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "position_ids": position_ids,
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
            }
        )
        return model_inputs

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        """Reorder the cache in place for beam search and return it."""
        past_key_values.reorder_cache(beam_idx)
        return past_key_values
|
|