Baki Şahin commited on
Commit
6d9bc50
·
verified ·
1 Parent(s): 9bea81d

Upload 2 files

Browse files
Files changed (2) hide show
  1. config.json +10 -0
  2. modeling_llada.py +148 -0
config.json ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": ["LLaDA_ModelForCausalLM"],
3
+ "model_type": "llada",
4
+ "vocab_size": 5000,
5
+ "max_seq_len": 512,
6
+ "d_model": 128,
7
+ "n_layers": 8,
8
+ "n_heads": 8,
9
+ "dropout": 0.1
10
+ }
modeling_llada.py ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # modeling_llada.py
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from transformers.modeling_utils import PreTrainedModel
7
+ from transformers import PretrainedConfig
8
+
9
+ # --- 1) Config Sınıfı --------------------------------------------------------
10
+ class LLaDAConfig(PretrainedConfig):
11
+ model_type = "llada"
12
+
13
+ def __init__(
14
+ self,
15
+ vocab_size=50000,
16
+ max_seq_len=512,
17
+ d_model=128,
18
+ n_layers=16,
19
+ n_heads=8,
20
+ dropout=0.1,
21
+ **kwargs
22
+ ):
23
+ super().__init__(**kwargs)
24
+ self.vocab_size = vocab_size
25
+ self.max_seq_len = max_seq_len
26
+ self.d_model = d_model
27
+ self.n_layers = n_layers
28
+ self.n_heads = n_heads
29
+ self.d_head = d_model // n_heads
30
+ self.d_ffn = 4 * d_model
31
+ self.dropout = dropout
32
+
33
+
34
+ # --- 2) PreTrainedModel Tabanı ------------------------------------------------
35
+ class LLaDAPreTrainedModel(PreTrainedModel):
36
+ config_class = LLaDAConfig
37
+ base_model_prefix = "llada"
38
+
39
+
40
+ # --- 3) Alt Modüller ---------------------------------------------------------
41
+ class RMSNorm(nn.Module):
42
+ def __init__(self, dim: int, eps: float = 1e-6):
43
+ super().__init__()
44
+ self.eps = eps
45
+ self.weight = nn.Parameter(torch.ones(dim))
46
+
47
+ def _norm(self, x):
48
+ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
49
+
50
+ def forward(self, x):
51
+ return self._norm(x.float()).type_as(x) * self.weight
52
+
53
+
54
+ class RotaryPositionalEmbedding(nn.Module):
55
+ def __init__(self, dim, max_seq_len=512):
56
+ super().__init__()
57
+ inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
58
+ t = torch.arange(max_seq_len, device=inv_freq.device).type_as(inv_freq)
59
+ freqs = torch.einsum("i,j->ij", t, inv_freq)
60
+ emb = torch.cat((freqs, freqs), dim=-1)
61
+ self.register_buffer("cos", emb.cos())
62
+ self.register_buffer("sin", emb.sin())
63
+
64
+ def forward(self, x):
65
+ return self.cos[: x.shape[2], :], self.sin[: x.shape[2], :]
66
+
67
+
68
+ def apply_rotary(q, k, cos, sin):
69
+ q2 = (q * cos) + (torch.cat([-q[..., 1::2], q[..., ::2]], -1) * sin)
70
+ k2 = (k * cos) + (torch.cat([-k[..., 1::2], k[..., ::2]], -1) * sin)
71
+ return q2, k2
72
+
73
+
74
+ class Attention(nn.Module):
75
+ def __init__(self, config: LLaDAConfig):
76
+ super().__init__()
77
+ self.n_heads = config.n_heads
78
+ self.d_head = config.d_head
79
+
80
+ self.wq = nn.Linear(config.d_model, config.n_heads * config.d_head, bias=False)
81
+ self.wk = nn.Linear(config.d_model, config.n_heads * config.d_head, bias=False)
82
+ self.wv = nn.Linear(config.d_model, config.n_heads * config.d_head, bias=False)
83
+ self.wo = nn.Linear(config.n_heads * config.d_head, config.d_model, bias=False)
84
+
85
+ self.rotary = RotaryPositionalEmbedding(config.d_head, config.max_seq_len)
86
+
87
+ def forward(self, x):
88
+ b, seq, _ = x.size()
89
+ q = self.wq(x).view(b, seq, self.n_heads, self.d_head).transpose(1, 2)
90
+ k = self.wk(x).view(b, seq, self.n_heads, self.d_head).transpose(1, 2)
91
+ v = self.wv(x).view(b, seq, self.n_heads, self.d_head).transpose(1, 2)
92
+ cos, sin = self.rotary(q)
93
+ q, k = apply_rotary(q, k, cos, sin)
94
+ out = F.scaled_dot_product_attention(q, k, v, is_causal=False)
95
+ out = out.transpose(1, 2).reshape(b, seq, -1)
96
+ return self.wo(out)
97
+
98
+
99
+ class FeedForward(nn.Module):
100
+ def __init__(self, config: LLaDAConfig):
101
+ super().__init__()
102
+ self.w1 = nn.Linear(config.d_model, config.d_ffn, bias=False)
103
+ self.w2 = nn.Linear(config.d_ffn, config.d_model, bias=False)
104
+ self.w3 = nn.Linear(config.d_model, config.d_ffn, bias=False)
105
+
106
+ def forward(self, x):
107
+ return self.w2(F.silu(self.w1(x)) * self.w3(x))
108
+
109
+
110
+ class TransformerBlock(nn.Module):
111
+ def __init__(self, config: LLaDAConfig):
112
+ super().__init__()
113
+ self.attn = Attention(config)
114
+ self.ff = FeedForward(config)
115
+ self.norm1 = RMSNorm(config.d_model)
116
+ self.norm2 = RMSNorm(config.d_model)
117
+
118
+ def forward(self, x):
119
+ h = x + self.attn(self.norm1(x))
120
+ return h + self.ff(self.norm2(h))
121
+
122
+
123
+ # --- 4) Ana Model Sınıfı ------------------------------------------------------
124
+ class LLaDA_Model(nn.Module):
125
+ def __init__(self, config: LLaDAConfig):
126
+ super().__init__()
127
+ self.embed = nn.Embedding(config.vocab_size, config.d_model)
128
+ self.layers = nn.ModuleList([TransformerBlock(config) for _ in range(config.n_layers)])
129
+ self.norm = RMSNorm(config.d_model)
130
+
131
+ def forward(self, input_ids):
132
+ x = self.embed(input_ids)
133
+ for layer in self.layers:
134
+ x = layer(x)
135
+ return self.norm(x)
136
+
137
+
138
+ # --- 5) LM Head ile CausalLM --------------------------------------------------
139
+ class LLaDA_ModelForCausalLM(LLaDAPreTrainedModel):
140
+ def __init__(self, config: LLaDAConfig):
141
+ super().__init__(config)
142
+ self.llada = LLaDA_Model(config)
143
+ self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
144
+
145
+ def forward(self, input_ids, **kwargs):
146
+ hidden = self.llada(input_ids)
147
+ logits = self.lm_head(hidden)
148
+ return logits