File size: 9,081 Bytes
ca731b9 c488635 ca731b9 c488635 ca731b9 4706a47 ca731b9 4706a47 ca731b9 4706a47 a08f903 ca731b9 a08f903 ca731b9 a08f903 4706a47 ca731b9 4706a47 ca731b9 a08f903 ca731b9 a08f903 ca731b9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 |
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel, PretrainedConfig, GenerationMixin
from transformers.modeling_outputs import CausalLMOutputWithPast
from typing import Optional, Tuple, Dict, Any
import math
class NebulaConfig(PretrainedConfig):
    """Hyper-parameters for the Nebula decoder-only language model.

    Args:
        dim: model width (embedding size); split evenly across ``n_heads``.
        n_layers: number of decoder blocks.
        n_heads: number of query attention heads.
        n_kv_heads: number of key/value heads; the attention layer repeats them
            ``n_heads // n_kv_heads`` times.
        vocab_size: number of token ids.
        multiple_of: the SwiGLU hidden width is rounded up to a multiple of this.
        ffn_dim_multiplier: SwiGLU hidden width as a multiple of ``dim`` before rounding.
        norm_eps: epsilon used by the RMSNorm layers.
        max_seq_len: number of positions covered by the rotary-embedding tables.
        dropout: dropout probability used on embeddings and residual branches.
        use_cache: default for whether ``forward`` returns key/value caches.
        **kwargs: forwarded unchanged to ``PretrainedConfig``.
    """

    model_type = "nebula"

    def __init__(self, dim=1280, n_layers=14, n_heads=10, n_kv_heads=10, vocab_size=60729,
                 multiple_of=256, ffn_dim_multiplier=8/3, norm_eps=1e-5, max_seq_len=2048,
                 dropout=0.1, use_cache=True, **kwargs):
        # Store the model-specific fields first (same order as the signature), then let
        # PretrainedConfig process the generic keyword arguments.
        model_fields = {
            "dim": dim,
            "n_layers": n_layers,
            "n_heads": n_heads,
            "n_kv_heads": n_kv_heads,
            "vocab_size": vocab_size,
            "multiple_of": multiple_of,
            "ffn_dim_multiplier": ffn_dim_multiplier,
            "norm_eps": norm_eps,
            "max_seq_len": max_seq_len,
            "dropout": dropout,
            "use_cache": use_cache,
        }
        for field_name, field_value in model_fields.items():
            setattr(self, field_name, field_value)
        super().__init__(**kwargs)
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last axis with a learned per-channel gain.

    Unlike LayerNorm there is no mean subtraction and no bias.

    Args:
        dim: size of the last axis (length of the gain vector).
        eps: added to the mean square before the reciprocal square root.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        """Scale ``x`` so that the mean of its squares along the last axis is ~1."""
        mean_square = x.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        return x * inv_rms

    def forward(self, x):
        # Normalise in float32 for numerical stability, cast back to the input dtype,
        # and only then apply the learned gain.
        normalised = self._norm(x.float()).type_as(x)
        return normalised * self.weight
class RoPE(nn.Module):
    """Rotary position embedding in the "rotate half" layout.

    The last axis of the input (size ``head_dim = dim // n_heads``) is split into two
    halves ``x1`` and ``x2``. Pair ``(x1[i], x2[i])`` at absolute position ``p`` is rotated
    by angle ``p * theta_i`` with ``theta_i = base ** (-2i / head_dim)``. cos/sin tables for
    positions ``0 .. max_seq_len - 1`` are precomputed once.
    """

    def __init__(self, config: "NebulaConfig"):
        super().__init__()
        self.dim = config.dim // config.n_heads
        self.max_seq_len = config.max_seq_len
        # Build the tables on CPU. Buffers follow ``module.to(...)`` and ``forward`` also
        # moves them to the input's device, so the model's placement decides where they
        # live, instead of defaulting to CUDA whenever it happens to be available.
        self._build_cache(torch.device("cpu"))

    def _build_cache(self, device, base=10000):
        """Register ``cos_cached`` / ``sin_cached`` of shape ``(max_seq_len, head_dim // 2)``."""
        theta = 1.0 / (base ** (torch.arange(0, self.dim, 2, device=device).float() / self.dim))
        t = torch.arange(self.max_seq_len, device=device, dtype=theta.dtype)
        freqs = torch.einsum("i,j->ij", t, theta)  # angle for (position, frequency)
        # Non-persistent: derived data, recomputed at construction and not saved in checkpoints.
        self.register_buffer('cos_cached', freqs.cos(), persistent=False)
        self.register_buffer('sin_cached', freqs.sin(), persistent=False)

    def forward(self, x: torch.Tensor, start_pos: int = 0):
        """Rotate ``x`` whose second-to-last axis holds consecutive positions from ``start_pos``.

        Args:
            x: tensor whose last axis has size ``head_dim`` and whose second-to-last axis
                indexes sequence positions.
            start_pos: absolute position of ``x[..., 0, :]`` (e.g. the KV-cache length).

        Raises:
            ValueError: if the requested positions exceed the precomputed ``max_seq_len``.
        """
        seq_len = x.shape[-2]
        end_pos = start_pos + seq_len
        if end_pos > self.max_seq_len:
            # Slicing past the table would silently yield a 1-row (broadcast) or empty
            # table and produce wrong rotations instead of an error.
            raise ValueError(
                f"RoPE table covers {self.max_seq_len} positions, but positions "
                f"{start_pos}..{end_pos - 1} were requested"
            )
        cos = self.cos_cached[start_pos:end_pos].to(x.device)
        sin = self.sin_cached[start_pos:end_pos].to(x.device)
        half = self.dim // 2
        x1 = x[..., :half]
        x2 = x[..., half:]
        rotated_x1 = x1 * cos - x2 * sin
        rotated_x2 = x1 * sin + x2 * cos
        return torch.cat([rotated_x1, rotated_x2], dim=-1).type_as(x)
