# Source: uploaded by SOVYN to the Hugging Face Hub via huggingface_hub
# ("Upload folder using huggingface_hub"), commit 681909f (verified).
from dataclasses import dataclass
import torch
import torch.nn as nn
import torch.nn.functional as F
@dataclass
class SovynConfig:
    """Hyper-parameters for the SOVYN decoder-only transformer.

    Consumed by ``SovynForCausalLM`` and its sub-modules: attention uses
    ``n_heads`` query heads sharing ``n_kv_heads`` key/value heads, the
    feed-forward block has ``ffn_size`` hidden units, and rotary position
    embeddings use ``rope_theta`` as their base.

    Raises:
        ValueError: on construction if the head layout cannot be realised
            (non-positive head counts, ``hidden_size`` not divisible by
            ``n_heads``, ``n_heads`` not divisible by ``n_kv_heads``, or an
            odd per-head dimension).
    """

    name: str = "SOVYN-120M-Cortex"
    vocab_size: int = 32000
    max_seq_len: int = 1024  # longest input forward() accepts; also the RoPE table length
    n_layers: int = 12
    hidden_size: int = 768
    n_heads: int = 12
    n_kv_heads: int = 4  # each K/V head is shared by n_heads // n_kv_heads query heads
    ffn_size: int = 2688
    dropout: float = 0.0  # attention dropout, only applied in training mode
    rope_theta: float = 10000.0
    tie_embeddings: bool = True  # reuse the embedding matrix as the output projection

    def __post_init__(self):
        """Reject head layouts the model cannot run with.

        Without this check these mistakes only surface later as shape errors
        inside attention / RoPE, far away from the misconfigured field.
        """
        if self.n_heads <= 0 or self.n_kv_heads <= 0:
            raise ValueError("n_heads and n_kv_heads must be positive")
        if self.hidden_size % self.n_heads != 0:
            raise ValueError("hidden_size must be divisible by n_heads")
        if self.n_heads % self.n_kv_heads != 0:
            raise ValueError("n_heads must be divisible by n_kv_heads")
        # RoPE rotates channel pairs, so each head needs an even width.
        if (self.hidden_size // self.n_heads) % 2 != 0:
            raise ValueError("head dimension (hidden_size // n_heads) must be even for RoPE")
class RMSNorm(nn.Module):
    """Root-mean-square normalisation with a learned per-channel gain.

    Normalises over the last dimension as
    ``x / sqrt(mean(x ** 2) + eps) * weight`` (no mean subtraction, no bias).

    Args:
        dim: size of the last input dimension (length of ``weight``).
        eps: added to the mean square before the reciprocal square root so
            all-zero rows do not divide by zero.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x):
        # Compute the statistic in at least float32: squaring fp16 activations
        # overflows for |x| > ~256 (turning the output into zeros), and bf16
        # loses precision. promote_types keeps float64 inputs in float64.
        compute_dtype = torch.promote_types(x.dtype, torch.float32)
        xf = x.to(compute_dtype)
        normed = xf * torch.rsqrt(xf.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        # Cast back before the gain so the result dtype follows the usual
        # promotion of (input dtype, weight dtype), exactly as before.
        return normed.type_as(x) * self.weight
def precompute_rope(head_dim: int, max_seq_len: int, theta: float):
    """Build the rotary-embedding cosine and sine tables.

    Frequency ``i`` is ``1 / theta ** (2i / head_dim)`` and position ``p`` is
    rotated by the angle ``p * frequency_i``.

    Returns:
        ``(cos, sin)``: float32 tensors of shape
        ``(max_seq_len, ceil(head_dim / 2))``.
    """
    exponents = torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim
    inv_freq = 1.0 / (theta ** exponents)
    positions = torch.arange(max_seq_len, dtype=torch.float32)
    # Broadcasted product == outer product: angles[p, i] = p * inv_freq[i].
    angles = positions[:, None] * inv_freq[None, :]
    return angles.cos(), angles.sin()
def apply_rope(x, cos, sin):
    """Rotate interleaved channel pairs of ``x`` by precomputed RoPE angles.

    Channels ``(2i, 2i + 1)`` are treated as a 2-D vector and rotated by the
    angle whose cosine / sine are ``cos[..., i]`` / ``sin[..., i]``.

    Args:
        x: tensor laid out as ``(batch, seq, heads, head_dim)`` -- the layout
            ``Attention.forward`` passes in.
        cos, sin: angle tables indexed ``(seq, head_dim // 2)``; they are
            broadcast over the batch and head axes.

    Returns:
        A new tensor with the same shape and dtype as ``x``.
    """
    cos_b = cos[None, :, None, :]
    sin_b = sin[None, :, None, :]
    even, odd = x[..., 0::2], x[..., 1::2]
    pairs = torch.stack(
        (even * cos_b - odd * sin_b, even * sin_b + odd * cos_b),
        dim=-1,
    )
    # Flattening the trailing pair axis re-interleaves as [r0, i0, r1, i1, ...];
    # type_as keeps x's dtype (the math above may promote to the table dtype).
    return pairs.flatten(-2).type_as(x)
class Attention(nn.Module):
    """Causal self-attention with grouped-query K/V heads and rotary embeddings.

    ``n_heads`` query heads share ``n_kv_heads`` key/value heads: each K/V head
    is repeated ``n_heads // n_kv_heads`` times before attention. RoPE is
    applied to queries and keys.

    Raises:
        ValueError: if the head configuration cannot be laid out (non-positive
            head counts, indivisible sizes, or an odd per-head dimension).
    """

    def __init__(self, cfg: SovynConfig):
        super().__init__()
        # Explicit check instead of a ZeroDivisionError from the modulo below.
        if cfg.n_heads <= 0 or cfg.n_kv_heads <= 0:
            raise ValueError("n_heads and n_kv_heads must be positive")
        if cfg.n_heads % cfg.n_kv_heads != 0:
            raise ValueError("n_heads must be divisible by n_kv_heads")
        # Otherwise q_proj's output cannot be viewed as (n_heads, head_dim) and
        # the mismatch only shows up as a view error inside forward().
        if cfg.hidden_size % cfg.n_heads != 0:
            raise ValueError("hidden_size must be divisible by n_heads")
        if (cfg.hidden_size // cfg.n_heads) % 2 != 0:
            raise ValueError("head dimension (hidden_size // n_heads) must be even for RoPE")
        self.n_heads = cfg.n_heads
        self.n_kv_heads = cfg.n_kv_heads
        self.head_dim = cfg.hidden_size // cfg.n_heads
        self.repeat = cfg.n_heads // cfg.n_kv_heads
        kv_dim = cfg.n_kv_heads * self.head_dim
        self.q_proj = nn.Linear(cfg.hidden_size, cfg.hidden_size, bias=False)
        self.k_proj = nn.Linear(cfg.hidden_size, kv_dim, bias=False)
        self.v_proj = nn.Linear(cfg.hidden_size, kv_dim, bias=False)
        self.o_proj = nn.Linear(cfg.hidden_size, cfg.hidden_size, bias=False)
        self.dropout = cfg.dropout

    def forward(self, x, cos, sin):
        """Attend causally over ``x``.

        Args:
            x: input of shape ``(batch, seq, hidden_size)``.
            cos, sin: RoPE tables with at least ``seq`` rows; only the first
                ``seq`` rows are used, so positions always start at 0.

        Returns:
            Tensor of shape ``(batch, seq, hidden_size)``.
        """
        bsz, seq_len, hidden = x.shape
        q = self.q_proj(x).view(bsz, seq_len, self.n_heads, self.head_dim)
        k = self.k_proj(x).view(bsz, seq_len, self.n_kv_heads, self.head_dim)
        v = self.v_proj(x).view(bsz, seq_len, self.n_kv_heads, self.head_dim)
        cos, sin = cos[:seq_len], sin[:seq_len]
        q = apply_rope(q, cos, sin)
        # Rotating K before repeating it means RoPE runs on n_kv_heads heads only.
        k = apply_rope(k, cos, sin)
        k = k.repeat_interleave(self.repeat, dim=2)
        v = v.repeat_interleave(self.repeat, dim=2)
        # scaled_dot_product_attention expects (batch, heads, seq, head_dim).
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
        y = F.scaled_dot_product_attention(
            q,
            k,
            v,
            attn_mask=None,
            dropout_p=self.dropout if self.training else 0.0,
            is_causal=True,
        )
        y = y.transpose(1, 2).contiguous().view(bsz, seq_len, hidden)
        return self.o_proj(y)
class SwiGLU(nn.Module):
    """Gated feed-forward layer computing ``down(silu(gate(x)) * up(x))``.

    ``gate`` and ``up`` map ``hidden_size -> ffn_size``; ``down`` maps back to
    ``hidden_size``. None of the projections has a bias.
    """

    def __init__(self, cfg: SovynConfig):
        super().__init__()
        d_model, d_ff = cfg.hidden_size, cfg.ffn_size
        # Registration order (gate, up, down) is kept for state_dict/init order.
        self.gate = nn.Linear(d_model, d_ff, bias=False)
        self.up = nn.Linear(d_model, d_ff, bias=False)
        self.down = nn.Linear(d_ff, d_model, bias=False)

    def forward(self, x):
        activated = F.silu(self.gate(x))
        projected = self.up(x)
        return self.down(activated * projected)
class Block(nn.Module):
    """Pre-norm transformer layer: RMSNorm -> attention and RMSNorm -> SwiGLU,
    each wrapped in a residual connection.
    """

    def __init__(self, cfg: SovynConfig):
        super().__init__()
        width = cfg.hidden_size
        # Registration order is kept so state_dict keys and init order match.
        self.attn_norm = RMSNorm(width)
        self.attn = Attention(cfg)
        self.ffn_norm = RMSNorm(width)
        self.ffn = SwiGLU(cfg)

    def forward(self, x, cos, sin):
        attn_out = self.attn(self.attn_norm(x), cos, sin)
        hidden = x + attn_out
        ffn_out = self.ffn(self.ffn_norm(hidden))
        return hidden + ffn_out
class SovynForCausalLM(nn.Module):
    """Decoder-only language model.

    Token embeddings feed ``cfg.n_layers`` pre-norm blocks, then a final
    RMSNorm and a linear vocabulary head. When ``cfg.tie_embeddings`` is true
    the head reuses the embedding matrix. RoPE cos/sin tables for
    ``cfg.max_seq_len`` positions are precomputed once and registered as
    non-persistent buffers.
    """

    def __init__(self, cfg: SovynConfig):
        super().__init__()
        self.cfg = cfg
        self.embed = nn.Embedding(cfg.vocab_size, cfg.hidden_size)
        self.blocks = nn.ModuleList([Block(cfg) for _ in range(cfg.n_layers)])
        self.norm = RMSNorm(cfg.hidden_size)
        self.lm_head = nn.Linear(cfg.hidden_size, cfg.vocab_size, bias=False)
        if cfg.tie_embeddings:
            self.lm_head.weight = self.embed.weight
        cos, sin = precompute_rope(
            cfg.hidden_size // cfg.n_heads,
            cfg.max_seq_len,
            cfg.rope_theta,
        )
        # persistent=False: derived from cfg, so kept out of the state_dict.
        self.register_buffer("rope_cos", cos, persistent=False)
        self.register_buffer("rope_sin", sin, persistent=False)
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialise Linear and Embedding weights from N(0, 0.02)."""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, input_ids, labels=None):
        """Compute logits (and optionally the loss) for a batch of token ids.

        Args:
            input_ids: integer tensor of shape ``(batch, seq)`` with
                ``seq <= cfg.max_seq_len``.
            labels: optional targets with as many elements as ``input_ids``;
                entries equal to -100 are ignored. Labels are compared
                position-by-position with the logits -- no shift is applied
                here, so presumably the caller passes already-shifted targets
                (TODO confirm against the training code).

        Returns:
            ``{"loss": mean cross-entropy or None, "logits": (batch, seq, vocab_size)}``.

        Raises:
            ValueError: if ``seq`` exceeds ``cfg.max_seq_len``.
        """
        if input_ids.size(1) > self.cfg.max_seq_len:
            raise ValueError("Sequence length exceeds max_seq_len")
        x = self.embed(input_ids)
        for block in self.blocks:
            x = block(x, self.rope_cos, self.rope_sin)
        x = self.norm(x)
        logits = self.lm_head(x)
        loss = None
        if labels is not None:
            # reshape (not view) so non-contiguous label tensors are accepted.
            loss = F.cross_entropy(
                logits.reshape(-1, logits.size(-1)),
                labels.reshape(-1),
                ignore_index=-100,
            )
        return {"loss": loss, "logits": logits}

    @staticmethod
    def _sample_next(logits, temperature, top_k, suppress_ids):
        """Pick one token id per row from last-position ``logits`` (batch, vocab).

        Suppressed ids are masked to -inf first (in place). ``temperature <= 0``
        means greedy argmax; otherwise logits are scaled by ``1 / temperature``,
        optionally restricted to the ``top_k`` highest, and sampled.
        Returns a ``(batch, 1)`` tensor of token ids.
        """
        if suppress_ids:
            logits[:, suppress_ids] = -float("inf")
        if temperature <= 0:
            return torch.argmax(logits, dim=-1, keepdim=True)
        logits = logits / temperature
        if top_k is not None and top_k > 0:
            values, _ = torch.topk(logits, min(top_k, logits.size(-1)))
            # Drop everything below the k-th largest logit of each row.
            logits[logits < values[:, [-1]]] = -float("inf")
        probs = F.softmax(logits, dim=-1)
        return torch.multinomial(probs, num_samples=1)

    @torch.no_grad()
    def generate(
        self,
        input_ids,
        max_new_tokens=96,
        temperature=0.8,
        top_k=50,
        eos_id=None,
        stop_ids=None,
        suppress_ids=None,
    ):
        """Append up to ``max_new_tokens`` sampled tokens to ``input_ids``.

        Each step re-runs the full forward pass on the last ``cfg.max_seq_len``
        tokens (no KV cache); RoPE positions restart at 0 for that window.

        Args:
            input_ids: ``(batch, seq)`` prompt token ids.
            max_new_tokens: maximum number of tokens to append.
            temperature: softmax temperature; ``<= 0`` selects greedy decoding.
            top_k: sample only among the ``top_k`` most likely tokens
                (``0``/``None`` disables the filter).
            eos_id, stop_ids: generation stops once every row has produced
                ``eos_id`` or a member of ``stop_ids``. Rows that stop early keep
                being extended until all rows have stopped, so callers should
                truncate each row at its first stop token.
            suppress_ids: token ids that are never produced.

        Returns:
            ``input_ids`` with the generated tokens appended along dim 1.

        The model is put in eval mode while generating; its previous
        train/eval mode is restored afterwards, even if an exception occurs.
        """
        stop_set = set(stop_ids or [])
        suppress_ids = list(suppress_ids or [])
        finished = [False] * input_ids.size(0)
        was_training = self.training
        self.eval()
        try:
            for _ in range(max_new_tokens):
                window = input_ids[:, -self.cfg.max_seq_len :]
                last_logits = self(window)["logits"][:, -1, :]
                next_id = self._sample_next(last_logits, temperature, top_k, suppress_ids)
                input_ids = torch.cat([input_ids, next_id], dim=1)
                # Per-row stop tracking (the old .item() only worked for batch 1).
                finished = [
                    done or (eos_id is not None and token == eos_id) or token in stop_set
                    for done, token in zip(finished, next_id.view(-1).tolist())
                ]
                if all(finished):
                    break
        finally:
            self.train(was_training)
        return input_ids