import functools
import math
from typing import List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast

from .configuration_rslm import RSLMConfig


class RMSNorm(nn.Module):
    """Root-mean-square normalization with a learned per-channel scale.

    No mean-centering and no bias: y = x / sqrt(mean(x^2) + eps) * weight,
    where the mean is taken over the last dimension.
    """

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        input_dtype = x.dtype
        # Normalize entirely in float32 and only cast back at the end, so that
        # bf16/fp16 activations do not lose precision in x * rsqrt(var).
        x_f = x.float()
        var = x_f.pow(2).mean(dim=-1, keepdim=True)
        x_f = x_f * torch.rsqrt(var + self.eps)
        return x_f.to(input_dtype) * self.weight


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Split the last dim into halves (x1, x2) and return cat(-x2, x1)."""
    half = x.shape[-1] // 2
    x1 = x[..., :half]
    x2 = x[..., half:]
    return torch.cat((-x2, x1), dim=-1)


def apply_rope(
    q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Apply rotary position embeddings to queries and keys.

    Args:
        q: queries, [batch, num_q_heads, seq, head_dim].
        k: keys, [batch, num_kv_heads, seq, head_dim].
        cos, sin: either [seq, head_dim] (positions shared by the whole batch)
            or [batch, seq, head_dim] (per-sequence positions, e.g. left padding).

    Returns:
        The rotated (q, k), with the same shapes as the inputs.
    """
    if cos.dim() == 2:
        # [t, d] -> [1, 1, t, d]: broadcast over batch and heads.
        cos = cos[None, None, :, :]
        sin = sin[None, None, :, :]
    else:
        # [b, t, d] -> [b, 1, t, d]: broadcast over heads only.
        cos = cos[:, None, :, :]
        sin = sin[:, None, :, :]
    q = (q * cos) + (rotate_half(q) * sin)
    k = (k * cos) + (rotate_half(k) * sin)
    return q, k


def _rope_inv_freq(dim: int, base: float, device=None) -> torch.Tensor:
    """Return the float32 RoPE inverse frequencies 1 / base^(2i/dim)."""
    return 1.0 / (base ** (torch.arange(0, dim, 2, device=device).float() / dim))


class RotaryEmbedding(nn.Module):
    """Plain (unscaled) rotary position embedding table generator."""

    def __init__(self, dim, max_position_embeddings=262144, base=1000000.0):
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        self.register_buffer("inv_freq", _rope_inv_freq(dim, base), persistent=False)

    def forward(self, position_ids, dtype, device):
        """Return (cos, sin) tables for ``position_ids``.

        Args:
            position_ids: integer positions, 1-D [seq] or 2-D [batch, seq].
            dtype: dtype of the returned tables.
            device: device of the returned tables.

        Returns:
            (cos, sin), each of shape ``position_ids.shape + (dim,)``.
        """
        # Simple RoPE reference. YaRN/LongRoPE scaling should be implemented
        # separately in the training kernel.
        inv_freq = self.inv_freq
        if inv_freq.dtype != torch.float32:
            # nn.Module.to(half dtype) also casts floating-point buffers; at
            # large positions a low-precision inv_freq gives badly wrong
            # phases, so rebuild it in float32.
            inv_freq = _rope_inv_freq(self.dim, self.base, device=device)
        inv_freq = inv_freq.to(device)
        # Outer product position x frequency; works for 1-D and 2-D positions.
        freqs = position_ids.to(device).float()[..., None] * inv_freq
        emb = torch.cat([freqs, freqs], dim=-1)
        return emb.cos().to(dtype), emb.sin().to(dtype)


