import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel, PretrainedConfig, GenerationMixin
from transformers.modeling_outputs import CausalLMOutputWithPast
from typing import Optional, Tuple, Dict, Any
import math

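# Nebula: a compact Llama-style decoder-only language model packaged as a Hugging Face
# PreTrainedModel. It combines RMSNorm, rotary position embeddings (RoPE), SwiGLU
# feed-forward layers and grouped-query attention with an optional KV cache.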
class NebulaConfig(PretrainedConfig):
    model_type = "nebula"
    def __init__(self, dim=1280, n_layers=14, n_heads=10, n_kv_heads=10, vocab_size=60729,
                 multiple_of=256, ffn_dim_multiplier=8/3, norm_eps=1e-5, max_seq_len=2048,
                 dropout=0.1, use_cache=True, **kwargs):
        self.dim = dim
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.vocab_size = vocab_size
        self.multiple_of = multiple_of
        self.ffn_dim_multiplier = ffn_dim_multiplier
        self.norm_eps = norm_eps
        self.max_seq_len = max_seq_len
        self.dropout = dropout
        self.use_cache = use_cache
        super().__init__(**kwargs)

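# RMSNorm: scale-only normalization (no mean subtraction, no bias); statistics are
# computed in float32 and cast back to the input dtype.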
class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))
    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
    def forward(self, x):
        return self._norm(x.float()).type_as(x) * self.weight

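# Rotary position embeddings: cos/sin tables are precomputed up to max_seq_len as
# non-persistent buffers. forward() rotates the two halves of each head dimension
# ("rotate-half" convention), offset by start_pos during cached decoding.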
class RoPE(nn.Module):
    def __init__(self, config: NebulaConfig):
        super().__init__()
        self.dim = config.dim // config.n_heads
        self.max_seq_len = config.max_seq_len
        # Build the cache eagerly on the best available device; the buffers are
        # non-persistent, so they simply follow the module when .to(device) is called.
        self._build_cache(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
    def _build_cache(self, device, base=10000):
        theta = 1.0 / (base ** (torch.arange(0, self.dim, 2, device=device).float() / self.dim))
        t = torch.arange(self.max_seq_len, device=device, dtype=theta.dtype)
        freqs = torch.einsum("i,j->ij", t, theta)
        self.register_buffer('cos_cached', freqs.cos(), persistent=False)
        self.register_buffer('sin_cached', freqs.sin(), persistent=False)
    def forward(self, x: torch.Tensor, start_pos: int = 0):
        seq_len = x.shape[-2]
        cos = self.cos_cached[start_pos : start_pos + seq_len]
        sin = self.sin_cached[start_pos : start_pos + seq_len]
        x1 = x[..., : self.dim // 2]
        x2 = x[..., self.dim // 2 :]
        rotated_x1 = x1 * cos - x2 * sin
        rotated_x2 = x1 * sin + x2 * cos
        return torch.cat([rotated_x1, rotated_x2], dim=-1).type_as(x)

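# SwiGLU feed-forward: w2(silu(w1(x)) * w3(x)). The hidden size is dim * ffn_dim_multiplier
# rounded up to a multiple of `multiple_of`; with the defaults (dim=1280, multiplier=8/3,
# multiple_of=256) that is int(3413.33) = 3413, rounded up to 3584.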
class SwiGLU(nn.Module):
    def __init__(self, config: NebulaConfig):
        super().__init__()
        hidden_dim = int(config.dim * config.ffn_dim_multiplier)
        hidden_dim = config.multiple_of * ((hidden_dim + config.multiple_of - 1) // config.multiple_of)
        self.w1 = nn.Linear(config.dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, config.dim, bias=False)
        self.w3 = nn.Linear(config.dim, hidden_dim, bias=False)
    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))

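# Multi-head attention with grouped-query support (n_kv_heads <= n_heads): K/V heads are
# repeated n_rep times to match the query heads, new K/V states are appended to the cache
# when use_cache is set, and scores are computed with F.scaled_dot_product_attention.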
class Attention(nn.Module):
    def __init__(self, config: NebulaConfig):
        super().__init__()
        self.config = config
        self.n_heads = config.n_heads
        self.n_kv_heads = config.n_kv_heads
        self.head_dim = config.dim // config.n_heads
        self.n_rep = self.n_heads // config.n_kv_heads
        self.wq = nn.Linear(config.dim, self.n_heads * self.head_dim, bias=False)
        self.wk = nn.Linear(config.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(config.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wo = nn.Linear(self.n_heads * self.head_dim, config.dim, bias=False)
        self.rope = RoPE(config)
    def repeat_kv(self, x: torch.Tensor) -> torch.Tensor:
        bs, n_kv_heads, seq_len_kv, head_dim = x.shape
        if self.n_rep == 1: return x
        return x.unsqueeze(3).expand(bs, n_kv_heads, seq_len_kv, self.n_rep, head_dim).reshape(bs, self.n_heads, seq_len_kv, head_dim)
    def forward(self, x: torch.Tensor, past_key_values: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, use_cache: bool = False, attention_mask: Optional[torch.Tensor] = None):
        bs, seq_len_q, _ = x.shape
        start_pos = past_key_values[0].shape[2] if past_key_values is not None else 0
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
        xq = xq.view(bs, seq_len_q, self.n_heads, self.head_dim).transpose(1, 2)
        xk = xk.view(bs, seq_len_q, self.n_kv_heads, self.head_dim).transpose(1, 2)
        xv = xv.view(bs, seq_len_q, self.n_kv_heads, self.head_dim).transpose(1, 2)
        xq = self.rope(xq, start_pos=start_pos)
        xk = self.rope(xk, start_pos=start_pos)
        if past_key_values is not None:
            past_k, past_v = past_key_values
            xk = torch.cat([past_k, xk], dim=2)
            xv = torch.cat([past_v, xv], dim=2)
        present_key_values = (xk, xv) if use_cache else None
        xk_rep, xv_rep = self.repeat_kv(xk), self.repeat_kv(xv)
        output = F.scaled_dot_product_attention(xq, xk_rep, xv_rep, attn_mask=attention_mask)
        output = output.transpose(1, 2).contiguous().view(bs, seq_len_q, -1)
        return self.wo(output), present_key_values

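# Pre-norm transformer block: x + dropout(attn(norm(x))), then h + dropout(ffn(norm(h))).
# The output projections (wo, w2) are tagged `is_residual_output` so _init_weights can give
# them the smaller depth-scaled initialization.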
class DecoderBlock(nn.Module):
    def __init__(self, config: NebulaConfig):
        super().__init__()
        self.attention = Attention(config)
        self.feed_forward = SwiGLU(config)
        self.attention_norm = RMSNorm(config.dim, eps=config.norm_eps)
        self.ffn_norm = RMSNorm(config.dim, eps=config.norm_eps)
        self.dropout = nn.Dropout(config.dropout)
        self.attention.wo.is_residual_output = True
        self.feed_forward.w2.is_residual_output = True
    def forward(self, x: torch.Tensor, past_key_values: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, use_cache: bool = False, attention_mask: Optional[torch.Tensor] = None):
        attn_out, present_kv = self.attention(self.attention_norm(x), past_key_values=past_key_values, use_cache=use_cache, attention_mask=attention_mask)
        h = x + self.dropout(attn_out)
        ff_out = self.feed_forward(self.ffn_norm(h))
        out = h + self.dropout(ff_out)
        return out, present_kv

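# Causal LM head: token embeddings, decoder stack, final RMSNorm and an LM head whose
# weight is tied to the embedding matrix. Inheriting PreTrainedModel and GenerationMixin
# makes the model usable with save_pretrained()/from_pretrained() and generate().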
class NebulaForCausalLM(PreTrainedModel, GenerationMixin):
    config_class = NebulaConfig
    def __init__(self, config: NebulaConfig):
        super().__init__(config)
        self.model = nn.ModuleDict({"tok_embeddings": nn.Embedding(config.vocab_size, config.dim),
                                   "layers": nn.ModuleList([DecoderBlock(config) for _ in range(config.n_layers)]),
                                   "norm": RMSNorm(config.dim, eps=config.norm_eps),
                                   "output": nn.Linear(config.dim, config.vocab_size, bias=False)})
        self.dropout = nn.Dropout(config.dropout)
        self.model.tok_embeddings.weight = self.model.output.weight
        self.post_init()
    def _init_weights(self, module):
        if isinstance(module, (nn.Linear, nn.Embedding)):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        # Residual output projections (wo, w2) get a depth-scaled std, GPT-2 style.
        if hasattr(module, 'is_residual_output'):
            torch.nn.init.normal_(module.weight, mean=0.0, std=(0.02 / math.sqrt(2 * self.config.n_layers)))
    def forward(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, use_cache: Optional[bool] = None, labels: Optional[torch.Tensor] = None, **kwargs) -> CausalLMOutputWithPast:
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        x = self.dropout(self.model.tok_embeddings(input_ids))
        if past_key_values is None:
            past_key_values = tuple([None] * self.config.n_layers)
        seq_len = input_ids.shape[1]
        past_len = past_key_values[0][0].shape[2] if past_key_values[0] is not None else 0
        # Build a mask usable by F.scaled_dot_product_attention (True = may attend):
        # convert a 2-D padding mask from generate()/data collators to a broadcastable
        # boolean mask and combine it with the causal structure; with a cache, query i
        # may attend to key positions 0..past_len+i.
        attn_mask = attention_mask
        if attention_mask is not None and attention_mask.dim() == 2:
            attn_mask = attention_mask[:, None, None, :].to(torch.bool)
        if seq_len > 1 and (attn_mask is None or attn_mask.dtype == torch.bool):
            causal = torch.tril(torch.ones(seq_len, past_len + seq_len, dtype=torch.bool, device=x.device), diagonal=past_len)
            attn_mask = causal if attn_mask is None else attn_mask & causal
        present_key_values_list = [] if use_cache else None
        for i, layer in enumerate(self.model.layers):
            x, present_kv = layer(x, past_key_values=past_key_values[i], use_cache=use_cache, attention_mask=attn_mask)
            if use_cache and present_key_values_list is not None:
                present_key_values_list.append(present_kv)
        logits = self.model.output(self.model.norm(x))
        loss = None
        if labels is not None:
            # Labels are used as-is (no next-token shift is applied here), so they must
            # already be aligned one step ahead of input_ids; -100 entries are ignored.
            loss = nn.CrossEntropyLoss()(logits.view(-1, self.config.vocab_size), labels.view(-1))
        return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=tuple(present_key_values_list) if present_key_values_list else None)
    def prepare_inputs_for_generation(self, input_ids: torch.Tensor, past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, attention_mask: Optional[torch.Tensor] = None, **kwargs) -> Dict[str, Any]:
        # With a KV cache present, only the most recent token needs to be fed forward.
        if past_key_values:
            input_ids = input_ids[:, -1:]
        return {"input_ids": input_ids, "past_key_values": past_key_values, "use_cache": kwargs.get("use_cache", True), "attention_mask": attention_mask}
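
# --- Minimal usage sketch (illustrative only) --------------------------------------
# Uses only the classes defined above; the weights are randomly initialized and the
# token ids below are random stand-ins for real tokenizer output.
if __name__ == "__main__":
    config = NebulaConfig()
    model = NebulaForCausalLM(config).eval()
    input_ids = torch.randint(0, config.vocab_size, (1, 16))
    with torch.no_grad():
        out = model(input_ids=input_ids, use_cache=True)
    print(out.logits.shape)  # torch.Size([1, 16, 60729])
    # Greedy decoding through GenerationMixin; the tuple-style KV cache returned by
    # forward() is the legacy Hugging Face cache format.
    generated = model.generate(input_ids, max_new_tokens=8, do_sample=False)
    print(generated.shape)   # torch.Size([1, 24])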