idah4 commited on
Commit
349ff93
·
verified ·
1 Parent(s): 4099e14

Upload ByteETM-Korean (HF inference compatible)

Browse files
Files changed (3) hide show
  1. config.json +1 -1
  2. model.safetensors +2 -2
  3. modeling_byteetm.py +67 -38
config.json CHANGED
@@ -6,7 +6,7 @@
6
  "dtype": "float32",
7
  "model_type": "byteetm",
8
  "n_embd": 512,
9
- "n_head": 8,
10
  "n_layer": 4,
11
  "transformers_version": "4.57.1",
12
  "vocab_size": 258,
 
6
  "dtype": "float32",
7
  "model_type": "byteetm",
8
  "n_embd": 512,
9
+ "n_head": 16,
10
  "n_layer": 4,
11
  "transformers_version": "4.57.1",
12
  "vocab_size": 258,
model.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:9ef2d672b4c0e5818c0cd68e45cdc879df4406dbb3520715dc7097b17e8d9f19
3
- size 65296016
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a9f6a671faac0a301ebebedfbf2bcaa3457ec013cff8059bc5efdc72eab66274
3
+ size 77830592
modeling_byteetm.py CHANGED
@@ -5,52 +5,74 @@ import math, random, numpy as np, torch, torch.nn as nn, torch.nn.functional as
5
  # ---------- 4. 모델 정의 ----------
6
  # === GeneratingSeries 기반 보조 모듈 ===
7
  class MomentumEncoder(nn.Module):
8
- """토큰 임베딩 간의 차분을 포함한 동적 인코딩"""
9
- def __init__(self, dim):
10
  super().__init__()
11
- self.linear = nn.Linear(dim * 2, dim)
 
 
12
  self.norm = nn.LayerNorm(dim)
13
- self.act = nn.Tanh()
14
- def forward(self, x): # [B,T,C]
15
- diff = F.pad(x[:, 1:] - x[:, :-1], (0,0,1,0))
16
- return self.act(self.norm(self.linear(torch.cat([x, diff], dim=-1))))
 
 
 
 
 
 
 
17
 
18
  class GFLayer(nn.Module):
19
- """지수 감쇠 기반의 생성함수 확장"""
20
- def __init__(self, dim, max_order=6, tau_scale=0.01):
21
  super().__init__()
22
  self.coeff = nn.Parameter(torch.randn(dim, max_order + 1) * 0.1)
23
- self.tau = nn.Parameter(torch.ones(dim) * tau_scale)
24
- self.max_order = max_order
25
  def forward(self, x):
26
  B, T, D = x.shape
27
- t = torch.arange(T, device=x.device).float().view(1,T,1)
28
- z = torch.exp(-t * self.tau.view(1,1,D))
29
- powers = torch.stack([z**k for k in range(self.max_order+1)], dim=-1)
30
- gen = torch.einsum('btdk,dk->btd', powers, self.coeff)
31
  return x + gen
32
 
 
33
  class OrthogonalTemporalProjector(nn.Module):
34
- """시퀀스 길이 방향으로 직교 기저 투영"""
35
- def __init__(self, t_len, rank=8):
36
  super().__init__()
 
37
  self.U = nn.Parameter(torch.randn(t_len, rank) / math.sqrt(t_len))
 
38
  def forward(self, x):
39
- B,T,D = x.shape
40
- if T != self.U.size(0):
41
- U = F.interpolate(self.U.T.unsqueeze(0), size=T, mode="linear", align_corners=False).squeeze(0).T
42
- else:
43
- U = self.U
44
  U = F.normalize(U, dim=0)
45
  P = U @ U.T
46
- trend = torch.einsum('btd,ts->bsd', x, P)
47
  resid = x - trend
48
- return 0.5*(trend + resid)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
 
50
  # === GPT Block 확장 ===
51
  class GeneratingBlock(nn.Module):
52
  """기존 Transformer Block + GeneratingSeries 동역학 통합"""
53
- def __init__(self, n_embd, n_head, block_size, dropout=0.0, gf_order=6):
54
  super().__init__()
55
  self.ln1 = nn.LayerNorm(n_embd)
56
  self.ln2 = nn.LayerNorm(n_embd)
@@ -59,7 +81,8 @@ class GeneratingBlock(nn.Module):
59
  # GeneratingSeries 요소
60
  self.momentum = MomentumEncoder(n_embd)
61
  self.gf = GFLayer(n_embd, max_order=gf_order)