class RSLMAttention(nn.Module):
    """Grouped-query self-attention with RoPE.

    Layers listed in ``config.global_layers_0idx`` attend causally to the
    whole context. All other ("local") layers additionally restrict each query
    to the last ``config.window_size`` keys, and can optionally trim their KV
    cache to ``config.local_cache_keep`` entries.
    """

    def __init__(self, config: RSLMConfig, layer_idx: int):
        super().__init__()
        if config.num_q_heads % config.num_kv_heads != 0:
            raise ValueError(
                f"num_q_heads ({config.num_q_heads}) must be divisible by "
                f"num_kv_heads ({config.num_kv_heads})"
            )
        self.config = config
        self.layer_idx = layer_idx
        self.hidden_size = config.hidden_size
        self.num_q_heads = config.num_q_heads
        self.num_kv_heads = config.num_kv_heads
        self.head_dim = config.head_dim
        self.window_size = config.window_size
        self.is_global = layer_idx in set(config.global_layers_0idx)
        self.q_proj = nn.Linear(config.hidden_size, config.num_q_heads * config.head_dim, bias=False)
        self.k_proj = nn.Linear(config.hidden_size, config.num_kv_heads * config.head_dim, bias=False)
        self.v_proj = nn.Linear(config.hidden_size, config.num_kv_heads * config.head_dim, bias=False)
        self.o_proj = nn.Linear(config.num_q_heads * config.head_dim, config.hidden_size, bias=False)
        self.rotary = RotaryEmbedding(
            config.head_dim,
            max_position_embeddings=config.max_position_embeddings,
            base=config.rope_theta,
        )

    def _repeat_kv(self, x: torch.Tensor) -> torch.Tensor:
        """Expand KV heads to match query heads: [b, kvh, t, d] -> [b, qh, t, d]."""
        if self.num_kv_heads == self.num_q_heads:
            return x
        repeat = self.num_q_heads // self.num_kv_heads
        return x.repeat_interleave(repeat, dim=1)

    @staticmethod
    def _padding_mask(attention_mask: torch.Tensor, k_len: int) -> torch.Tensor:
        """Convert a [batch, total_len] 0/1 keep-mask to a bool [batch, 1, 1, k_len] mask.

        Only the last ``k_len`` columns are used. Local layers with cache
        eviction hold fewer keys than the full sequence, and those keys are
        the most recent ones. This assumes the mask spans past + current
        tokens, as in the usual HF convention.
        """
        if attention_mask.shape[-1] < k_len:
            raise ValueError(
                f"attention_mask covers {attention_mask.shape[-1]} positions "
                f"but there are {k_len} keys"
            )
        keep = attention_mask[:, -k_len:].to(torch.bool)
        return keep[:, None, None, :]

    def forward(
        self,
        x: torch.Tensor,
        position_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False,
    ):
        """Run attention over ``x`` of shape [batch, seq, hidden_size].

        Args:
            position_ids: 1-D [seq] or 2-D [batch, seq] positions. When None,
                they are derived from the cached length. That is only exact
                if this layer's cache was not trimmed.
            attention_mask: either a 2-D [batch, total_len] 0/1 keep-mask
                (padding), or any other additive float mask that is added to
                the attention scores.
            past_key_value: cached (k, v), each [batch, num_kv_heads, t, head_dim].
            use_cache: when True, also return the updated (k, v) cache.

        Returns:
            (output [batch, seq, hidden_size], present cache or None).
        """
        bsz, seqlen, _ = x.shape
        q = self.q_proj(x).view(bsz, seqlen, self.num_q_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2)

        if position_ids is None:
            past_len = 0 if past_key_value is None else past_key_value[0].shape[-2]
            position_ids = torch.arange(past_len, past_len + seqlen, device=x.device)
        cos, sin = self.rotary(position_ids, q.dtype, x.device)
        q, k = apply_rope(q, k, cos, sin)

        if past_key_value is not None:
            pk, pv = past_key_value
            k = torch.cat([pk, k], dim=-2)
            v = torch.cat([pv, v], dim=-2)

        present = None
        if use_cache:
            cache_k, cache_v = k, v
            # Cache eviction on local layers. It trims only the returned cache,
            # never the keys used below. Trimming those would drop every key
            # inside the window of early queries during long prefills.
            keep = self.config.local_cache_keep
            if (not self.is_global) and self.config.evict_local_kv and cache_k.shape[-2] > keep:
                cache_k = cache_k[..., -keep:, :]
                cache_v = cache_v[..., -keep:, :]
            present = (cache_k, cache_v)

        k_rep = self._repeat_kv(k)
        v_rep = self._repeat_kv(v)

        # Simple reference attention. Large 256K prefills need FlashAttention
        # or a custom kernel.
        attn_scores = torch.matmul(q, k_rep.transpose(-2, -1)) / math.sqrt(self.head_dim)
        q_len = q.shape[-2]
        k_len = k_rep.shape[-2]

        # Causal mask: the current queries are the last q_len of the k_len keys.
        allowed = torch.ones((q_len, k_len), dtype=torch.bool, device=x.device).tril(diagonal=k_len - q_len)
        # Local sliding-window mask.
        if not self.is_global:
            q_positions = torch.arange(k_len - q_len, k_len, device=x.device)[:, None]
            k_positions = torch.arange(0, k_len, device=x.device)[None, :]
            allowed = allowed & (k_positions >= (q_positions - self.window_size + 1))
        allowed = allowed[None, None, :, :]

        additive_mask = None
        if attention_mask is not None:
            if attention_mask.dim() == 2:
                allowed = allowed & self._padding_mask(attention_mask, k_len)
            else:
                additive_mask = attention_mask
                if additive_mask.dim() >= 1 and additive_mask.shape[-1] > k_len:
                    additive_mask = additive_mask[..., -k_len:]

        attn_scores = attn_scores.masked_fill(~allowed, torch.finfo(attn_scores.dtype).min)
        if additive_mask is not None:
            attn_scores = attn_scores + additive_mask
        attn_weights = F.softmax(attn_scores.float(), dim=-1).to(q.dtype)
        out = torch.matmul(attn_weights, v_rep)
        out = out.transpose(1, 2).contiguous().view(bsz, seqlen, self.num_q_heads * self.head_dim)
        return self.o_proj(out), present


class RSLMMLP(nn.Module):
    """SwiGLU feed-forward: down(silu(gate(x)) * up(x))."""

    def __init__(self, config: RSLMConfig):
        super().__init__()
        self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
        self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
        self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))


