"""
SOVYN-85M 모델 아키텍처
https://huggingface.co/SOVYN/SOVYN-85M
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from dataclasses import dataclass

@dataclass
class ModelConfig:
    """Hyperparameters for the SOVYN-85M decoder-only transformer."""
    vocab_size: int = 16384      # tokenizer vocabulary size
    context_length: int = 512    # maximum sequence length (size of the positional embedding table)
    embed_dim: int = 768         # width of the residual stream / embeddings
    num_heads: int = 12          # attention heads per layer
    num_layers: int = 12         # number of transformer blocks
    dropout: float = 0.1         # dropout probability (embedding, attention, residual, MLP)
    bias: bool = False           # whether Linear layers carry bias terms
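
# Rough parameter budget for the defaults above (a back-of-envelope estimate, not a
# load-bearing constant): per block, qkv (768*2304) + output proj (768*768) +
# fc1/fc2 (2*768*3072) + two LayerNorms come to ~7.08M weights, so 12 blocks are ~85.0M;
# with the positional embedding (512*768 ~ 0.39M) and the final LayerNorm that is ~85.4M
# non-embedding parameters, which is what SOVYN85M.num_params reports. The tied token
# embedding adds another 16384*768 ~ 12.6M.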

class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention built on PyTorch's fused scaled_dot_product_attention."""

    def __init__(self, cfg):
        super().__init__()
        self.num_heads = cfg.num_heads
        self.head_dim = cfg.embed_dim // cfg.num_heads
        self.embed_dim = cfg.embed_dim
        # One projection produces queries, keys and values in a single matmul.
        self.qkv = nn.Linear(cfg.embed_dim, 3 * cfg.embed_dim, bias=cfg.bias)
        self.proj = nn.Linear(cfg.embed_dim, cfg.embed_dim, bias=cfg.bias)
        self.resid_drop = nn.Dropout(cfg.dropout)
        self.dropout_p = cfg.dropout

    def forward(self, x):
        B, T, C = x.shape
        # (B, T, 3*C) -> (3, B, num_heads, T, head_dim)
        qkv = self.qkv(x).reshape(B, T, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        # Fused attention with a causal mask; attention dropout only while training.
        out = F.scaled_dot_product_attention(
            q, k, v, is_causal=True,
            dropout_p=self.dropout_p if self.training else 0.0,
        )
        out = out.transpose(1, 2).reshape(B, T, C)  # merge heads back to (B, T, C)
        return self.resid_drop(self.proj(out))
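
# Illustrative sketch (not used by the model): what F.scaled_dot_product_attention computes
# above when is_causal=True and dropout is off, written out by hand. The helper name is ours,
# purely for illustration; q, k, v have shape (B, num_heads, T, head_dim) as in forward().
def _manual_causal_attention(q, k, v):
    T = q.size(-2)
    scale = 1.0 / math.sqrt(q.size(-1))                   # 1/sqrt(head_dim)
    scores = (q @ k.transpose(-2, -1)) * scale            # (B, H, T, T) attention scores
    causal = torch.tril(torch.ones(T, T, dtype=torch.bool, device=q.device))
    scores = scores.masked_fill(~causal, float('-inf'))   # queries may not attend to future keys
    return F.softmax(scores, dim=-1) @ v                  # (B, H, T, head_dim)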

class FeedForward(nn.Module):
    """Position-wise MLP: expand to 4x the embedding width, GELU, project back."""

    def __init__(self, cfg):
        super().__init__()
        hidden = 4 * cfg.embed_dim
        self.fc1 = nn.Linear(cfg.embed_dim, hidden, bias=cfg.bias)
        self.fc2 = nn.Linear(hidden, cfg.embed_dim, bias=cfg.bias)
        self.drop = nn.Dropout(cfg.dropout)

    def forward(self, x):
        return self.drop(self.fc2(F.gelu(self.fc1(x))))

class Block(nn.Module):
    """Pre-norm transformer block: attention and MLP sub-layers, each behind a residual connection."""

    def __init__(self, cfg):
        super().__init__()
        self.ln1 = nn.LayerNorm(cfg.embed_dim)
        self.attn = CausalSelfAttention(cfg)
        self.ln2 = nn.LayerNorm(cfg.embed_dim)
        self.ffn = FeedForward(cfg)

    def forward(self, x):
        x = x + self.attn(self.ln1(x))   # x <- x + Attn(LN1(x))
        x = x + self.ffn(self.ln2(x))    # x <- x + FFN(LN2(x))
        return x

class SOVYN85M(nn.Module):
    """SOVYN-85M: a GPT-style decoder-only language model (~85M non-embedding parameters)."""

    def __init__(self, cfg=None):
        super().__init__()
        if cfg is None:
            cfg = ModelConfig()
        self.cfg = cfg
        self.tok_emb = nn.Embedding(cfg.vocab_size, cfg.embed_dim)
        self.pos_emb = nn.Embedding(cfg.context_length, cfg.embed_dim)
        self.drop = nn.Dropout(cfg.dropout)
        self.blocks = nn.ModuleList([Block(cfg) for _ in range(cfg.num_layers)])
        self.ln_f = nn.LayerNorm(cfg.embed_dim)
        self.head = nn.Linear(cfg.embed_dim, cfg.vocab_size, bias=False)
        self.head.weight = self.tok_emb.weight  # weight tying: LM head shares the token embedding matrix

    @property
    def num_params(self):
        """Parameter count excluding the (tied) token embedding matrix."""
        return sum(p.numel() for p in self.parameters()) - self.tok_emb.weight.numel()

    def forward(self, idx, targets=None):
        B, T = idx.shape
        pos = torch.arange(T, device=idx.device)
        # Token embeddings plus learned positional embeddings, then the transformer stack.
        x = self.drop(self.tok_emb(idx) + self.pos_emb(pos))
        for block in self.blocks:
            x = block(x)
        x = self.ln_f(x)
        logits = self.head(x)
        loss = None
        if targets is not None:
            # targets are next-token ids; positions whose target id is 0 (assumed padding) are ignored.
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1),
                                   ignore_index=0)
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens=200, temperature=0.8, top_k=50):
        """Autoregressively sample up to max_new_tokens tokens, stopping early at the EOS id."""
        self.eval()
        for _ in range(max_new_tokens):
            # Crop the running sequence to the model's maximum context length.
            ctx = idx[:, -self.cfg.context_length:]
            logits, _ = self(ctx)
            # Use only the logits for the last position; guard against temperature == 0.
            logits = logits[:, -1, :] / max(temperature, 1e-8)
            if top_k > 0:
                # Keep only the top_k highest-scoring tokens before sampling.
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, -1:]] = float('-inf')
            probs = F.softmax(logits, dim=-1)
            nxt = torch.multinomial(probs, 1)
            idx = torch.cat([idx, nxt], dim=1)
            if nxt.item() == 2:  # token id 2 is EOS
                break
        return idx
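
# Minimal smoke test: a sketch showing how the model is meant to be driven, not part of the
# released training or inference code. The dummy token ids are arbitrary; id 2 is assumed to
# be EOS and id 0 padding, matching generate() and the ignore_index above.
if __name__ == "__main__":
    cfg = ModelConfig()
    model = SOVYN85M(cfg)
    print(f"non-embedding parameters: {model.num_params / 1e6:.1f}M")

    # Forward pass with next-token targets shifted by one position.
    tokens = torch.randint(3, cfg.vocab_size, (2, 17))
    logits, loss = model(tokens[:, :-1], targets=tokens[:, 1:])
    print("logits:", tuple(logits.shape), "loss:", float(loss))

    # Sampling from an untrained model only exercises the generate() loop.
    prompt = torch.randint(3, cfg.vocab_size, (1, 4))
    out = model.generate(prompt, max_new_tokens=8, temperature=0.8, top_k=50)
    print("generated ids:", out[0].tolist())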