# modeling_llada.py
import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers import PretrainedConfig
from transformers.modeling_utils import PreTrainedModel


# --- 1) Config class ----------------------------------------------------------
class LLaDAConfig(PretrainedConfig):
    model_type = "llada"

    def __init__(
        self,
        vocab_size=50000,
        max_seq_len=512,
        d_model=128,
        n_layers=16,
        n_heads=8,
        dropout=0.1,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.max_seq_len = max_seq_len
        self.d_model = d_model
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.d_head = d_model // n_heads
        self.d_ffn = 4 * d_model
        self.dropout = dropout


# --- 2) PreTrainedModel base --------------------------------------------------
class LLaDAPreTrainedModel(PreTrainedModel):
    config_class = LLaDAConfig
    base_model_prefix = "llada"


# --- 3) Submodules ------------------------------------------------------------
class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        # Normalize in float32 for numerical stability, then cast back to the input dtype.
        return self._norm(x.float()).type_as(x) * self.weight


class RotaryPositionalEmbedding(nn.Module):
    def __init__(self, dim, max_seq_len=512):
        super().__init__()
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        t = torch.arange(max_seq_len, device=inv_freq.device).type_as(inv_freq)
        freqs = torch.einsum("i,j->ij", t, inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos", emb.cos())
        self.register_buffer("sin", emb.sin())

    def forward(self, x):
        # x: (batch, n_heads, seq, d_head) -> cos/sin sliced to the current sequence length.
        return self.cos[: x.shape[2], :], self.sin[: x.shape[2], :]


def apply_rotary(q, k, cos, sin):
    # Rotate by pairing the first and second halves of the head dimension,
    # matching the cat((freqs, freqs)) layout used for the cos/sin buffers.
    half = q.shape[-1] // 2
    rot_q = torch.cat((-q[..., half:], q[..., :half]), dim=-1)
    rot_k = torch.cat((-k[..., half:], k[..., :half]), dim=-1)
    q2 = (q * cos) + (rot_q * sin)
    k2 = (k * cos) + (rot_k * sin)
    return q2, k2


class Attention(nn.Module):
    def __init__(self, config: LLaDAConfig):
        super().__init__()
        self.n_heads = config.n_heads
        self.d_head = config.d_head
        self.wq = nn.Linear(config.d_model, config.n_heads * config.d_head, bias=False)
        self.wk = nn.Linear(config.d_model, config.n_heads * config.d_head, bias=False)
        self.wv = nn.Linear(config.d_model, config.n_heads * config.d_head, bias=False)
        self.wo = nn.Linear(config.n_heads * config.d_head, config.d_model, bias=False)
        self.rotary = RotaryPositionalEmbedding(config.d_head, config.max_seq_len)

    def forward(self, x):
        b, seq, _ = x.size()
        q = self.wq(x).view(b, seq, self.n_heads, self.d_head).transpose(1, 2)
        k = self.wk(x).view(b, seq, self.n_heads, self.d_head).transpose(1, 2)
        v = self.wv(x).view(b, seq, self.n_heads, self.d_head).transpose(1, 2)
        cos, sin = self.rotary(q)
        q, k = apply_rotary(q, k, cos, sin)
        # Bidirectional (non-causal) attention: no causal mask is applied.
        out = F.scaled_dot_product_attention(q, k, v, is_causal=False)
        out = out.transpose(1, 2).reshape(b, seq, -1)
        return self.wo(out)


class FeedForward(nn.Module):
    def __init__(self, config: LLaDAConfig):
        super().__init__()
        self.w1 = nn.Linear(config.d_model, config.d_ffn, bias=False)
        self.w2 = nn.Linear(config.d_ffn, config.d_model, bias=False)
        self.w3 = nn.Linear(config.d_model, config.d_ffn, bias=False)

    def forward(self, x):
        # SwiGLU-style gating: SiLU(w1(x)) gates w3(x), then w2 projects back to d_model.
        return self.w2(F.silu(self.w1(x)) * self.w3(x))


class TransformerBlock(nn.Module):
    def __init__(self, config: LLaDAConfig):
        super().__init__()
        self.attn = Attention(config)
        self.ff = FeedForward(config)
        self.norm1 = RMSNorm(config.d_model)
        self.norm2 = RMSNorm(config.d_model)

    def forward(self, x):
        # Pre-norm residual connections around attention and feed-forward.
        h = x + self.attn(self.norm1(x))
        return h + self.ff(self.norm2(h))


# --- 4) Main model class ------------------------------------------------------
class LLaDA_Model(nn.Module):
    def __init__(self, config: LLaDAConfig):
        super().__init__()
        self.embed = nn.Embedding(config.vocab_size, config.d_model)
        self.layers = nn.ModuleList([TransformerBlock(config) for _ in range(config.n_layers)])
        self.norm = RMSNorm(config.d_model)

    def forward(self, input_ids):
        x = self.embed(input_ids)
        for layer in self.layers:
            x = layer(x)
        return self.norm(x)


# --- 5) CausalLM with LM head -------------------------------------------------
class LLaDA_ModelForCausalLM(LLaDAPreTrainedModel):
    def __init__(self, config: LLaDAConfig):
        super().__init__(config)
        self.llada = LLaDA_Model(config)
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
        # Standard PreTrainedModel pattern: finalize weight initialization.
        self.post_init()

    def forward(self, input_ids, **kwargs):
        hidden = self.llada(input_ids)
        logits = self.lm_head(hidden)
        return logits
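

# --- 6) Quick smoke test (illustrative) ----------------------------------------
# A minimal sketch of how the classes above fit together; it is not part of the
# model definition. The small config values and the dummy token ids below are
# arbitrary choices for a fast CPU check, not trained or recommended settings.
if __name__ == "__main__":
    config = LLaDAConfig(vocab_size=1000, max_seq_len=64, d_model=64, n_layers=2, n_heads=4)
    model = LLaDA_ModelForCausalLM(config)
    dummy_ids = torch.randint(0, config.vocab_size, (2, 16))  # (batch, seq) of token ids
    with torch.no_grad():
        logits = model(dummy_ids)
    # Expected shape: (batch, seq, vocab_size) -> (2, 16, 1000)
    print(logits.shape)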