"""TinyLlama-style decoder-only causal LM.

Components: RMSNorm, rotary position embeddings (RoPE), grouped-query causal
self-attention via ``F.scaled_dot_product_attention``, SwiGLU MLP, and a
``TinyLlamaForCausalLM`` head with optional weight tying and masked loss.
"""
import json
from dataclasses import asdict

import torch
import torch.nn as nn
import torch.nn.functional as F

from .config import TinyConfig


class RMSNorm(nn.Module):
    """Root-mean-square layer norm: learned scale, no mean subtraction."""

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # Normalize by the RMS over the last dimension, then apply the gain.
        return self.weight * x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)


def rotate_half(x):
    """Rotate the two halves of the last dim: (x1, x2) -> (-x2, x1)."""
    half = x.shape[-1] // 2
    x1, x2 = x[..., :half], x[..., half:]
    return torch.cat((-x2, x1), dim=-1)


class RotaryEmbedding(nn.Module):
    """Precomputed rotary position embeddings applied to queries and keys."""

    def __init__(self, dim, max_position_embeddings=2048, base=10000.0):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        t = torch.arange(max_position_embeddings, dtype=torch.float)
        freqs = torch.einsum('i,j->ij', t, inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        # Cached as (1, 1, max_pos, dim) so they broadcast over (B, heads, T, hd).
        # Non-persistent: cheap to recompute, so keep them out of checkpoints.
        self.register_buffer('cos_cached', emb.cos()[None, None, :, :], persistent=False)
        self.register_buffer('sin_cached', emb.sin()[None, None, :, :], persistent=False)

    def forward(self, q, k, seq_len):
        """Return (q_rot, k_rot) with RoPE applied over the first seq_len positions."""
        max_pos = self.cos_cached.shape[2]
        # Fail loudly rather than letting a silently-truncated cache surface as
        # an opaque broadcast error inside the attention math.
        if seq_len > max_pos:
            raise ValueError(
                f'seq_len {seq_len} exceeds max_position_embeddings {max_pos}'
            )
        cos = self.cos_cached[:, :, :seq_len, :].to(q.device, q.dtype)
        sin = self.sin_cached[:, :, :seq_len, :].to(q.device, q.dtype)
        return (q * cos + rotate_half(q) * sin), (k * cos + rotate_half(k) * sin)


class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention with grouped-query KV heads and RoPE."""

    def __init__(self, cfg: TinyConfig):
        super().__init__()
        assert cfg.hidden_size % cfg.num_attention_heads == 0
        assert cfg.num_attention_heads % cfg.num_key_value_heads == 0
        self.nh = cfg.num_attention_heads
        self.nkv = cfg.num_key_value_heads
        self.hd = cfg.hidden_size // cfg.num_attention_heads
        self.q_proj = nn.Linear(cfg.hidden_size, self.nh * self.hd, bias=cfg.attention_bias)
        self.k_proj = nn.Linear(cfg.hidden_size, self.nkv * self.hd, bias=cfg.attention_bias)
        self.v_proj = nn.Linear(cfg.hidden_size, self.nkv * self.hd, bias=cfg.attention_bias)
        self.o_proj = nn.Linear(cfg.hidden_size, cfg.hidden_size, bias=cfg.attention_bias)
        self.rotary = RotaryEmbedding(self.hd, cfg.max_position_embeddings, cfg.rope_theta)
        self.dropout = cfg.dropout

    def forward(self, x):
        B, T, C = x.shape
        # Project and reshape to (B, heads, T, head_dim).
        q = self.q_proj(x).view(B, T, self.nh, self.hd).transpose(1, 2)
        k = self.k_proj(x).view(B, T, self.nkv, self.hd).transpose(1, 2)
        v = self.v_proj(x).view(B, T, self.nkv, self.hd).transpose(1, 2)
        q, k = self.rotary(q, k, T)
        if self.nkv != self.nh:
            # Grouped-query attention: replicate each KV head across its group
            # so SDPA sees matching head counts.
            repeat = self.nh // self.nkv
            k = k.repeat_interleave(repeat, dim=1)
            v = v.repeat_interleave(repeat, dim=1)
        y = F.scaled_dot_product_attention(
            q, k, v,
            dropout_p=self.dropout if self.training else 0.0,
            is_causal=True,
        )
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.o_proj(y)


class SwiGLU(nn.Module):
    """SwiGLU feed-forward block: down(silu(gate(x)) * up(x))."""

    def __init__(self, cfg: TinyConfig):
        super().__init__()
        self.gate_proj = nn.Linear(cfg.hidden_size, cfg.intermediate_size, bias=cfg.mlp_bias)
        self.up_proj = nn.Linear(cfg.hidden_size, cfg.intermediate_size, bias=cfg.mlp_bias)
        self.down_proj = nn.Linear(cfg.intermediate_size, cfg.hidden_size, bias=cfg.mlp_bias)

    def forward(self, x):
        return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))


class Block(nn.Module):
    """Pre-norm transformer block: x + attn(norm(x)); x + mlp(norm(x))."""

    def __init__(self, cfg):
        super().__init__()
        self.input_norm = RMSNorm(cfg.hidden_size, cfg.rms_norm_eps)
        self.attn = CausalSelfAttention(cfg)
        self.post_norm = RMSNorm(cfg.hidden_size, cfg.rms_norm_eps)
        self.mlp = SwiGLU(cfg)

    def forward(self, x):
        x = x + self.attn(self.input_norm(x))
        x = x + self.mlp(self.post_norm(x))
        return x


class TinyLlamaForCausalLM(nn.Module):
    """Decoder-only causal language model with optional tied embeddings."""

    def __init__(self, cfg: TinyConfig):
        super().__init__()
        self.config = cfg
        self.embed_tokens = nn.Embedding(cfg.vocab_size, cfg.hidden_size)
        self.layers = nn.ModuleList([Block(cfg) for _ in range(cfg.num_hidden_layers)])
        self.norm = RMSNorm(cfg.hidden_size, cfg.rms_norm_eps)
        self.lm_head = nn.Linear(cfg.hidden_size, cfg.vocab_size, bias=False)
        if cfg.tie_word_embeddings:
            # Share one Parameter; _init_weights re-inits it but cannot untie it.
            self.lm_head.weight = self.embed_tokens.weight
        self.apply(self._init_weights)

    def _init_weights(self, m):
        # Plain normal(0, 0.02) init for all linear and embedding weights.
        if isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, mean=0.0, std=0.02)
        if isinstance(m, nn.Embedding):
            nn.init.normal_(m.weight, mean=0.0, std=0.02)

    def forward(self, input_ids, labels=None, loss_mask=None):
        """Run the model.

        Args:
            input_ids: (B, T) token ids.
            labels: optional (B, T) targets; next-token loss is computed with
                the standard one-position shift.
            loss_mask: optional (B, T) mask; positions with 0 are excluded
                from the loss (mask is shifted in lockstep with labels).

        Returns:
            dict with 'loss' (None when labels is None) and 'logits' (B, T, V).
        """
        x = self.embed_tokens(input_ids)
        for layer in self.layers:
            x = layer(x)
        logits = self.lm_head(self.norm(x))
        loss = None
        if labels is not None:
            # Predict token t+1 from position t.
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = labels[:, 1:].contiguous()
            per = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
                reduction='none',
            )
            if loss_mask is not None:
                mask = loss_mask[:, 1:].contiguous().view(-1).float()
                # clamp_min guards against division by zero on an all-masked batch.
                loss = (per * mask).sum() / mask.sum().clamp_min(1.0)
            else:
                loss = per.mean()
        return {'loss': loss, 'logits': logits}

    def num_parameters(self):
        """Total parameter count (tied weights are counted once per Parameter)."""
        return sum(p.numel() for p in self.parameters())

    def save_config(self, path):
        """Write the model config to *path* as pretty-printed JSON."""
        # 'with' ensures the handle is flushed and closed (the original leaked it).
        with open(path, 'w') as f:
            json.dump(asdict(self.config), f, indent=2)