class SwiGLU(nn.Module):
    """Gated feed-forward block computing ``w2(silu(w1(x)) * w3(x))``.

    The hidden width is ``int(dim * ffn_dim_multiplier)`` rounded up to the next multiple
    of ``multiple_of``. All projections are bias-free.
    """

    def __init__(self, config: NebulaConfig):
        super().__init__()
        step = config.multiple_of
        raw_width = int(config.dim * config.ffn_dim_multiplier)
        # Round up to a multiple of `step` (hardware-friendly matrix sizes).
        hidden_dim = step * ((raw_width + step - 1) // step)
        self.w1 = nn.Linear(config.dim, hidden_dim, bias=False)  # gate projection
        self.w2 = nn.Linear(hidden_dim, config.dim, bias=False)  # down projection
        self.w3 = nn.Linear(config.dim, hidden_dim, bias=False)  # value ("up") projection

    def forward(self, x):
        gate = F.silu(self.w1(x))
        up = self.w3(x)
        return self.w2(gate * up)
class Attention(nn.Module):
    """Causal multi-head self-attention with rotary embeddings, grouped KV heads and a KV cache.

    Keys/values use ``n_kv_heads`` heads. Each is shared by ``n_rep = n_heads // n_kv_heads``
    consecutive query heads. Attention itself is computed by
    ``F.scaled_dot_product_attention``.
    """

    def __init__(self, config: "NebulaConfig"):
        super().__init__()
        self.config = config
        self.n_heads = config.n_heads
        self.n_kv_heads = config.n_kv_heads
        self.head_dim = config.dim // config.n_heads
        self.n_rep = self.n_heads // config.n_kv_heads
        self.wq = nn.Linear(config.dim, self.n_heads * self.head_dim, bias=False)
        self.wk = nn.Linear(config.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(config.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wo = nn.Linear(self.n_heads * self.head_dim, config.dim, bias=False)
        self.rope = RoPE(config)

    def repeat_kv(self, x: torch.Tensor) -> torch.Tensor:
        """Expand ``(bs, n_kv_heads, S, head_dim)`` to ``(bs, n_heads, S, head_dim)``.

        Output head ``h`` is a copy of KV head ``h // n_rep``.
        """
        bs, n_kv_heads, seq_len_kv, head_dim = x.shape
        if self.n_rep == 1:
            return x
        # The repeat axis must sit between the head and sequence axes. Placing it after
        # the sequence axis would make the reshape interleave positions across heads.
        return (
            x[:, :, None, :, :]
            .expand(bs, n_kv_heads, self.n_rep, seq_len_kv, head_dim)
            .reshape(bs, self.n_heads, seq_len_kv, head_dim)
        )

    @staticmethod
    def _build_attention_mask(attention_mask: Optional[torch.Tensor], seq_len_q: int,
                              seq_len_kv: int, device: torch.device) -> torch.Tensor:
        """Build a boolean SDPA mask (True = may attend) of shape ``(B, 1, seq_len_q, seq_len_kv)``.

        The queries are taken to be the last ``seq_len_q`` of the ``seq_len_kv`` key positions.
        Cached keys come first, so query ``i`` may see keys ``0 .. i + (seq_len_kv - seq_len_q)``.
        ``attention_mask``, if given, is a 2-D ``(bs, seq_len_kv)`` padding mask where
        non-zero means "real token". ``B`` is 1 without it and ``bs`` with it.

        Raises:
            ValueError: if the padding mask does not cover every key position.
        """
        offset = seq_len_kv - seq_len_q  # number of cached positions before the first query
        q_pos = torch.arange(seq_len_q, device=device).unsqueeze(1) + offset
        k_pos = torch.arange(seq_len_kv, device=device).unsqueeze(0)
        mask = (k_pos <= q_pos)[None, None]  # (1, 1, L, S)
        if attention_mask is not None:
            if attention_mask.shape[-1] != seq_len_kv:
                raise ValueError(
                    f"2-D attention_mask must cover all {seq_len_kv} key positions "
                    f"(cached + new), got {attention_mask.shape[-1]}"
                )
            keep = attention_mask.to(device=device, dtype=torch.bool)[:, None, None, :]
            mask = mask & keep
            # A query row with no visible key (e.g. a left-padding position) makes SDPA
            # produce NaN, which later layers would propagate even through zero weights.
            # Letting such rows attend everywhere keeps values finite. Real tokens never
            # attend to padding keys, so their outputs are unaffected.
            mask = mask | ~mask.any(dim=-1, keepdim=True)
        return mask

    def forward(self, x: torch.Tensor, past_key_values: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, use_cache: bool = False, attention_mask: Optional[torch.Tensor] = None):
        """Attend over cached plus new positions.

        Args:
            x: ``(bs, seq_len, dim)`` hidden states for the new positions.
            past_key_values: optional ``(key, value)`` cache, each
                ``(bs, n_kv_heads, past_len, head_dim)``, with keys already rotated.
            use_cache: if True, also return the updated ``(key, value)`` cache.
            attention_mask: ``None`` (causal only), a 2-D ``(bs, past_len + seq_len)``
                padding mask (non-zero = real token; combined with the causal mask), or a
                mask of any other rank, which is passed to SDPA unchanged.

        Returns:
            ``(output, present_key_values)`` where ``output`` is ``(bs, seq_len, dim)`` and
            ``present_key_values`` is ``None`` unless ``use_cache``.
        """
        bs, seq_len_q, _ = x.shape
        # NOTE(review): positions come from the cache length only, not from attention_mask,
        # so left-padded rows do not start at position 0 — confirm this is acceptable.
        start_pos = past_key_values[0].shape[2] if past_key_values is not None else 0
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
        xq = xq.view(bs, seq_len_q, self.n_heads, self.head_dim).transpose(1, 2)
        xk = xk.view(bs, seq_len_q, self.n_kv_heads, self.head_dim).transpose(1, 2)
        xv = xv.view(bs, seq_len_q, self.n_kv_heads, self.head_dim).transpose(1, 2)
        xq = self.rope(xq, start_pos=start_pos)
        xk = self.rope(xk, start_pos=start_pos)
        if past_key_values is not None:
            past_k, past_v = past_key_values
            xk = torch.cat([past_k, xk], dim=2)
            xv = torch.cat([past_v, xv], dim=2)
        present_key_values = (xk, xv) if use_cache else None
        seq_len_kv = xk.shape[2]

        is_causal = False
        if attention_mask is not None and attention_mask.dim() != 2:
            # Caller-built (e.g. 4-D) masks are used as-is.
            attn_mask = attention_mask
        elif attention_mask is None and seq_len_q == seq_len_kv:
            # No cache and no padding: SDPA's built-in causal mask is exact for a square
            # score matrix and keeps fused kernels available.
            attn_mask, is_causal = None, True
        elif attention_mask is None and seq_len_q == 1:
            # One new token may attend to every cached key and to itself.
            attn_mask = None
        else:
            attn_mask = self._build_attention_mask(attention_mask, seq_len_q, seq_len_kv, x.device)

        xk_rep, xv_rep = self.repeat_kv(xk), self.repeat_kv(xv)
        output = F.scaled_dot_product_attention(xq, xk_rep, xv_rep, attn_mask=attn_mask, is_causal=is_causal)
        output = output.transpose(1, 2).contiguous().view(bs, seq_len_q, -1)
        return self.wo(output), present_key_values
class DecoderBlock(nn.Module):
    """Pre-norm transformer block: ``x + drop(attn(norm(x)))`` then ``h + drop(ffn(norm(h)))``."""

    def __init__(self, config: NebulaConfig):
        super().__init__()
        # Registration order is kept stable (it fixes parameter order and init RNG usage).
        self.attention = Attention(config)
        self.feed_forward = SwiGLU(config)
        self.attention_norm = RMSNorm(config.dim, eps=config.norm_eps)
        self.ffn_norm = RMSNorm(config.dim, eps=config.norm_eps)
        self.dropout = nn.Dropout(config.dropout)
        # Mark the two projections that write into the residual stream. The model's
        # `_init_weights` looks for this attribute to apply a depth-scaled init.
        for residual_proj in (self.attention.wo, self.feed_forward.w2):
            residual_proj.is_residual_output = True

    def forward(self, x: torch.Tensor, past_key_values: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, use_cache: bool = False, attention_mask: Optional[torch.Tensor] = None):
        """Run attention then feed-forward; return ``(hidden_states, present_key_values)``."""
        normed_x = self.attention_norm(x)
        attn_out, present_kv = self.attention(
            normed_x,
            past_key_values=past_key_values,
            use_cache=use_cache,
            attention_mask=attention_mask,
        )
        hidden = x + self.dropout(attn_out)
        ffn_out = self.feed_forward(self.ffn_norm(hidden))
        return hidden + self.dropout(ffn_out), present_kv
class NebulaForCausalLM(PreTrainedModel, GenerationMixin):
    """Nebula decoder-only language model exposed through the Hugging Face model API.

    Sub-modules live in ``self.model`` (an ``nn.ModuleDict``): token embeddings, the stack of
    ``DecoderBlock`` layers, a final ``RMSNorm`` and the output projection. The embedding
    table and the output projection share one weight matrix.
    """

    config_class = NebulaConfig

    def __init__(self, config: NebulaConfig):
        super().__init__(config)
        self.model = nn.ModuleDict({"tok_embeddings": nn.Embedding(config.vocab_size, config.dim),
                                    "layers": nn.ModuleList([DecoderBlock(config) for _ in range(config.n_layers)]),
                                    "norm": RMSNorm(config.dim, eps=config.norm_eps),
                                    "output": nn.Linear(config.dim, config.vocab_size, bias=False)})
        self.dropout = nn.Dropout(config.dropout)
        # Weight tying: both modules now hold the same (vocab_size, dim) Parameter.
        self.model.tok_embeddings.weight = self.model.output.weight
        self.post_init()

    def _init_weights(self, module):
        """Hugging Face weight-init hook, applied to sub-modules during ``post_init``.

        Linear and embedding weights are drawn from N(0, 0.02). Projections tagged with
        ``is_residual_output`` (two per decoder layer, hence ``2 * n_layers``) are re-drawn
        with the std divided by ``sqrt(2 * n_layers)``.
        """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        if hasattr(module, 'is_residual_output'):
            torch.nn.init.normal_(module.weight, mean=0.0, std=(0.02 / math.sqrt(2 * self.config.n_layers)))

    @staticmethod
    def _causal_lm_loss(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """Next-token cross-entropy: logits at position t are scored against labels[t + 1].

        Label entries equal to -100 are ignored.
        """
        shift_logits = logits[:, :-1, :].contiguous().float()
        shift_labels = labels[:, 1:].contiguous().to(shift_logits.device)
        return F.cross_entropy(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1),
            ignore_index=-100,
        )

    def forward(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, use_cache: Optional[bool] = None, labels: Optional[torch.Tensor] = None, **kwargs) -> CausalLMOutputWithPast:
        """Compute logits (and optionally the LM loss and an updated KV cache).

        Args:
            input_ids: ``(bs, seq_len)`` token ids for the new positions.
            attention_mask: forwarded unchanged to every decoder layer.
            past_key_values: optional per-layer ``(key, value)`` caches, one entry per layer.
            use_cache: return updated caches; defaults to ``config.use_cache``.
            labels: optional ``(bs, seq_len)`` targets, usually ``input_ids`` itself. They are
                shifted internally so each position predicts the next token. Entries equal
                to -100 are ignored.
            **kwargs: accepted for generation-API compatibility and ignored.

        Returns:
            ``CausalLMOutputWithPast`` with ``loss`` (None without labels), ``logits`` of shape
            ``(bs, seq_len, vocab_size)`` and ``past_key_values`` (None unless caching).
        """
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        if past_key_values is None:
            # Needed whether or not a new cache is requested: each layer indexes into it.
            past_key_values = tuple([None] * self.config.n_layers)
        x = self.dropout(self.model.tok_embeddings(input_ids))
        present_key_values_list = [] if use_cache else None
        for i, layer in enumerate(self.model.layers):
            x, present_kv = layer(x, past_key_values=past_key_values[i], use_cache=use_cache, attention_mask=attention_mask)
            if present_key_values_list is not None:
                present_key_values_list.append(present_kv)
        logits = self.model.output(self.model.norm(x))
        loss = self._causal_lm_loss(logits, labels) if labels is not None else None
        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=tuple(present_key_values_list) if present_key_values_list else None,
        )

    def prepare_inputs_for_generation(self, input_ids: torch.Tensor, past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, attention_mask: Optional[torch.Tensor] = None, **kwargs) -> Dict[str, Any]:
        """Build ``forward`` kwargs for one generation step.

        Once a cache exists, only the newest token is fed. Earlier positions are represented
        by ``past_key_values``, while ``attention_mask`` is passed through at full length.
        """
        if past_key_values:
            input_ids = input_ids[:, -1:]
        return {"input_ids": input_ids, "past_key_values": past_key_values, "use_cache": kwargs.get("use_cache", True), "attention_mask": attention_mask}
|