idah4 committed
Commit 9666b4c · verified · 1 Parent(s): c7fa4ec

Upload ETM-Korean (HF inference compatible)

config.json ADDED
@@ -0,0 +1,18 @@
+ {
+   "architectures": [
+     "HFETM"
+   ],
+   "block_size": 512,
+   "dtype": "float32",
+   "is_decoder": true,
+   "model_type": "etm",
+   "n_embd": 512,
+   "n_head": 16,
+   "n_layer": 4,
+   "transformers_version": "4.57.1",
+   "vocab_size": 30000,
+   "auto_map": {
+     "AutoModelForCausalLM": "modeling_etm.HFETM",
+     "AutoConfig": "modeling_etm.ETMConfig"
+   }
+ }
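
For reference, a minimal loading sketch implied by the auto_map entries above. The repo id "idah4/ETM-Korean" is an assumption inferred from the commit message, not stated in these files:

from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

repo_id = "idah4/ETM-Korean"  # assumed repo id; substitute the actual Hub path

# trust_remote_code is required because auto_map points at modeling_etm.py inside the repo
config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)           # -> ETMConfig
model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)  # -> HFETM
tokenizer = AutoTokenizer.from_pretrained(repo_id)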
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:cf4989f171451a477048c2aaa02f161e2871ec1fe04dd9615f3ef10a3f6f71ae
+ size 199653840
modeling_etm.py ADDED
@@ -0,0 +1,347 @@
+ import math
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from transformers import PreTrainedModel, PretrainedConfig
+
+ # ---------- 4. Model definition ----------
+ # === Auxiliary modules based on GeneratingSeries ===
+ class MomentumEncoder(nn.Module):
+     """Polynomial differencing + gated fusion (sequence length preserved exactly)."""
+     def __init__(self, dim, max_order=3):
+         super().__init__()
+         self.max_order = max_order
+         self.proj = nn.Linear(dim * (max_order + 1), dim)
+         self.gate = nn.Linear(dim, dim)
+         self.norm = nn.LayerNorm(dim)
+
+     def forward(self, x):
+         # x: (B, T, D)
+         B, T, D = x.size()
+         diffs = [x]
+
+         for k in range(1, self.max_order + 1):
+             if T <= k:
+                 # sequence too short for a k-step difference -> zero-pad the whole term
+                 d = torch.zeros(B, T, D, device=x.device, dtype=x.dtype)
+             else:
+                 d_raw = x[:, k:] - x[:, :-k]                               # (B, T-k, D)
+                 pad = torch.zeros(B, k, D, device=x.device, dtype=x.dtype)
+                 d = torch.cat([pad, d_raw], dim=1)                         # (B, T, D)
+
+             diffs.append(d)
+
+         concat = torch.cat(diffs, dim=-1)  # (B, T, D*(max_order+1))
+         h = self.proj(concat)
+         g = torch.sigmoid(self.gate(x))
+         return self.norm(h * g + x * (1 - g))
+
+
+ class GFLayer(nn.Module):
+     """Adaptive polynomial generating function."""
+     def __init__(self, dim, max_order=6):
+         super().__init__()
+         self.coeff = nn.Parameter(torch.randn(dim, max_order + 1) * 0.1)
+         self.alpha = nn.Parameter(torch.randn(dim) * 0.1)
+
+     def forward(self, x):
+         B, T, D = x.shape
+         t = torch.linspace(0, 1, T, device=x.device).view(1, T, 1)
+         # basis[..., k] = t^k * exp(-alpha * t), evaluated per feature dimension
+         basis = torch.stack(
+             [(t ** k) * torch.exp(-self.alpha.view(1, 1, D) * t) for k in range(self.coeff.size(1))],
+             dim=-1,
+         )
+         gen = torch.einsum("btdk,dk->btd", basis, self.coeff)
+         return x + gen
+
+
+ class OrthogonalTemporalProjector(nn.Module):
+     """Adaptive-rank orthogonal projection."""
+     def __init__(self, t_len, dim, rank_ratio=0.25):
+         super().__init__()
+         rank = max(4, int(rank_ratio * math.sqrt(dim)))
+         self.U = nn.Parameter(torch.randn(t_len, rank) / math.sqrt(t_len))
+
+     def forward(self, x):
+         B, T, D = x.shape
+         # resample the learned basis to the current sequence length, then normalize its columns
+         U = F.interpolate(self.U.T.unsqueeze(0), size=T, mode="linear", align_corners=False).squeeze(0).T
+         U = F.normalize(U, dim=0)
+         P = U @ U.T
+         trend = torch.einsum("btd,ts->bsd", x, P)
+         resid = x - trend
+         return trend + 0.5 * resid
+
+
+ class SinusoidalPositionalEncoding(nn.Module):
+     def __init__(self, dim, max_len=2048):
+         super().__init__()
+         pe = torch.zeros(max_len, dim)
+         pos = torch.arange(0, max_len).unsqueeze(1)
+         div = torch.exp(torch.arange(0, dim, 2) * (-math.log(10000.0) / dim))
+         pe[:, 0::2] = torch.sin(pos * div)
+         pe[:, 1::2] = torch.cos(pos * div)
+         self.register_buffer("pe", pe.unsqueeze(0))
+
+     def forward(self, x):
+         return x + self.pe[:, :x.size(1)]
+
+
+ # === Extended GPT block ===
+ class GeneratingBlock(nn.Module):
+     """Standard Transformer block with GeneratingSeries dynamics folded in."""
+     def __init__(self, n_embd, n_head, block_size, dropout=0.0, gf_order=2):
+         super().__init__()
+         self.ln1 = nn.LayerNorm(n_embd)
+         self.ln2 = nn.LayerNorm(n_embd)
+         self.attn = CausalSelfAttention(n_embd, n_head, block_size, dropout)
+         self.mlp = MLP(n_embd, dropout)
+         # GeneratingSeries components
+         self.momentum = MomentumEncoder(n_embd)
+         self.gf = GFLayer(n_embd, max_order=gf_order)
+         self.otp = OrthogonalTemporalProjector(block_size, n_embd)
+
+     def forward(self, x):
+         # step 1: momentum encoding (local differences)
+         x = self.momentum(x)
+         # step 2: attention + residual
+         x = x + self.attn(self.ln1(x))
+         # step 3: generating-function expansion in the feature domain
+         x = self.gf(x)
+         # step 4: feedforward + residual
+         x = x + self.mlp(self.ln2(x))
+         # step 5: orthogonal trend projection (temporal disentangling)
+         x = self.otp(x)
+         return x
+
+ # === CausalSelfAttention and MLP are unchanged from the baseline ===
+ class CausalSelfAttention(nn.Module):
+     def __init__(self, n_embd, n_head, block_size, dropout=0.0):
+         super().__init__()
+         assert n_embd % n_head == 0
+         self.n_head = n_head
+         self.key = nn.Linear(n_embd, n_embd)
+         self.query = nn.Linear(n_embd, n_embd)
+         self.value = nn.Linear(n_embd, n_embd)
+         self.proj = nn.Linear(n_embd, n_embd)
+         self.attn_drop = nn.Dropout(dropout)
+         self.resid_drop = nn.Dropout(dropout)
+         self.register_buffer("mask", torch.tril(torch.ones(block_size, block_size)).view(1, 1, block_size, block_size))
+
+     def forward(self, x):
+         B, T, C = x.size()
+         k = self.key(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
+         q = self.query(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
+         v = self.value(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
+
+         # RMS normalization per head
+         q = q / (q.pow(2).mean(-1, keepdim=True).sqrt() + 1e-6)
+         k = k / (k.pow(2).mean(-1, keepdim=True).sqrt() + 1e-6)
+
+         att = (q @ k.transpose(-2, -1)) / math.sqrt(k.size(-1))
+         att = att.masked_fill(self.mask[:, :, :T, :T] == 0, float("-inf"))
+         att = F.softmax(att, dim=-1)
+         att = self.attn_drop(att)
+         y = (att @ v).transpose(1, 2).contiguous().view(B, T, C)
+         return self.resid_drop(self.proj(y))
+
+ class MLP(nn.Module):
+     def __init__(self, n_embd, dropout=0.0):
+         super().__init__()
+         self.fc = nn.Sequential(
+             nn.Linear(n_embd, 4 * n_embd),
+             nn.GELU(),
+             nn.Linear(4 * n_embd, n_embd),
+             nn.Dropout(dropout),
+         )
+
+     def forward(self, x):
+         return self.fc(x)
+
+ class Block(nn.Module):
+     def __init__(self, n_embd, n_head, block_size, dropout=0.0):
+         super().__init__()
+         self.ln1 = nn.LayerNorm(n_embd)
+         self.attn = CausalSelfAttention(n_embd, n_head, block_size, dropout)
+         self.ln2 = nn.LayerNorm(n_embd)
+         self.mlp = MLP(n_embd, dropout)
+
+     def forward(self, x):
+         x = x + self.attn(self.ln1(x))
+         x = x + self.mlp(self.ln2(x))
+         return x
+
+ class ByteETM(nn.Module):
+     def __init__(self, vocab_size, n_embd, n_head, n_layer, block_size, dropout=0.0):
+         super().__init__()
+         self.token_emb = nn.Embedding(vocab_size, n_embd)
+         self.pos_enc = SinusoidalPositionalEncoding(n_embd, max_len=block_size)
+         self.drop = nn.Dropout(dropout)
+
+         self.blocks = nn.ModuleList([
+             GeneratingBlock(n_embd, n_head, block_size, dropout) for _ in range(n_layer)
+         ])
+         self.ln_f = nn.LayerNorm(n_embd)
+         self.head = nn.Linear(n_embd, vocab_size, bias=False)
+         self.block_size = block_size
+         self.apply(self._init_weights)
+
+     def _init_weights(self, m):
+         if isinstance(m, (nn.Linear, nn.Embedding)):
+             nn.init.normal_(m.weight, mean=0.0, std=0.02)
+         if isinstance(m, nn.Linear) and m.bias is not None:
+             nn.init.zeros_(m.bias)
+
+     def forward(self, idx, targets=None):
+         B, T = idx.size()
+         assert T <= self.block_size
+         x = self.token_emb(idx)
+         x = self.pos_enc(x)  # sine/cosine positional information is added here
+         x = self.drop(x)
+
+         for blk in self.blocks:
+             x = blk(x)
+         x = self.ln_f(x)
+         logits = self.head(x)
+         loss = None
+         if targets is not None:
+             loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
+         return logits, loss
+
+     # ====================== Sampler optimized for the byte-level LM ======================
+     @staticmethod
+     def _sample_next_token(
+         logits,                  # (1, vocab_size)
+         prev_tokens,             # (1, T)
+         temperature: float = 0.7,
+         top_k: int | None = 64,
+         top_p: float | None = 0.9,
+         repetition_penalty: float = 1.1,
+         typical_p: float | None = None,
+     ):
+         """
+         Logit post-processing + sampling for a byte-level LM:
+         - temperature
+         - repetition penalty
+         - top-k
+         - top-p (nucleus)
+         - optional typical sampling
+         """
+
+         # assume batch size 1 (matches the current usage pattern)
+         assert logits.size(0) == 1, "The sampler currently assumes batch=1."
+
+         # 1) temperature scaling
+         logits = logits / max(temperature, 1e-6)
+
+         # 2) repetition penalty (down-weight tokens that have already appeared)
+         if repetition_penalty is not None and repetition_penalty != 1.0:
+             unique_tokens = prev_tokens.unique()
+             # kept simple: divide the logits of previously seen tokens to reduce their probability
+             # (note: negative logits are divided too, which slightly boosts them)
+             logits[:, unique_tokens] /= repetition_penalty
+
+         # 3) top-k (keep only the k largest logits)
+         if top_k is not None and top_k > 0 and top_k < logits.size(-1):
+             v, _ = torch.topk(logits, top_k)
+             logits[logits < v[:, [-1]]] = -float("inf")
+
+         # 4) sort, then apply top-p / typical sampling
+         sorted_logits, sorted_idx = torch.sort(logits, descending=True)
+         sorted_probs = F.softmax(sorted_logits, dim=-1)
+
+         # 4-1) typical sampling (optional)
+         if typical_p is not None:
+             log_probs = torch.log(sorted_probs + 1e-12)
+             entropy = -(sorted_probs * log_probs).sum(-1, keepdim=True)
+             # typical sampling, following https://arxiv.org/abs/2202.00666
+             shifted_kl = torch.cumsum(sorted_probs * (entropy - log_probs), dim=-1)
+             typical_mask = shifted_kl > typical_p
+             if typical_mask.any():
+                 first_idx = torch.nonzero(typical_mask[0], as_tuple=False)[0, 0]
+                 sorted_logits[:, first_idx:] = -float("inf")
+                 sorted_probs = F.softmax(sorted_logits, dim=-1)
+
+         # 4-2) nucleus (top-p) sampling
+         if top_p is not None and 0.0 < top_p < 1.0:
+             cumulative = torch.cumsum(sorted_probs, dim=-1)
+             # drop everything from the point where the cumulative probability exceeds top_p
+             cutoff_mask = cumulative > top_p
+             if cutoff_mask.any():
+                 first_cut = torch.nonzero(cutoff_mask[0], as_tuple=False)[0, 0]
+                 sorted_logits[:, first_cut:] = -float("inf")
+
+         # 5) map the filtered logits back to the original (unsorted) token indices
+         filtered_logits = torch.full_like(logits, -float("inf"))
+         filtered_logits.scatter_(1, sorted_idx, sorted_logits)
+
+         # 6) sample from the final distribution
+         probs = F.softmax(filtered_logits, dim=-1)
+
+         # ========= stabilization: handle all-NaN or all-zero distributions =========
+         if torch.isnan(probs).any() or torch.isinf(probs).any() or probs.sum() == 0:
+             # fallback: greedily pick the highest-logit token from the unfiltered logits
+             next_id = torch.argmax(logits, dim=-1, keepdim=True)
+             return next_id
+
+         next_id = torch.multinomial(probs, num_samples=1)
+         return next_id
+
+
+     @torch.no_grad()
+     def generate(
+         self,
+         idx,
+         max_new_tokens: int = 200,
+         temperature: float = 0.7,
+         top_k: int | None = 64,
+         top_p: float | None = 0.9,
+         repetition_penalty: float = 1.1,
+         typical_p: float | None = None,
+         eos_token: int | None = None,
+     ):
+         """
+         Advanced generate() for the byte-level LM:
+         - supports temperature, top_k, top_p, repetition_penalty, typical_p
+         - if eos_token is set, generation stops early once that token is produced
+         """
+         for _ in range(max_new_tokens):
+             idx_cond = idx[:, -self.block_size:]   # (1, T')
+             logits, _ = self(idx_cond)             # (1, T', V)
+             last_logits = logits[:, -1, :]         # (1, V)
+
+             next_id = self._sample_next_token(
+                 last_logits,
+                 prev_tokens=idx,
+                 temperature=temperature,
+                 top_k=top_k,
+                 top_p=top_p,
+                 repetition_penalty=repetition_penalty,
+                 typical_p=typical_p,
+             )  # (1, 1)
+
+             idx = torch.cat((idx, next_id), dim=1)  # (1, T+1)
+
+             if eos_token is not None and next_id.item() == eos_token:
+                 break
+
+         return idx
+
+ class ETMConfig(PretrainedConfig):
+     model_type = "etm"
+
+     def __init__(self, vocab_size=256, n_embd=512, n_head=8, n_layer=6, block_size=256, **kwargs):
+         super().__init__(**kwargs)
+         self.vocab_size = vocab_size
+         self.n_embd = n_embd
+         self.n_head = n_head
+         self.n_layer = n_layer
+         self.block_size = block_size
+
+ # 3️⃣ HF wrapper class
+ class HFETM(PreTrainedModel):
+     config_class = ETMConfig
+
+     def __init__(self, config):
+         super().__init__(config)
+         self.model = ByteETM(
+             vocab_size=config.vocab_size,
+             n_embd=config.n_embd,
+             n_head=config.n_head,
+             n_layer=config.n_layer,
+             block_size=config.block_size,
+         )
+
+     def forward(self, input_ids, **kwargs):
+         logits, _ = self.model(input_ids)
+         return {"logits": logits}
+
+     def generate(self, *args, **kwargs):  # <── added: delegate to the inner ByteETM sampler
+         return self.model.generate(*args, **kwargs)
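
Below is a minimal end-to-end generation sketch against the custom generate() defined above. The repo id and the prompt are assumptions; the keyword names (max_new_tokens, eos_token, ...) follow the signature of ByteETM.generate, to which HFETM.generate delegates:

from transformers import AutoModelForCausalLM, AutoTokenizer

repo_id = "idah4/ETM-Korean"  # assumed repo id
tokenizer = AutoTokenizer.from_pretrained(repo_id)
model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True).eval()

# batch size 1, as required by the sampler's assert; any Korean prompt works here
input_ids = tokenizer("안녕하세요", return_tensors="pt").input_ids

output_ids = model.generate(
    input_ids,
    max_new_tokens=100,
    temperature=0.7,
    top_k=64,
    top_p=0.9,
    repetition_penalty=1.1,
    eos_token=tokenizer.eos_token_id,
)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))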
special_tokens_map.json ADDED
@@ -0,0 +1,23 @@
+ {
+   "additional_special_tokens": [
+     "<|endoftext|>",
+     "<|sep|>",
+     "<|acc|>",
+     "<|tel|>",
+     "<|rrn|>"
+   ],
+   "eos_token": {
+     "content": "<|endoftext|>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   },
+   "pad_token": {
+     "content": "<|endoftext|>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   }
+ }
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,73 @@
+ {
+   "added_tokens_decoder": {
+     "0": {
+       "content": "<|unused0|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "1": {
+       "content": "<|unused1|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "2": {
+       "content": "<|endoftext|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "3": {
+       "content": "<|sep|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "30000": {
+       "content": "<|acc|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "30001": {
+       "content": "<|tel|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "30002": {
+       "content": "<|rrn|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     }
+   },
+   "additional_special_tokens": [
+     "<|endoftext|>",
+     "<|sep|>",
+     "<|acc|>",
+     "<|tel|>",
+     "<|rrn|>"
+   ],
+   "clean_up_tokenization_spaces": false,
+   "eos_token": "<|endoftext|>",
+   "extra_special_tokens": {},
+   "model_max_length": 1000000000000000019884624838656,
+   "pad_token": "<|endoftext|>",
+   "tokenizer_class": "PreTrainedTokenizerFast"
+ }
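
A quick sanity-check sketch of the special-token layout declared above (same assumed repo id as in the earlier sketches; the expected ids come from added_tokens_decoder):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("idah4/ETM-Korean")  # assumed repo id
print(tok.eos_token, tok.eos_token_id)                    # "<|endoftext|>", expected id 2
print(tok.pad_token == tok.eos_token)                     # True: pad is mapped to eos
print(tok.convert_tokens_to_ids(["<|acc|>", "<|tel|>", "<|rrn|>"]))  # expected [30000, 30001, 30002]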