idah4 committed
Commit 300d0ea · verified · 1 Parent(s): ec380c9

Upload ByteETM-Korean (HF inference compatible)

Files changed (1):
  1. modeling_byteetm.py +171 -1
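Since the commit message advertises HF-inference compatibility, the snippet below is a hypothetical loading sketch, not something taken from the repo: it assumes the repository's config.json declares an auto_map pointing at ByteETMConfig / HFByteETM in modeling_byteetm.py, and the repo id is only a guess inferred from the commit message.

# Hypothetical loading sketch. Assumes config.json maps "byteetm" to
# ByteETMConfig / HFByteETM via auto_map; the repo id below is illustrative only.
from transformers import AutoConfig, AutoModel

repo_id = "idah4/ByteETM-Korean"  # assumed repo id, inferred from the commit message
config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)
print(type(model).__name__)  # should report the custom wrapper if auto_map is set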
modeling_byteetm.py CHANGED
@@ -1,5 +1,176 @@
 from transformers import PreTrainedModel, PretrainedConfig
 import torch.nn as nn, torch.nn.functional as F, torch
+import math, random, numpy as np, torch, torch.nn as nn, torch.nn.functional as F
+
+# ---------- 4. Model definition ----------
+# === GeneratingSeries-based auxiliary modules ===
+class MomentumEncoder(nn.Module):
+    """Dynamic encoding that includes differences between adjacent token embeddings."""
+    def __init__(self, dim):
+        super().__init__()
+        self.linear = nn.Linear(dim * 2, dim)
+        self.norm = nn.LayerNorm(dim)
+        self.act = nn.Tanh()
+    def forward(self, x): # [B,T,C]
+        diff = F.pad(x[:, 1:] - x[:, :-1], (0,0,1,0))
+        return self.act(self.norm(self.linear(torch.cat([x, diff], dim=-1))))
+
+class GFLayer(nn.Module):
+    """Generating-function expansion based on exponential decay."""
+    def __init__(self, dim, max_order=6, tau_scale=0.01):
+        super().__init__()
+        self.coeff = nn.Parameter(torch.randn(dim, max_order + 1) * 0.1)
+        self.tau = nn.Parameter(torch.ones(dim) * tau_scale)
+        self.max_order = max_order
+    def forward(self, x):
+        B, T, D = x.shape
+        t = torch.arange(T, device=x.device).float().view(1,T,1)
+        z = torch.exp(-t * self.tau.view(1,1,D))
+        powers = torch.stack([z**k for k in range(self.max_order+1)], dim=-1)
+        gen = torch.einsum('btdk,dk->btd', powers, self.coeff)
+        return x + gen
+
+class OrthogonalTemporalProjector(nn.Module):
+    """Orthogonal basis projection along the sequence-length axis."""
+    def __init__(self, t_len, rank=8):
+        super().__init__()
+        self.U = nn.Parameter(torch.randn(t_len, rank) / math.sqrt(t_len))
+    def forward(self, x):
+        B,T,D = x.shape
+        if T != self.U.size(0):
+            U = F.interpolate(self.U.T.unsqueeze(0), size=T, mode="linear", align_corners=False).squeeze(0).T
+        else:
+            U = self.U
+        U = F.normalize(U, dim=0)
+        P = U @ U.T
+        trend = torch.einsum('btd,ts->bsd', x, P)
+        resid = x - trend
+        return 0.5*(trend + resid)
+
+# === GPT Block extension ===
+class GeneratingBlock(nn.Module):
+    """Standard Transformer block with GeneratingSeries dynamics integrated."""
+    def __init__(self, n_embd, n_head, block_size, dropout=0.0, gf_order=6):
+        super().__init__()
+        self.ln1 = nn.LayerNorm(n_embd)
+        self.ln2 = nn.LayerNorm(n_embd)
+        self.attn = CausalSelfAttention(n_embd, n_head, block_size, dropout)
+        self.mlp = MLP(n_embd, dropout)
+        # GeneratingSeries components
+        self.momentum = MomentumEncoder(n_embd)
+        self.gf = GFLayer(n_embd, max_order=gf_order)
+        self.otp = OrthogonalTemporalProjector(block_size, rank=min(8, block_size//4))
+    def forward(self, x):
+        # step1: momentum encoding (local diff)
+        x = self.momentum(x)
+        # step2: attention + residual
+        x = x + self.attn(self.ln1(x))
+        # step3: generating function expansion in feature domain
+        x = self.gf(x)
+        # step4: feedforward + residual
+        x = x + self.mlp(self.ln2(x))
+        # step5: orthogonal trend projection (temporal disentangling)
+        x = self.otp(x)
+        return x
+
+# === CausalSelfAttention and MLP are unchanged from the original ===
+class CausalSelfAttention(nn.Module):
+    def __init__(self, n_embd, n_head, block_size, dropout=0.0):
+        super().__init__()
+        assert n_embd % n_head == 0
+        self.n_head = n_head
+        self.key = nn.Linear(n_embd, n_embd)
+        self.query = nn.Linear(n_embd, n_embd)
+        self.value = nn.Linear(n_embd, n_embd)
+        self.proj = nn.Linear(n_embd, n_embd)
+        self.attn_drop = nn.Dropout(dropout)
+        self.resid_drop = nn.Dropout(dropout)
+        self.register_buffer("mask", torch.tril(torch.ones(block_size, block_size)).view(1,1,block_size,block_size))
+
+    def forward(self, x):
+        B, T, C = x.size()
+        k = self.key(x).view(B, T, self.n_head, C//self.n_head).transpose(1,2)
+        q = self.query(x).view(B, T, self.n_head, C//self.n_head).transpose(1,2)
+        v = self.value(x).view(B, T, self.n_head, C//self.n_head).transpose(1,2)
+        att = (q @ k.transpose(-2,-1)) / math.sqrt(k.size(-1))
+        att = att.masked_fill(self.mask[:,:,:T,:T]==0, float("-inf"))
+        att = F.softmax(att, dim=-1)
+        att = self.attn_drop(att)
+        y = att @ v
+        y = y.transpose(1,2).contiguous().view(B,T,C)
+        y = self.resid_drop(self.proj(y))
+        return y
+
+class MLP(nn.Module):
+    def __init__(self, n_embd, dropout=0.0):
+        super().__init__()
+        self.fc = nn.Sequential(
+            nn.Linear(n_embd, 4*n_embd),
+            nn.GELU(),
+            nn.Linear(4*n_embd, n_embd),
+            nn.Dropout(dropout),
+        )
+    def forward(self, x): return self.fc(x)
+
+class Block(nn.Module):
+    def __init__(self, n_embd, n_head, block_size, dropout=0.0):
+        super().__init__()
+        self.ln1 = nn.LayerNorm(n_embd)
+        self.attn = CausalSelfAttention(n_embd, n_head, block_size, dropout)
+        self.ln2 = nn.LayerNorm(n_embd)
+        self.mlp = MLP(n_embd, dropout)
+    def forward(self, x):
+        x = x + self.attn(self.ln1(x))
+        x = x + self.mlp(self.ln2(x))
+        return x
+
+class ByteETM(nn.Module):
+    def __init__(self, vocab_size, n_embd, n_head, n_layer, block_size, dropout=0.0):
+        super().__init__()
+        self.token_emb = nn.Embedding(vocab_size, n_embd)
+        self.pos_emb = nn.Embedding(block_size, n_embd)
+        self.drop = nn.Dropout(dropout)
+        # self.blocks = nn.ModuleList([Block(n_embd, n_head, block_size, dropout) for _ in range(n_layer)])
+        self.blocks = nn.ModuleList([GeneratingBlock(n_embd, n_head, block_size, dropout) for _ in range(n_layer)])
+        self.ln_f = nn.LayerNorm(n_embd)
+        self.head = nn.Linear(n_embd, vocab_size, bias=False)
+        self.block_size = block_size
+        self.apply(self._init_weights)
+
+    def _init_weights(self, m):
+        if isinstance(m, (nn.Linear, nn.Embedding)):
+            nn.init.normal_(m.weight, mean=0.0, std=0.02)
+        if isinstance(m, nn.Linear) and m.bias is not None:
+            nn.init.zeros_(m.bias)
+
+    def forward(self, idx, targets=None):
+        B, T = idx.size()
+        assert T <= self.block_size
+        pos = torch.arange(0, T, device=idx.device).unsqueeze(0)
+        x = self.token_emb(idx) + self.pos_emb(pos)
+        x = self.drop(x)
+        for blk in self.blocks:
+            x = blk(x)
+        x = self.ln_f(x)
+        logits = self.head(x) # (B,T,V)
+        loss = None
+        if targets is not None:
+            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
+        return logits, loss
+
+    @torch.no_grad()
+    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
+        for _ in range(max_new_tokens):
+            idx_cond = idx[:, -self.block_size:]
+            logits, _ = self(idx_cond)
+            logits = logits[:, -1, :] / max(temperature, 1e-8)
+            if top_k is not None:
+                v, _ = torch.topk(logits, top_k)
+                logits[logits < v[:, [-1]]] = -float("inf")
+            probs = F.softmax(logits, dim=-1)
+            next_id = torch.multinomial(probs, num_samples=1)
+            idx = torch.cat((idx, next_id), dim=1)
+        return idx
 
 class ByteETMConfig(PretrainedConfig):
     model_type = "byteetm"
@@ -15,7 +186,6 @@ class HFByteETM(PreTrainedModel):
     config_class = ByteETMConfig
     def __init__(self, config):
         super().__init__(config)
-        from model import ByteETM  # the actual model you defined
         self.model = ByteETM(
             vocab_size=config.vocab_size,
             n_embd=config.n_embd,
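For reference, a minimal smoke test of the newly added module, run outside the diff: it imports ByteETM directly from modeling_byteetm.py and exercises forward and generate. The hyperparameters below are small illustrative values; the real ones live in the repo's config.json and are not part of this commit.

# Minimal sanity check for the standalone ByteETM model (illustrative hyperparameters only).
import torch
from modeling_byteetm import ByteETM

model = ByteETM(vocab_size=256, n_embd=64, n_head=4, n_layer=2, block_size=128)
model.eval()

idx = torch.randint(0, 256, (1, 16))   # a batch of 16 byte-level token ids
logits, loss = model(idx)              # loss is None when no targets are passed
print(logits.shape)                    # expected: torch.Size([1, 16, 256])

out = model.generate(idx, max_new_tokens=8, temperature=1.0, top_k=50)
print(out.shape)                       # expected: torch.Size([1, 24])

Note that generate conditions only on the last block_size tokens (idx[:, -self.block_size:]), so prompts longer than the context window are truncated rather than rejected.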