|
|
|
|
|
|
|
|
import torch
|
|
|
import torch.nn as nn
|
|
|
import torch.nn.functional as F
|
|
|
from transformers.modeling_utils import PreTrainedModel
|
|
|
from transformers import PretrainedConfig
|
|
|
|
|
|
|
|
|
class LLaDAConfig(PretrainedConfig):
    """Hyper-parameters for the LLaDA transformer.

    Args:
        vocab_size: Number of token ids (embedding rows / lm_head outputs).
        max_seq_len: Number of positions precomputed for rotary embeddings.
        d_model: Hidden size of the residual stream.
        n_layers: Number of transformer blocks.
        n_heads: Number of attention heads; the per-head size is
            ``d_head = d_model // n_heads`` (floor division).
        dropout: Stored on the config. NOTE(review): none of the modules
            defined alongside this config read it — confirm whether dropout
            is meant to be applied anywhere.
        d_ffn: Inner width of the feed-forward layer. ``None`` (default)
            keeps the previous fixed value ``4 * d_model``.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.

    Raises:
        ValueError: If ``n_heads`` is not positive, or if ``d_head`` is not a
            positive even number (rotary embeddings rotate channel pairs, so
            an odd head size cannot be used).
    """

    model_type = "llada"

    def __init__(
        self,
        vocab_size=50000,
        max_seq_len=512,
        d_model=128,
        n_layers=16,
        n_heads=8,
        dropout=0.1,
        d_ffn=None,
        **kwargs,
    ):
        # Fail fast here instead of deep inside the first forward pass.
        if n_heads <= 0:
            raise ValueError(f"n_heads must be positive, got {n_heads}")
        d_head = d_model // n_heads
        if d_head <= 0 or d_head % 2 != 0:
            raise ValueError(
                "d_model // n_heads must be a positive even number for rotary "
                f"embeddings, got {d_model} // {n_heads} = {d_head}"
            )

        super().__init__(**kwargs)

        self.vocab_size = vocab_size
        self.max_seq_len = max_seq_len
        self.d_model = d_model
        self.n_layers = n_layers
        self.n_heads = n_heads
        # Derived sizes are (re)computed here. A saved config that already
        # carries ``d_head`` has it overwritten with the same value.
        self.d_head = d_head
        self.d_ffn = 4 * d_model if d_ffn is None else d_ffn
        self.dropout = dropout
