# new-model / modeling_rslm.py
# Uploaded by Efe2898.
# Commit message: Add RSLM-1B-Speed architecture and weights
# Commit e5c4555 (verified)
import math
from typing import Optional, Tuple, List
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast
from .configuration_rslm import RSLMConfig
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension with a learnable scale.

    The mean of squares is accumulated in float32; the resulting reciprocal
    root is cast back to the input dtype before it is applied.
    """

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        # Per-channel gain, initialised to 1 (identity scaling).
        self.weight = nn.Parameter(torch.ones(hidden_size))
        # Added to the mean square before the reciprocal root for stability.
        self.eps = eps

    def forward(self, x):
        """Return `x` divided by its RMS along the last axis, times `weight`."""
        mean_square = x.float().pow(2).mean(dim=-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps).to(x.dtype)
        normalised = x * inv_rms
        return normalised * self.weight
def rotate_half(x):
    """Split the last axis at its midpoint and return `cat(-second_half, first_half)`.

    The split index is `x.shape[-1] // 2`, so for an odd-sized last axis the
    second half is one element longer than the first.
    """
    midpoint = x.shape[-1] // 2
    first_half, second_half = x[..., :midpoint], x[..., midpoint:]
    return torch.cat((-second_half, first_half), dim=-1)
def apply_rope(q, k, cos, sin):
    """Rotate queries and keys with the given cosine/sine tables.

    Each input `t` becomes `t * cos + rotate_half(t) * sin`. `cos` and `sin`
    receive two leading singleton axes so that they broadcast over the first
    two axes of `q` and `k`.

    Returns:
        The `(q, k)` pair after rotation.
    """
    # q: [b, qh, t, d], k: [b, kvh, t, d]
    cos_b, sin_b = (table[None, None, :, :] for table in (cos, sin))
    q_rotated = (q * cos_b) + (rotate_half(q) * sin_b)
    k_rotated = (k * cos_b) + (rotate_half(k) * sin_b)
    return q_rotated, k_rotated
class RotaryEmbedding(nn.Module):
    """Produce RoPE cosine/sine tables for a 1-D sequence of positions.

    Args:
        dim: Rotary dimension; one inverse frequency is created per pair of
            channels (`dim // 2` frequencies for even `dim`).
        max_position_embeddings: Stored on the instance; not read by `forward`.
        base: RoPE base (theta) used to derive the inverse frequencies.
    """

    def __init__(self, dim, max_position_embeddings=262144, base=1000000.0):
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        # Non-persistent: rebuilt from `dim`/`base`, never stored in checkpoints.
        self.register_buffer("inv_freq", self._compute_inv_freq(), persistent=False)

    def _compute_inv_freq(self, device=None):
        """Return the float32 inverse frequencies `1 / base**(2i / dim)`."""
        exponents = torch.arange(0, self.dim, 2, device=device).float() / self.dim
        return 1.0 / (self.base ** exponents)

    def forward(self, position_ids, dtype, device):
        """Return `(cos, sin)`, each of shape `[len(position_ids), 2 * len(inv_freq)]`.

        Args:
            position_ids: 1-D tensor of positions (the einsum below only
                accepts one axis).
            dtype: dtype of the returned tables.
            device: device on which the tables are computed.
        """
        # Simple reference RoPE. YaRN/LongRoPE should be improved separately
        # in the training kernel.
        inv_freq = self.inv_freq
        if inv_freq.dtype != torch.float32 or inv_freq.is_meta:
            # `module.to(bfloat16)` / `.half()` also casts this buffer, which
            # both loses frequency precision and no longer matches the float32
            # positions below; a meta-device buffer holds no data at all.
            # Rebuilding it is cheap and restores full-precision angles.
            inv_freq = self._compute_inv_freq(device)
        else:
            inv_freq = inv_freq.to(device)
        # Angles are always computed in float32, then cast on the way out.
        freqs = torch.einsum("t,d->td", position_ids.float().to(device), inv_freq)
        emb = torch.cat([freqs, freqs], dim=-1)
        return emb.cos().to(dtype), emb.sin().to(dtype)
class RSLMAttention(nn.Module):
    """Grouped-query causal self-attention with RoPE.

    A layer is "global" (plain causal attention) when its index is listed in
    `config.global_layers_0idx`; otherwise it is "local" and each query only
    sees the most recent `config.window_size` keys. Local layers may also trim
    their KV cache to `config.local_cache_keep` entries when
    `config.evict_local_kv` is set.

    Raises:
        ValueError: if `config.num_q_heads` is not a multiple of
            `config.num_kv_heads`.
    """

    def __init__(self, config: RSLMConfig, layer_idx: int):
        super().__init__()
        if config.num_q_heads % config.num_kv_heads != 0:
            # `_repeat_kv` uses integer division; a remainder would leave K/V
            # with fewer heads than Q and fail later inside matmul.
            raise ValueError(
                f"num_q_heads ({config.num_q_heads}) must be a multiple of "
                f"num_kv_heads ({config.num_kv_heads})"
            )
        self.config = config
        self.layer_idx = layer_idx
        self.hidden_size = config.hidden_size
        self.num_q_heads = config.num_q_heads
        self.num_kv_heads = config.num_kv_heads
        self.head_dim = config.head_dim
        self.window_size = config.window_size
        self.is_global = layer_idx in set(config.global_layers_0idx)
        self.q_proj = nn.Linear(config.hidden_size, config.num_q_heads * config.head_dim, bias=False)
        self.k_proj = nn.Linear(config.hidden_size, config.num_kv_heads * config.head_dim, bias=False)
        self.v_proj = nn.Linear(config.hidden_size, config.num_kv_heads * config.head_dim, bias=False)
        self.o_proj = nn.Linear(config.num_q_heads * config.head_dim, config.hidden_size, bias=False)
        self.rotary = RotaryEmbedding(
            config.head_dim,
            max_position_embeddings=config.max_position_embeddings,
            base=config.rope_theta,
        )

    def _repeat_kv(self, x):
        """Expand K/V heads so that each query head has a matching K/V head."""
        # x: [b, kvh, t, d] -> [b, qh, t, d]
        if self.num_kv_heads == self.num_q_heads:
            return x
        repeat = self.num_q_heads // self.num_kv_heads
        return x.repeat_interleave(repeat, dim=1)

    def forward(self, x, position_ids=None, attention_mask=None, past_key_value=None, use_cache=False):
        """Attend over `x` plus any cached keys/values.

        Args:
            x: hidden states, `[batch, seq, hidden]`.
            position_ids: RoPE positions, either 1-D `[seq]` (shared by the
                batch) or 2-D `[batch, seq]`. When omitted they continue from
                the cached length, which is only correct if this layer's cache
                has not been trimmed.
            attention_mask: either a 2-D integer/bool key-padding mask
                `[batch, keys]` (non-zero = attend), or a floating-point
                additive mask that broadcasts against the
                `[batch, heads, seq, keys]` scores.
            past_key_value: optional `(key, value)` pair from an earlier call.
            use_cache: when true, also return the updated `(key, value)` pair.

        Returns:
            `(output, present)` where `output` is `[batch, seq, hidden]` and
            `present` is the `(key, value)` cache or `None`.

        Raises:
            ValueError: if a 2-D padding mask has fewer columns than there are
                keys to attend over.
        """
        bsz, seqlen, _ = x.shape
        q = self.q_proj(x).view(bsz, seqlen, self.num_q_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2)

        if position_ids is None:
            past_len = 0 if past_key_value is None else past_key_value[0].shape[-2]
            position_ids = torch.arange(past_len, past_len + seqlen, device=x.device)

        if position_ids.dim() == 2:
            # Per-sequence positions [b, t]: the rotary module only accepts a
            # flat vector, so look the angles up flat and restore the batch
            # axis, adding a singleton head axis for broadcasting.
            pos_b, pos_t = position_ids.shape
            cos, sin = self.rotary(position_ids.reshape(-1), q.dtype, x.device)
            cos = cos.reshape(pos_b, 1, pos_t, -1)
            sin = sin.reshape(pos_b, 1, pos_t, -1)
            q = (q * cos) + (rotate_half(q) * sin)
            k = (k * cos) + (rotate_half(k) * sin)
        else:
            cos, sin = self.rotary(position_ids, q.dtype, x.device)
            q, k = apply_rope(q, k, cos, sin)

        if past_key_value is not None:
            pk, pv = past_key_value
            k = torch.cat([pk, k], dim=-2)
            v = torch.cat([pv, v], dim=-2)

        q_len = q.shape[-2]
        k_len = k.shape[-2]

        # Cache eviction on local layers. A non-positive keep size is treated
        # as "no eviction" (a `-0:` slice would keep everything anyway).
        keep = self.config.local_cache_keep
        evict = (not self.is_global) and bool(self.config.evict_local_kv) and keep > 0

        k_rep = self._repeat_kv(k)
        v_rep = self._repeat_kv(v)

        # Simple reference attention. Large 256K prefill requires
        # FlashAttention / a custom kernel.
        attn_scores = torch.matmul(q, k_rep.transpose(-2, -1)) / math.sqrt(self.head_dim)

        # Boolean visibility mask, True = query may attend to key. The new
        # queries occupy the last `q_len` key slots.
        q_positions = torch.arange(k_len - q_len, k_len, device=x.device)[:, None]
        k_positions = torch.arange(0, k_len, device=x.device)[None, :]
        visible = k_positions <= q_positions  # causal
        if not self.is_global:
            # Local window. With eviction enabled the window is capped at the
            # cache size, so a long prefill sees exactly the keys that
            # token-by-token decoding with a trimmed cache would see.
            window = min(self.window_size, keep) if evict else self.window_size
            visible = visible & (k_positions >= (q_positions - window + 1))
        visible = visible[None, None, :, :]

        if (
            attention_mask is not None
            and attention_mask.dim() == 2
            and not torch.is_floating_point(attention_mask)
        ):
            # Tokenizer-style padding mask [b, keys]: fold it into the boolean
            # mask instead of adding 0/1 to the scores. Only the trailing
            # `k_len` columns refer to keys that are still present.
            key_keep = attention_mask[:, -k_len:].to(device=x.device, dtype=torch.bool)
            if key_keep.shape[-1] != k_len:
                raise ValueError(
                    f"attention_mask has {attention_mask.shape[-1]} columns but "
                    f"{k_len} keys are being attended over"
                )
            visible = visible & key_keep[:, None, None, :]
            attention_mask = None

        attn_scores = attn_scores.masked_fill(~visible, torch.finfo(attn_scores.dtype).min)
        if attention_mask is not None:
            # Floating-point masks keep their additive meaning.
            attn_scores = attn_scores + attention_mask

        attn_weights = F.softmax(attn_scores.float(), dim=-1).to(q.dtype)
        out = torch.matmul(attn_weights, v_rep)
        out = out.transpose(1, 2).contiguous().view(bsz, seqlen, self.num_q_heads * self.head_dim)

        present = None
        if use_cache:
            if evict and k_len > keep:
                # Trim only what is stored. Attention above ran on the full
                # keys; trimming first left early queries of a long prefill
                # with no visible key at all.
                k = k[..., -keep:, :]
                v = v[..., -keep:, :]
            present = (k, v)
        return self.o_proj(out), present
class RSLMMLP(nn.Module):
    """Gated feed-forward block: `down_proj(silu(gate_proj(x)) * up_proj(x))`.

    All three projections are bias-free linear layers sized by
    `config.hidden_size` and `config.intermediate_size`.
    """

    def __init__(self, config: RSLMConfig):
        super().__init__()
        model_width = config.hidden_size
        inner_width = config.intermediate_size
        self.gate_proj = nn.Linear(model_width, inner_width, bias=False)
        self.up_proj = nn.Linear(model_width, inner_width, bias=False)
        self.down_proj = nn.Linear(inner_width, model_width, bias=False)

    def forward(self, x):
        """Apply the gated projection to `x` and map back to the model width."""
        gate = F.silu(self.gate_proj(x))
        value = self.up_proj(x)
        return self.down_proj(gate * value)
class RSLMBlock(nn.Module):
    """One decoder layer: a single shared norm feeding attention and MLP in parallel.

    Both branches read the same normalised input and their outputs are added
    to the residual stream together.
    """

    def __init__(self, config: RSLMConfig, layer_idx: int):
        super().__init__()
        self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps)
        self.attn = RSLMAttention(config, layer_idx)
        self.mlp = RSLMMLP(config)
        # Stored from the config; `forward` does not read it and always uses
        # the parallel formulation.
        self.parallel_block = config.parallel_block

    def forward(self, x, position_ids=None, attention_mask=None, past_key_value=None, use_cache=False):
        """Return `(x + attention(norm(x)) + mlp(norm(x)), present)`.

        `present` is whatever the attention module returns as its cache entry.
        """
        normed = self.norm(x)
        attn_out, present = self.attn(
            normed,
            position_ids=position_ids,
            attention_mask=attention_mask,
            past_key_value=past_key_value,
            use_cache=use_cache,
        )
        mlp_out = self.mlp(normed)
        return x + attn_out + mlp_out, present
class RSLMPreTrainedModel(PreTrainedModel):
    """Common Hugging Face base class for the RSLM models defined in this file.

    It only declares class-level metadata and adds no methods of its own.
    """

    # Configuration class paired with these models.
    config_class = RSLMConfig
    # Attribute name under which RSLMForCausalLM stores its backbone (`self.model`).
    base_model_prefix = "model"
    # NOTE(review): declared as supported, but no forward() visible in this
    # file checks a gradient-checkpointing flag -- confirm it has any effect.
    supports_gradient_checkpointing = True
    # Module class names listed as not to be split; presumably consumed by
    # device-map placement in the base class -- verify.
    _no_split_modules = ["RSLMBlock"]
class RSLMModel(RSLMPreTrainedModel):
    """RSLM backbone: token embedding, a stack of `RSLMBlock`s, and a final norm."""

    def __init__(self, config: RSLMConfig):
        super().__init__(config)
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = nn.ModuleList([RSLMBlock(config, i) for i in range(config.num_layers)])
        self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps)
        self.post_init()

    def forward(self, input_ids, attention_mask=None, position_ids=None, past_key_values=None, use_cache=False):
        """Run the decoder stack.

        Args:
            input_ids: token ids, `[batch, seq]`.
            attention_mask: forwarded unchanged to every layer.
            position_ids: RoPE positions forwarded to every layer. When
                omitted they continue from the longest cached sequence.
            past_key_values: per-layer sequence of `(key, value)` pairs (or
                `None` entries); `None` or an empty sequence means no cache.
            use_cache: when true, collect each layer's updated cache entry.

        Returns:
            `(hidden_states, presents)`; `presents` is a list with one entry
            per layer, or `None` when `use_cache` is false.

        Raises:
            ValueError: if a non-empty `past_key_values` does not have exactly
                one entry per layer.
        """
        x = self.embed_tokens(input_ids)
        presents = [] if use_cache else None

        if past_key_values is None or len(past_key_values) == 0:
            past_key_values = [None] * len(self.layers)
        elif len(past_key_values) != len(self.layers):
            # zip() below would silently skip layers on a short cache list.
            raise ValueError(
                f"past_key_values has {len(past_key_values)} entries but the model "
                f"has {len(self.layers)} layers"
            )

        if position_ids is None:
            # Continue counting after the cached tokens; starting at 0 again
            # would rotate every decoded token as if it were the first one.
            # Layers that trim their cache hold fewer entries than were
            # actually processed, so take the longest cache as the history
            # length. If every layer trims, callers must pass `position_ids`.
            past_len = max(
                (pkv[0].shape[-2] for pkv in past_key_values if pkv is not None),
                default=0,
            )
            position_ids = torch.arange(
                past_len, past_len + input_ids.shape[1], device=input_ids.device
            )

        for layer, pkv in zip(self.layers, past_key_values):
            x, present = layer(
                x,
                position_ids=position_ids,
                attention_mask=attention_mask,
                past_key_value=pkv,
                use_cache=use_cache,
            )
            if use_cache:
                presents.append(present)

        x = self.norm(x)
        return x, presents
class RSLMForCausalLM(RSLMPreTrainedModel):
    """RSLM backbone with a bias-free linear language-model head.

    When `config.tie_word_embeddings` is set, the head shares its weight with
    the input embedding.
    """

    def __init__(self, config: RSLMConfig):
        super().__init__(config)
        self.model = RSLMModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        if config.tie_word_embeddings:
            self.lm_head.weight = self.model.embed_tokens.weight
        self.post_init()

    # Embedding accessors: these give the Hugging Face base-class utilities
    # (weight tying, vocabulary resizing) a way to locate both embeddings.
    def get_input_embeddings(self):
        """Return the token embedding module of the backbone."""
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the token embedding module of the backbone."""
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        """Return the language-model head."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Replace the language-model head."""
        self.lm_head = new_embeddings

    def forward(self, input_ids, labels=None, attention_mask=None, position_ids=None, past_key_values=None, use_cache=False, **kwargs):
        """Compute next-token logits and, when `labels` is given, the LM loss.

        Args:
            input_ids: token ids, `[batch, seq]`.
            labels: optional target ids aligned with `input_ids`; they are
                shifted by one internally and `-100` entries are ignored.
            attention_mask, position_ids, past_key_values, use_cache:
                forwarded to the backbone unchanged.
            **kwargs: accepted and ignored.

        Returns:
            `CausalLMOutputWithPast` with `loss` (or `None`), `logits` and the
            backbone's cache list as `past_key_values`.
        """
        hidden, presents = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
        )
        logits = self.lm_head(hidden)

        loss = None
        if labels is not None:
            # Position t predicts token t + 1. The loss is taken in float32 so
            # half-precision logits do not degrade the log-softmax.
            shift_logits = logits[..., :-1, :].float().contiguous()
            shift_labels = labels[..., 1:].contiguous().to(shift_logits.device)
            loss = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
                ignore_index=-100,
            )

        return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=presents)