from transformers import PreTrainedModel, PretrainedConfig
import math
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# ---------- 4. Model definition ----------

# === Auxiliary modules based on GeneratingSeries ===


class MomentumEncoder(nn.Module):
    """Fuse lag-k temporal differences (k = 1..max_order) into x through a sigmoid gate.

    For each k the lag-k difference at step t is ``x[:, t] - x[:, t - k]``; the first
    k steps (which have no step k positions earlier) are zero-filled. ``x`` and all
    differences are concatenated on the last axis, projected back to ``dim``, blended
    with ``x`` by a per-feature gate computed from ``x``, and layer-normalized.

    Only the current and earlier steps are read, so the module is causal.
    """

    def __init__(self, dim, max_order=3):
        super().__init__()
        self.max_order = max_order
        self.proj = nn.Linear(dim * (max_order + 1), dim)
        self.gate = nn.Linear(dim, dim)
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        """Return the gated fusion; expects x as (batch, time, dim)."""
        seq_len = x.size(1)
        diffs = [x]
        for k in range(1, self.max_order + 1):
            if k < seq_len:
                # Lag-k difference, left-padded with k zero steps so the length stays T.
                d = F.pad(x[:, k:] - x[:, :-k], (0, 0, k, 0))
            else:
                # No step has a predecessor k steps back (e.g. a 1-2 token prompt in
                # generate()), so the difference is all zeros. Slicing + padding here
                # would yield k > T steps and make torch.cat below fail.
                d = torch.zeros_like(x)
            diffs.append(d)
        concat = torch.cat(diffs, dim=-1)
        h = self.proj(concat)
        g = torch.sigmoid(self.gate(x))
        return self.norm(h * g + x * (1 - g))


class GFLayer(nn.Module):
    """Add a learned generating-function bias over normalized time.

    For feature d at normalized time t the added bias is
    ``sum_k coeff[d, k] * t**k * exp(-alpha[d] * t)`` for k = 0..max_order.

    Args:
        dim: feature size (last axis of the input).
        max_order: highest polynomial power k.
        max_len: if given, step i maps to ``t = i / (max_len - 1)``, so the bias at a
            step does not depend on the current sequence length. If None (original
            behavior), t spans [0, 1] over the current length T, so the bias at step i
            changes with T.
    """

    def __init__(self, dim, max_order=6, max_len=None):
        super().__init__()
        self.coeff = nn.Parameter(torch.randn(dim, max_order + 1) * 0.1)
        self.alpha = nn.Parameter(torch.randn(dim) * 0.1)
        self.max_len = max_len

    def forward(self, x):
        _, T, D = x.shape
        if self.max_len is None:
            t = torch.linspace(0, 1, T, device=x.device)
        else:
            # Fixed normalization: equals linspace(0, 1, T) when T == max_len, and a
            # prefix of that grid when T < max_len.
            t = torch.arange(T, device=x.device) / max(self.max_len - 1, 1)
        t = t.view(1, T, 1)
        decay = torch.exp(-self.alpha.view(1, 1, D) * t)  # (1, T, D)
        basis = torch.stack(
            [(t ** k) * decay for k in range(self.coeff.size(1))], dim=-1
        )  # (1, T, D, max_order + 1)
        gen = torch.einsum("btdk,dk->btd", basis, self.coeff)
        return x + gen


class OrthogonalTemporalProjector(nn.Module):
    """Blend x with a low-rank temporal mixing of itself.

    A learned basis U of shape (t_len, rank) is column-normalized and turned into a
    (T, T) mixing matrix ``P = U @ U.T``. With ``trend[:, s] = sum_t P[t, s] * x[:, t]``
    the output is ``trend + 0.5 * (x - trend)``, i.e. ``0.5 * x + 0.5 * trend``.
    The columns of U are unit-normalized but not orthogonalized, so P is not in
    general an exact orthogonal projection.

    Args:
        t_len: number of learned time steps (the block size in GeneratingBlock).
        dim: feature size; only used to choose the rank.
        rank_ratio: rank = max(4, int(rank_ratio * sqrt(dim))).
        causal: if True (default), P is masked so ``trend[:, s]`` uses only steps
            t <= s, and for T <= t_len the first T rows of the full-length normalized
            basis are used, so step s never depends on later tokens or on T. If False,
            the original behavior is used: U is linearly resampled to T steps and every
            output step mixes all input steps (including future ones).
    """

    def __init__(self, t_len, dim, rank_ratio=0.25, causal=True):
        super().__init__()
        rank = max(4, int(rank_ratio * math.sqrt(dim)))
        self.U = nn.Parameter(torch.randn(t_len, rank) / math.sqrt(t_len))
        self.causal = causal

    def _basis(self, T):
        """Return the (T, rank) column-normalized temporal basis for length T."""
        if self.causal and T <= self.U.size(0):
            # Normalize over the full learned length, then take a prefix, so row i is
            # the same regardless of T (prefix-consistent for autoregressive decoding).
            return F.normalize(self.U, dim=0)[:T]
        U = F.interpolate(
            self.U.T.unsqueeze(0), size=T, mode="linear", align_corners=False
        ).squeeze(0).T
        return F.normalize(U, dim=0)

    def forward(self, x):
        _, T, _ = x.shape
        U = self._basis(T)
        P = U @ U.T  # (T, T), symmetric before masking
        if self.causal:
            # Keep P[t, s] only for t <= s. The dense P lets position s read tokens
            # t > s, which under a next-token objective exposes the target itself.
            P = torch.triu(P)
        trend = torch.einsum("btd,ts->bsd", x, P)
        resid = x - trend
        return trend + 0.5 * resid