|
|
|
|
|
|
|
|
|
|
|
|
class LLaDAPreTrainedModel(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` base shared by the LLaDA models.

    ``config_class`` names the configuration type used when loading, and
    ``base_model_prefix`` is the attribute name of the backbone in subclasses
    (the causal-LM wrapper stores it as ``self.llada``).
    """

    base_model_prefix = "llada"
    config_class = LLaDAConfig
|
|
|
|
|
|
|
|
|
|
|
|
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension.

    Each vector is divided by ``sqrt(mean(x**2) + eps)``, the result is
    multiplied by a learnable per-channel ``weight`` that starts at ones, and
    no mean is subtracted.

    Args:
        dim: Size of the last dimension (length of ``weight``).
        eps: Small constant added under the square root for stability.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        # Reciprocal RMS, kept with a trailing singleton dim for broadcasting.
        inv_rms = (x.pow(2).mean(-1, keepdim=True) + self.eps).rsqrt()
        return x * inv_rms

    def forward(self, x):
        # Normalise in float32, cast back to the input dtype, then scale.
        normed = self._norm(x.float())
        return normed.type_as(x) * self.weight
|
|
|
|
|
|
|
|
|
class RotaryPositionalEmbedding(nn.Module):
    """Precomputed cosine/sine tables for rotary position embeddings (RoPE).

    Both buffers have shape ``(max_seq_len, dim)``. Each of the ``dim // 2``
    frequencies is laid out as ``cat(freqs, freqs)``, so channel ``i`` shares
    its angle with channel ``i + dim // 2`` (the "rotate half" pairing).

    Args:
        dim: Feature size to rotate (the per-head size); must be positive
            and even.
        max_seq_len: Number of positions to precompute.
        base: Frequency base; defaults to 10000, the previous fixed value.

    Raises:
        ValueError: If ``dim`` is not a positive even number. An odd ``dim``
            would produce ``dim + 1`` table columns that can never broadcast
            against a ``dim``-wide input.
    """

    def __init__(self, dim, max_seq_len=512, base=10000):
        super().__init__()
        if dim <= 0 or dim % 2 != 0:
            raise ValueError(f"rotary dim must be a positive even number, got {dim}")

        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        t = torch.arange(max_seq_len, device=inv_freq.device).type_as(inv_freq)
        freqs = torch.einsum("i,j->ij", t, inv_freq)  # (max_seq_len, dim // 2)
        emb = torch.cat((freqs, freqs), dim=-1)  # (max_seq_len, dim)

        self.register_buffer("cos", emb.cos())
        self.register_buffer("sin", emb.sin())

    def forward(self, x):
        """Return ``(cos, sin)`` truncated to the sequence length of ``x``.

        Only ``x.shape[2]`` is read, as the sequence length. The attention
        layer in this file passes queries laid out as
        ``(batch, heads, seq, d_head)``.

        Raises:
            ValueError: If the sequence is longer than the precomputed tables.
                Without this check the slice would silently return too few
                rows and fail later with an opaque broadcasting error.
        """
        seq_len = x.shape[2]
        max_len = self.cos.shape[0]
        if seq_len > max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds rotary max_seq_len {max_len}"
            )
        return self.cos[:seq_len, :], self.sin[:seq_len, :]
|
|
|
|
|
|
|
|
|
def _rotate_half(x):
    """Map ``[x1, x2]`` (split on the last dim) to ``[-x2, x1]``."""
    half = x.shape[-1] // 2
    return torch.cat((-x[..., half:], x[..., :half]), dim=-1)


def apply_rotary(q, k, cos, sin):
    """Rotate query and key features by position-dependent angles (RoPE).

    Uses the "rotate half" pairing (channel ``i`` with ``i + d // 2``). This
    matches cos/sin tables built as ``cat(freqs, freqs)``. The previous
    interleaved partner (``cat(-x[1::2], x[::2])``) mixed mismatched channels:
    it did not preserve norms, and q·k did not depend only on relative
    position.

    Args:
        q: Query tensor whose last dimension is the (even) head size ``d``.
        k: Key tensor with the same last dimension as ``q``.
        cos: Cosine table broadcastable against ``q``/``k``, last dim ``d``.
        sin: Sine table with the same shape as ``cos``.

    Returns:
        ``(q_rot, k_rot)`` with the same shapes and dtypes as ``q`` and ``k``.
        The tables are cast to each input's dtype, so half-precision q/k stay
        half precision and match ``v`` in the attention call.
    """
    cos_q, sin_q = cos.to(dtype=q.dtype), sin.to(dtype=q.dtype)
    cos_k, sin_k = cos.to(dtype=k.dtype), sin.to(dtype=k.dtype)
    q_rot = q * cos_q + _rotate_half(q) * sin_q
    k_rot = k * cos_k + _rotate_half(k) * sin_k
    return q_rot, k_rot
|
|
|
|
|
|
|
|
|
class Attention(nn.Module):
    """Bidirectional multi-head self-attention with rotary position embeddings.

    The input is projected to per-head queries, keys and values, and q/k are
    rotated with RoPE. ``F.scaled_dot_product_attention`` then runs with
    ``is_causal=False`` and no attention mask, so every position can attend
    to every other position. All projections are bias-free.
    """

    def __init__(self, config: LLaDAConfig):
        super().__init__()
        self.n_heads = config.n_heads
        self.d_head = config.d_head

        inner_dim = config.n_heads * config.d_head
        # Submodules are created in the same order as before (q, k, v, o,
        # rotary) so parameter ordering and seeded initialisation are stable.
        self.wq = nn.Linear(config.d_model, inner_dim, bias=False)
        self.wk = nn.Linear(config.d_model, inner_dim, bias=False)
        self.wv = nn.Linear(config.d_model, inner_dim, bias=False)
        self.wo = nn.Linear(inner_dim, config.d_model, bias=False)

        self.rotary = RotaryPositionalEmbedding(config.d_head, config.max_seq_len)

    def _split_heads(self, proj, x):
        """Project ``x`` (batch, seq, d_model) to (batch, heads, seq, d_head)."""
        batch, seq_len, _ = x.shape
        return proj(x).view(batch, seq_len, self.n_heads, self.d_head).transpose(1, 2)

    def forward(self, x):
        batch, seq_len, _ = x.shape
        q, k, v = (self._split_heads(proj, x) for proj in (self.wq, self.wk, self.wv))

        cos, sin = self.rotary(q)
        q, k = apply_rotary(q, k, cos, sin)

        attended = F.scaled_dot_product_attention(q, k, v, is_causal=False)
        # Merge heads back: (batch, heads, seq, d_head) -> (batch, seq, heads * d_head).
        merged = attended.transpose(1, 2).reshape(batch, seq_len, -1)
        return self.wo(merged)
|
|
|
|
|
|
|
|
|
class FeedForward(nn.Module):
    """SwiGLU feed-forward layer computing ``w2(silu(w1(x)) * w3(x))``.

    ``w1`` and ``w3`` map ``d_model -> d_ffn``, and ``w2`` maps back to
    ``d_model``. No projection has a bias.
    """

    def __init__(self, config: LLaDAConfig):
        super().__init__()
        model_dim, hidden_dim = config.d_model, config.d_ffn
        self.w1 = nn.Linear(model_dim, hidden_dim, bias=False)  # gate branch
        self.w2 = nn.Linear(hidden_dim, model_dim, bias=False)  # output projection
        self.w3 = nn.Linear(model_dim, hidden_dim, bias=False)  # linear branch

    def forward(self, x):
        gate = F.silu(self.w1(x))
        linear = self.w3(x)
        return self.w2(gate * linear)
|
|
|
|
|
|
|
|
|
class TransformerBlock(nn.Module):
    """Pre-norm residual block: attention sub-layer, then feed-forward sub-layer.

    Each sub-layer reads an RMS-normalised copy of the residual stream, and
    its output is added back onto the un-normalised stream.
    """

    def __init__(self, config: LLaDAConfig):
        super().__init__()
        self.attn = Attention(config)
        self.ff = FeedForward(config)
        self.norm1 = RMSNorm(config.d_model)
        self.norm2 = RMSNorm(config.d_model)

    def forward(self, x):
        residual = x + self.attn(self.norm1(x))
        residual = residual + self.ff(self.norm2(residual))
        return residual
|
|
|
|
|
|
|
|
|
|
|
|
class LLaDA_Model(nn.Module):
    """LLaDA backbone: token embedding, ``n_layers`` blocks, then a final RMSNorm.

    ``forward`` returns the normalised hidden states (last dim ``d_model``).
    No output projection to the vocabulary is applied here.
    """

    def __init__(self, config: LLaDAConfig):
        super().__init__()
        self.embed = nn.Embedding(config.vocab_size, config.d_model)
        self.layers = nn.ModuleList(
            [TransformerBlock(config) for _ in range(config.n_layers)]
        )
        self.norm = RMSNorm(config.d_model)

    def forward(self, input_ids):
        # input_ids presumably (batch, seq) integer token ids — TODO confirm with callers.
        hidden_states = self.embed(input_ids)
        for block in self.layers:
            hidden_states = block(hidden_states)
        return self.norm(hidden_states)
|
|
|
|
|
|
|
|
|
|
|
|
class LLaDA_ModelForCausalLM(LLaDAPreTrainedModel):
    """LLaDA backbone (stored as ``self.llada``) plus a bias-free vocabulary head.

    ``forward`` returns a plain logits tensor whose last dim is
    ``vocab_size``, not a ``ModelOutput``. Any extra keyword arguments (for
    example an attention mask) are accepted but ignored.
    """

    def __init__(self, config: LLaDAConfig):
        super().__init__(config)
        self.llada = LLaDA_Model(config)
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)

    def forward(self, input_ids, **kwargs):
        return self.lm_head(self.llada(input_ids))
|
|
|
|