Syamsuddin committed on
Commit bb09e93 · verified · 1 Parent(s): 6589101

Update src/model.py

Files changed (1)
  1. src/model.py +496 -0
src/model.py CHANGED
# Copyright (c) 2025
# G-Transformer: Energy-Efficient Transformer based on GIT
# Author: Syamsuddin B. Ideris, S.Pd.MM

import math
from typing import Optional, Tuple, List

import torch
import torch.nn as nn
import torch.nn.functional as F

try:
    from transformers import PreTrainedModel, PretrainedConfig
    from transformers.modeling_outputs import CausalLMOutputWithPast
except Exception as e:
    raise ImportError(
        "Please install transformers >= 4.40.0: "
        "pip install transformers"
    ) from e


# ----------------------------
# Configuration
# ----------------------------
class GTransformerConfig(PretrainedConfig):
    model_type = "gtransformer"

    def __init__(
        self,
        vocab_size: int = 65536,
        hidden_size: int = 8192,
        intermediate_size: int = 22016,
        num_hidden_layers: int = 48,
        num_attention_heads: int = 64,
        max_position_embeddings: int = 65536,
        hidden_act: str = "swiglu",
        layer_norm_epsilon: float = 1e-5,
        attention_dropout: float = 0.05,
        hidden_dropout_prob: float = 0.05,
        rotary_emb_base: int = 10000,
        use_flash_attention: bool = True,
        use_low_rank_ffn: bool = True,
        use_entropy_gate: bool = True,
        use_moe: bool = False,
        num_experts: int = 0,
        top_k_experts: int = 0,
        fp8_precision: bool = False,
        dvfs_enabled: bool = False,
        informational_constant_kI: float = 2.612e-20,
        energy_per_token_target_J: float = 0.07,
        delta_I_gate: float = 0.75,
        local_window: int = 512,
        global_rank: int = 64,
        kv_compression_rank: int = 64,
        bos_token_id: int = 1,
        eos_token_id: int = 2,
        pad_token_id: int = 0,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.max_position_embeddings = max_position_embeddings
        self.hidden_act = hidden_act
        self.layer_norm_epsilon = layer_norm_epsilon
        self.attention_dropout = attention_dropout
        self.hidden_dropout_prob = hidden_dropout_prob
        self.rotary_emb_base = rotary_emb_base

        self.use_flash_attention = use_flash_attention
        self.use_low_rank_ffn = use_low_rank_ffn
        self.use_entropy_gate = use_entropy_gate

        self.use_moe = use_moe
        self.num_experts = num_experts
        self.top_k_experts = top_k_experts

        self.fp8_precision = fp8_precision
        self.dvfs_enabled = dvfs_enabled

        self.informational_constant_kI = informational_constant_kI
        self.energy_per_token_target_J = energy_per_token_target_J

        self.delta_I_gate = delta_I_gate
        self.local_window = local_window
        self.global_rank = global_rank
        self.kv_compression_rank = kv_compression_rank

        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.pad_token_id = pad_token_id
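

# Config usage sketch: the defaults above describe a very large model, so a
# scaled-down instance (the sizes below are illustrative assumptions) is the
# practical way to exercise the code paths. save_pretrained/from_pretrained
# come from the standard PretrainedConfig API:
#
#   cfg = GTransformerConfig(hidden_size=512, intermediate_size=1024,
#                            num_hidden_layers=4, num_attention_heads=8)
#   cfg.save_pretrained("./gtransformer-tiny")
#   cfg = GTransformerConfig.from_pretrained("./gtransformer-tiny")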


# ----------------------------
# Utilities
# ----------------------------
def swiglu(x: torch.Tensor) -> torch.Tensor:
    x1, x2 = x.chunk(2, dim=-1)
    return F.silu(x1) * x2


def build_activation(name: str):
    if name.lower() == "swiglu":
        return swiglu
    return getattr(F, name)


# Simple rotary position embedding
class RotaryEmbedding(nn.Module):
    def __init__(self, dim: int, base: int = 10000):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    def forward(self, x: torch.Tensor, seq_len: int, offset: int = 0):
        # `offset` shifts the positions so cached decoding rotates new tokens
        # at their true absolute positions instead of restarting from 0.
        t = torch.arange(offset, offset + seq_len, device=x.device, dtype=self.inv_freq.dtype)
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        cos = emb.cos()[None, None, :, :]
        sin = emb.sin()[None, None, :, :]
        return cos, sin


def apply_rotary(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
    # q, k: [B, H, T, D]
    def rotate_half(x):
        # Half-split rotation, matching the cat((freqs, freqs)) layout above.
        # (An interleaved even/odd rotation would not match that layout.)
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)

    q_rot = (q * cos) + (rotate_half(q) * sin)
    k_rot = (k * cos) + (rotate_half(k) * sin)
    return q_rot, k_rot
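

# Shape sketch (illustrative sizes, not part of the model path): the rotation
# is elementwise and norm-preserving, so shapes are unchanged.
#
#   rot = RotaryEmbedding(dim=32)
#   q = torch.randn(2, 4, 16, 32); k = torch.randn(2, 4, 16, 32)
#   cos, sin = rot(q, seq_len=16)
#   q2, k2 = apply_rotary(q, k, cos, sin)
#   assert q2.shape == q.shape
#   assert torch.allclose(q2.norm(dim=-1), q.norm(dim=-1), atol=1e-5)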


# ----------------------------
# IA-Attention
# ----------------------------
class InformationalAttention(nn.Module):
    """
    Energy-efficient attention.
    1. Local attention with window w.
    2. Global token selection based on an information score.
    3. Low-rank projection for the global path.
    """

    def __init__(self, config: GTransformerConfig):
        super().__init__()
        self.config = config
        self.d_model = config.hidden_size
        self.n_heads = config.num_attention_heads
        self.head_dim = self.d_model // self.n_heads
        assert self.d_model % self.n_heads == 0

        self.w_qkv = nn.Linear(self.d_model, 3 * self.d_model, bias=False)
        self.w_o = nn.Linear(self.d_model, self.d_model, bias=False)

        self.rotary = RotaryEmbedding(self.head_dim, base=config.rotary_emb_base)

        # Global low-rank projections
        self.rank = config.global_rank
        self.Pk = nn.Linear(self.head_dim, self.rank, bias=False)
        self.Pv = nn.Linear(self.head_dim, self.rank, bias=False)
        self.Uo = nn.Linear(self.rank, self.head_dim, bias=False)

        # Information scorer
        self.info_scorer = nn.Sequential(
            nn.Linear(self.d_model, self.d_model // 4, bias=False),
            nn.GELU(),
            nn.Linear(self.d_model // 4, 1, bias=False),
        )

        self.attn_drop = nn.Dropout(config.attention_dropout)
        self.proj_drop = nn.Dropout(config.hidden_dropout_prob)

        self.local_window = config.local_window
        self.delta_I_gate = config.delta_I_gate
        self.use_entropy_gate = config.use_entropy_gate

    def _causal_local_mask(self, T: int, w: int, device) -> torch.Tensor:
        idxs = torch.arange(T, device=device)
        mask = idxs[None, :] - idxs[:, None]
        # allow only past positions inside the local window
        mask = (mask > 0) | (mask < -(w - 1))
        return mask  # True means masked

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:

        B, T, C = x.shape
        H, D = self.n_heads, self.head_dim

        qkv = self.w_qkv(x)  # [B, T, 3C]
        q, k, v = qkv.split(C, dim=-1)
        q = q.view(B, T, H, D).transpose(1, 2)  # [B, H, T, D]
        k = k.view(B, T, H, D).transpose(1, 2)
        v = v.view(B, T, H, D).transpose(1, 2)

        # Rotate at the true absolute positions when a cache is present
        past_len = 0 if past_key_value is None else past_key_value[0].size(2)
        cos, sin = self.rotary(q, T, offset=past_len)
        q, k = apply_rotary(q, k, cos, sin)

        # Append the cache, if any
        if past_key_value is not None:
            pk, pv = past_key_value  # [B, H, T_past, D]
            k = torch.cat([pk, k], dim=2)
            v = torch.cat([pv, v], dim=2)
            T_total = k.size(2)
        else:
            T_total = T

        # Local attention
        w = min(self.local_window, T_total)
        scale = 1.0 / math.sqrt(D)
        attn_scores = torch.einsum("bhtd,bhsd->bhts", q, k) * scale  # s = T_total
        # Local causal mask
        local_mask = self._causal_local_mask(T_total, w, x.device)  # [T_total, T_total]
        local_mask = local_mask[-T:]  # rows for the current queries
        attn_scores = attn_scores.masked_fill(local_mask[None, None, :, :], float("-inf"))
        if attention_mask is not None:
            attn_scores = attn_scores + attention_mask  # shape must be broadcastable

        attn_w_local = F.softmax(attn_scores, dim=-1)
        attn_w_local = self.attn_drop(attn_w_local)
        ctx_local = torch.einsum("bhts,bhsd->bhtd", attn_w_local, v)

        # Information-based global selection
        # Information score from the representation x
        # NOTE: computed under no_grad, so the scorer itself is not trained here
        with torch.no_grad():
            info_score = self.info_scorer(x).squeeze(-1)  # [B, T]
            # scale to 0..1 via sigmoid
            info_score = torch.sigmoid(info_score)
        if self.use_entropy_gate:
            gate = (info_score > self.delta_I_gate).float()  # [B, T]
        else:
            gate = torch.ones_like(info_score)

        # Low-rank projection for the global path, applied only to gated tokens.
        # Simple form: compress k, v to a small rank, attend over all positions,
        # then scale the result by the per-query gate.
        ctx_global = torch.zeros_like(ctx_local)
        if gate.sum() > 0:
            # compress k and v
            k_r = self.Pk(k)  # [B, H, T_total, R]
            v_r = self.Pv(v)  # [B, H, T_total, R]
            q_r = self.Pk(q)  # reuse Pk for q

            gate_q = gate[:, -T:].unsqueeze(1).unsqueeze(-1)  # [B, 1, T, 1]
            attn_scores_g = torch.einsum("bhtr,bhsr->bhts", q_r, k_r) * (scale * D / self.rank)
            # keep the global path causal as well; without this mask the
            # low-rank path would see future tokens during training
            causal_g = torch.triu(
                torch.ones(T_total, T_total, device=x.device, dtype=torch.bool), diagonal=1
            )[-T:]
            attn_scores_g = attn_scores_g.masked_fill(causal_g[None, None, :, :], float("-inf"))
            attn_w_g = F.softmax(attn_scores_g, dim=-1)
            attn_w_g = self.attn_drop(attn_w_g)
            ctx_g_r = torch.einsum("bhts,bhsr->bhtr", attn_w_g, v_r)
            ctx_g = self.Uo(ctx_g_r)  # [B, H, T, D]
            ctx_global = ctx_g * gate_q

        ctx = ctx_local + ctx_global
        ctx = ctx.transpose(1, 2).contiguous().view(B, T, C)
        out = self.w_o(ctx)
        out = self.proj_drop(out)

        present = (k, v) if use_cache else None
        return out, present
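

# Standalone usage sketch for the attention module (a deliberately small
# config; all sizes here are illustrative assumptions, the defaults above
# are far larger):
#
#   cfg = GTransformerConfig(hidden_size=256, num_attention_heads=4,
#                            local_window=64, global_rank=16)
#   attn = InformationalAttention(cfg)
#   y, present = attn(torch.randn(2, 128, 256), use_cache=True)
#   # y: [2, 128, 256]; present: (k, v), each [2, 4, 128, 64]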


# ----------------------------
# Low-Rank FFN
# ----------------------------
class LowRankFFN(nn.Module):
    def __init__(self, config: GTransformerConfig):
        super().__init__()
        d = config.hidden_size
        i = config.intermediate_size
        # Factorization: d -> i -> d with bottleneck rank r_ffn;
        # the activation is hard-wired to SwiGLU below
        r_ffn = max(128, i // 8)
        self.w1a = nn.Linear(d, r_ffn, bias=False)
        self.w1b = nn.Linear(d, r_ffn, bias=False)
        self.w2 = nn.Linear(r_ffn, d, bias=False)
        self.drop = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Low-rank SwiGLU: silu(gate) * value, equivalent to
        # swiglu(torch.cat([u, v], dim=-1)) without the concat
        u = self.w1a(x)
        v = self.w1b(x)
        h = F.silu(u) * v
        out = self.w2(h)
        return self.drop(out)
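

# Parameter-count sketch at the module defaults (d = 8192, i = 22016,
# so r_ffn = max(128, 22016 // 8) = 2752), biases ignored:
#   a conventional SwiGLU FFN: 3 * d * i     = 3 * 8192 * 22016 ≈ 541M weights
#   this low-rank variant:     3 * d * r_ffn = 3 * 8192 * 2752  ≈ 67.6M weights
# i.e. an i / r_ffn = 8x reduction per FFN.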


# ----------------------------
# Optional MoE router
# ----------------------------
class EntropyMoE(nn.Module):
    def __init__(self, config: GTransformerConfig):
        super().__init__()
        assert config.num_experts > 0
        self.num_experts = config.num_experts
        self.top_k = max(1, config.top_k_experts)
        d = config.hidden_size
        i = config.intermediate_size

        self.router = nn.Sequential(
            nn.Linear(d, d // 2, bias=False),
            nn.GELU(),
            nn.Linear(d // 2, self.num_experts, bias=False),
        )
        self.experts = nn.ModuleList(
            [nn.Sequential(nn.Linear(d, i), nn.GELU(), nn.Linear(i, d)) for _ in range(self.num_experts)]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T, D = x.shape
        logits = self.router(x)  # [B, T, E]
        probs = F.softmax(logits, dim=-1)
        topk = torch.topk(probs, k=self.top_k, dim=-1)
        idx = topk.indices  # [B, T, K]
        wgt = topk.values   # [B, T, K]

        out = torch.zeros_like(x)
        for k in range(self.top_k):
            sel = idx[..., k]  # [B, T]
            # gather the tokens routed to each expert (a simple dense loop)
            for e in range(self.num_experts):
                mask = (sel == e).float().unsqueeze(-1)  # [B, T, 1]
                if mask.sum() == 0:
                    continue
                xe = x * mask
                ye = self.experts[e](xe)
                # mask the output as well: the expert Linear layers have
                # biases, which would otherwise leak into unrouted positions
                out = out + ye * mask * (wgt[..., k].unsqueeze(-1))
        return out
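

# Routing sketch (illustrative sizes): with num_experts=4 and top_k_experts=2,
# each token is processed by its two highest-probability experts and the
# outputs are blended with the router's softmax weights.
#
#   cfg = GTransformerConfig(hidden_size=128, intermediate_size=256,
#                            use_moe=True, num_experts=4, top_k_experts=2)
#   moe = EntropyMoE(cfg)
#   y = moe(torch.randn(2, 10, 128))  # [2, 10, 128]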


# ----------------------------
# Transformer block
# ----------------------------
class GTransformerBlock(nn.Module):
    def __init__(self, config: GTransformerConfig):
        super().__init__()
        self.ln1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
        self.attn = InformationalAttention(config)
        self.ln2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
        if config.use_moe and config.num_experts > 0:
            self.ff = EntropyMoE(config)
        else:
            self.ff = LowRankFFN(config) if config.use_low_rank_ffn else nn.Sequential(
                nn.Linear(config.hidden_size, config.intermediate_size),
                nn.GELU(),
                nn.Linear(config.intermediate_size, config.hidden_size),
            )

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        # Pre-LN residual structure: x + attn(ln(x)), then x + ff(ln(x))
        h, present = self.attn(
            self.ln1(x),
            attention_mask=attention_mask,
            past_key_value=past_key_value,
            use_cache=use_cache,
        )
        x = x + h
        x = x + self.ff(self.ln2(x))
        return x, present


# ----------------------------
# Base model
# ----------------------------
class GTransformerModel(PreTrainedModel):
    config_class = GTransformerConfig

    def __init__(self, config: GTransformerConfig):
        super().__init__(config)
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = nn.ModuleList([GTransformerBlock(config) for _ in range(config.num_hidden_layers)])
        self.ln_f = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)

        self.gradient_checkpointing = False

        self.post_init()

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
        use_cache: Optional[bool] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[List[Tuple[torch.Tensor, torch.Tensor]]]]:

        x = self.embed_tokens(input_ids)

        new_past = [] if use_cache else None
        for i, layer in enumerate(self.layers):
            pkv = None if past_key_values is None else past_key_values[i]
            x, present = layer(x, attention_mask=attention_mask, past_key_value=pkv, use_cache=use_cache)
            if use_cache:
                new_past.append(present)

        x = self.ln_f(x)
        return x, new_past


# ----------------------------
# Causal LM
# ----------------------------
class GTransformerForCausalLM(PreTrainedModel):
    config_class = GTransformerConfig

    def __init__(self, config: GTransformerConfig):
        super().__init__(config)
        self.transformer = GTransformerModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.post_init()

    def get_input_embeddings(self):
        return self.transformer.embed_tokens

    def set_input_embeddings(self, new_embeddings):
        self.transformer.embed_tokens = new_embeddings

    def tie_weights(self):
        # intentionally left untied, for FP8 stability
        pass

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
        use_cache: Optional[bool] = None,
        **kwargs,
    ) -> CausalLMOutputWithPast:

        hidden_states, new_past = self.transformer(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
        )
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = labels[:, 1:].contiguous()
            loss = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
                ignore_index=-100,
            )

            # Simple informational regularization: penalize predictive entropy.
            # This must live inside the labels branch (shift_logits does not
            # exist otherwise) and be computed with gradients, or it would
            # only add a constant to the loss.
            if self.config.use_entropy_gate:
                logp = F.log_softmax(shift_logits, dim=-1)
                probs = logp.exp()
                H = -(probs * logp).sum(dim=-1).mean()
                # target a moderate entropy reduction
                loss = loss + 1e-4 * H

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=new_past,
            hidden_states=None,
            attentions=None,
        )

    @torch.no_grad()
    def generate_simple(
        self,
        input_ids: torch.LongTensor,
        max_new_tokens: int = 64,
        temperature: float = 1.0,
    ) -> torch.LongTensor:
        self.eval()
        past = None
        out = input_ids
        # Feed the whole prompt once, then one new token per step; a single
        # forward call per step yields both the logits and the KV cache.
        cur = input_ids
        for _ in range(max_new_tokens):
            outputs = self(cur, use_cache=True, past_key_values=past)
            past = outputs.past_key_values
            logits = outputs.logits[:, -1, :] / max(1e-6, temperature)
            next_token = torch.distributions.Categorical(logits=logits).sample()
            out = torch.cat([out, next_token.unsqueeze(-1)], dim=1)
            cur = next_token.unsqueeze(-1)
            if int(next_token[0].item()) == self.config.eos_token_id:
                break
        return out
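

# Smoke-test sketch: a deliberately tiny config (all sizes below are
# illustrative assumptions) so the forward pass and generate_simple run
# quickly on CPU.
if __name__ == "__main__":
    cfg = GTransformerConfig(
        vocab_size=1000,
        hidden_size=128,
        intermediate_size=256,
        num_hidden_layers=2,
        num_attention_heads=4,
        max_position_embeddings=512,
        local_window=32,
        global_rank=16,
    )
    model = GTransformerForCausalLM(cfg)
    ids = torch.randint(0, cfg.vocab_size, (1, 16))
    out = model(input_ids=ids, labels=ids)
    print("loss:", float(out.loss))
    gen = model.generate_simple(ids, max_new_tokens=8)
    print("generated:", tuple(gen.shape))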