File size: 5,661 Bytes
441dc44
 
 
 
 
 
 
 
 
 
a6a57d5
441dc44
a6a57d5
 
441dc44
 
a6a57d5
441dc44
a6a57d5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38b7e42
 
441dc44
 
 
 
a6a57d5
441dc44
a6a57d5
 
 
441dc44
a6a57d5
 
441dc44
 
 
 
a6a57d5
 
 
 
38b7e42
a6a57d5
 
441dc44
 
 
 
 
 
 
 
 
 
 
 
 
a6a57d5
38b7e42
a6a57d5
441dc44
a6a57d5
 
 
441dc44
 
 
 
a6a57d5
441dc44
 
 
 
 
 
 
a6a57d5
 
 
 
 
 
 
 
 
 
441dc44
 
 
a6a57d5
441dc44
 
 
 
 
a6a57d5
441dc44
 
 
38b7e42
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel, GenerationMixin
from transformers.modeling_outputs import CausalLMOutput

from .configuration_dwarf import DwarfConfig

class RMSNorm(nn.Module):
    """Root-mean-square normalization with a learned per-feature gain.

    The statistic is computed over the last dimension in float32. The
    normalized tensor is cast back to the input dtype and then multiplied by
    ``scale`` (initialized to ones).

    Args:
        dim: size of the last (normalized) dimension.
        eps: value added to the mean square before the reciprocal sqrt.
    """

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.scale = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        x32 = x.float()
        mean_sq = x32.square().mean(dim=-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_sq + self.eps)
        normed = (x32 * inv_rms).to(x.dtype)
        return normed * self.scale

class RotaryEmbedding(nn.Module):
    """Rotary position embedding (RoPE) applied to query and key tensors.

    The cos/sin tables are built lazily in float32. They are stored as plain
    attributes, not registered buffers, so they are not part of
    ``state_dict`` and are not moved by ``.to()``. ``forward`` therefore
    rebuilds them when the device changes or when a longer sequence arrives.

    Args:
        head_dim: per-head feature size; must be even.
        max_seq_len: minimum number of positions to precompute.
        theta: RoPE base frequency.

    Raises:
        ValueError: if ``head_dim`` is odd.
    """

    def __init__(self, head_dim, max_seq_len, theta=10000.0):
        super().__init__()
        # Raise instead of assert so the check survives `python -O`.
        if head_dim % 2 != 0:
            raise ValueError(f"head_dim must be even for rotary embeddings, got {head_dim}")
        self.head_dim = head_dim
        self.max_seq_len = max_seq_len
        self.theta = theta
        self.cos_cache = None
        self.sin_cache = None
        self._max = 0

    def _build_cache(self, seq_len, device):
        """Compute cos/sin tables of shape (1, 1, seq_len, head_dim) on ``device``."""
        # inv_freq[i] = theta ** (-2i / head_dim) for i in [0, head_dim/2).
        inv_freq = 1.0 / (self.theta ** (torch.arange(0, self.head_dim, 2, device=device).float() / self.head_dim))
        t = torch.arange(seq_len, device=device).float()
        freqs = torch.outer(t, inv_freq)
        # Duplicate so feature i and i + head_dim/2 share an angle (matches _rotate_half).
        emb = torch.cat([freqs, freqs], dim=-1)
        self.cos_cache = emb.cos()[None, None]
        self.sin_cache = emb.sin()[None, None]
        self._max = seq_len

    @staticmethod
    def _rotate_half(x):
        """Map (x1, x2) halves of the last dim to (-x2, x1)."""
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat([-x2, x1], dim=-1)

    def forward(self, q, k):
        """Rotate ``q`` and ``k`` by their position (sequence axis is dim 2).

        Positions always start at 0. Inputs are expected as
        (batch, heads, seq, head_dim). The outputs keep the dtypes of ``q``
        and ``k``.
        """
        T = q.size(2)
        if self.cos_cache is None or T > self._max or self.cos_cache.device != q.device:
            self._build_cache(max(T, self.max_seq_len), q.device)
        cos = self.cos_cache[:, :, :T, :]
        sin = self.sin_cache[:, :, :T, :]
        # Cast the float32 tables to each input's dtype. Otherwise bf16/fp16 q/k
        # would be promoted to float32 while v is not, and
        # scaled_dot_product_attention rejects mismatched dtypes.
        q_cos, q_sin = cos.to(q.dtype), sin.to(q.dtype)
        k_cos, k_sin = cos.to(k.dtype), sin.to(k.dtype)
        q = q * q_cos + self._rotate_half(q) * q_sin
        k = k * k_cos + self._rotate_half(k) * k_sin
        return q, k

