import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import asdict
from .config import TinyConfig

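# RMSNorm: scale-only normalization by the root-mean-square of the last dimension
# (no mean subtraction, no bias), with a learnable per-channel weight.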
class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps=eps
        self.weight=nn.Parameter(torch.ones(dim))
    def forward(self, x):
        return self.weight * x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

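# Splits the last dimension in half and returns (-x2, x1); helper for rotary embeddings.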
def rotate_half(x):
    x1, x2 = x[..., :x.shape[-1]//2], x[..., x.shape[-1]//2:]
    return torch.cat((-x2, x1), dim=-1)

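# Rotary position embeddings (RoPE): precomputes cos/sin tables up to
# max_position_embeddings and rotates q and k by position-dependent angles.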
class RotaryEmbedding(nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000.0):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        t = torch.arange(max_position_embeddings, dtype=torch.float)
        freqs = torch.einsum('i,j->ij', t, inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer('cos_cached', emb.cos()[None, None, :, :], persistent=False)
        self.register_buffer('sin_cached', emb.sin()[None, None, :, :], persistent=False)
    def forward(self, q, k, seq_len):
        cos = self.cos_cached[:, :, :seq_len, :].to(q.device, q.dtype)
        sin = self.sin_cached[:, :, :seq_len, :].to(q.device, q.dtype)
        return (q*cos + rotate_half(q)*sin), (k*cos + rotate_half(k)*sin)

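# Causal self-attention with RoPE and grouped-query attention (fewer key/value heads
# than query heads, expanded at runtime); uses the fused scaled_dot_product_attention
# kernel with an implicit causal mask.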
class CausalSelfAttention(nn.Module):
    def __init__(self, cfg: TinyConfig):
        super().__init__()
        assert cfg.hidden_size % cfg.num_attention_heads == 0
        assert cfg.num_attention_heads % cfg.num_key_value_heads == 0
        self.nh=cfg.num_attention_heads
        self.nkv=cfg.num_key_value_heads
        self.hd=cfg.hidden_size//cfg.num_attention_heads
        self.q_proj=nn.Linear(cfg.hidden_size, self.nh*self.hd, bias=cfg.attention_bias)
        self.k_proj=nn.Linear(cfg.hidden_size, self.nkv*self.hd, bias=cfg.attention_bias)
        self.v_proj=nn.Linear(cfg.hidden_size, self.nkv*self.hd, bias=cfg.attention_bias)
        self.o_proj=nn.Linear(cfg.hidden_size, cfg.hidden_size, bias=cfg.attention_bias)
        self.rotary=RotaryEmbedding(self.hd, cfg.max_position_embeddings, cfg.rope_theta)
        self.dropout=cfg.dropout
    def forward(self, x):
        B,T,C=x.shape
        q=self.q_proj(x).view(B,T,self.nh,self.hd).transpose(1,2)
        k=self.k_proj(x).view(B,T,self.nkv,self.hd).transpose(1,2)
        v=self.v_proj(x).view(B,T,self.nkv,self.hd).transpose(1,2)
        q,k=self.rotary(q,k,T)
        if self.nkv != self.nh:
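            # Grouped-query attention: repeat k and v heads to match the number of query heads.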
            repeat = self.nh // self.nkv
            k = k.repeat_interleave(repeat, dim=1)
            v = v.repeat_interleave(repeat, dim=1)
        y=F.scaled_dot_product_attention(q,k,v,dropout_p=self.dropout if self.training else 0.0,is_causal=True)
        y=y.transpose(1,2).contiguous().view(B,T,C)
        return self.o_proj(y)

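# SwiGLU feed-forward network: down_proj(silu(gate_proj(x)) * up_proj(x)).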
class SwiGLU(nn.Module):
    def __init__(self, cfg: TinyConfig):
        super().__init__()
        self.gate_proj=nn.Linear(cfg.hidden_size,cfg.intermediate_size,bias=cfg.mlp_bias)
        self.up_proj=nn.Linear(cfg.hidden_size,cfg.intermediate_size,bias=cfg.mlp_bias)
        self.down_proj=nn.Linear(cfg.intermediate_size,cfg.hidden_size,bias=cfg.mlp_bias)
    def forward(self,x):
        return self.down_proj(F.silu(self.gate_proj(x))*self.up_proj(x))

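# Pre-norm transformer block: residual attention followed by a residual SwiGLU MLP.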
class Block(nn.Module):
    def __init__(self,cfg):
        super().__init__()
        self.input_norm=RMSNorm(cfg.hidden_size,cfg.rms_norm_eps)
        self.attn=CausalSelfAttention(cfg)
        self.post_norm=RMSNorm(cfg.hidden_size,cfg.rms_norm_eps)
        self.mlp=SwiGLU(cfg)
    def forward(self,x):
        x=x+self.attn(self.input_norm(x))
        x=x+self.mlp(self.post_norm(x))
        return x

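# Decoder-only causal LM: token embeddings -> transformer blocks -> final RMSNorm ->
# LM head (optionally weight-tied to the embeddings). forward() returns logits and,
# when labels are given, a next-token cross-entropy loss that honours an optional loss_mask.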
class TinyLlamaForCausalLM(nn.Module):
    def __init__(self,cfg:TinyConfig):
        super().__init__()
        self.config=cfg
        self.embed_tokens=nn.Embedding(cfg.vocab_size,cfg.hidden_size)
        self.layers=nn.ModuleList([Block(cfg) for _ in range(cfg.num_hidden_layers)])
        self.norm=RMSNorm(cfg.hidden_size,cfg.rms_norm_eps)
        self.lm_head=nn.Linear(cfg.hidden_size,cfg.vocab_size,bias=False)
        if cfg.tie_word_embeddings:
            self.lm_head.weight=self.embed_tokens.weight
        self.apply(self._init_weights)
    def _init_weights(self,m):
        if isinstance(m, nn.Linear): nn.init.normal_(m.weight, mean=0.0, std=0.02)
        if isinstance(m, nn.Embedding): nn.init.normal_(m.weight, mean=0.0, std=0.02)
    def forward(self,input_ids,labels=None,loss_mask=None):
        x=self.embed_tokens(input_ids)
        for layer in self.layers: x=layer(x)
        logits=self.lm_head(self.norm(x))
        loss=None
        if labels is not None:
            # Next-token prediction: logits at position t are scored against the token at t+1.
            shift_logits=logits[:,:-1,:].contiguous()
            shift_labels=labels[:,1:].contiguous()
            per=F.cross_entropy(shift_logits.view(-1,shift_logits.size(-1)),shift_labels.view(-1),reduction='none')
            if loss_mask is not None:
                # Average only over positions kept by the (shifted) mask; clamp avoids 0/0.
                mask=loss_mask[:,1:].contiguous().view(-1).float()
                loss=(per*mask).sum()/mask.sum().clamp_min(1.0)
            else:
                loss=per.mean()
        return {'loss':loss,'logits':logits}
    def num_parameters(self):
        return sum(p.numel() for p in self.parameters())
    def save_config(self,path):
        import json
        with open(path,'w') as f:
            json.dump(asdict(self.config),f,indent=2)
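
# Minimal usage sketch (illustrative, not the project's configuration): assumes the
# TinyConfig dataclass accepts exactly the fields read above and that any other fields
# have defaults. Because of the relative import of TinyConfig, run this as a module,
# e.g. `python -m <package>.<this_module>`, rather than as a standalone script.
if __name__ == '__main__':
    cfg = TinyConfig(
        vocab_size=256, hidden_size=64, intermediate_size=128,
        num_hidden_layers=2, num_attention_heads=4, num_key_value_heads=2,
        max_position_embeddings=128, rope_theta=10000.0, rms_norm_eps=1e-5,
        attention_bias=False, mlp_bias=False, dropout=0.0, tie_word_embeddings=True,
    )
    model = TinyLlamaForCausalLM(cfg)
    ids = torch.randint(0, cfg.vocab_size, (2, 16))  # dummy batch of token ids
    out = model(ids, labels=ids)                     # next-token loss on the same batch
    print(out['logits'].shape, float(out['loss']), model.num_parameters())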