from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F


@dataclass
class SovynConfig:
    """Hyperparameters for the Sovyn decoder-only language model."""

    name: str = "SOVYN-120M-Cortex"  # descriptive label; not used by the model code here
    vocab_size: int = 32000  # rows of the embedding table / width of the output logits
    max_seq_len: int = 1024  # length of the precomputed RoPE tables; longer inputs are rejected
    n_layers: int = 12  # number of stacked Block modules
    hidden_size: int = 768  # residual-stream width
    n_heads: int = 12  # query heads
    n_kv_heads: int = 4  # key/value heads; must divide n_heads
    ffn_size: int = 2688  # inner width of the SwiGLU feed-forward
    dropout: float = 0.0  # attention dropout probability, applied only in training mode
    rope_theta: float = 10000.0  # base of the rotary frequency schedule
    tie_embeddings: bool = True  # share one weight between the embedding and lm_head


class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension with a learned scale."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x):
        """Divide ``x`` by its RMS over the last axis, then multiply by ``weight``."""
        # eps is added under the square root so an all-zero row does not divide by zero.
        inv_rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return (x * inv_rms) * self.weight


def precompute_rope(head_dim: int, max_seq_len: int, theta: float):
    """Build the rotary-embedding cosine and sine tables.

    Args:
        head_dim: Per-head feature size; one frequency is produced per pair
            of features.
        max_seq_len: Number of positions to tabulate.
        theta: Base of the geometric frequency schedule.

    Returns:
        ``(cos, sin)``, each of shape ``(max_seq_len, head_dim // 2)`` for an
        even ``head_dim``.
    """
    exponents = torch.arange(0, head_dim, 2).float() / head_dim
    inv_freq = 1.0 / (theta ** exponents)
    positions = torch.arange(max_seq_len).float()
    angles = torch.outer(positions, inv_freq)
    return torch.cos(angles), torch.sin(angles)


def apply_rope(x, cos, sin):
    """Rotate adjacent feature pairs of ``x`` by position-dependent angles.

    Args:
        x: Tensor laid out as ``(batch, seq, heads, head_dim)``.
        cos: Cosine table of shape ``(seq, head_dim // 2)``.
        sin: Sine table of shape ``(seq, head_dim // 2)``.

    Returns:
        A new tensor with the same shape and dtype as ``x``.
    """
    # Insert batch and head axes so the tables broadcast against x.
    cos = cos[None, :, None, :]
    sin = sin[None, :, None, :]
    x_even = x[..., 0::2]
    x_odd = x[..., 1::2]
    # Writing into an empty_like buffer keeps the result in x's dtype even
    # when the tables are stored in a different precision.
    out = torch.empty_like(x)
    out[..., 0::2] = x_even * cos - x_odd * sin
    out[..., 1::2] = x_even * sin + x_odd * cos
    return out


