# Source: Hugging Face commit ca2f8ca (verified), author razor5050:
# "Add tokenizer, inference code, model card, and 20-query report"
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import asdict
from .config import TinyConfig
class RMSNorm(nn.Module):
    """Root-mean-square layer norm: rescales by 1/RMS of the last dim, no mean-centering, no bias."""

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps  # added inside rsqrt for numerical stability
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # Reciprocal RMS over the feature (last) dimension, then learnable per-channel gain.
        inv_rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return x * inv_rms * self.weight
def rotate_half(x):
    """Rotate the last dimension by 90 degrees pairwise: (a, b) -> (-b, a) on the two halves."""
    half = x.shape[-1] // 2
    first, second = x[..., :half], x[..., half:]
    return torch.cat((-second, first), dim=-1)
class RotaryEmbedding(nn.Module):
    """Rotary position embeddings (RoPE) applied to query/key tensors from a precomputed table."""

    def __init__(self, dim, max_position_embeddings=2048, base=10000.0):
        super().__init__()
        # Per-pair inverse frequencies: base^(-2i/dim) for i in [0, dim/2).
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        positions = torch.arange(max_position_embeddings, dtype=torch.float)
        angles = torch.outer(positions, inv_freq)
        # Duplicate the half-table so it spans the full head dimension.
        table = torch.cat((angles, angles), dim=-1)
        # Cached with broadcast dims for (batch, heads, seq, head_dim); excluded from state_dict.
        self.register_buffer('cos_cached', table.cos()[None, None, :, :], persistent=False)
        self.register_buffer('sin_cached', table.sin()[None, None, :, :], persistent=False)

    def forward(self, q, k, seq_len):
        """Rotate q and k by their absolute positions; returns (q_rot, k_rot)."""
        cos = self.cos_cached[:, :, :seq_len, :].to(q.device, q.dtype)
        sin = self.sin_cached[:, :, :seq_len, :].to(q.device, q.dtype)
        q_rot = q * cos + rotate_half(q) * sin
        k_rot = k * cos + rotate_half(k) * sin
        return q_rot, k_rot
class CausalSelfAttention(nn.Module):
    """Causal multi-head self-attention with RoPE and grouped-query (GQA) key/value heads."""

    def __init__(self, cfg: TinyConfig):
        super().__init__()
        assert cfg.hidden_size % cfg.num_attention_heads == 0
        assert cfg.num_attention_heads % cfg.num_key_value_heads == 0
        self.nh = cfg.num_attention_heads
        self.nkv = cfg.num_key_value_heads
        self.hd = cfg.hidden_size // cfg.num_attention_heads
        bias = cfg.attention_bias
        self.q_proj = nn.Linear(cfg.hidden_size, self.nh * self.hd, bias=bias)
        self.k_proj = nn.Linear(cfg.hidden_size, self.nkv * self.hd, bias=bias)
        self.v_proj = nn.Linear(cfg.hidden_size, self.nkv * self.hd, bias=bias)
        self.o_proj = nn.Linear(cfg.hidden_size, cfg.hidden_size, bias=bias)
        self.rotary = RotaryEmbedding(self.hd, cfg.max_position_embeddings, cfg.rope_theta)
        self.dropout = cfg.dropout  # attention dropout prob, active only in training mode

    def forward(self, x):
        batch, seq, width = x.shape
        # Project then reshape to (batch, heads, seq, head_dim).
        q = self.q_proj(x).view(batch, seq, self.nh, self.hd).transpose(1, 2)
        k = self.k_proj(x).view(batch, seq, self.nkv, self.hd).transpose(1, 2)
        v = self.v_proj(x).view(batch, seq, self.nkv, self.hd).transpose(1, 2)
        q, k = self.rotary(q, k, seq)
        # GQA: replicate each KV head to cover its group of query heads.
        groups = self.nh // self.nkv
        if groups > 1:
            k = k.repeat_interleave(groups, dim=1)
            v = v.repeat_interleave(groups, dim=1)
        attn = F.scaled_dot_product_attention(
            q, k, v,
            dropout_p=self.dropout if self.training else 0.0,
            is_causal=True,
        )
        # Back to (batch, seq, hidden) and mix heads through the output projection.
        out = attn.transpose(1, 2).contiguous().view(batch, seq, width)
        return self.o_proj(out)