class RSLMBlock(nn.Module):
    """Decoder block with a parallel residual: x + attn(norm(x)) + mlp(norm(x))."""

    def __init__(self, config: RSLMConfig, layer_idx: int):
        super().__init__()
        self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps)
        self.attn = RSLMAttention(config, layer_idx)
        self.mlp = RSLMMLP(config)
        # NOTE(review): stored but not consulted. forward() is always the
        # parallel variant, whatever this flag says.
        self.parallel_block = config.parallel_block

    def forward(self, x, position_ids=None, attention_mask=None, past_key_value=None, use_cache=False):
        """Return (updated hidden states, present KV cache or None)."""
        n = self.norm(x)
        attn_out, present = self.attn(
            n,
            position_ids=position_ids,
            attention_mask=attention_mask,
            past_key_value=past_key_value,
            use_cache=use_cache,
        )
        mlp_out = self.mlp(n)
        x = x + attn_out + mlp_out
        return x, present


class RSLMPreTrainedModel(PreTrainedModel):
    """Shared Hugging Face plumbing (config class, prefix, sharding hints) for RSLM models."""

    config_class = RSLMConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["RSLMBlock"]


class RSLMModel(RSLMPreTrainedModel):
    """Token embedding, a stack of RSLMBlocks, and a final RMSNorm."""

    def __init__(self, config: RSLMConfig):
        super().__init__(config)
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = nn.ModuleList([RSLMBlock(config, i) for i in range(config.num_layers)])
        self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps)
        # Toggled by PreTrainedModel.gradient_checkpointing_enable()/disable().
        self.gradient_checkpointing = False
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    @staticmethod
    def _past_length(past_key_values) -> int:
        """Return the number of already-processed tokens implied by the cache.

        The longest per-layer cache is used, because local layers may have
        trimmed theirs. If every layer trims its cache, this underestimates
        the true length; callers should then pass explicit position_ids.
        """
        lengths = [pkv[0].shape[-2] for pkv in past_key_values if pkv is not None]
        return max(lengths, default=0)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[Optional[Tuple[torch.Tensor, torch.Tensor]]]] = None,
        use_cache: bool = False,
    ):
        """Encode ``input_ids`` [batch, seq].

        Returns:
            (hidden states [batch, seq, hidden_size], list of per-layer (k, v)
            caches when use_cache is True, otherwise None).
        """
        x = self.embed_tokens(input_ids)

        checkpointing = self.gradient_checkpointing and self.training
        if checkpointing:
            # Recomputed layers cannot also hand back a KV cache.
            use_cache = False
        presents = [] if use_cache else None

        if past_key_values is None:
            past_key_values = [None] * len(self.layers)
        elif len(past_key_values) != len(self.layers):
            raise ValueError(
                f"past_key_values has {len(past_key_values)} entries, expected {len(self.layers)}"
            )

        if position_ids is None:
            # Continue numbering after the cached tokens instead of restarting
            # at 0 on every incremental decoding step.
            past_len = self._past_length(past_key_values)
            position_ids = torch.arange(past_len, past_len + input_ids.shape[1], device=input_ids.device)

        ckpt_fn = None
        if checkpointing:
            # Prefer the function HF installs (it honors
            # gradient_checkpointing_kwargs). Otherwise fall back to the
            # non-reentrant torch checkpoint.
            ckpt_fn = getattr(self, "_gradient_checkpointing_func", None) or functools.partial(
                torch.utils.checkpoint.checkpoint, use_reentrant=False
            )

        for layer, pkv in zip(self.layers, past_key_values):
            if ckpt_fn is not None:
                x, present = ckpt_fn(layer, x, position_ids, attention_mask, pkv, False)
            else:
                x, present = layer(
                    x,
                    position_ids=position_ids,
                    attention_mask=attention_mask,
                    past_key_value=pkv,
                    use_cache=use_cache,
                )
            if use_cache:
                presents.append(present)
        x = self.norm(x)
        return x, presents


class RSLMForCausalLM(RSLMPreTrainedModel):
    """RSLMModel plus a linear LM head, optionally tied to the input embedding."""

    def __init__(self, config: RSLMConfig):
        super().__init__(config)
        self.model = RSLMModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        if config.tie_word_embeddings:
            self.lm_head.weight = self.model.embed_tokens.weight
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def forward(
        self,
        input_ids,
        labels=None,
        attention_mask=None,
        position_ids=None,
        past_key_values=None,
        use_cache=False,
        **kwargs,
    ):
        """Compute logits and, when ``labels`` is given, the next-token loss.

        Labels equal to -100 are ignored. Extra keyword arguments are
        accepted and ignored.
        """
        hidden, presents = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
        )
        logits = self.lm_head(hidden)
        loss = None
        if labels is not None:
            # Upcast so the log-softmax inside the loss runs in float32 even
            # for bf16/fp16 logits.
            shift_logits = logits[..., :-1, :].float()
            shift_labels = labels[..., 1:].to(shift_logits.device)
            loss = F.cross_entropy(
                shift_logits.reshape(-1, shift_logits.size(-1)),
                shift_labels.reshape(-1),
                ignore_index=-100,
            )
        return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=presents)