import json
from dataclasses import asdict

import torch
import torch.nn as nn
import torch.nn.functional as F

from .config import TinyConfig


class RMSNorm(nn.Module):
    """Root-mean-square layer norm: rescale by 1/RMS(x) with a learned gain, no bias."""

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        return self.weight * x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)


def rotate_half(x):
    """Split the last dim in half and rotate: (x1, x2) -> (-x2, x1); the non-interleaved RoPE layout."""
    x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
    return torch.cat((-x2, x1), dim=-1)


class RotaryEmbedding(nn.Module):
    """Precomputes RoPE cos/sin tables up to max_position_embeddings and applies them to q and k."""

    def __init__(self, dim, max_position_embeddings=2048, base=10000.0):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        t = torch.arange(max_position_embeddings, dtype=torch.float)
        freqs = torch.einsum('i,j->ij', t, inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        # Cache with shape (1, 1, seq, dim) so the tables broadcast over (batch, heads, seq, head_dim).
        self.register_buffer('cos_cached', emb.cos()[None, None, :, :], persistent=False)
        self.register_buffer('sin_cached', emb.sin()[None, None, :, :], persistent=False)

    def forward(self, q, k, seq_len):
        cos = self.cos_cached[:, :, :seq_len, :].to(q.device, q.dtype)
        sin = self.sin_cached[:, :, :seq_len, :].to(q.device, q.dtype)
        return (q * cos + rotate_half(q) * sin), (k * cos + rotate_half(k) * sin)


class CausalSelfAttention(nn.Module):
    """Causal multi-head attention with rotary embeddings and optional grouped-query KV heads."""

    def __init__(self, cfg: TinyConfig):
        super().__init__()
        assert cfg.hidden_size % cfg.num_attention_heads == 0
        assert cfg.num_attention_heads % cfg.num_key_value_heads == 0
        self.nh = cfg.num_attention_heads
        self.nkv = cfg.num_key_value_heads
        self.hd = cfg.hidden_size // cfg.num_attention_heads
        self.q_proj = nn.Linear(cfg.hidden_size, self.nh * self.hd, bias=cfg.attention_bias)
        self.k_proj = nn.Linear(cfg.hidden_size, self.nkv * self.hd, bias=cfg.attention_bias)
        self.v_proj = nn.Linear(cfg.hidden_size, self.nkv * self.hd, bias=cfg.attention_bias)
        self.o_proj = nn.Linear(cfg.hidden_size, cfg.hidden_size, bias=cfg.attention_bias)
        self.rotary = RotaryEmbedding(self.hd, cfg.max_position_embeddings, cfg.rope_theta)
        self.dropout = cfg.dropout

    def forward(self, x):
        B, T, C = x.shape
        # Project to (B, heads, T, head_dim); K/V may use fewer heads than Q (grouped-query attention).
        q = self.q_proj(x).view(B, T, self.nh, self.hd).transpose(1, 2)
        k = self.k_proj(x).view(B, T, self.nkv, self.hd).transpose(1, 2)
        v = self.v_proj(x).view(B, T, self.nkv, self.hd).transpose(1, 2)
        q, k = self.rotary(q, k, T)
        if self.nkv != self.nh:
            # Expand each KV head so every query head has a matching key/value head.
            repeat = self.nh // self.nkv
            k = k.repeat_interleave(repeat, dim=1)
            v = v.repeat_interleave(repeat, dim=1)
        y = F.scaled_dot_product_attention(
            q, k, v,
            dropout_p=self.dropout if self.training else 0.0,
            is_causal=True,
        )
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.o_proj(y)


class SwiGLU(nn.Module):
    """Gated MLP: down_proj(silu(gate_proj(x)) * up_proj(x))."""

    def __init__(self, cfg: TinyConfig):
        super().__init__()
        self.gate_proj = nn.Linear(cfg.hidden_size, cfg.intermediate_size, bias=cfg.mlp_bias)
        self.up_proj = nn.Linear(cfg.hidden_size, cfg.intermediate_size, bias=cfg.mlp_bias)
        self.down_proj = nn.Linear(cfg.intermediate_size, cfg.hidden_size, bias=cfg.mlp_bias)

    def forward(self, x):
        return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))


class Block(nn.Module):
    """Pre-norm transformer block: attention and MLP sub-layers, each wrapped in a residual connection."""

    def __init__(self, cfg):
        super().__init__()
        self.input_norm = RMSNorm(cfg.hidden_size, cfg.rms_norm_eps)
        self.attn = CausalSelfAttention(cfg)
        self.post_norm = RMSNorm(cfg.hidden_size, cfg.rms_norm_eps)
        self.mlp = SwiGLU(cfg)

    def forward(self, x):
        x = x + self.attn(self.input_norm(x))
        x = x + self.mlp(self.post_norm(x))
        return x


class TinyLlamaForCausalLM(nn.Module):
    """Minimal LLaMA-style decoder-only language model."""

    def __init__(self, cfg: TinyConfig):
        super().__init__()
        self.config = cfg
        self.embed_tokens = nn.Embedding(cfg.vocab_size, cfg.hidden_size)
        self.layers = nn.ModuleList([Block(cfg) for _ in range(cfg.num_hidden_layers)])
        self.norm = RMSNorm(cfg.hidden_size, cfg.rms_norm_eps)
        self.lm_head = nn.Linear(cfg.hidden_size, cfg.vocab_size, bias=False)
        if cfg.tie_word_embeddings:
            # Share the output projection with the input embedding matrix.
            self.lm_head.weight = self.embed_tokens.weight
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, mean=0.0, std=0.02)
            # Zero any optional attention/MLP biases instead of leaving PyTorch's default init.
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        if isinstance(m, nn.Embedding):
            nn.init.normal_(m.weight, mean=0.0, std=0.02)

    def forward(self, input_ids, labels=None, loss_mask=None):
        x = self.embed_tokens(input_ids)
        for layer in self.layers:
            x = layer(x)
        logits = self.lm_head(self.norm(x))
        loss = None
        if labels is not None:
            # Next-token prediction: the logits at position t are scored against the label at t+1.
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = labels[:, 1:].contiguous()
            per = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
                reduction='none',
            )
            if loss_mask is not None:
                # Average only over positions whose shifted mask entry is set.
                mask = loss_mask[:, 1:].contiguous().view(-1).float()
                loss = (per * mask).sum() / mask.sum().clamp_min(1.0)
            else:
                loss = per.mean()
        return {'loss': loss, 'logits': logits}

    def num_parameters(self):
        return sum(p.numel() for p in self.parameters())

    def save_config(self, path):
        with open(path, 'w') as f:
            json.dump(asdict(self.config), f, indent=2)
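

# Minimal smoke test; a sketch only. It assumes TinyConfig is a dataclass whose remaining fields
# (rope_theta, rms_norm_eps, attention_bias, mlp_bias, dropout, tie_word_embeddings) have defaults,
# and the sizes below are illustrative, not the project's defaults. Because of the relative import
# above, run it as a module, e.g. `python -m <package>.model`.
if __name__ == '__main__':
    cfg = TinyConfig(
        vocab_size=1000,
        hidden_size=128,
        intermediate_size=256,
        num_hidden_layers=2,
        num_attention_heads=4,
        num_key_value_heads=2,
        max_position_embeddings=256,
    )
    model = TinyLlamaForCausalLM(cfg)
    input_ids = torch.randint(0, cfg.vocab_size, (2, 16))
    out = model(input_ids, labels=input_ids)
    print(out['logits'].shape, out['loss'].item(), model.num_parameters())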