| import math |
| from typing import Optional, Tuple, List |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
| from transformers import PreTrainedModel |
| from transformers.modeling_outputs import CausalLMOutputWithPast |
|
|
| from .configuration_rslm import RSLMConfig |
|
|
|
|
class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension with a learned scale.

    No mean-centering and no bias: ``y = x / sqrt(mean(x**2) + eps) * weight``.

    Args:
        hidden_size: size of the last dimension of the input (length of ``weight``).
        eps: added to the mean square before the reciprocal square root.
    """

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, x):
        input_dtype = x.dtype
        # Do the whole normalization in at least float32 and only cast back at the
        # end. Previously the scale was down-cast to x.dtype first, so bf16/fp16
        # inputs were normalized in low precision. fp32 inputs are unaffected.
        h = x.to(torch.promote_types(input_dtype, torch.float32))
        var = h.pow(2).mean(dim=-1, keepdim=True)
        normed = (h * torch.rsqrt(var + self.eps)).to(input_dtype)
        return normed * self.weight
|
|
|
|
def rotate_half(x):
    """Split the last dim into halves (a, b) and return concat(-b, a).

    For an odd-sized last dimension the first half is the shorter one
    (``size // 2`` elements).
    """
    size = x.shape[-1]
    half = size // 2
    front, back = x.split([half, size - half], dim=-1)
    return torch.cat((back.neg(), front), dim=-1)
|
|
|
|
def apply_rope(q, k, cos, sin):
    """Apply rotary position embedding to query and key tensors.

    ``q`` and ``k`` are expected to be ``(bsz, heads, seq, dim)``; the head axis is
    the one that cos/sin are broadcast over.

    Args:
        q, k: query / key tensors.
        cos, sin: rotation tables, either ``(seq, dim)`` shared across the batch, or
            ``(bsz, seq, dim)`` per sequence. Tables with 4 or more dims are used
            as-is and must already broadcast against ``q``/``k``.

    Returns:
        ``(q_rotated, k_rotated)`` with the same shapes as the inputs.
    """
    if cos.dim() == 2:
        cos, sin = cos[None, None, :, :], sin[None, None, :, :]
    elif cos.dim() == 3:
        # Batched tables: insert the head axis so they broadcast over heads.
        cos, sin = cos[:, None, :, :], sin[:, None, :, :]
    q_rot = (q * cos) + (rotate_half(q) * sin)
    k_rot = (k * cos) + (rotate_half(k) * sin)
    return q_rot, k_rot
|
|
|
|
class RotaryEmbedding(nn.Module):
    """Produce RoPE cos/sin tables for arbitrary token positions.

    Args:
        dim: rotary dimension (the per-head size). Expected to be even: the table
            width is ``2 * len(inv_freq)``, which equals ``dim`` only for even ``dim``.
        max_position_embeddings: stored on the module; positions are not bounded by it here.
        base: RoPE theta; inverse frequencies are ``base ** (-(0, 2, 4, ...) / dim)``.
    """

    def __init__(self, dim, max_position_embeddings=262144, base=1000000.0):
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    def forward(self, position_ids, dtype, device):
        """Return ``(cos, sin)`` of shape ``(*position_ids.shape, 2 * len(inv_freq))``.

        ``position_ids`` may be 1-D ``(seq,)`` or batched, e.g. ``(bsz, seq)``.
        The angles are always computed in float32 and only the results are cast to
        ``dtype``.
        """
        # Force float32 even if the buffer was cast along with the model
        # (e.g. model.to(torch.bfloat16)); low-precision angles distort RoPE.
        inv_freq = self.inv_freq.to(device=device, dtype=torch.float32)
        positions = position_ids.to(device=device, dtype=torch.float32)
        # Outer product over a new trailing axis; works for any position shape.
        freqs = positions[..., None] * inv_freq
        emb = torch.cat([freqs, freqs], dim=-1)
        return emb.cos().to(dtype), emb.sin().to(dtype)
|
|
|
|
class RSLMAttention(nn.Module):
    """Causal self-attention with grouped-query KV heads, RoPE and an optional sliding window.

    Layers whose index is in ``config.global_layers_0idx`` attend causally over every
    available key. All other ("local") layers also limit each query to the most recent
    ``config.window_size`` keys.

    When ``config.evict_local_kv`` is set, local layers keep at most
    ``config.local_cache_keep`` key/value positions in the returned cache. Their window
    is then capped at that size, so a full-sequence forward sees the same keys as
    cached step-by-step decoding.
    """

    def __init__(self, config: RSLMConfig, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.hidden_size = config.hidden_size
        self.num_q_heads = config.num_q_heads
        self.num_kv_heads = config.num_kv_heads
        self.head_dim = config.head_dim
        self.window_size = config.window_size
        self.is_global = layer_idx in set(config.global_layers_0idx)

        self.q_proj = nn.Linear(config.hidden_size, config.num_q_heads * config.head_dim, bias=False)
        self.k_proj = nn.Linear(config.hidden_size, config.num_kv_heads * config.head_dim, bias=False)
        self.v_proj = nn.Linear(config.hidden_size, config.num_kv_heads * config.head_dim, bias=False)
        self.o_proj = nn.Linear(config.num_q_heads * config.head_dim, config.hidden_size, bias=False)
        self.rotary = RotaryEmbedding(
            config.head_dim,
            max_position_embeddings=config.max_position_embeddings,
            base=config.rope_theta,
        )

    def _repeat_kv(self, x):
        """Repeat KV heads along dim 1 so there is one KV head per query head.

        KV head ``j`` is shared by query heads ``j*r .. j*r + r - 1``, where
        ``r = num_q_heads // num_kv_heads``.
        """
        if self.num_kv_heads == self.num_q_heads:
            return x
        repeat = self.num_q_heads // self.num_kv_heads
        return x.repeat_interleave(repeat, dim=1)

    def forward(self, x, position_ids=None, attention_mask=None, past_key_value=None, use_cache=False):
        """Attend from the tokens in ``x`` to the cached plus current keys.

        Args:
            x: hidden states, ``(bsz, seqlen, hidden_size)``.
            position_ids: RoPE positions of the tokens in ``x``. If omitted, they
                continue after the cached length.
            attention_mask: optional mask, in one of two forms.
                - Floating point: additive mask, added to the scores.
                - Bool/integer: True/nonzero marks keys that may be attended. A 2-D
                  mask is taken as ``(bsz, total_len)`` and aligned to the last
                  ``k_len`` keys.
            past_key_value: optional ``(k, v)`` cache from a previous call, each
                ``(bsz, num_kv_heads, past_len, head_dim)``.
            use_cache: when True, also return the updated ``(k, v)`` cache.

        Returns:
            ``(output, present)``: output is ``(bsz, seqlen, hidden_size)``; present is
            the (possibly evicted) cache, or None when ``use_cache`` is False.
        """
        bsz, seqlen, _ = x.shape
        # Only local layers with eviction enabled ever drop cached keys.
        evict = (not self.is_global) and self.config.evict_local_kv

        q = self.q_proj(x).view(bsz, seqlen, self.num_q_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2)

        if position_ids is None:
            # NOTE: with eviction, a local layer's cached length is capped at
            # local_cache_keep, so this fallback under-counts the true offset.
            # Pass position_ids explicitly when decoding with eviction.
            past_len = 0 if past_key_value is None else past_key_value[0].shape[-2]
            position_ids = torch.arange(past_len, past_len + seqlen, device=x.device)

        cos, sin = self.rotary(position_ids, q.dtype, x.device)
        q, k = apply_rope(q, k, cos, sin)

        if past_key_value is not None:
            pk, pv = past_key_value
            k = torch.cat([pk, k], dim=-2)
            v = torch.cat([pv, v], dim=-2)

        # Evict only from what is returned as cache; this step still attends over
        # every key it has. Truncating before attention dropped keys that the earlier
        # queries of a multi-token chunk need. Their rows were then fully masked, and
        # the softmax spread weight uniformly over future keys.
        present = None
        if use_cache:
            if evict and k.shape[-2] > self.config.local_cache_keep:
                keep = self.config.local_cache_keep
                present = (k[..., -keep:, :], v[..., -keep:, :])
            else:
                present = (k, v)

        k_rep = self._repeat_kv(k)
        v_rep = self._repeat_kv(v)

        attn_scores = torch.matmul(q, k_rep.transpose(-2, -1)) / math.sqrt(self.head_dim)

        q_len = q.shape[-2]
        k_len = k_rep.shape[-2]
        mask_value = torch.finfo(attn_scores.dtype).min

        # allowed[i, j] is True when query i may see key j. The queries occupy the
        # last q_len positions of the key sequence, hence the diagonal offset.
        allowed = torch.ones((q_len, k_len), dtype=torch.bool, device=x.device).tril(diagonal=k_len - q_len)

        if not self.is_global:
            span = self.window_size
            if evict:
                # A cached decode step can only see the last local_cache_keep keys.
                # Cap the window the same way so full-sequence forwards agree.
                span = min(span, self.config.local_cache_keep)
            q_positions = torch.arange(k_len - q_len, k_len, device=x.device)[:, None]
            k_positions = torch.arange(0, k_len, device=x.device)[None, :]
            allowed = allowed & (k_positions >= (q_positions - span + 1))

        attn_scores = attn_scores.masked_fill(~allowed[None, None, :, :], mask_value)

        if attention_mask is not None:
            if attention_mask.is_floating_point():
                # Additive mask: same behavior as before.
                attn_scores = attn_scores + attention_mask
            else:
                # Bool / 0-1 integer "keep" mask, e.g. a tokenizer padding mask.
                # Adding it would only shift scores by 0/1 and never exclude padded
                # keys, so apply it with masked_fill instead.
                keep_keys = attention_mask.to(device=attn_scores.device, dtype=torch.bool)
                if keep_keys.dim() == 2:
                    # (bsz, total_len): align with the last k_len keys actually present.
                    keep_keys = keep_keys[:, None, None, -k_len:]
                attn_scores = attn_scores.masked_fill(~keep_keys, mask_value)

        attn_weights = F.softmax(attn_scores.float(), dim=-1).to(q.dtype)
        out = torch.matmul(attn_weights, v_rep)
        out = out.transpose(1, 2).contiguous().view(bsz, seqlen, self.num_q_heads * self.head_dim)
        return self.o_proj(out), present
|
|
|
|
class RSLMMLP(nn.Module):
    """Gated feed-forward layer: ``down_proj(silu(gate_proj(x)) * up_proj(x))``.

    All three projections are bias-free linear layers between ``hidden_size`` and
    ``intermediate_size``.
    """

    def __init__(self, config: RSLMConfig):
        super().__init__()
        d_model = config.hidden_size
        d_ff = config.intermediate_size
        # Creation order is kept (gate, up, down) so parameter registration and
        # random initialization order are unchanged.
        self.gate_proj = nn.Linear(d_model, d_ff, bias=False)
        self.up_proj = nn.Linear(d_model, d_ff, bias=False)
        self.down_proj = nn.Linear(d_ff, d_model, bias=False)

    def forward(self, x):
        gated = F.silu(self.gate_proj(x))
        values = self.up_proj(x)
        return self.down_proj(gated * values)
|
|
|
|
class RSLMBlock(nn.Module):
    """One decoder layer: a single pre-norm feeds both attention and the MLP in parallel.

    The output is ``x + attn(norm(x)) + mlp(norm(x))``.

    NOTE(review): ``parallel_block`` is copied from the config, but ``forward`` never
    reads it. The parallel form is always used, whatever its value.
    """

    def __init__(self, config: RSLMConfig, layer_idx: int):
        super().__init__()
        self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps)
        self.attn = RSLMAttention(config, layer_idx)
        self.mlp = RSLMMLP(config)
        self.parallel_block = config.parallel_block

    def forward(self, x, position_ids=None, attention_mask=None, past_key_value=None, use_cache=False):
        """Return ``(updated_hidden_states, present_kv_or_None)``."""
        shared_in = self.norm(x)
        attn_branch, present = self.attn(
            shared_in,
            position_ids=position_ids,
            attention_mask=attention_mask,
            past_key_value=past_key_value,
            use_cache=use_cache,
        )
        mlp_branch = self.mlp(shared_in)
        # Same summation order as before: (x + attn) + mlp.
        residual = x + attn_branch
        return residual + mlp_branch, present
|
|
|
|
class RSLMPreTrainedModel(PreTrainedModel):
    """Common Hugging Face base class for RSLM models.

    It binds ``RSLMConfig`` and names the base-model attribute (``model``), so that
    ``from_pretrained`` / ``save_pretrained`` can map checkpoints.
    """

    config_class = RSLMConfig
    base_model_prefix = "model"
    # Blocks that must not be split across devices when sharding with device_map.
    _no_split_modules = ["RSLMBlock"]
    # NOTE(review): advertised as True, but no module in this file defines a
    # `gradient_checkpointing` attribute or checkpoints its layers. Confirm that
    # gradient_checkpointing_enable() actually works.
    supports_gradient_checkpointing = True
|
|
|
|
class RSLMModel(RSLMPreTrainedModel):
    """Decoder stack: token embedding -> ``num_layers`` RSLMBlocks -> final RMSNorm."""

    def __init__(self, config: RSLMConfig):
        super().__init__(config)
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = nn.ModuleList([RSLMBlock(config, i) for i in range(config.num_layers)])
        self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps)
        self.post_init()

    def get_input_embeddings(self):
        """Return the token embedding module (used by HF embedding utilities)."""
        return self.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the token embedding module."""
        self.embed_tokens = value

    def forward(self, input_ids, attention_mask=None, position_ids=None, past_key_values=None, use_cache=False):
        """Run the decoder stack.

        Args:
            input_ids: token ids, ``(bsz, seqlen)``.
            attention_mask: forwarded unchanged to every attention layer.
            position_ids: RoPE positions. If omitted, they continue after the tokens
                already held in ``past_key_values`` instead of restarting at 0.
            past_key_values: per-layer ``(k, v)`` caches (or None) from a previous call.
            use_cache: when True, collect each layer's updated cache.

        Returns:
            ``(hidden_states, presents)``. ``presents`` is a list of per-layer caches
            when ``use_cache`` is True, otherwise None.
        """
        x = self.embed_tokens(input_ids)
        presents = [] if use_cache else None
        if past_key_values is None:
            past_key_values = [None] * len(self.layers)
        if position_ids is None:
            # Previously this always started at 0, so every cached decoding step
            # re-used position 0. Global layers never evict, so the longest cached
            # sequence equals the number of past tokens.
            # NOTE(review): if every layer is local with evict_local_kv on, this is
            # capped at local_cache_keep. It also ignores left padding. Pass
            # position_ids explicitly in those cases.
            past_len = max(
                (pkv[0].shape[-2] for pkv in past_key_values if pkv is not None),
                default=0,
            )
            position_ids = torch.arange(past_len, past_len + input_ids.shape[1], device=input_ids.device)
        for layer, pkv in zip(self.layers, past_key_values):
            x, present = layer(
                x,
                position_ids=position_ids,
                attention_mask=attention_mask,
                past_key_value=pkv,
                use_cache=use_cache,
            )
            if use_cache:
                presents.append(present)
        x = self.norm(x)
        return x, presents
|
|
|
|
class RSLMForCausalLM(RSLMPreTrainedModel):
    """RSLM decoder stack topped with a bias-free linear language-modeling head.

    When ``config.tie_word_embeddings`` is set, ``lm_head.weight`` is the same
    Parameter object as the input embedding matrix.
    """

    def __init__(self, config: RSLMConfig):
        super().__init__(config)
        self.model = RSLMModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        if config.tie_word_embeddings:
            self.lm_head.weight = self.model.embed_tokens.weight
        self.post_init()

    # Explicit accessors so Hugging Face utilities (tie_weights,
    # resize_token_embeddings) can locate the embedding and head modules.
    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def forward(self, input_ids, labels=None, attention_mask=None, position_ids=None, past_key_values=None, use_cache=False, **kwargs):
        """Compute next-token logits and, when ``labels`` is given, the shifted CE loss.

        Extra keyword arguments are accepted and ignored. Label positions equal to
        -100 are excluded from the loss.

        Returns:
            ``CausalLMOutputWithPast`` with ``loss`` (or None), ``logits`` and
            ``past_key_values`` (per-layer caches, or None without ``use_cache``).
        """
        hidden, presents = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
        )
        logits = self.lm_head(hidden)
        loss = None
        if labels is not None:
            # Token t predicts label t+1. Upcast so the vocabulary log-softmax is
            # not done in bf16/fp16, and move labels next to the logits in case
            # the model is sharded across devices.
            shift_logits = logits[..., :-1, :].float()
            shift_labels = labels[..., 1:].to(shift_logits.device)
            loss = F.cross_entropy(
                shift_logits.reshape(-1, shift_logits.size(-1)),
                shift_labels.reshape(-1),
                ignore_index=-100,
            )
        return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=presents)
|
|