class Attention(nn.Module):
    """Causal self-attention with rotary embeddings and shared key/value heads.

    Each of the ``n_kv_heads`` key/value heads serves
    ``n_heads // n_kv_heads`` query heads.

    Raises:
        ValueError: If ``n_heads`` is not divisible by ``n_kv_heads``, if
            ``hidden_size`` is not divisible by ``n_heads``, or if the
            resulting head size is odd (the rotary embedding pairs features).
    """

    def __init__(self, cfg: SovynConfig):
        super().__init__()
        if cfg.n_heads % cfg.n_kv_heads != 0:
            raise ValueError("n_heads must be divisible by n_kv_heads")
        # Without these checks the failure only surfaces inside forward() as
        # an opaque shape error from view() or from the rotary broadcast.
        if cfg.hidden_size % cfg.n_heads != 0:
            raise ValueError("hidden_size must be divisible by n_heads")
        if (cfg.hidden_size // cfg.n_heads) % 2 != 0:
            raise ValueError("head_dim must be even for rotary embeddings")
        self.n_heads = cfg.n_heads
        self.n_kv_heads = cfg.n_kv_heads
        self.head_dim = cfg.hidden_size // cfg.n_heads
        self.repeat = cfg.n_heads // cfg.n_kv_heads
        kv_dim = cfg.n_kv_heads * self.head_dim
        self.q_proj = nn.Linear(cfg.hidden_size, cfg.hidden_size, bias=False)
        self.k_proj = nn.Linear(cfg.hidden_size, kv_dim, bias=False)
        self.v_proj = nn.Linear(cfg.hidden_size, kv_dim, bias=False)
        self.o_proj = nn.Linear(cfg.hidden_size, cfg.hidden_size, bias=False)
        self.dropout = cfg.dropout

    def forward(self, x, cos, sin):
        """Attend over ``x`` of shape ``(batch, seq, hidden)``.

        ``cos`` and ``sin`` are the full rotary tables; only the first
        ``seq`` rows are used.
        """
        bsz, seq_len, hidden = x.shape
        q = self.q_proj(x).view(bsz, seq_len, self.n_heads, self.head_dim)
        k = self.k_proj(x).view(bsz, seq_len, self.n_kv_heads, self.head_dim)
        v = self.v_proj(x).view(bsz, seq_len, self.n_kv_heads, self.head_dim)

        cos = cos[:seq_len]
        sin = sin[:seq_len]
        q = apply_rope(q, cos, sin)
        k = apply_rope(k, cos, sin)

        # Expand the key/value heads so every query head has a partner.
        k = k.repeat_interleave(self.repeat, dim=2)
        v = v.repeat_interleave(self.repeat, dim=2)

        # scaled_dot_product_attention expects (batch, heads, seq, head_dim).
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
        y = F.scaled_dot_product_attention(
            q,
            k,
            v,
            attn_mask=None,
            dropout_p=self.dropout if self.training else 0.0,
            is_causal=True,
        )
        y = y.transpose(1, 2).contiguous().view(bsz, seq_len, hidden)
        return self.o_proj(y)


class SwiGLU(nn.Module):
    """Gated feed-forward: ``down(silu(gate(x)) * up(x))``."""

    def __init__(self, cfg: SovynConfig):
        super().__init__()
        self.gate = nn.Linear(cfg.hidden_size, cfg.ffn_size, bias=False)
        self.up = nn.Linear(cfg.hidden_size, cfg.ffn_size, bias=False)
        self.down = nn.Linear(cfg.ffn_size, cfg.hidden_size, bias=False)

    def forward(self, x):
        """Apply the gated feed-forward to ``x``."""
        return self.down(F.silu(self.gate(x)) * self.up(x))


class Block(nn.Module):
    """Pre-norm transformer layer: attention then feed-forward, each residual."""

    def __init__(self, cfg: SovynConfig):
        super().__init__()
        self.attn_norm = RMSNorm(cfg.hidden_size)
        self.attn = Attention(cfg)
        self.ffn_norm = RMSNorm(cfg.hidden_size)
        self.ffn = SwiGLU(cfg)

    def forward(self, x, cos, sin):
        """Run one layer; ``cos``/``sin`` are forwarded to the attention."""
        x = x + self.attn(self.attn_norm(x), cos, sin)
        x = x + self.ffn(self.ffn_norm(x))
        return x


class SovynForCausalLM(nn.Module):
    """Token embedding, a stack of ``Block`` layers, final norm and LM head."""

    def __init__(self, cfg: SovynConfig):
        super().__init__()
        self.cfg = cfg
        self.embed = nn.Embedding(cfg.vocab_size, cfg.hidden_size)
        self.blocks = nn.ModuleList([Block(cfg) for _ in range(cfg.n_layers)])
        self.norm = RMSNorm(cfg.hidden_size)
        self.lm_head = nn.Linear(cfg.hidden_size, cfg.vocab_size, bias=False)
        if cfg.tie_embeddings:
            # One parameter serves both as input embedding and output projection.
            self.lm_head.weight = self.embed.weight
        cos, sin = precompute_rope(
            cfg.hidden_size // cfg.n_heads,
            cfg.max_seq_len,
            cfg.rope_theta,
        )
        # persistent=False keeps the tables out of the state dict.
        self.register_buffer("rope_cos", cos, persistent=False)
        self.register_buffer("rope_sin", sin, persistent=False)
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Draw Linear and Embedding weights from N(0, 0.02**2)."""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, input_ids, labels=None):
        """Compute next-token logits and, optionally, the cross-entropy loss.

        Args:
            input_ids: Integer token ids of shape ``(batch, seq)``.
            labels: Optional targets with the same shape as ``input_ids``.
                They are compared position-by-position with the logits (no
                shifting happens here); entries equal to ``-100`` are ignored.

        Returns:
            A dict with ``"loss"`` (``None`` when ``labels`` is ``None``) and
            ``"logits"`` of shape ``(batch, seq, vocab_size)``.

        Raises:
            ValueError: If ``seq`` exceeds ``cfg.max_seq_len``.
        """
        if input_ids.size(1) > self.cfg.max_seq_len:
            raise ValueError("Sequence length exceeds max_seq_len")
        x = self.embed(input_ids)
        for block in self.blocks:
            x = block(x, self.rope_cos, self.rope_sin)
        logits = self.lm_head(self.norm(x))

        loss = None
        if labels is not None:
            # reshape (not view): labels are often a non-contiguous slice
            # such as ids[:, 1:], on which view(-1) raises.
            loss = F.cross_entropy(
                logits.reshape(-1, logits.size(-1)),
                labels.reshape(-1),
                ignore_index=-100,
            )
        return {"loss": loss, "logits": logits}

    @torch.no_grad()
    def generate(
        self,
        input_ids,
        max_new_tokens=96,
        temperature=0.8,
        top_k=50,
        eos_id=None,
        stop_ids=None,
        suppress_ids=None,
    ):
        """Autoregressively extend ``input_ids``.

        The module is switched to eval mode and left in that mode.

        Args:
            input_ids: Prompt token ids of shape ``(batch, seq)``.
            max_new_tokens: Upper bound on the number of appended tokens.
            temperature: Logit divisor for sampling; ``<= 0`` selects greedy
                argmax decoding.
            top_k: When positive, sample only among logits at least as large
                as the k-th largest; ``None`` or ``<= 0`` disables the filter.
            eos_id: Token id that terminates a sequence.
            stop_ids: Additional terminating token ids.
            suppress_ids: Token ids whose logits are forced to ``-inf``.

        Returns:
            ``input_ids`` with the generated tokens appended along dim 1. A
            terminating token is included in the output. With several rows,
            generation stops once every row has produced a terminating token;
            rows that finish early keep repeating that token.
        """
        self.eval()
        terminators = set(stop_ids or [])
        if eos_id is not None:
            terminators.add(eos_id)
        banned = list(suppress_ids or [])

        device = input_ids.device
        stop_tensor = torch.tensor(list(terminators), dtype=torch.long, device=device)
        # Per-row flag; replaces the scalar next_id.item() check, which
        # raised for any batch size other than one.
        finished = torch.zeros(input_ids.size(0), dtype=torch.bool, device=device)

        for _ in range(max_new_tokens):
            # Only the most recent max_seq_len tokens fit the RoPE tables.
            window = input_ids[:, -self.cfg.max_seq_len:]
            logits = self(window)["logits"][:, -1, :]
            if banned:
                logits[:, banned] = -float("inf")

            if temperature <= 0:
                next_id = torch.argmax(logits, dim=-1, keepdim=True)
            else:
                logits = logits / temperature
                if top_k is not None and top_k > 0:
                    k = min(top_k, logits.size(-1))
                    kth_value = torch.topk(logits, k).values[:, [-1]]
                    logits = logits.masked_fill(logits < kth_value, -float("inf"))
                probs = F.softmax(logits, dim=-1)
                next_id = torch.multinomial(probs, num_samples=1)

            # Rows that already terminated repeat their last (terminating) token.
            next_id = torch.where(finished.unsqueeze(1), input_ids[:, -1:], next_id)
            input_ids = torch.cat([input_ids, next_id], dim=1)

            # (batch, 1) == (n_stop,) broadcasts to (batch, n_stop).
            finished = finished | (next_id == stop_tensor).any(dim=-1)
            if finished.all():
                break
        return input_ids