# modeling_llada.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.modeling_utils import PreTrainedModel
from transformers import PretrainedConfig
# --- 1) Config class ----------------------------------------------------------
class LLaDAConfig(PretrainedConfig):
    model_type = "llada"

    def __init__(
        self,
        vocab_size=50000,
        max_seq_len=512,
        d_model=128,
        n_layers=16,
        n_heads=8,
        dropout=0.1,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.max_seq_len = max_seq_len
        self.d_model = d_model
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.d_head = d_model // n_heads  # per-head dimension
        self.d_ffn = 4 * d_model          # feed-forward hidden size
        self.dropout = dropout
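
# Usage sketch for the config (the values shown are the defaults defined above):
#   cfg = LLaDAConfig()   # d_model=128, n_heads=8
#   cfg.d_head            # -> 16  (128 // 8)
#   cfg.d_ffn             # -> 512 (4 * 128)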
# --- 2) PreTrainedModel base ----------------------------------------------------
class LLaDAPreTrainedModel(PreTrainedModel):
    config_class = LLaDAConfig
    base_model_prefix = "llada"
# --- 3) Submodules --------------------------------------------------------------
class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        # x / sqrt(mean(x^2) + eps), computed over the last dimension
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        # Normalize in float32 for stability, then cast back to the input dtype
        return self._norm(x.float()).type_as(x) * self.weight
class RotaryPositionalEmbedding(nn.Module):
    def __init__(self, dim, max_seq_len=512):
        super().__init__()
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        t = torch.arange(max_seq_len, device=inv_freq.device).type_as(inv_freq)
        freqs = torch.einsum("i,j->ij", t, inv_freq)
        # Duplicate the frequencies so cos/sin cover the full head dimension
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos", emb.cos())
        self.register_buffer("sin", emb.sin())

    def forward(self, x):
        # x: (batch, n_heads, seq, d_head) -> tables sliced to the current sequence length
        return self.cos[: x.shape[2], :], self.sin[: x.shape[2], :]


def rotate_half(x):
    # Pair dimension i with i + d/2, matching the (freqs, freqs) concatenation above
    half = x.shape[-1] // 2
    return torch.cat((-x[..., half:], x[..., :half]), dim=-1)


def apply_rotary(q, k, cos, sin):
    q2 = (q * cos) + (rotate_half(q) * sin)
    k2 = (k * cos) + (rotate_half(k) * sin)
    return q2, k2
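
# Shape sketch for the rotary helpers (sizes are illustrative, not fixed by the model):
#   rotary = RotaryPositionalEmbedding(dim=16, max_seq_len=512)
#   q = k = torch.randn(1, 8, 32, 16)      # (batch, n_heads, seq, d_head)
#   cos, sin = rotary(q)                    # each (32, 16), broadcast over batch and heads
#   q_rot, k_rot = apply_rotary(q, k, cos, sin)
#   assert q_rot.shape == q.shape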
class Attention(nn.Module):
    def __init__(self, config: LLaDAConfig):
        super().__init__()
        self.n_heads = config.n_heads
        self.d_head = config.d_head
        self.wq = nn.Linear(config.d_model, config.n_heads * config.d_head, bias=False)
        self.wk = nn.Linear(config.d_model, config.n_heads * config.d_head, bias=False)
        self.wv = nn.Linear(config.d_model, config.n_heads * config.d_head, bias=False)
        self.wo = nn.Linear(config.n_heads * config.d_head, config.d_model, bias=False)
        self.rotary = RotaryPositionalEmbedding(config.d_head, config.max_seq_len)

    def forward(self, x):
        b, seq, _ = x.size()
        # Project and reshape to (batch, n_heads, seq, d_head)
        q = self.wq(x).view(b, seq, self.n_heads, self.d_head).transpose(1, 2)
        k = self.wk(x).view(b, seq, self.n_heads, self.d_head).transpose(1, 2)
        v = self.wv(x).view(b, seq, self.n_heads, self.d_head).transpose(1, 2)
        cos, sin = self.rotary(q)
        q, k = apply_rotary(q, k, cos, sin)
        # Full bidirectional attention: no causal mask is applied
        out = F.scaled_dot_product_attention(q, k, v, is_causal=False)
        out = out.transpose(1, 2).reshape(b, seq, -1)
        return self.wo(out)
class FeedForward(nn.Module):
    def __init__(self, config: LLaDAConfig):
        super().__init__()
        self.w1 = nn.Linear(config.d_model, config.d_ffn, bias=False)
        self.w2 = nn.Linear(config.d_ffn, config.d_model, bias=False)
        self.w3 = nn.Linear(config.d_model, config.d_ffn, bias=False)

    def forward(self, x):
        # SwiGLU: silu(w1(x)) gated by w3(x), projected back with w2
        return self.w2(F.silu(self.w1(x)) * self.w3(x))
class TransformerBlock(nn.Module):
    def __init__(self, config: LLaDAConfig):
        super().__init__()
        self.attn = Attention(config)
        self.ff = FeedForward(config)
        self.norm1 = RMSNorm(config.d_model)
        self.norm2 = RMSNorm(config.d_model)

    def forward(self, x):
        # Pre-norm residual blocks: normalize, transform, then add back to the residual stream
        h = x + self.attn(self.norm1(x))
        return h + self.ff(self.norm2(h))
# --- 4) Main model class ---------------------------------------------------------
class LLaDA_Model(nn.Module):
    def __init__(self, config: LLaDAConfig):
        super().__init__()
        self.embed = nn.Embedding(config.vocab_size, config.d_model)
        self.layers = nn.ModuleList(
            [TransformerBlock(config) for _ in range(config.n_layers)]
        )
        self.norm = RMSNorm(config.d_model)

    def forward(self, input_ids):
        x = self.embed(input_ids)
        for layer in self.layers:
            x = layer(x)
        return self.norm(x)
# --- 5) CausalLM wrapper with LM head ----------------------------------------------
class LLaDA_ModelForCausalLM(LLaDAPreTrainedModel):
    def __init__(self, config: LLaDAConfig):
        super().__init__(config)
        self.llada = LLaDA_Model(config)
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
        # Standard PreTrainedModel hook: runs weight init and any final setup
        self.post_init()

    def forward(self, input_ids, **kwargs):
        hidden = self.llada(input_ids)
        logits = self.lm_head(hidden)
        return logits
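

# Minimal smoke test (a sketch for local sanity checking, not part of the checkpoint's
# training or inference code; the batch and sequence sizes below are illustrative).
if __name__ == "__main__":
    config = LLaDAConfig()
    model = LLaDA_ModelForCausalLM(config)
    input_ids = torch.randint(0, config.vocab_size, (2, 16))  # (batch=2, seq=16) dummy tokens
    logits = model(input_ids)
    # Expect one vocabulary-sized logit vector per position
    assert logits.shape == (2, 16, config.vocab_size)
    print("logits:", tuple(logits.shape))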