| """
|
| SOVYN-85M 모델 아키텍처
|
| https://huggingface.co/SOVYN/SOVYN-85M
|
| """
|
import math
from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F
|
|
|
|
|
@dataclass
class ModelConfig:
    """Hyperparameters for the SOVYN-85M transformer.

    Every field has a default, so ``ModelConfig()`` yields the stock
    configuration. Individual values can be overridden by keyword, e.g.
    ``ModelConfig(num_layers=2, embed_dim=128, num_heads=4)``.
    """

    vocab_size: int = 16384  # rows of the token embedding / size of the logits
    context_length: int = 512  # rows of the positional embedding table
    embed_dim: int = 768  # width of the hidden representation
    num_heads: int = 12  # attention heads per layer
    num_layers: int = 12  # number of stacked transformer blocks
    dropout: float = 0.1  # dropout probability used across the model
    bias: bool = False  # whether attention / feed-forward linears have a bias
|
|
|
|
|
class CausalSelfAttention(nn.Module):
    """Multi-head self-attention with a causal mask.

    Queries, keys and values are produced by a single fused linear layer and
    the attention itself is delegated to ``F.scaled_dot_product_attention``
    with ``is_causal=True``.

    Args:
        cfg: Configuration object exposing ``embed_dim``, ``num_heads``,
            ``dropout`` and ``bias``.

    Raises:
        ValueError: If ``cfg.num_heads`` is not positive or does not divide
            ``cfg.embed_dim`` evenly.
    """

    def __init__(self, cfg):
        super().__init__()
        # Fail at construction time: otherwise the floor division below
        # silently truncates and the error only shows up as an opaque
        # reshape failure inside ``forward``.
        if cfg.num_heads <= 0 or cfg.embed_dim % cfg.num_heads != 0:
            raise ValueError(
                f"embed_dim ({cfg.embed_dim}) must be divisible by "
                f"num_heads ({cfg.num_heads})"
            )
        self.num_heads = cfg.num_heads
        self.head_dim = cfg.embed_dim // cfg.num_heads
        self.embed_dim = cfg.embed_dim
        # One projection yields q, k and v concatenated on the last axis.
        self.qkv = nn.Linear(cfg.embed_dim, 3 * cfg.embed_dim, bias=cfg.bias)
        self.proj = nn.Linear(cfg.embed_dim, cfg.embed_dim, bias=cfg.bias)
        self.resid_drop = nn.Dropout(cfg.dropout)
        self.dropout_p = cfg.dropout

    def forward(self, x):
        """Apply causal self-attention.

        Args:
            x: Tensor of shape ``(B, T, C)`` where ``C`` equals ``embed_dim``.

        Returns:
            Tensor of shape ``(B, T, C)``.
        """
        B, T, C = x.shape
        # (B, T, 3*C) -> (B, T, 3, heads, head_dim) -> (3, B, heads, T, head_dim)
        qkv = self.qkv(x).reshape(B, T, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        out = F.scaled_dot_product_attention(
            q, k, v, is_causal=True,
            # The functional API has no train/eval switch of its own, so
            # attention dropout is turned off explicitly outside training.
            dropout_p=self.dropout_p if self.training else 0.0,
        )
        # (B, heads, T, head_dim) -> (B, T, C): merge the heads back together.
        out = out.transpose(1, 2).reshape(B, T, C)
        return self.resid_drop(self.proj(out))
|
|
|
|
|
class FeedForward(nn.Module):
    """Position-wise MLP: expand to four times the width, GELU, project back.

    Args:
        cfg: Configuration object exposing ``embed_dim``, ``dropout`` and
            ``bias``.
    """

    def __init__(self, cfg):
        super().__init__()
        width = cfg.embed_dim
        inner_width = 4 * width
        self.fc1 = nn.Linear(width, inner_width, bias=cfg.bias)
        self.fc2 = nn.Linear(inner_width, width, bias=cfg.bias)
        self.drop = nn.Dropout(cfg.dropout)

    def forward(self, x):
        """Return ``dropout(fc2(gelu(fc1(x))))``; the last axis keeps its size."""
        expanded = self.fc1(x)
        activated = F.gelu(expanded)
        projected = self.fc2(activated)
        return self.drop(projected)
|
|
|
|
|
class Block(nn.Module):
    """One transformer layer with layer normalisation applied before each sub-layer.

    The input passes through attention and then the feed-forward network;
    each sub-layer's output is added back onto its own input.

    Args:
        cfg: Configuration object forwarded to the attention and
            feed-forward sub-modules; ``embed_dim`` sizes the layer norms.
    """

    def __init__(self, cfg):
        super().__init__()
        width = cfg.embed_dim
        self.ln1 = nn.LayerNorm(width)
        self.attn = CausalSelfAttention(cfg)
        self.ln2 = nn.LayerNorm(width)
        self.ffn = FeedForward(cfg)

    def forward(self, x):
        """Return ``x`` after the attention and feed-forward residual updates."""
        attended = self.attn(self.ln1(x))
        x = x + attended
        transformed = self.ffn(self.ln2(x))
        return x + transformed
|
|
|
|
|
class SOVYN85M(nn.Module):
    """Decoder-only transformer language model.

    Token and learned positional embeddings are summed, passed through a
    stack of ``Block`` layers and a final layer norm, then projected to
    vocabulary logits by a linear head whose weight is shared with the
    token embedding.

    Args:
        cfg: Model configuration. When ``None`` a default ``ModelConfig()``
            is used.
    """

    def __init__(self, cfg=None):
        super().__init__()
        if cfg is None:
            cfg = ModelConfig()
        self.cfg = cfg
        self.tok_emb = nn.Embedding(cfg.vocab_size, cfg.embed_dim)
        self.pos_emb = nn.Embedding(cfg.context_length, cfg.embed_dim)
        self.drop = nn.Dropout(cfg.dropout)
        self.blocks = nn.ModuleList([Block(cfg) for _ in range(cfg.num_layers)])
        self.ln_f = nn.LayerNorm(cfg.embed_dim)
        self.head = nn.Linear(cfg.embed_dim, cfg.vocab_size, bias=False)
        # Weight tying: the output projection reuses the embedding matrix.
        self.head.weight = self.tok_emb.weight

    @property
    def num_params(self):
        """Parameter count with the token-embedding matrix subtracted.

        The tied head/embedding weight is a single shared tensor, so after
        the subtraction it does not contribute to the reported number.
        """
        return sum(p.numel() for p in self.parameters()) - self.tok_emb.weight.numel()

    def forward(self, idx, targets=None):
        """Compute logits and, optionally, the training loss.

        Args:
            idx: Integer tensor of token ids with shape ``(B, T)``.
            targets: Optional tensor of target ids with one entry per
                position of ``idx``. Positions whose target id is ``0`` are
                ignored by the loss.

        Returns:
            ``(logits, loss)`` where ``logits`` has shape
            ``(B, T, vocab_size)`` and ``loss`` is ``None`` when ``targets``
            is not given.

        Raises:
            ValueError: If ``T`` exceeds ``cfg.context_length``.
        """
        B, T = idx.shape
        # Guard the positional lookup: an out-of-range position otherwise
        # surfaces as an opaque indexing error inside the embedding.
        if T > self.cfg.context_length:
            raise ValueError(
                f"sequence length {T} exceeds context_length "
                f"{self.cfg.context_length}"
            )
        pos = torch.arange(T, device=idx.device)
        x = self.drop(self.tok_emb(idx) + self.pos_emb(pos))
        for block in self.blocks:
            x = block(x)
        x = self.ln_f(x)
        logits = self.head(x)
        loss = None
        if targets is not None:
            # reshape (not view) so non-contiguous targets are accepted too.
            loss = F.cross_entropy(
                logits.reshape(-1, logits.size(-1)),
                targets.reshape(-1),
                ignore_index=0,
            )
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens=200, temperature=0.8, top_k=50,
                 eos_token_id=2):
        """Sample a continuation of ``idx`` one token at a time.

        Puts the module in eval mode (and leaves it there), then repeatedly
        feeds the last ``cfg.context_length`` tokens through the model and
        samples the next token from the temperature-scaled, top-k filtered
        distribution.

        Args:
            idx: Integer tensor of prompt token ids with shape ``(B, T)``.
            max_new_tokens: Upper bound on the number of sampled tokens.
            temperature: Logit divisor; clamped below at ``1e-8``.
            top_k: Keep only the ``top_k`` highest logits before sampling.
                ``None`` or a value ``<= 0`` disables the filter.
            eos_token_id: Token id that ends a sequence. Sampling stops once
                every row has produced it; rows that finished earlier are
                padded with this id. ``None`` disables early stopping.

        Returns:
            Tensor of shape ``(B, T + n)`` with the prompt followed by the
            ``n <= max_new_tokens`` generated tokens.
        """
        self.eval()
        # Per-row completion flags, so early stopping also works for B > 1.
        finished = torch.zeros(idx.size(0), dtype=torch.bool, device=idx.device)
        for _ in range(max_new_tokens):
            ctx = idx[:, -self.cfg.context_length:]
            logits, _ = self(ctx)
            logits = logits[:, -1, :] / max(temperature, 1e-8)
            if top_k is not None and top_k > 0:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                # Everything below the k-th largest logit gets zero probability.
                logits = logits.masked_fill(logits < v[:, -1:], float('-inf'))
            probs = F.softmax(logits, dim=-1)
            nxt = torch.multinomial(probs, 1)
            if eos_token_id is not None:
                # Rows that already ended keep emitting EOS as padding.
                nxt = nxt.masked_fill(finished.unsqueeze(1), eos_token_id)
                finished = finished | (nxt.squeeze(1) == eos_token_id)
            idx = torch.cat([idx, nxt], dim=1)
            if eos_token_id is not None and bool(finished.all()):
                break
        return idx
|
|
|