62
- self.otp = OrthogonalTemporalProjector(block_size, rank=min(8, block_size//4))
 
63
  def forward(self, x):
64
  # step1: momentum encoding (local diff)
65
  x = self.momentum(x)
@@ -92,14 +115,17 @@ class CausalSelfAttention(nn.Module):
92
  k = self.key(x).view(B, T, self.n_head, C//self.n_head).transpose(1,2)
93
  q = self.query(x).view(B, T, self.n_head, C//self.n_head).transpose(1,2)
94
  v = self.value(x).view(B, T, self.n_head, C//self.n_head).transpose(1,2)
95
- att = (q @ k.transpose(-2,-1)) / math.sqrt(k.size(-1))
96
- att = att.masked_fill(self.mask[:,:,:T,:T]==0, float("-inf"))
 
 
 
 
 
97
  att = F.softmax(att, dim=-1)
98
  att = self.attn_drop(att)
99
- y = att @ v
100
- y = y.transpose(1,2).contiguous().view(B,T,C)
101
- y = self.resid_drop(self.proj(y))
102
- return y
103
 
104
  class MLP(nn.Module):
105
  def __init__(self, n_embd, dropout=0.0):
@@ -128,10 +154,12 @@ class ByteETM(nn.Module):
128
  def __init__(self, vocab_size, n_embd, n_head, n_layer, block_size, dropout=0.0):
129
  super().__init__()
130
  self.token_emb = nn.Embedding(vocab_size, n_embd)
131
- self.pos_emb = nn.Embedding(block_size, n_embd)
132
  self.drop = nn.Dropout(dropout)
133
- # self.blocks = nn.ModuleList([Block(n_embd, n_head, block_size, dropout) for _ in range(n_layer)])
134
- self.blocks = nn.ModuleList([GeneratingBlock(n_embd, n_head, block_size, dropout) for _ in range(n_layer)])
 
 
135
  self.ln_f = nn.LayerNorm(n_embd)
136
  self.head = nn.Linear(n_embd, vocab_size, bias=False)
137
  self.block_size = block_size
@@ -146,13 +174,14 @@ class ByteETM(nn.Module):
146
  def forward(self, idx, targets=None):
147
  B, T = idx.size()
148
  assert T <= self.block_size
149
- pos = torch.arange(0, T, device=idx.device).unsqueeze(0)
150
- x = self.token_emb(idx) + self.pos_emb(pos)
151
  x = self.drop(x)
 
152
  for blk in self.blocks:
153
  x = blk(x)
154
  x = self.ln_f(x)
155
- logits = self.head(x) # (B,T,V)
156
  loss = None
157
  if targets is not None:
158
  loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
 
5
  # ---------- 4. 모델 정의 ----------
6
  # === GeneratingSeries 기반 보조 모듈 ===
7
  class MomentumEncoder(nn.Module):
8
+ """다항 차분 + 게이트 통합"""
9
+ def __init__(self, dim, max_order=3):
10
  super().__init__()
11
+ self.max_order = max_order
12
+ self.proj = nn.Linear(dim * (max_order + 1), dim)
13
+ self.gate = nn.Linear(dim, dim)
14
  self.norm = nn.LayerNorm(dim)
15
+
16
+ def forward(self, x):
17
+ diffs = [x]
18
+ for k in range(1, self.max_order + 1):
19
+ d = F.pad(x[:, k:] - x[:, :-k], (0, 0, k, 0))
20
+ diffs.append(d)
21
+ concat = torch.cat(diffs, dim=-1)
22
+ h = self.proj(concat)
23
+ g = torch.sigmoid(self.gate(x))
24
+ return self.norm(h * g + x * (1 - g))
25
+
26
 
27
  class GFLayer(nn.Module):
28
+ """Adaptive polynomial generating function"""
29
+ def __init__(self, dim, max_order=6):
30
  super().__init__()
31
  self.coeff = nn.Parameter(torch.randn(dim, max_order + 1) * 0.1)
32
+ self.alpha = nn.Parameter(torch.randn(dim) * 0.1)
33
+
34
  def forward(self, x):
35
  B, T, D = x.shape
36
+ t = torch.linspace(0, 1, T, device=x.device).view(1, T, 1)
37
+ basis = torch.stack([(t ** k) * torch.exp(-self.alpha.view(1,1,D)*t) for k in range(self.coeff.size(1))], dim=-1)
38
+ gen = torch.einsum("btdk,dk->btd", basis, self.coeff)
 
39
  return x + gen
40
 
41
+
42
  class OrthogonalTemporalProjector(nn.Module):
43
+ """Adaptive rank orthogonal projection"""
44
+ def __init__(self, t_len, dim, rank_ratio=0.25):
45
  super().__init__()
46
+ rank = max(4, int(rank_ratio * math.sqrt(dim)))
47
  self.U = nn.Parameter(torch.randn(t_len, rank) / math.sqrt(t_len))
48
+
49
  def forward(self, x):
50
+ B, T, D = x.shape
51
+ U = F.interpolate(self.U.T.unsqueeze(0), size=T, mode="linear", align_corners=False).squeeze(0).T
 
 
 
52
  U = F.normalize(U, dim=0)
53
  P = U @ U.T
54
+ trend = torch.einsum("btd,ts->bsd", x, P)
55
  resid = x - trend
56
+ return trend + 0.5 * resid
57
+
58
+ class SinusoidalPositionalEncoding(nn.Module):
59
+ def __init__(self, dim, max_len=2048):
60
+ super().__init__()
61
+ pe = torch.zeros(max_len, dim)
62
+ pos = torch.arange(0, max_len).unsqueeze(1)
63
+ div = torch.exp(torch.arange(0, dim, 2) * (-math.log(10000.0) / dim))
64
+ pe[:, 0::2] = torch.sin(pos * div)
65
+ pe[:, 1::2] = torch.cos(pos * div)
66
+ self.register_buffer("pe", pe.unsqueeze(0))
67
+
68
+ def forward(self, x):
69
+ return x + self.pe[:, :x.size(1)]
70
+
71
 
72
  # === GPT Block 확장 ===
73
  class GeneratingBlock(nn.Module):
74
  """기존 Transformer Block + GeneratingSeries 동역학 통합"""
75
+ def __init__(self, n_embd, n_head, block_size, dropout=0.0, gf_order=2):
76
  super().__init__()
77
  self.ln1 = nn.LayerNorm(n_embd)
78
  self.ln2 = nn.LayerNorm(n_embd)
 
81
  # GeneratingSeries 요소
82
  self.momentum = MomentumEncoder(n_embd)
83
  self.gf = GFLayer(n_embd, max_order=gf_order)
84
+ self.otp = OrthogonalTemporalProjector(block_size, n_embd)
85
+
86
  def forward(self, x):
87
  # step1: momentum encoding (local diff)
88
  x = self.momentum(x)
 
115
  k = self.key(x).view(B, T, self.n_head, C//self.n_head).transpose(1,2)
116
  q = self.query(x).view(B, T, self.n_head, C//self.n_head).transpose(1,2)
117
  v = self.value(x).view(B, T, self.n_head, C//self.n_head).transpose(1,2)
118
+
119
+ # RMS normalization per head
120
+ q = q / (q.pow(2).mean(-1, keepdim=True).sqrt() + 1e-6)
121
+ k = k / (k.pow(2).mean(-1, keepdim=True).sqrt() + 1e-6)
122
+
123
+ att = (q @ k.transpose(-2, -1)) / math.sqrt(k.size(-1))
124
+ att = att.masked_fill(self.mask[:, :, :T, :T] == 0, float("-inf"))
125
  att = F.softmax(att, dim=-1)
126
  att = self.attn_drop(att)
127
+ y = (att @ v).transpose(1, 2).contiguous().view(B, T, C)
128
+ return self.resid_drop(self.proj(y))
 
 
129
 
130
  class MLP(nn.Module):
131
  def __init__(self, n_embd, dropout=0.0):
 
154
  def __init__(self, vocab_size, n_embd, n_head, n_layer, block_size, dropout=0.0):
155
  super().__init__()
156
  self.token_emb = nn.Embedding(vocab_size, n_embd)
157
+ self.pos_enc = SinusoidalPositionalEncoding(n_embd, max_len=block_size)
158
  self.drop = nn.Dropout(dropout)
159
+
160
+ self.blocks = nn.ModuleList([
161
+ GeneratingBlock(n_embd, n_head, block_size, dropout) for _ in range(n_layer)
162
+ ])
163
  self.ln_f = nn.LayerNorm(n_embd)
164
  self.head = nn.Linear(n_embd, vocab_size, bias=False)
165
  self.block_size = block_size
 
174
  def forward(self, idx, targets=None):
175
  B, T = idx.size()
176
  assert T <= self.block_size
177
+ x = self.token_emb(idx)
178
+ x = self.pos_enc(x) # 여기서 사인·코사인 위치 정보 추가
179
  x = self.drop(x)
180
+
181
  for blk in self.blocks:
182
  x = blk(x)
183
  x = self.ln_f(x)
184
+ logits = self.head(x)
185
  loss = None
186
  if targets is not None:
187
  loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))