Transformers
JonusNattapong committed
Commit d5b4be0 · verified · 1 Parent(s): fcd9d64

Create modeling_openthaiwilai.py

Files changed (1)
  1. modeling_openthaiwilai.py +274 -0
modeling_openthaiwilai.py ADDED
@@ -0,0 +1,274 @@
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import (
    PreTrainedModel,
    PretrainedConfig,
    AutoConfig,
    AutoModelForCausalLM,
)
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
from transformers.generation.utils import GenerationMixin

# ------------------------------------------------------------
# 🧩 Rotary Positional Embedding (RoPE)
# ------------------------------------------------------------
def build_rope_cache(seq_len, head_dim, device):
    half_dim = head_dim // 2
    freq_seq = torch.arange(half_dim, device=device, dtype=torch.float32)
    inv_freq = 1.0 / (10000 ** (freq_seq / half_dim))
    t = torch.arange(seq_len, device=device, dtype=torch.float32)
    freqs = torch.outer(t, inv_freq)  # (seq_len, half_dim)
    cos, sin = torch.cos(freqs), torch.sin(freqs)
    return cos, sin

def apply_rope(x, cos, sin):
    # x: (B, T, H, D)
    B, T, H, D = x.shape
    cos = cos[:T, :].unsqueeze(0).unsqueeze(2)  # (1, T, 1, D/2)
    sin = sin[:T, :].unsqueeze(0).unsqueeze(2)
    x1 = x[..., ::2]   # even feature dims
    x2 = x[..., 1::2]  # odd feature dims
    # Note: the rotated pair components are concatenated rather than
    # re-interleaved. This permutes the feature order relative to the standard
    # RoPE layout, but the same permutation is applied to q and k, so the q·k
    # dot products (and hence attention scores) keep the relative-position
    # property.
    out = torch.cat([x1 * cos - x2 * sin,
                     x1 * sin + x2 * cos], dim=-1)
    return out

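# Illustrative sketch (not in the original commit): a tiny shape check for the
# two helpers above, assuming head_dim=64, 4 heads, and an 8-token sequence.
def _demo_rope_shapes():
    cos, sin = build_rope_cache(8, 64, torch.device("cpu"))  # each (8, 32)
    q = torch.randn(1, 8, 4, 64)                             # (B, T, H, D)
    q_rot = apply_rope(q, cos, sin)
    assert q_rot.shape == q.shape  # rotation never changes the shape
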
# ------------------------------------------------------------
# 🧩 Config
# ------------------------------------------------------------
class OpenThaiWilaiConfig(PretrainedConfig):
    model_type = "OpenThaiWilai"

    def __init__(
        self,
        vocab_size=50000,
        hidden_size=768,
        num_layers=6,
        num_heads=8,
        num_key_value_heads=None,
        num_experts=4,
        top_k=2,
        max_position_embeddings=2048,
        intermediate_size=3072,
        rope=True,
        use_flashattn=True,
        eos_token_id=None,
        bos_token_id=None,
        pad_token_id=None,
        **kwargs
    ):
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            **kwargs
        )
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.num_hidden_layers = num_layers
        self.num_heads = num_heads
        self.num_key_value_heads = num_key_value_heads or num_heads
        self.num_experts = num_experts
        self.top_k = top_k
        self.max_position_embeddings = max_position_embeddings
        self.intermediate_size = intermediate_size
        self.rope = rope
        self.use_flashattn = use_flashattn

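# Illustrative sketch (not in the original commit): a deliberately tiny config
# for smoke tests; every value below is an assumption chosen for speed.
def _demo_config():
    return OpenThaiWilaiConfig(
        vocab_size=128,
        hidden_size=64,
        num_layers=2,
        num_heads=4,
        num_key_value_heads=2,   # GQA: two query heads share each KV head
        num_experts=2,
        top_k=1,
        max_position_embeddings=64,
        intermediate_size=128,
        use_flashattn=False,     # keep to the pure-PyTorch attention path
    )
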
# ------------------------------------------------------------
# 🧩 Custom Components
# ------------------------------------------------------------
class RMSNorm(nn.Module):
    def __init__(self, d, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(d))
        self.eps = eps

    def forward(self, x):
        # Root-mean-square over the last dimension; eps keeps division stable.
        norm = x.norm(dim=-1, keepdim=True) * (1.0 / math.sqrt(x.size(-1)))
        return self.weight * x / (norm + self.eps)

class SwiGLU(nn.Module):
    def __init__(self, d_model, d_ff):
        super().__init__()
        self.w1 = nn.Linear(d_model, d_ff)
        self.w2 = nn.Linear(d_model, d_ff)
        # Down-projection added so expert outputs match the residual width;
        # without it the MoE would return d_ff features and the residual
        # addition in Block would fail whenever d_ff != d_model.
        self.w3 = nn.Linear(d_ff, d_model)

    def forward(self, x):
        # SwiGLU(x) = (SiLU(x·W1) ⊙ x·W2)·W3
        return self.w3(F.silu(self.w1(x)) * self.w2(x))

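# Illustrative sketch (not in the original commit): with the w3 down-projection,
# a SwiGLU block maps d_model back to d_model and can sit on a residual stream.
def _demo_swiglu_shapes():
    ff = SwiGLU(d_model=64, d_ff=128)
    x = torch.randn(2, 8, 64)
    assert ff(x).shape == x.shape
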
# ------------------------------------------------------------
# 🧩 Multi-Head Attention with RoPE + FlashAttention + GQA
# ------------------------------------------------------------
try:
    from flash_attn import flash_attn_func
    FLASH_AVAILABLE = True
except ImportError:
    FLASH_AVAILABLE = False

class MultiHeadAttention(nn.Module):
    def __init__(self, config: OpenThaiWilaiConfig):
        super().__init__()
        self.num_heads = config.num_heads
        self.num_kv_heads = config.num_key_value_heads
        self.head_dim = config.hidden_size // config.num_heads

        self.q_proj = nn.Linear(config.hidden_size, config.hidden_size)
        self.k_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim)
        self.v_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim)
        self.o_proj = nn.Linear(config.hidden_size, config.hidden_size)

        self.rope = config.rope
        self.use_flash = config.use_flashattn

    def forward(self, x, attention_mask=None):
        B, T, C = x.shape
        q = self.q_proj(x).view(B, T, self.num_heads, self.head_dim)
        k = self.k_proj(x).view(B, T, self.num_kv_heads, self.head_dim)
        v = self.v_proj(x).view(B, T, self.num_kv_heads, self.head_dim)

        # RoPE
        if self.rope:
            cos, sin = build_rope_cache(T, self.head_dim, x.device)
            q = apply_rope(q, cos, sin)
            k = apply_rope(k, cos, sin)

        # GQA: replicate each KV head so every query head has a matching KV head
        if self.num_kv_heads != self.num_heads:
            k = k.repeat_interleave(self.num_heads // self.num_kv_heads, dim=2)
            v = v.repeat_interleave(self.num_heads // self.num_kv_heads, dim=2)

        # FlashAttention path (CUDA, Ampere or newer), else pure-PyTorch fallback
        if (self.use_flash and FLASH_AVAILABLE and x.is_cuda
                and torch.cuda.get_device_capability(x.device)[0] >= 8):
            # flash_attn_func expects (B, T, H, D) layout directly
            out = flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=True)
            out = out.reshape(B, T, C)
        else:
            q = q.transpose(1, 2)  # (B, H, T, D)
            k = k.transpose(1, 2)
            v = v.transpose(1, 2)
            attn = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
            # Causal mask (the flash path applies this via causal=True)
            causal = torch.tril(torch.ones(T, T, device=x.device, dtype=torch.bool))
            attn = attn.masked_fill(~causal, float("-inf"))
            if attention_mask is not None:
                # (B, T) padding mask -> (B, 1, 1, T) for broadcasting
                attn = attn.masked_fill(attention_mask[:, None, None, :] == 0, float("-inf"))
            attn = F.softmax(attn, dim=-1)
            out = attn @ v
            out = out.transpose(1, 2).contiguous().view(B, T, C)

        return self.o_proj(out)

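# Illustrative sketch (not in the original commit): exercise the pure-PyTorch
# fallback on CPU with a padding mask; the hyperparameters are assumptions.
def _demo_attention_shapes():
    cfg = OpenThaiWilaiConfig(hidden_size=64, num_heads=4,
                              num_key_value_heads=2, use_flashattn=False)
    attn = MultiHeadAttention(cfg)
    x = torch.randn(2, 8, 64)
    mask = torch.ones(2, 8, dtype=torch.long)
    assert attn(x, attention_mask=mask).shape == x.shape
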
# ------------------------------------------------------------
# 🧩 MoE with load balancing
# ------------------------------------------------------------
class MoE(nn.Module):
    def __init__(self, config: OpenThaiWilaiConfig):
        super().__init__()
        self.experts = nn.ModuleList([
            SwiGLU(config.hidden_size, config.intermediate_size)
            for _ in range(config.num_experts)
        ])
        self.gate = nn.Linear(config.hidden_size, config.num_experts)
        self.top_k = config.top_k
        self.num_experts = config.num_experts

    def forward(self, x):
        B, T, C = x.shape
        scores = F.softmax(self.gate(x), dim=-1)  # (B, T, E)
        current_top_k = min(self.top_k, self.num_experts)
        topk_scores, topk_idx = torch.topk(scores, current_top_k, dim=-1)
        # Note: every expert runs on every token here (dense compute); only the
        # combination is sparse. A production MoE would dispatch tokens instead.
        expert_outputs = torch.stack([expert(x) for expert in self.experts], dim=2)  # (B, T, E, C)
        topk_idx_expanded = topk_idx.unsqueeze(-1).expand(-1, -1, -1, C)
        selected_expert_outputs = torch.gather(expert_outputs, dim=2, index=topk_idx_expanded)
        topk_scores_expanded = topk_scores.unsqueeze(-1).expand(-1, -1, -1, C)
        weighted_expert_outputs = selected_expert_outputs * topk_scores_expanded

        # Load-balancing proxy: variance of mean gate probabilities across
        # experts; it is zero when expert usage is uniform.
        aux_loss = (scores.mean(0).var(dim=-1)).mean()
        self.last_aux_loss = aux_loss

        return torch.sum(weighted_expert_outputs, dim=2)

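# Illustrative sketch (not in the original commit): with top_k=1 each token's
# output is a single expert's output scaled by its gate probability.
def _demo_moe_routing():
    cfg = OpenThaiWilaiConfig(hidden_size=64, intermediate_size=128,
                              num_experts=2, top_k=1)
    moe = MoE(cfg)
    x = torch.randn(2, 8, 64)
    out = moe(x)
    assert out.shape == x.shape
    assert moe.last_aux_loss.ndim == 0  # scalar load-balancing term
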
# ------------------------------------------------------------
# 🧩 Transformer Block
# ------------------------------------------------------------
class Block(nn.Module):
    def __init__(self, config: OpenThaiWilaiConfig):
        super().__init__()
        self.ln1 = RMSNorm(config.hidden_size)
        self.attn = MultiHeadAttention(config)
        self.ln2 = RMSNorm(config.hidden_size)
        self.moe = MoE(config)

    def forward(self, x, attention_mask=None):
        # Pre-norm residual connections; the padding mask is threaded through
        # so attention can ignore padded positions.
        x = x + self.attn(self.ln1(x), attention_mask)
        x = x + self.moe(self.ln2(x))
        return x

# ------------------------------------------------------------
# 🧩 OpenThaiWilai For Causal LM
# ------------------------------------------------------------
class OpenThaiWilaiForCausalLM(PreTrainedModel, GenerationMixin):
    config_class = OpenThaiWilaiConfig
    _keys_to_ignore_on_save = []
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config: OpenThaiWilaiConfig):
        super().__init__(config)
        self.embed = nn.Embedding(config.vocab_size, config.hidden_size)
        # Learned absolute positions are added on top of RoPE; redundant but harmless.
        self.pos_embed = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.blocks = nn.ModuleList([Block(config) for _ in range(config.num_layers)])
        self.ln_f = RMSNorm(config.hidden_size)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        self.post_init()
        self.tie_weights()

    def tie_weights(self):
        self.lm_head.weight = self.embed.weight

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
        # No KV cache is implemented, so each step re-encodes the full sequence.
        return {
            "input_ids": input_ids,
            "past_key_values": past_key_values,
            "attention_mask": kwargs.get("attention_mask"),
        }

    def forward(
        self,
        input_ids,
        labels=None,
        attention_mask=None,
        past_key_values=None,
        use_cache: bool = False,
        **kwargs
    ):
        B, T = input_ids.shape
        pos = torch.arange(0, T, device=input_ids.device).unsqueeze(0)
        x = self.embed(input_ids) + self.pos_embed(pos)
        for block in self.blocks:
            x = block(x, attention_mask)
        x = self.ln_f(x)
        logits = self.lm_head(x)

        loss = None
        aux_loss = 0
        for block in self.blocks:
            if hasattr(block.moe, "last_aux_loss"):
                aux_loss += block.moe.last_aux_loss

        if labels is not None:
            # Shift so that tokens < t predict token t (standard causal LM loss)
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            ce_loss = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
                ignore_index=-100
            )
            loss = ce_loss + 0.01 * aux_loss

        return CausalLMOutputWithCrossAttentions(
            loss=loss,
            logits=logits,
            past_key_values=past_key_values if use_cache else None,
            hidden_states=None,
            attentions=None,
        )


# ------------------------------------------------------------
# 🧩 Register model for Auto classes
# ------------------------------------------------------------
AutoConfig.register("OpenThaiWilai", OpenThaiWilaiConfig)
AutoModelForCausalLM.register(OpenThaiWilaiConfig, OpenThaiWilaiForCausalLM)
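
# ------------------------------------------------------------
# 🧩 Smoke test (illustrative; not part of the original commit)
# ------------------------------------------------------------
if __name__ == "__main__":
    # A minimal sketch under assumed tiny hyperparameters: one forward pass
    # with labels, then a short greedy generation. With no KV cache, each
    # decoding step re-encodes the full sequence.
    config = OpenThaiWilaiConfig(
        vocab_size=128, hidden_size=64, num_layers=2, num_heads=4,
        num_key_value_heads=2, num_experts=2, top_k=1,
        max_position_embeddings=64, intermediate_size=128,
        use_flashattn=False, pad_token_id=0, bos_token_id=1, eos_token_id=2,
    )
    model = OpenThaiWilaiForCausalLM(config)
    input_ids = torch.randint(3, config.vocab_size, (2, 16))
    out = model(input_ids, labels=input_ids)
    print("loss:", out.loss.item(), "logits:", tuple(out.logits.shape))
    generated = model.generate(input_ids[:, :4], max_new_tokens=8, do_sample=False)
    print("generated:", tuple(generated.shape))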