| from dataclasses import dataclass |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
|
|
@dataclass
class SovynConfig:
    """Hyperparameters for the Sovyn decoder-only language model.

    Every field has a default, so ``SovynConfig()`` yields the
    "SOVYN-120M-Cortex" settings. Field order is part of the positional
    constructor signature and must not be rearranged.
    """

    # Identifier string for this configuration.
    name: str = "SOVYN-120M-Cortex"

    # --- Vocabulary and context -------------------------------------------
    # Rows of the token-embedding table and output features of the LM head.
    vocab_size: int = 32000
    # Length of the precomputed RoPE tables; longer inputs are rejected by
    # the model's forward pass.
    max_seq_len: int = 1024

    # --- Transformer shape --------------------------------------------------
    # Number of stacked transformer blocks.
    n_layers: int = 12
    # Width of the residual stream.
    hidden_size: int = 768
    # Number of query heads; the per-head width is hidden_size // n_heads.
    n_heads: int = 12
    # Number of key/value heads; n_heads must be a multiple of this value.
    n_kv_heads: int = 4
    # Inner width of the SwiGLU feed-forward layers.
    ffn_size: int = 2688

    # --- Regularisation / positional encoding / weight sharing ------------
    # Attention dropout probability (only applied in training mode).
    dropout: float = 0.0
    # Base of the geometric progression of rotary frequencies.
    rope_theta: float = 10000.0
    # When true, the LM head shares its weight matrix with the embedding.
    tie_embeddings: bool = True
