import math
import random

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers import PreTrainedModel, PretrainedConfig
|
|
class MomentumEncoder(nn.Module):
    """Polynomial (multi-order) differencing fused with the input via a learned gate."""

    def __init__(self, dim, max_order=3):
        super().__init__()
        self.max_order = max_order
        # Projects [x, diff_1, ..., diff_max_order] back down to dim channels.
        self.proj = nn.Linear(dim * (max_order + 1), dim)
        self.gate = nn.Linear(dim, dim)
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        # x: (B, T, D). Each k-step difference x[t] - x[t-k] is left-padded
        # with k zeros so every tensor keeps length T before concatenation.
        diffs = [x]
        for k in range(1, self.max_order + 1):
            d = F.pad(x[:, k:] - x[:, :-k], (0, 0, k, 0))
            diffs.append(d)
        concat = torch.cat(diffs, dim=-1)
        h = self.proj(concat)
        # Per-channel sigmoid gate interpolates between momentum features and input.
        g = torch.sigmoid(self.gate(x))
        return self.norm(h * g + x * (1 - g))
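# Illustrative shape check for MomentumEncoder (not part of the original module):
# with dim=8 and max_order=3, a (2, 16, 8) input concatenates four (2, 16, 8)
# tensors into (2, 16, 32), which self.proj maps back to (2, 16, 8):
#
#     enc = MomentumEncoder(dim=8)
#     assert enc(torch.randn(2, 16, 8)).shape == (2, 16, 8)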
|
|
class GFLayer(nn.Module):
    """Adaptive polynomial generating function, added to the input as a residual."""

    def __init__(self, dim, max_order=6):
        super().__init__()
        # Per-channel polynomial coefficients c_{d,k} and decay rates alpha_d.
        self.coeff = nn.Parameter(torch.randn(dim, max_order + 1) * 0.1)
        self.alpha = nn.Parameter(torch.randn(dim) * 0.1)

    def forward(self, x):
        B, T, D = x.shape
        # Normalized time grid in [0, 1], shaped (1, T, 1) for broadcasting.
        t = torch.linspace(0, 1, T, device=x.device).view(1, T, 1)
        # Basis t^k * exp(-alpha_d * t), stacked over k: shape (1, T, D, K).
        basis = torch.stack(
            [(t ** k) * torch.exp(-self.alpha.view(1, 1, D) * t) for k in range(self.coeff.size(1))],
            dim=-1,
        )
        # Contract the order dimension against the learned coefficients.
        gen = torch.einsum("btdk,dk->btd", basis, self.coeff)
        return x + gen
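# In closed form (read off the code above), each channel d receives the
# deterministic signal
#
#     gen_d(t) = sum_{k=0}^{K} c_{d,k} * t^k * exp(-alpha_d * t),  t in [0, 1],
#
# i.e. a learned exponentially damped polynomial evaluated on the sequence grid.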
|
|
class OrthogonalTemporalProjector(nn.Module):
    """Adaptive-rank temporal projection: keep the low-rank trend, damp the rest."""

    def __init__(self, t_len, dim, rank_ratio=0.25):
        super().__init__()
        rank = max(4, int(rank_ratio * math.sqrt(dim)))
        self.U = nn.Parameter(torch.randn(t_len, rank) / math.sqrt(t_len))

    def forward(self, x):
        B, T, D = x.shape
        # Resample the learned basis along time so any sequence length T works.
        U = F.interpolate(self.U.T.unsqueeze(0), size=T, mode="linear", align_corners=False).squeeze(0).T
        # Columns are unit-normalized (not strictly orthogonalized), so U @ U.T
        # is an approximate rather than exact orthogonal projector.
        U = F.normalize(U, dim=0)
        P = U @ U.T
        trend = torch.einsum("btd,ts->bsd", x, P)
        resid = x - trend
        # Pass the trend through untouched; attenuate the residual by half.
        return trend + 0.5 * resid
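# Worked example of the rank formula (illustrative values): with dim=512 and
# rank_ratio=0.25, rank = max(4, int(0.25 * sqrt(512))) = max(4, int(5.65...)) = 5,
# so the projector keeps a 5-dimensional temporal subspace regardless of T.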
|
|
class SinusoidalPositionalEncoding(nn.Module):
    def __init__(self, dim, max_len=2048):
        super().__init__()
        pe = torch.zeros(max_len, dim)
        pos = torch.arange(0, max_len).unsqueeze(1)
        div = torch.exp(torch.arange(0, dim, 2) * (-math.log(10000.0) / dim))
        pe[:, 0::2] = torch.sin(pos * div)
        pe[:, 1::2] = torch.cos(pos * div)
        self.register_buffer("pe", pe.unsqueeze(0))

    def forward(self, x):
        return x + self.pe[:, :x.size(1)]
|
|
class GeneratingBlock(nn.Module):
    """Standard Transformer block with the GeneratingSeries dynamics integrated."""

    def __init__(self, n_embd, n_head, block_size, dropout=0.0, gf_order=2):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)
        self.attn = CausalSelfAttention(n_embd, n_head, block_size, dropout)
        self.mlp = MLP(n_embd, dropout)
        self.momentum = MomentumEncoder(n_embd)
        self.gf = GFLayer(n_embd, max_order=gf_order)
        self.otp = OrthogonalTemporalProjector(block_size, n_embd)

    def forward(self, x):
        x = self.momentum(x)            # gated multi-order differencing
        x = x + self.attn(self.ln1(x))  # pre-norm causal self-attention
        x = self.gf(x)                  # damped-polynomial residual signal
        x = x + self.mlp(self.ln2(x))   # pre-norm feed-forward
        x = self.otp(x)                 # low-rank temporal projection
        return x
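# Dataflow note: unlike attn and mlp, the three GeneratingSeries modules are
# applied in place rather than as residual branches, so each must map
# (B, T, D) -> (B, T, D) unchanged, which MomentumEncoder, GFLayer, and
# OrthogonalTemporalProjector all do.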
|
|
class CausalSelfAttention(nn.Module):
    def __init__(self, n_embd, n_head, block_size, dropout=0.0):
        super().__init__()
        assert n_embd % n_head == 0
        self.n_head = n_head
        self.key = nn.Linear(n_embd, n_embd)
        self.query = nn.Linear(n_embd, n_embd)
        self.value = nn.Linear(n_embd, n_embd)
        self.proj = nn.Linear(n_embd, n_embd)
        self.attn_drop = nn.Dropout(dropout)
        self.resid_drop = nn.Dropout(dropout)
        self.register_buffer("mask", torch.tril(torch.ones(block_size, block_size)).view(1, 1, block_size, block_size))

    def forward(self, x):
        B, T, C = x.size()
        k = self.key(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        q = self.query(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        v = self.value(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2)

        # RMS-normalize queries and keys per head (QK-norm) to bound logit scale.
        q = q / (q.pow(2).mean(-1, keepdim=True).sqrt() + 1e-6)
        k = k / (k.pow(2).mean(-1, keepdim=True).sqrt() + 1e-6)

        att = (q @ k.transpose(-2, -1)) / math.sqrt(k.size(-1))
        att = att.masked_fill(self.mask[:, :, :T, :T] == 0, float("-inf"))
        att = F.softmax(att, dim=-1)
        att = self.attn_drop(att)
        y = (att @ v).transpose(1, 2).contiguous().view(B, T, C)
        return self.resid_drop(self.proj(y))
|
|
class MLP(nn.Module):
    def __init__(self, n_embd, dropout=0.0):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.GELU(),
            nn.Linear(4 * n_embd, n_embd),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.fc(x)
|
|
class Block(nn.Module):
    """Plain pre-norm Transformer block (baseline; not used by ByteETM below)."""

    def __init__(self, n_embd, n_head, block_size, dropout=0.0):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        self.attn = CausalSelfAttention(n_embd, n_head, block_size, dropout)
        self.ln2 = nn.LayerNorm(n_embd)
        self.mlp = MLP(n_embd, dropout)

    def forward(self, x):
        x = x + self.attn(self.ln1(x))
        x = x + self.mlp(self.ln2(x))
        return x
|
|
class ByteETM(nn.Module):
    def __init__(self, vocab_size, n_embd, n_head, n_layer, block_size, dropout=0.0):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, n_embd)
        self.pos_enc = SinusoidalPositionalEncoding(n_embd, max_len=block_size)
        self.drop = nn.Dropout(dropout)
        self.blocks = nn.ModuleList([
            GeneratingBlock(n_embd, n_head, block_size, dropout) for _ in range(n_layer)
        ])
        self.ln_f = nn.LayerNorm(n_embd)
        self.head = nn.Linear(n_embd, vocab_size, bias=False)
        self.block_size = block_size
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, (nn.Linear, nn.Embedding)):
            nn.init.normal_(m.weight, mean=0.0, std=0.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.zeros_(m.bias)

    def forward(self, idx, targets=None):
        B, T = idx.size()
        assert T <= self.block_size
        x = self.token_emb(idx)
        x = self.pos_enc(x)
        x = self.drop(x)
        for blk in self.blocks:
            x = blk(x)
        x = self.ln_f(x)
        logits = self.head(x)
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        for _ in range(max_new_tokens):
            # Condition on at most the block_size most recent tokens.
            idx_cond = idx[:, -self.block_size:]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :] / max(temperature, 1e-8)
            if top_k is not None:
                # Keep only the top_k logits; mask the rest to -inf before softmax.
                v, _ = torch.topk(logits, top_k)
                logits[logits < v[:, [-1]]] = -float("inf")
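                # Worked example (illustrative numbers): with top_k=2 and a row
                # of logits [2.0, 1.0, 0.5], torch.topk keeps v = [2.0, 1.0];
                # v[:, [-1]] is 1.0, so only the 0.5 entry is set to -inf and
                # receives zero probability after the softmax below.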
            probs = F.softmax(logits, dim=-1)
            next_id = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, next_id), dim=1)
        return idx
|
|
class ByteETMConfig(PretrainedConfig):
    model_type = "byteetm"

    def __init__(self, vocab_size=258, n_embd=512, n_head=8, n_layer=6, block_size=256, **kwargs):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.n_embd = n_embd
        self.n_head = n_head
        self.n_layer = n_layer
        self.block_size = block_size
|
|
class HFByteETM(PreTrainedModel):
    config_class = ByteETMConfig

    def __init__(self, config):
        super().__init__(config)
        self.model = ByteETM(
            vocab_size=config.vocab_size,
            n_embd=config.n_embd,
            n_head=config.n_head,
            n_layer=config.n_layer,
            block_size=config.block_size,
        )

    def forward(self, input_ids, **kwargs):
        # Inference-only wrapper: exposes logits; the training loss is computed
        # by calling the inner ByteETM with targets directly.
        logits, _ = self.model(input_ids)
        return {"logits": logits}

    def generate(self, *args, **kwargs):
        # Delegates to ByteETM.generate, bypassing HF GenerationMixin.
        return self.model.generate(*args, **kwargs)
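# ---------------------------------------------------------------------------
# Minimal smoke test (a sketch; the small hyperparameters below are arbitrary
# choices for a quick shape check, not values prescribed by the model code).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)
    model = ByteETM(vocab_size=258, n_embd=64, n_head=4, n_layer=2, block_size=32)
    idx = torch.randint(0, 258, (2, 32))
    logits, loss = model(idx, targets=idx)
    print(logits.shape, loss.item())  # torch.Size([2, 32, 258]) and a scalar loss
    out = model.generate(idx[:, :8], max_new_tokens=8, top_k=50)
    print(out.shape)  # torch.Size([2, 16])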