class SwiGLU(nn.Module):
    """LLaMA-style gated MLP: down(silu(gate(x)) * up(x))."""

    def __init__(self, cfg: TinyConfig):
        super().__init__()
        hidden, inner, bias = cfg.hidden_size, cfg.intermediate_size, cfg.mlp_bias
        self.gate_proj = nn.Linear(hidden, inner, bias=bias)
        self.up_proj = nn.Linear(hidden, inner, bias=bias)
        self.down_proj = nn.Linear(inner, hidden, bias=bias)

    def forward(self, x):
        # SiLU-gated elementwise product, then project back to the model width.
        gated = F.silu(self.gate_proj(x))
        return self.down_proj(gated * self.up_proj(x))
class Block(nn.Module):
    """One pre-norm transformer layer: attention then MLP, each wrapped in a residual."""

    def __init__(self, cfg):
        super().__init__()
        self.input_norm = RMSNorm(cfg.hidden_size, cfg.rms_norm_eps)
        self.attn = CausalSelfAttention(cfg)
        self.post_norm = RMSNorm(cfg.hidden_size, cfg.rms_norm_eps)
        self.mlp = SwiGLU(cfg)

    def forward(self, x):
        # Pre-norm residual sublayers.
        x = x + self.attn(self.input_norm(x))
        return x + self.mlp(self.post_norm(x))
class TinyLlamaForCausalLM(nn.Module):
    """Decoder-only LLaMA-style causal LM: embeddings -> N blocks -> final norm -> lm_head."""

    def __init__(self, cfg: TinyConfig):
        super().__init__()
        self.config = cfg
        self.embed_tokens = nn.Embedding(cfg.vocab_size, cfg.hidden_size)
        self.layers = nn.ModuleList([Block(cfg) for _ in range(cfg.num_hidden_layers)])
        self.norm = RMSNorm(cfg.hidden_size, cfg.rms_norm_eps)
        self.lm_head = nn.Linear(cfg.hidden_size, cfg.vocab_size, bias=False)
        if cfg.tie_word_embeddings:
            # Share a single weight tensor between input embedding and output head.
            self.lm_head.weight = self.embed_tokens.weight
        self.apply(self._init_weights)

    def _init_weights(self, m):
        # GPT-2-style init: both Linear and Embedding weights drawn from N(0, 0.02).
        if isinstance(m, (nn.Linear, nn.Embedding)):
            nn.init.normal_(m.weight, mean=0.0, std=0.02)

    def forward(self, input_ids, labels=None, loss_mask=None):
        """Run the model.

        Args:
            input_ids: (batch, seq) token-id tensor.
            labels: optional (batch, seq) targets; when given, next-token CE loss is computed.
            loss_mask: optional (batch, seq) 0/1 mask selecting which label positions count.

        Returns:
            dict with 'logits' (batch, seq, vocab) and 'loss' (scalar tensor or None).
        """
        x = self.embed_tokens(input_ids)
        for layer in self.layers:
            x = layer(x)
        logits = self.lm_head(self.norm(x))
        loss = None
        if labels is not None:
            # Shift so position t predicts token t+1.
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = labels[:, 1:].contiguous()
            per_token = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
                reduction='none',
            )
            if loss_mask is not None:
                mask = loss_mask[:, 1:].contiguous().view(-1).float()
                # clamp_min guards against division by zero on an all-zero mask.
                loss = (per_token * mask).sum() / mask.sum().clamp_min(1.0)
            else:
                loss = per_token.mean()
        return {'loss': loss, 'logits': logits}

    def num_parameters(self):
        """Total parameter count (Module.parameters() yields shared/tied tensors once)."""
        return sum(p.numel() for p in self.parameters())

    def save_config(self, path):
        """Serialize the config dataclass to *path* as pretty-printed JSON."""
        import json
        # Bug fix: the original used open(path, 'w').write(...) and never closed the
        # handle, relying on GC to flush — a context manager guarantees close/flush.
        with open(path, 'w') as f:
            json.dump(asdict(self.config), f, indent=2)