"""SOVYN-85M model architecture.

A GPT-style decoder-only Transformer: learned token and absolute position
embeddings, a stack of pre-LayerNorm residual blocks, a final LayerNorm and a
language-model head whose weight is tied to the token embedding.

https://huggingface.co/SOVYN/SOVYN-85M
"""

import math
from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F


@dataclass
class ModelConfig:
    """Hyper-parameters for :class:`SOVYN85M`.

    The defaults are the values :class:`SOVYN85M` uses when constructed
    without a config. Any field can be overridden by keyword, e.g.
    ``ModelConfig(num_layers=2)``.
    """

    vocab_size: int = 16384  # rows of the token-embedding table / LM-head outputs
    context_length: int = 512  # maximum sequence length (rows of the position table)
    embed_dim: int = 768  # model width; must be divisible by num_heads
    num_heads: int = 12  # attention heads per block
    num_layers: int = 12  # number of Transformer blocks
    dropout: float = 0.1  # probability for every dropout site (active only in training)
    bias: bool = False  # bias on the Linear layers inside the blocks (LayerNorm always has one)


class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention.

    q, k and v come from one fused projection. Attention is computed by
    :func:`torch.nn.functional.scaled_dot_product_attention` with its built-in
    causal mask, then projected back to ``embed_dim`` and passed through
    residual dropout.
    """

    def __init__(self, cfg):
        super().__init__()
        if cfg.embed_dim % cfg.num_heads != 0:
            # Otherwise the per-head reshape in forward() cannot split the width.
            raise ValueError(
                f"embed_dim ({cfg.embed_dim}) must be divisible by "
                f"num_heads ({cfg.num_heads})"
            )
        self.num_heads = cfg.num_heads
        self.head_dim = cfg.embed_dim // cfg.num_heads
        self.embed_dim = cfg.embed_dim
        # One matmul yields q, k and v concatenated along the last dimension.
        self.qkv = nn.Linear(cfg.embed_dim, 3 * cfg.embed_dim, bias=cfg.bias)
        self.proj = nn.Linear(cfg.embed_dim, cfg.embed_dim, bias=cfg.bias)
        self.resid_drop = nn.Dropout(cfg.dropout)
        # Attention-probability dropout is applied inside SDPA, so keep the raw rate.
        self.dropout_p = cfg.dropout

    def forward(self, x):
        """Attend over ``x`` of shape (B, T, C); return a tensor of the same shape."""
        B, T, C = x.shape
        # (B, T, 3*C) -> (B, T, 3, H, D) -> (3, B, H, T, D)
        qkv = self.qkv(x).reshape(B, T, 3, self.num_heads, self.head_dim)
        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)
        out = F.scaled_dot_product_attention(
            q,
            k,
            v,
            is_causal=True,
            # SDPA does not look at self.training, so switch dropout off in eval here.
            dropout_p=self.dropout_p if self.training else 0.0,
        )
        # (B, H, T, D) -> (B, T, H, D) -> (B, T, C)
        out = out.transpose(1, 2).reshape(B, T, C)
        return self.resid_drop(self.proj(out))


class FeedForward(nn.Module):
    """Position-wise MLP: Linear(C -> 4C) -> GELU -> Linear(4C -> C) -> Dropout."""

    def __init__(self, cfg):
        super().__init__()
        hidden = 4 * cfg.embed_dim
        self.fc1 = nn.Linear(cfg.embed_dim, hidden, bias=cfg.bias)
        self.fc2 = nn.Linear(hidden, cfg.embed_dim, bias=cfg.bias)
        self.drop = nn.Dropout(cfg.dropout)

    def forward(self, x):
        return self.drop(self.fc2(F.gelu(self.fc1(x))))


class Block(nn.Module):
    """Pre-LayerNorm Transformer block: x + attn(ln1(x)), then x + ffn(ln2(x))."""

    def __init__(self, cfg):
        super().__init__()
        self.ln1 = nn.LayerNorm(cfg.embed_dim)
        self.attn = CausalSelfAttention(cfg)
        self.ln2 = nn.LayerNorm(cfg.embed_dim)
        self.ffn = FeedForward(cfg)

    def forward(self, x):
        x = x + self.attn(self.ln1(x))
        x = x + self.ffn(self.ln2(x))
        return x


class SOVYN85M(nn.Module):
    """Decoder-only language model: embeddings -> blocks -> LayerNorm -> LM head."""

    def __init__(self, cfg=None):
        super().__init__()
        if cfg is None:
            cfg = ModelConfig()
        self.cfg = cfg
        self.tok_emb = nn.Embedding(cfg.vocab_size, cfg.embed_dim)
        self.pos_emb = nn.Embedding(cfg.context_length, cfg.embed_dim)
        self.drop = nn.Dropout(cfg.dropout)
        self.blocks = nn.ModuleList([Block(cfg) for _ in range(cfg.num_layers)])
        self.ln_f = nn.LayerNorm(cfg.embed_dim)
        self.head = nn.Linear(cfg.embed_dim, cfg.vocab_size, bias=False)
        # Weight tying: the output projection shares its matrix with the token embedding.
        self.head.weight = self.tok_emb.weight

    @property
    def num_params(self):
        """Parameter count excluding the (tied) token-embedding matrix.

        ``self.parameters()`` yields the shared embedding/head weight only once,
        so one subtraction removes it entirely. Position embeddings are still
        counted.
        """
        return sum(p.numel() for p in self.parameters()) - self.tok_emb.weight.numel()

    def forward(self, idx, targets=None):
        """Run the model on a batch of token ids.

        Args:
            idx: integer token ids of shape (B, T), with T <= cfg.context_length.
            targets: optional integer ids of shape (B, T). Positions equal to 0
                are excluded from the loss (presumably the padding id; TODO
                confirm against the tokenizer). May be non-contiguous, e.g.
                ``data[:, 1:]``.

        Returns:
            ``(logits, loss)``: logits of shape (B, T, vocab_size), and the mean
            cross-entropy loss (scalar), or ``None`` when ``targets`` is None.

        Raises:
            ValueError: if T exceeds ``cfg.context_length``.
        """
        _, T = idx.shape
        if T > self.cfg.context_length:
            # Without this check the position lookup fails with an opaque
            # index error (or a device-side assert on CUDA).
            raise ValueError(
                f"sequence length {T} exceeds context_length {self.cfg.context_length}"
            )
        pos = torch.arange(T, device=idx.device)
        x = self.drop(self.tok_emb(idx) + self.pos_emb(pos))
        for block in self.blocks:
            x = block(x)
        logits = self.head(self.ln_f(x))

        loss = None
        if targets is not None:
            # reshape (not view): shifted targets are usually non-contiguous.
            loss = F.cross_entropy(
                logits.reshape(-1, logits.size(-1)),
                targets.reshape(-1),
                ignore_index=0,
            )
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens=200, temperature=0.8, top_k=50,
                 eos_token_id=2):
        """Autoregressively sample up to ``max_new_tokens`` tokens.

        Args:
            idx: prompt token ids of shape (B, T0).
            max_new_tokens: upper bound on the number of tokens appended.
            temperature: logits are divided by ``max(temperature, 1e-8)``.
            top_k: keep only the ``top_k`` most likely tokens. ``None`` or a
                value <= 0 disables the filter.
            eos_token_id: end-of-sequence id (default 2). Once a row emits it,
                that row is padded with it. Generation stops early when every
                row has finished. ``None`` disables early stopping.

        Returns:
            ``idx`` extended along dim 1 with the sampled tokens.

        The model is put in eval mode while sampling; its previous
        training/eval mode is restored afterwards.
        """
        was_training = self.training
        self.eval()
        try:
            finished = torch.zeros(idx.size(0), 1, dtype=torch.bool, device=idx.device)
            for _ in range(max_new_tokens):
                # Crop to the last context_length tokens so pos_emb stays in range.
                ctx = idx[:, -self.cfg.context_length:]
                logits, _ = self(ctx)
                logits = logits[:, -1, :] / max(temperature, 1e-8)
                if top_k is not None and top_k > 0:
                    v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                    logits[logits < v[:, -1:]] = float('-inf')
                probs = F.softmax(logits, dim=-1)
                nxt = torch.multinomial(probs, 1)
                if eos_token_id is not None:
                    # Rows that already ended keep emitting EOS as padding.
                    nxt = nxt.masked_fill(finished, eos_token_id)
                    finished |= nxt == eos_token_id
                idx = torch.cat([idx, nxt], dim=1)
                if eos_token_id is not None and finished.all():
                    break
        finally:
            self.train(was_training)
        return idx