class SinusoidalPositionalEncoding(nn.Module):
    """Add fixed sine/cosine positional encodings (sin on even, cos on odd features)."""

    def __init__(self, dim, max_len=2048):
        super().__init__()
        pe = torch.zeros(max_len, dim)
        pos = torch.arange(0, max_len).unsqueeze(1)
        div = torch.exp(torch.arange(0, dim, 2) * (-math.log(10000.0) / dim))
        pe[:, 0::2] = torch.sin(pos * div)
        # For odd dim there is one fewer odd column than frequencies; trim to fit.
        pe[:, 1::2] = torch.cos(pos * div)[:, : dim // 2]
        self.register_buffer("pe", pe.unsqueeze(0))

    def forward(self, x):
        # Assumes x.size(1) <= max_len; longer inputs fail to broadcast.
        return x + self.pe[:, :x.size(1)]


# === GPT block extension ===


class GeneratingBlock(nn.Module):
    """Pre-LN Transformer block interleaved with GeneratingSeries modules.

    Order: MomentumEncoder -> attention (residual) -> GFLayer -> MLP (residual)
    -> OrthogonalTemporalProjector.
    """

    def __init__(self, n_embd, n_head, block_size, dropout=0.0, gf_order=2):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)
        self.attn = CausalSelfAttention(n_embd, n_head, block_size, dropout)
        self.mlp = MLP(n_embd, dropout)
        # GeneratingSeries components. Passing block_size fixes the time axis so a
        # step is treated the same in training (T == block_size) and in generation.
        self.momentum = MomentumEncoder(n_embd)
        self.gf = GFLayer(n_embd, max_order=gf_order, max_len=block_size)
        self.otp = OrthogonalTemporalProjector(block_size, n_embd)

    def forward(self, x):
        # step1: momentum encoding (local lag differences)
        x = self.momentum(x)
        # step2: attention + residual
        x = x + self.attn(self.ln1(x))
        # step3: generating-function bias over time
        x = self.gf(x)
        # step4: feed-forward + residual
        x = x + self.mlp(self.ln2(x))
        # step5: low-rank temporal mixing ("trend" projection)
        x = self.otp(x)
        return x


# === CausalSelfAttention and MLP are the same as the original GPT versions ===


