import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers import PreTrainedModel, PretrainedConfig

# ---------- 4. Model definition ----------
# === GeneratingSeries-based auxiliary modules ===
class MomentumEncoder(nn.Module):
    """다항 차분 + 게이트 통합"""
    def __init__(self, dim, max_order=3):
        super().__init__()
        self.max_order = max_order
        self.proj = nn.Linear(dim * (max_order + 1), dim)
        self.gate = nn.Linear(dim, dim)
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
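        # x: (B, T, D). Build [x, Δ¹x, ..., Δᵏx] along the feature axis,
        # where Δᵏx[t] = x[t] - x[t-k]; each difference is front-padded
        # with k zeros so every term keeps the original length T.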
        diffs = [x]
        for k in range(1, self.max_order + 1):
            d = F.pad(x[:, k:] - x[:, :-k], (0, 0, k, 0))
            diffs.append(d)
        concat = torch.cat(diffs, dim=-1)
        h = self.proj(concat)
        g = torch.sigmoid(self.gate(x))
        return self.norm(h * g + x * (1 - g))


class GFLayer(nn.Module):
    """Adaptive polynomial generating function"""
    def __init__(self, dim, max_order=6):
        super().__init__()
        self.coeff = nn.Parameter(torch.randn(dim, max_order + 1) * 0.1)
        self.alpha = nn.Parameter(torch.randn(dim) * 0.1)

    def forward(self, x):
        B, T, D = x.shape
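        # Damped polynomial basis per feature: t^k * exp(-alpha_d * t) for
        # k = 0..max_order over normalized time t in [0, 1]; the learned
        # coefficients mix the basis into an additive correction term.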
        t = torch.linspace(0, 1, T, device=x.device).view(1, T, 1)
        basis = torch.stack([(t ** k) * torch.exp(-self.alpha.view(1,1,D)*t) for k in range(self.coeff.size(1))], dim=-1)
        gen = torch.einsum("btdk,dk->btd", basis, self.coeff)
        return x + gen


class OrthogonalTemporalProjector(nn.Module):
    """Adaptive rank orthogonal projection"""
    def __init__(self, t_len, dim, rank_ratio=0.25):
        super().__init__()
        rank = max(4, int(rank_ratio * math.sqrt(dim)))
        self.U = nn.Parameter(torch.randn(t_len, rank) / math.sqrt(t_len))

    def forward(self, x):
        B, T, D = x.shape
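        # Interpolate the learned basis to the runtime length T, then build
        # the rank-r projector P = U Uᵀ; the output keeps the projected
        # trend in full and attenuates the residual by half.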
        U = F.interpolate(self.U.T.unsqueeze(0), size=T, mode="linear", align_corners=False).squeeze(0).T
        U = F.normalize(U, dim=0)
        P = U @ U.T
        trend = torch.einsum("btd,ts->bsd", x, P)
        resid = x - trend
        return trend + 0.5 * resid

class SinusoidalPositionalEncoding(nn.Module):
    def __init__(self, dim, max_len=2048):
        super().__init__()
        pe = torch.zeros(max_len, dim)
        pos = torch.arange(0, max_len).unsqueeze(1)
        div = torch.exp(torch.arange(0, dim, 2) * (-math.log(10000.0) / dim))
        pe[:, 0::2] = torch.sin(pos * div)
        pe[:, 1::2] = torch.cos(pos * div)
        self.register_buffer("pe", pe.unsqueeze(0))

    def forward(self, x):
        return x + self.pe[:, :x.size(1)]


# === GPT Block 확장 ===
class GeneratingBlock(nn.Module):
    """기존 Transformer Block + GeneratingSeries 동역학 통합"""
    def __init__(self, n_embd, n_head, block_size, dropout=0.0, gf_order=2):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)
        self.attn = CausalSelfAttention(n_embd, n_head, block_size, dropout)
        self.mlp = MLP(n_embd, dropout)
        # GeneratingSeries components
        self.momentum = MomentumEncoder(n_embd)
        self.gf = GFLayer(n_embd, max_order=gf_order)
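        # note: the OTP below is sized for block_size but interpolates its
        # basis to the actual sequence length at forward time.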
        self.otp = OrthogonalTemporalProjector(block_size, n_embd)
        
    def forward(self, x):
        # step1: momentum encoding (local diff)
        x = self.momentum(x)
        # step2: attention + residual
        x = x + self.attn(self.ln1(x))
        # step3: generating function expansion in feature domain
        x = self.gf(x)
        # step4: feedforward + residual
        x = x + self.mlp(self.ln2(x))
        # step5: orthogonal trend projection (temporal disentangling)
        x = self.otp(x)
        return x

# === CausalSelfAttention and MLP are unchanged from the baseline ===
class CausalSelfAttention(nn.Module):
    def __init__(self, n_embd, n_head, block_size, dropout=0.0):
        super().__init__()
        assert n_embd % n_head == 0
        self.n_head = n_head
        self.key = nn.Linear(n_embd, n_embd)
        self.query = nn.Linear(n_embd, n_embd)
        self.value = nn.Linear(n_embd, n_embd)
        self.proj = nn.Linear(n_embd, n_embd)
        self.attn_drop = nn.Dropout(dropout)
        self.resid_drop = nn.Dropout(dropout)
        self.register_buffer("mask", torch.tril(torch.ones(block_size, block_size)).view(1,1,block_size,block_size))

    def forward(self, x):
        B, T, C = x.size()
        k = self.key(x).view(B, T, self.n_head, C//self.n_head).transpose(1,2)
        q = self.query(x).view(B, T, self.n_head, C//self.n_head).transpose(1,2)
        v = self.value(x).view(B, T, self.n_head, C//self.n_head).transpose(1,2)

        # RMS normalization per head
        q = q / (q.pow(2).mean(-1, keepdim=True).sqrt() + 1e-6)
        k = k / (k.pow(2).mean(-1, keepdim=True).sqrt() + 1e-6)

        att = (q @ k.transpose(-2, -1)) / math.sqrt(k.size(-1))
        att = att.masked_fill(self.mask[:, :, :T, :T] == 0, float("-inf"))
        att = F.softmax(att, dim=-1)
        att = self.attn_drop(att)
        y = (att @ v).transpose(1, 2).contiguous().view(B, T, C)
        return self.resid_drop(self.proj(y))

class MLP(nn.Module):
    def __init__(self, n_embd, dropout=0.0):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(n_embd, 4*n_embd),
            nn.GELU(),
            nn.Linear(4*n_embd, n_embd),
            nn.Dropout(dropout),
        )
    def forward(self, x): return self.fc(x)

class Block(nn.Module):
    def __init__(self, n_embd, n_head, block_size, dropout=0.0):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        self.attn = CausalSelfAttention(n_embd, n_head, block_size, dropout)
        self.ln2 = nn.LayerNorm(n_embd)
        self.mlp = MLP(n_embd, dropout)
    def forward(self, x):
        x = x + self.attn(self.ln1(x))
        x = x + self.mlp(self.ln2(x))
        return x

class ByteETM(nn.Module):
    def __init__(self, vocab_size, n_embd, n_head, n_layer, block_size, dropout=0.0):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, n_embd)
        self.pos_enc   = SinusoidalPositionalEncoding(n_embd, max_len=block_size)
        self.drop = nn.Dropout(dropout)

        self.blocks = nn.ModuleList([
            GeneratingBlock(n_embd, n_head, block_size, dropout) for _ in range(n_layer)
        ])
        self.ln_f = nn.LayerNorm(n_embd)
        self.head = nn.Linear(n_embd, vocab_size, bias=False)
        self.block_size = block_size
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, (nn.Linear, nn.Embedding)):
            nn.init.normal_(m.weight, mean=0.0, std=0.02)
        if isinstance(m, nn.Linear) and m.bias is not None:
            nn.init.zeros_(m.bias)

    def forward(self, idx, targets=None):
        B, T = idx.size()
        assert T <= self.block_size
        x = self.token_emb(idx)
        x = self.pos_enc(x)          # add sine/cosine positional information here
        x = self.drop(x)

        for blk in self.blocks:
            x = blk(x)
        x = self.ln_f(x)
        logits = self.head(x)
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -self.block_size:]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :] / max(temperature, 1e-8)
            if top_k is not None:
                # clamp top_k so torch.topk never exceeds the vocabulary size
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float("inf")
            probs = F.softmax(logits, dim=-1)
            next_id = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, next_id), dim=1)
        return idx

class ByteETMConfig(PretrainedConfig):
    model_type = "byteetm"
    def __init__(self, vocab_size=258, n_embd=512, n_head=8, n_layer=6, block_size=256, **kwargs):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.n_embd = n_embd
        self.n_head = n_head
        self.n_layer = n_layer
        self.block_size = block_size

class HFByteETM(PreTrainedModel):
    config_class = ByteETMConfig
    def __init__(self, config):
        super().__init__(config)
        self.model = ByteETM(
            vocab_size=config.vocab_size,
            n_embd=config.n_embd,
            n_head=config.n_head,
            n_layer=config.n_layer,
            block_size=config.block_size
        )
    def forward(self, input_ids, **kwargs):
        logits, _ = self.model(input_ids)
        return {"logits": logits}
    
    def generate(self, *args, **kwargs):   # added: delegate to the inner model
        return self.model.generate(*args, **kwargs)
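

# A minimal smoke test, not part of the original file: the hyperparameters
# below (tiny n_embd/n_layer, block_size=64, the 16-token random prompt) are
# assumptions chosen only to exercise forward() and generate() quickly on CPU,
# not the training configuration.
if __name__ == "__main__":
    config = ByteETMConfig(vocab_size=258, n_embd=64, n_head=4, n_layer=2, block_size=64)
    model = HFByteETM(config)
    ids = torch.randint(0, config.vocab_size, (1, 16))
    out = model(ids)
    print(out["logits"].shape)  # expected: torch.Size([1, 16, 258])
    sample = model.generate(ids, max_new_tokens=8, temperature=1.0, top_k=50)
    print(sample.shape)         # expected: torch.Size([1, 24])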