|
|
|
|
class RMSNorm(nn.Module):
    """Root-mean-square normalisation with a learned per-feature gain.

    Args:
        dim: Size of the last (feature) dimension of the inputs.
        eps: Constant added to the mean square before the reciprocal square
            root, so an all-zero input does not divide by zero.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x):
        """Scale ``x`` to unit RMS over its last dimension, then apply the gain.

        Args:
            x: Floating-point tensor whose last dimension has size ``dim``.

        Returns:
            Tensor of the same shape as ``x``. Its dtype is the promotion of
            ``x.dtype`` and the gain's dtype, as in a plain multiplication.
        """
        # Squaring float16 values above ~255 overflows to inf (which then
        # turns the whole row into zeros), and bfloat16 has very few mantissa
        # bits, so the statistics for half-precision inputs are computed in
        # float32. Wider inputs (float32/float64) are left untouched.
        if x.dtype in (torch.float16, torch.bfloat16):
            work = x.float()
        else:
            work = x
        inv_rms = torch.rsqrt(work.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        normed = (work * inv_rms).type_as(x)
        return normed * self.weight
|
|
|
|
def precompute_rope(head_dim: int, max_seq_len: int, theta: float):
    """Build the cosine and sine lookup tables for rotary position embeddings.

    Args:
        head_dim: Per-head feature width; one frequency is generated for each
            index in ``range(0, head_dim, 2)``.
        max_seq_len: Number of positions (rows) to tabulate, starting at 0.
        theta: Base of the geometric frequency progression.

    Returns:
        A ``(cos, sin)`` tuple of float32 tensors, each of shape
        ``(max_seq_len, len(range(0, head_dim, 2)))``, holding the cosine and
        sine of ``position * inverse_frequency``.
    """
    # Frequency j is 1 / theta ** (2j / head_dim).
    exponents = torch.arange(0, head_dim, 2).float() / head_dim
    inv_freq = 1.0 / (theta ** exponents)
    positions = torch.arange(max_seq_len).float()
    # Outer product via broadcasting: angles[p, j] = positions[p] * inv_freq[j].
    angles = positions[:, None] * inv_freq[None, :]
    return angles.cos(), angles.sin()
|
|
|
|
def apply_rope(x, cos, sin):
    """Rotate adjacent feature pairs of ``x`` by position-dependent angles.

    Features ``(2j, 2j + 1)`` along the last dimension are treated as a 2-D
    point and rotated by the angle whose cosine/sine is found in column ``j``
    of the tables.

    Args:
        x: Tensor indexed as ``(batch, position, head, feature)``.
        cos: 2-D table of cosines indexed as ``(position, pair)``.
        sin: 2-D table of sines with the same layout as ``cos``.

    Returns:
        A new tensor with the same shape and dtype as ``x``.
    """
    # Insert singleton batch and head axes so the tables broadcast against x.
    cos_b = cos[None, :, None, :]
    sin_b = sin[None, :, None, :]

    evens, odds = x[..., 0::2], x[..., 1::2]

    # Standard 2-D rotation, written back into the interleaved layout.
    rotated = torch.empty_like(x)
    rotated[..., 0::2] = evens * cos_b - odds * sin_b
    rotated[..., 1::2] = evens * sin_b + odds * cos_b
    return rotated
|
|
|
|
class Attention(nn.Module):
    """Causal self-attention with grouped key/value heads and rotary embeddings.

    Queries use ``cfg.n_heads`` heads while keys and values use
    ``cfg.n_kv_heads`` heads; each key/value head is shared by
    ``n_heads // n_kv_heads`` consecutive query heads.

    Args:
        cfg: Model configuration providing ``hidden_size``, ``n_heads``,
            ``n_kv_heads`` and ``dropout``.

    Raises:
        ValueError: If ``n_heads`` is not a multiple of ``n_kv_heads``, if
            ``hidden_size`` is not a multiple of ``n_heads``, or if the
            resulting head dimension is odd.
    """

    def __init__(self, cfg: SovynConfig):
        super().__init__()
        if cfg.n_heads % cfg.n_kv_heads != 0:
            raise ValueError("n_heads must be divisible by n_kv_heads")
        # Without these two checks an invalid shape is only discovered in
        # forward(), as an opaque view/broadcast error.
        if cfg.hidden_size % cfg.n_heads != 0:
            raise ValueError("hidden_size must be divisible by n_heads")
        head_dim = cfg.hidden_size // cfg.n_heads
        if head_dim % 2 != 0:
            # The rotary embedding rotates features in (even, odd) pairs.
            raise ValueError("hidden_size // n_heads must be even")

        self.n_heads = cfg.n_heads
        self.n_kv_heads = cfg.n_kv_heads
        self.head_dim = head_dim
        # How many query heads share each key/value head.
        self.repeat = cfg.n_heads // cfg.n_kv_heads

        kv_dim = cfg.n_kv_heads * self.head_dim
        self.q_proj = nn.Linear(cfg.hidden_size, cfg.hidden_size, bias=False)
        self.k_proj = nn.Linear(cfg.hidden_size, kv_dim, bias=False)
        self.v_proj = nn.Linear(cfg.hidden_size, kv_dim, bias=False)
        self.o_proj = nn.Linear(cfg.hidden_size, cfg.hidden_size, bias=False)
        self.dropout = cfg.dropout

    def forward(self, x, cos, sin):
        """Apply causal self-attention to ``x``.

        Args:
            x: Input of shape ``(batch, seq_len, hidden_size)``.
            cos: Rotary cosine table; only its first ``seq_len`` rows are used.
            sin: Rotary sine table; only its first ``seq_len`` rows are used.

        Returns:
            Tensor of shape ``(batch, seq_len, hidden_size)``.
        """
        bsz, seq_len, hidden = x.shape
        q = self.q_proj(x).view(bsz, seq_len, self.n_heads, self.head_dim)
        k = self.k_proj(x).view(bsz, seq_len, self.n_kv_heads, self.head_dim)
        v = self.v_proj(x).view(bsz, seq_len, self.n_kv_heads, self.head_dim)

        # Positions always start at 0, so take the leading rows of the tables.
        cos_t, sin_t = cos[:seq_len], sin[:seq_len]
        q = apply_rope(q, cos_t, sin_t)
        k = apply_rope(k, cos_t, sin_t)

        # Duplicate each key/value head so the head count matches the queries.
        k = k.repeat_interleave(self.repeat, dim=2)
        v = v.repeat_interleave(self.repeat, dim=2)

        # scaled_dot_product_attention expects (batch, heads, seq, head_dim).
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        y = F.scaled_dot_product_attention(
            q,
            k,
            v,
            attn_mask=None,
            # Dropout is a training-only regulariser.
            dropout_p=self.dropout if self.training else 0.0,
            is_causal=True,
        )
        # Back to (batch, seq, heads, head_dim), then merge the heads.
        y = y.transpose(1, 2).contiguous().view(bsz, seq_len, hidden)
        return self.o_proj(y)
|
|
|
|
class SwiGLU(nn.Module):
    """Gated feed-forward layer computing ``down(silu(gate(x)) * up(x))``.

    Args:
        cfg: Model configuration providing ``hidden_size`` (input/output
            width) and ``ffn_size`` (inner width). All projections are
            bias-free.
    """

    def __init__(self, cfg: SovynConfig):
        super().__init__()
        width, inner = cfg.hidden_size, cfg.ffn_size
        # Registration order (gate, up, down) is kept so parameter ordering
        # and random initialisation order are unchanged.
        self.gate = nn.Linear(width, inner, bias=False)
        self.up = nn.Linear(width, inner, bias=False)
        self.down = nn.Linear(inner, width, bias=False)

    def forward(self, x):
        """Return the gated projection of ``x``, with the same trailing width."""
        gated = F.silu(self.gate(x))
        return self.down(gated * self.up(x))
|
|
|
|
class Block(nn.Module):
    """One pre-norm transformer layer: attention then feed-forward.

    Each sub-layer sees an RMS-normalised copy of the residual stream and
    its output is added back onto that stream.

    Args:
        cfg: Model configuration forwarded to the attention and feed-forward
            sub-layers; ``hidden_size`` sizes both norms.
    """

    def __init__(self, cfg: SovynConfig):
        super().__init__()
        # Registration order is kept so parameter ordering is unchanged.
        self.attn_norm = RMSNorm(cfg.hidden_size)
        self.attn = Attention(cfg)
        self.ffn_norm = RMSNorm(cfg.hidden_size)
        self.ffn = SwiGLU(cfg)

    def forward(self, x, cos, sin):
        """Run the layer on ``x``; ``cos``/``sin`` are passed to attention."""
        attn_out = self.attn(self.attn_norm(x), cos, sin)
        hidden = x + attn_out
        ffn_out = self.ffn(self.ffn_norm(hidden))
        return hidden + ffn_out
|
|
|
|
class SovynForCausalLM(nn.Module):
    """Decoder-only causal language model with an LM head.

    The model embeds token ids, runs ``cfg.n_layers`` transformer blocks,
    applies a final RMS norm and projects to vocabulary logits.

    Args:
        cfg: Model configuration. When ``cfg.tie_embeddings`` is true the LM
            head shares its weight tensor with the token embedding.
    """

    def __init__(self, cfg: SovynConfig):
        super().__init__()
        self.cfg = cfg
        self.embed = nn.Embedding(cfg.vocab_size, cfg.hidden_size)
        self.blocks = nn.ModuleList([Block(cfg) for _ in range(cfg.n_layers)])
        self.norm = RMSNorm(cfg.hidden_size)
        self.lm_head = nn.Linear(cfg.hidden_size, cfg.vocab_size, bias=False)
        if cfg.tie_embeddings:
            self.lm_head.weight = self.embed.weight

        cos, sin = precompute_rope(
            cfg.hidden_size // cfg.n_heads,
            cfg.max_seq_len,
            cfg.rope_theta,
        )
        # Non-persistent: the tables follow .to()/.cuda() but are left out of
        # the state dict, since they are recomputed from the config.
        self.register_buffer("rope_cos", cos, persistent=False)
        self.register_buffer("rope_sin", sin, persistent=False)
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialise linear and embedding weights from N(0, 0.02**2)."""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, input_ids, labels=None):
        """Compute next-token logits and, optionally, the cross-entropy loss.

        Args:
            input_ids: Integer token ids of shape ``(batch, seq_len)``.
            labels: Optional target ids with the same number of elements as
                ``input_ids``. They are compared position-by-position with
                the logits (no shifting is done here, so the caller must
                supply already-aligned targets). Entries equal to ``-100``
                are ignored.

        Returns:
            A dict with ``"logits"`` of shape ``(batch, seq_len, vocab_size)``
            and ``"loss"`` (a scalar tensor, or ``None`` without labels).

        Raises:
            ValueError: If ``seq_len`` exceeds ``cfg.max_seq_len``.
        """
        if input_ids.size(1) > self.cfg.max_seq_len:
            raise ValueError("Sequence length exceeds max_seq_len")

        x = self.embed(input_ids)
        for block in self.blocks:
            x = block(x, self.rope_cos, self.rope_sin)
        x = self.norm(x)
        logits = self.lm_head(x)

        loss = None
        if labels is not None:
            # reshape (not view) so sliced / non-contiguous labels also work.
            loss = F.cross_entropy(
                logits.reshape(-1, logits.size(-1)),
                labels.reshape(-1),
                ignore_index=-100,
            )
        return {"loss": loss, "logits": logits}

    @torch.no_grad()
    def generate(
        self,
        input_ids,
        max_new_tokens=96,
        temperature=0.8,
        top_k=50,
        eos_id=None,
        stop_ids=None,
        suppress_ids=None,
    ):
        """Autoregressively extend ``input_ids`` by sampling or greedy decoding.

        The whole (cropped) sequence is re-encoded at every step; there is no
        key/value cache. The model is switched to eval mode and is left in
        eval mode when the call returns.

        Args:
            input_ids: Prompt token ids of shape ``(batch, seq_len)``.
            max_new_tokens: Upper bound on the number of tokens appended.
            temperature: Softmax temperature; values ``<= 0`` select greedy
                (argmax) decoding.
            top_k: Keep only the ``top_k`` highest logits before sampling
                (ties with the k-th value are kept). ``None`` or a value
                ``<= 0`` disables the filter.
            eos_id: Optional token id that ends a sequence once generated.
            stop_ids: Optional iterable of additional terminating token ids.
            suppress_ids: Optional iterable of token ids that must never be
                generated.

        Returns:
            The prompt with the generated tokens appended along dim 1. A
            terminating token is included in the output. With a batch, rows
            that terminated early are padded by repeating their terminating
            token until every row has terminated or the budget is exhausted.
        """
        self.eval()
        terminators = set(stop_ids or [])
        if eos_id is not None:
            terminators.add(eos_id)
        suppress_ids = list(suppress_ids or [])

        device = input_ids.device
        stop_tensor = None
        if terminators:
            stop_tensor = torch.tensor(
                sorted(terminators), dtype=torch.long, device=device
            )
        # Per-row termination flags; calling .item() on the sampled ids would
        # raise for any batch larger than one.
        finished = torch.zeros(input_ids.size(0), dtype=torch.bool, device=device)

        for _ in range(max_new_tokens):
            # Only the most recent max_seq_len tokens fit in the context.
            window = input_ids[:, -self.cfg.max_seq_len :]
            logits = self(window)["logits"][:, -1, :]
            if suppress_ids:
                logits[:, suppress_ids] = -float("inf")
            if temperature <= 0:
                next_id = torch.argmax(logits, dim=-1, keepdim=True)
            else:
                logits = logits / temperature
                if top_k is not None and top_k > 0:
                    values, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                    # values[:, [-1]] is each row's k-th largest logit.
                    logits[logits < values[:, [-1]]] = -float("inf")
                probs = F.softmax(logits, dim=-1)
                next_id = torch.multinomial(probs, num_samples=1)

            # Rows that already terminated keep repeating their last token
            # (their terminator) so the batch stays rectangular.
            next_id = torch.where(finished[:, None], input_ids[:, -1:], next_id)
            input_ids = torch.cat([input_ids, next_id], dim=1)

            if stop_tensor is not None:
                finished |= torch.isin(next_id.squeeze(1), stop_tensor)
                if bool(finished.all()):
                    break
        return input_ids
|
|