class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention with per-head RMS normalization of q and k."""

    def __init__(self, n_embd, n_head, block_size, dropout=0.0):
        super().__init__()
        assert n_embd % n_head == 0
        self.n_head = n_head
        self.key = nn.Linear(n_embd, n_embd)
        self.query = nn.Linear(n_embd, n_embd)
        self.value = nn.Linear(n_embd, n_embd)
        self.proj = nn.Linear(n_embd, n_embd)
        self.attn_drop = nn.Dropout(dropout)
        self.resid_drop = nn.Dropout(dropout)
        # Lower-triangular ones, shape (1, 1, block_size, block_size).
        self.register_buffer(
            "mask",
            torch.tril(torch.ones(block_size, block_size)).view(1, 1, block_size, block_size),
        )

    def forward(self, x):
        B, T, C = x.size()
        k = self.key(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        q = self.query(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        v = self.value(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        # RMS normalization per head
        q = q / (q.pow(2).mean(-1, keepdim=True).sqrt() + 1e-6)
        k = k / (k.pow(2).mean(-1, keepdim=True).sqrt() + 1e-6)
        att = (q @ k.transpose(-2, -1)) / math.sqrt(k.size(-1))
        att = att.masked_fill(self.mask[:, :, :T, :T] == 0, float("-inf"))
        att = F.softmax(att, dim=-1)
        att = self.attn_drop(att)
        y = (att @ v).transpose(1, 2).contiguous().view(B, T, C)
        return self.resid_drop(self.proj(y))


class MLP(nn.Module):
    """Position-wise feed-forward: Linear(4x) -> GELU -> Linear -> Dropout."""

    def __init__(self, n_embd, dropout=0.0):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.GELU(),
            nn.Linear(4 * n_embd, n_embd),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.fc(x)


class Block(nn.Module):
    """Plain pre-LN Transformer block (ByteETM below uses GeneratingBlock instead)."""

    def __init__(self, n_embd, n_head, block_size, dropout=0.0):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        self.attn = CausalSelfAttention(n_embd, n_head, block_size, dropout)
        self.ln2 = nn.LayerNorm(n_embd)
        self.mlp = MLP(n_embd, dropout)

    def forward(self, x):
        x = x + self.attn(self.ln1(x))
        x = x + self.mlp(self.ln2(x))
        return x


class ByteETM(nn.Module):
    """Decoder-only language model: embeddings + sinusoidal positions + GeneratingBlocks."""

    def __init__(self, vocab_size, n_embd, n_head, n_layer, block_size, dropout=0.0):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, n_embd)
        self.pos_enc = SinusoidalPositionalEncoding(n_embd, max_len=block_size)
        self.drop = nn.Dropout(dropout)
        self.blocks = nn.ModuleList([
            GeneratingBlock(n_embd, n_head, block_size, dropout)
            for _ in range(n_layer)
        ])
        self.ln_f = nn.LayerNorm(n_embd)
        self.head = nn.Linear(n_embd, vocab_size, bias=False)
        self.block_size = block_size
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Init Linear/Embedding weights ~ N(0, 0.02) and zero Linear biases."""
        if isinstance(m, (nn.Linear, nn.Embedding)):
            nn.init.normal_(m.weight, mean=0.0, std=0.02)
        if isinstance(m, nn.Linear) and m.bias is not None:
            nn.init.zeros_(m.bias)

    def forward(self, idx, targets=None):
        """Return (logits, loss); loss is None unless targets is given.

        ``targets`` is compared position-by-position with the logits (no shifting is
        done here), so callers pass already-shifted next-token ids.
        """
        B, T = idx.size()
        assert T <= self.block_size, f"sequence length {T} exceeds block_size {self.block_size}"
        x = self.token_emb(idx)
        x = self.pos_enc(x)  # add sine/cosine positional information here
        x = self.drop(x)
        for blk in self.blocks:
            x = blk(x)
        x = self.ln_f(x)
        logits = self.head(x)
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        """Append max_new_tokens sampled ids to idx and return the extended tensor.

        The context is cropped to the last block_size tokens each step; temperature is
        clamped to >= 1e-8; top_k (if given) keeps only the k largest logits.
        """
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -self.block_size:]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :] / max(temperature, 1e-8)
            if top_k is not None:
                # torch.topk raises if k exceeds the vocabulary size, so clamp it.
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float("inf")
            probs = F.softmax(logits, dim=-1)
            next_id = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, next_id), dim=1)
        return idx


class ByteETMConfig(PretrainedConfig):
    """Hugging Face config for ByteETM.

    NOTE(review): vocab_size=258 presumably means 256 byte values + 2 special ids — confirm.
    """

    model_type = "byteetm"

    def __init__(self, vocab_size=258, n_embd=512, n_head=8, n_layer=6, block_size=256, **kwargs):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.n_embd = n_embd
        self.n_head = n_head
        self.n_layer = n_layer
        self.block_size = block_size


class HFByteETM(PreTrainedModel):
    """Hugging Face wrapper around ByteETM (dropout is left at ByteETM's default 0.0)."""

    config_class = ByteETMConfig

    def __init__(self, config):
        super().__init__(config)
        self.model = ByteETM(
            vocab_size=config.vocab_size,
            n_embd=config.n_embd,
            n_head=config.n_head,
            n_layer=config.n_layer,
            block_size=config.block_size,
        )

    def forward(self, input_ids, **kwargs):
        """Return {"logits": ...}; extra kwargs (including any labels) are ignored."""
        logits, _ = self.model(input_ids)
        return {"logits": logits}

    def generate(self, *args, **kwargs):  # added: delegate to ByteETM.generate
        return self.model.generate(*args, **kwargs)