class GQAAttention(nn.Module):
    """Causal self-attention with grouped-query key/value sharing.

    ``n_heads`` query heads attend over ``n_kv_heads`` key/value heads. Each
    key/value head is repeated ``n_heads // n_kv_heads`` times along the head
    axis, so query head ``h`` reads key/value head ``h // n_groups``. Rotary
    embeddings are applied to queries and keys before attention.

    Only the causal mask is applied; no padding mask is accepted.
    """

    def __init__(self, config):
        super().__init__()
        d_model, head_dim = config.d_model, config.head_dim
        n_q, n_kv = config.n_heads, config.n_kv_heads
        self.n_heads = n_q
        self.n_kv_heads = n_kv
        self.n_groups = n_q // n_kv
        self.head_dim = head_dim
        # Keep creation order (q, k, v, o, rope) so parameter init consumes RNG identically.
        self.q_proj = nn.Linear(d_model, n_q * head_dim, bias=True)
        self.k_proj = nn.Linear(d_model, n_kv * head_dim, bias=True)
        self.v_proj = nn.Linear(d_model, n_kv * head_dim, bias=True)
        self.o_proj = nn.Linear(n_q * head_dim, d_model, bias=False)
        self.rope = RotaryEmbedding(head_dim, config.max_seq_len, config.rope_theta)

    def _split_heads(self, proj, x, heads):
        """Project ``x`` and reshape (B, T, heads*head_dim) -> (B, heads, T, head_dim)."""
        batch, seq, _ = x.shape
        return proj(x).view(batch, seq, heads, self.head_dim).transpose(1, 2)

    def forward(self, x):
        batch, seq, _ = x.shape
        query = self._split_heads(self.q_proj, x, self.n_heads)
        key = self._split_heads(self.k_proj, x, self.n_kv_heads)
        value = self._split_heads(self.v_proj, x, self.n_kv_heads)
        query, key = self.rope(query, key)
        if self.n_groups > 1:
            # Expand kv heads to match query heads: [kv0, kv0, kv1, kv1, ...].
            key = key.repeat_interleave(self.n_groups, dim=1)
            value = value.repeat_interleave(self.n_groups, dim=1)
        attended = F.scaled_dot_product_attention(query, key, value, is_causal=True)
        merged = attended.transpose(1, 2).contiguous().view(batch, seq, self.n_heads * self.head_dim)
        return self.o_proj(merged)

class SwiGLUFFN(nn.Module):
    """Gated feed-forward block computing ``down(silu(gate(x)) * up(x))``.

    All three projections are bias-free. The hidden width is ``config.d_ff``.
    """

    def __init__(self, config):
        super().__init__()
        width, hidden = config.d_model, config.d_ff
        self.gate_proj = nn.Linear(width, hidden, bias=False)
        self.up_proj = nn.Linear(width, hidden, bias=False)
        self.down_proj = nn.Linear(hidden, width, bias=False)

    def forward(self, x):
        gated = F.silu(self.gate_proj(x))
        mixed = gated * self.up_proj(x)
        return self.down_proj(mixed)

class DwarfBlock(nn.Module):
    """Pre-norm transformer layer.

    Applies RMSNorm -> attention and then RMSNorm -> SwiGLU FFN. Each
    sublayer output is added back to its input as a residual.
    """

    def __init__(self, config):
        super().__init__()
        # Registration order (norm_attn, attn, norm_ffn, ffn) fixes parameter
        # naming/order and RNG consumption during init.
        self.norm_attn = RMSNorm(config.d_model, config.norm_eps)
        self.attn = GQAAttention(config)
        self.norm_ffn = RMSNorm(config.d_model, config.norm_eps)
        self.ffn = SwiGLUFFN(config)

    def forward(self, x):
        for norm, sublayer in ((self.norm_attn, self.attn), (self.norm_ffn, self.ffn)):
            x = x + sublayer(norm(x))
        return x

class DwarfForCausalLM(PreTrainedModel, GenerationMixin):
    """Decoder-only causal language model built from ``DwarfBlock`` layers.

    Token embeddings pass through ``config.n_layers`` blocks, a final RMSNorm
    and a bias-free ``lm_head``. ``tie_weights`` makes ``lm_head`` share the
    input embedding's weight Parameter.

    NOTE: there is no key/value cache. ``prepare_inputs_for_generation``
    forwards only ``input_ids``, so each generation step re-runs the full
    prefix.
    """

    config_class = DwarfConfig
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model)
        self.layers = nn.ModuleList([DwarfBlock(config) for _ in range(config.n_layers)])
        self.norm = RMSNorm(config.d_model, config.norm_eps)
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
        self.post_init()

    def tie_weights(self, **kwargs):
        """Point ``lm_head.weight`` at ``embed_tokens.weight`` (unconditionally; kwargs ignored)."""
        self.lm_head.weight = self.embed_tokens.weight

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
        """Compute next-token logits and, when ``labels`` are given, the LM loss.

        Args:
            input_ids: token ids; presumably (batch, seq) given the slicing below.
            attention_mask: accepted for API compatibility but NOT applied.
                Attention is purely causal, so padded positions are attended to.
            labels: target ids aligned with ``input_ids``. They are shifted
                internally, and positions equal to -100 are ignored.

        Returns:
            ``CausalLMOutput`` with ``logits`` in the model's dtype and
            ``loss`` (float32 scalar) or ``None``.
        """
        x = self.embed_tokens(input_ids)
        for layer in self.layers:
            x = layer(x)
        logits = self.lm_head(self.norm(x))
        loss = None
        if labels is not None:
            # Position t predicts token t+1: drop the last logit and the first label.
            # Upcast to float32 so log-softmax over the vocabulary does not lose
            # precision when the model runs in bf16/fp16. The returned logits are
            # left unchanged.
            shift_logits = logits[:, :-1, :].float()
            shift_labels = labels[:, 1:].to(logits.device)
            loss = F.cross_entropy(
                shift_logits.reshape(-1, shift_logits.size(-1)),
                shift_labels.reshape(-1),
                ignore_index=-100,
            )
        return CausalLMOutput(loss=loss, logits=logits)

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        """Return only ``input_ids``; other generation kwargs (e.g. attention_mask) are dropped."""
        return {"input_ids": input_ids}