File size: 13,575 Bytes
ae96e35 12621a6 ae96e35 12621a6 ae96e35 12621a6 ae96e35 5446460 ae96e35 5446460 ae96e35 5446460 ae96e35 5446460 ae96e35 5446460 ae96e35 5446460 ae96e35 5446460 ae96e35 5446460 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 | import math
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel, PretrainedConfig
from transformers.modeling_outputs import CausalLMOutputWithPast
class SparrowConfig(PretrainedConfig):
    """Hyper-parameters for the Sparrow decoder-only language model.

    ``num_key_value_heads`` falls back to ``num_attention_heads`` (one
    key/value head per query head) and ``hidden_dim`` falls back to
    ``hidden_size`` when either is left as ``None``. Any extra keyword
    arguments are forwarded unchanged to ``PretrainedConfig``.
    """

    model_type = "sparrow"

    def __init__(
        self,
        hidden_size: int = 512,
        num_hidden_layers: int = 8,
        num_attention_heads: int = 16,
        num_key_value_heads: Optional[int] = None,
        max_seq_len: int = 512,
        attention_bias: bool = False,
        flash_attn: bool = True,
        vocab_size: int = 32000,
        hidden_dim: Optional[int] = None,
        intermediate_dim: int = 2048,
        norm_eps: float = 1e-5,
        mlp_bias: bool = False,
        dropout: float = 0.0,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # --- attention settings ---
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        if num_key_value_heads is None:
            self.num_key_value_heads = num_attention_heads
        else:
            self.num_key_value_heads = num_key_value_heads
        self.max_seq_len = max_seq_len
        self.attention_bias = attention_bias
        self.flash_attn = flash_attn

        # --- embedding / feed-forward / normalization settings ---
        self.vocab_size = vocab_size
        if hidden_dim is None:
            self.hidden_dim = hidden_size
        else:
            self.hidden_dim = hidden_dim
        self.intermediate_dim = intermediate_dim
        self.norm_eps = norm_eps
        self.mlp_bias = mlp_bias
        self.dropout = dropout
## RoPE - from https://arxiv.org/pdf/2104.09864v5
def rotate_half(x):
    """Rotate the last dimension by halves: ``[a, b] -> [-b, a]``.

    The split uses ``torch.chunk(..., 2)``, so an odd-sized last dimension
    yields a first half one element longer than the second.
    """
    lower, upper = torch.chunk(x, 2, dim=-1)
    return torch.cat([upper.neg(), lower], dim=-1)
def apply_rotate_pos_emb(q, k, cos, sin, unsqueeze_dim=2):
    """Apply rotary position embeddings to a query and a key tensor.

    ``cos`` and ``sin`` gain a singleton axis at ``unsqueeze_dim`` so they
    broadcast against ``q``/``k``. With the default of 2, cos/sin tables of
    shape ``(1, seq, head_dim)`` (as sliced by ``RotaryEmbedding.forward``)
    line up with ``(batch, seq, heads, head_dim)`` inputs.

    Returns:
        ``(q_rot, k_rot)`` where each is ``t * cos + rotate_half(t) * sin``.
    """
    cos, sin = cos.unsqueeze(unsqueeze_dim), sin.unsqueeze(unsqueeze_dim)

    def _rotate(t):
        return (t * cos) + (rotate_half(t) * sin)

    return _rotate(q), _rotate(k)
class RotaryEmbedding(nn.Module):
    """Precomputed cosine/sine tables for rotary position embeddings (RoPE).

    Tables cover positions ``0 .. max_seq_len - 1`` with base 10000. Each
    row holds the per-frequency angles twice over, so its width matches
    ``dim`` for even ``dim``. The tables are registered as buffers named
    ``cos_cached`` and ``sin_cached``.
    """

    def __init__(self, dim, max_seq_len=2048):
        super().__init__()
        self.hidden_size = dim
        self.max_seq_len = max_seq_len

        exponents = torch.arange(0, dim, 2).float() / dim
        inv_freq = 1.0 / (10000 ** exponents)
        positions = torch.arange(max_seq_len).float()
        # Outer product: angles[p, i] = p * inv_freq[i].
        angles = positions[:, None] * inv_freq[None, :]
        angles = angles.repeat(1, 2)

        self.register_buffer("cos_cached", angles.cos())
        self.register_buffer("sin_cached", angles.sin())

    def forward(self, q, k):
        """Rotate ``q`` and ``k`` using the first ``q.shape[1]`` table rows."""
        seq_len = q.shape[1]
        cos = self.cos_cached[None, :seq_len]
        sin = self.sin_cached[None, :seq_len]
        return apply_rotate_pos_emb(q, k, cos, sin)
## RMSNorm
class RMSNorm(nn.Module):
    """Root-mean-square normalization with a learned per-channel scale.

    Computes ``x / sqrt(mean(x**2, dim=-1) + eps) * weight``. The statistic
    is computed in float32 and the result is cast back to the input dtype.
    This keeps fp16/bf16 activations from overflowing (``x**2`` exceeds the
    fp16 range once ``|x|`` is above ~256) or losing precision in the mean.
    """

    def __init__(self, dim: int, eps: float = 1.0e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def normalize(self, x):
        """Scale ``x`` to unit root-mean-square over its last dimension."""
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        # Upcast before normalizing; type_as restores the caller's dtype afterwards.
        output = self.normalize(x.float()).type_as(x)
        return output * self.weight
def repeat_kv(x, n_rep):
    """Repeat every key/value head ``n_rep`` times along the head axis.

    ``x`` must be 4-D ``(batch, length, num_key_value_heads, head_dim)``; the
    shape is unpacked before the early return, so other ranks raise even
    when ``n_rep == 1``. The result is
    ``(batch, length, num_key_value_heads * n_rep, head_dim)``, and copies
    of the same head are adjacent. ``n_rep == 1`` returns ``x`` itself.
    """
    bsz, seq_len, n_kv, dim = x.shape
    if n_rep == 1:
        return x
    widened = x.unsqueeze(3).expand(bsz, seq_len, n_kv, n_rep, dim)
    return widened.reshape(bsz, seq_len, n_kv * n_rep, dim)
## SparrowAttention
class SparrowAttention(nn.Module):
    """Causal self-attention with rotary embeddings and grouped-query attention.

    Queries use ``num_attention_heads`` heads and keys/values use
    ``num_key_value_heads`` heads. Each key/value head is repeated
    ``num_attention_heads // num_key_value_heads`` times before attention.
    An optional cache of key/value projections supports incremental decoding.
    """

    def __init__(self, config: SparrowConfig = None):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.num_hidden_layers = config.num_hidden_layers
        self.num_attention_heads = config.num_attention_heads
        self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_attention_heads)
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_attention_heads // self.num_key_value_heads
        self.vocab_size = config.vocab_size
        self.dropout = config.dropout
        # Uses RotaryEmbedding's default table length, not config.max_seq_len.
        self.rotary_emb = RotaryEmbedding(dim=self.head_dim)
        self.wq = nn.Linear(self.hidden_size, self.num_attention_heads * self.head_dim, bias=self.config.attention_bias)
        self.wk = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=self.config.attention_bias)
        self.wv = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=self.config.attention_bias)
        self.wo = nn.Linear(self.num_attention_heads * self.head_dim, self.hidden_size, bias=self.config.attention_bias)
        # Pre-RoPE key/value projections of every token seen by the last cached
        # call, shaped (batch, seq, num_key_value_heads * head_dim); None when empty.
        self.k_cache, self.v_cache = None, None
        self.attention_dropout = nn.Dropout(self.dropout)
        self.residual_dropout = nn.Dropout(self.dropout)

    def forward(self, x: torch.Tensor, use_kv_cache=False):
        """Attend causally over ``x`` and return a tensor of the same shape.

        Args:
            x: hidden states of shape ``(batch, seq, hidden_size)``.
            use_kv_cache: when True and the module is in eval mode, cache the
                key/value projections. If the previous cached call saw exactly
                ``seq - 1`` tokens for the same batch size, only the newest token
                is projected. On that path the query rows of the cached prefix
                are zeros, so only the last output position is meaningful.
        """
        b, s = x.shape[:2]
        # Check the mode flag. Calling self.eval() here would switch the module
        # to eval mode as a side effect, and it always returns a truthy module.
        if use_kv_cache and not self.training:
            can_extend = (
                self.k_cache is not None
                and self.k_cache.shape[0] == b
                and self.k_cache.shape[1] == s - 1
            )
            if can_extend:
                token = x[:, -1:, :]
                q_new = self.wq(token)
                # Only the newest position needs a real query. Zero-fill the prefix
                # rows using the query width, which may differ from hidden_size.
                q = torch.cat((q_new.new_zeros(b, s - 1, q_new.shape[-1]), q_new), dim=1)
                k = torch.cat((self.k_cache, self.wk(token)), dim=1)
                v = torch.cat((self.v_cache, self.wv(token)), dim=1)
            else:
                q, k, v = self.wq(x), self.wk(x), self.wv(x)
            self.k_cache, self.v_cache = k, v
        else:
            q, k, v = self.wq(x), self.wk(x), self.wv(x)

        q = q.view(b, s, self.num_attention_heads, self.head_dim)
        k = k.view(b, s, self.num_key_value_heads, self.head_dim)
        v = v.view(b, s, self.num_key_value_heads, self.head_dim)
        q, k = self.rotary_emb(q, k)
        k = repeat_kv(k, self.num_key_value_groups)
        v = repeat_kv(v, self.num_key_value_groups)
        # Move heads ahead of the sequence axis: (batch, heads, seq, head_dim).
        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)

        if self.config.flash_attn:
            output = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask=None,
                dropout_p=self.dropout if self.training else 0.0,
                is_causal=True,
            )
        else:
            # Causal mask sized to the current sequence: -inf above the diagonal.
            # It is not capped at max_seq_len, which matches the flash path.
            mask = torch.full((s, s), float("-inf"), device=x.device).triu(1)
            scores = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(self.head_dim)
            scores = scores + mask
            scores = F.softmax(scores.float(), dim=-1).type_as(q)
            scores = self.attention_dropout(scores)
            output = torch.matmul(scores, v)

        output = output.transpose(1, 2).contiguous().view(b, s, -1)
        output = self.wo(output)
        return self.residual_dropout(output)
class SparrowLinear(nn.Module):
    """SwiGLU feed-forward block computing ``out(silu(gate(x)) * up(x))``.

    ``gate`` and ``up`` project ``hidden_size -> intermediate_dim``, and
    ``out`` projects back to ``hidden_size``. All three use ``config.mlp_bias``.
    """

    def __init__(self, config: SparrowConfig = None):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_dim = config.intermediate_dim
        use_bias = config.mlp_bias
        self.gate = nn.Linear(self.hidden_size, self.intermediate_dim, bias=use_bias)
        self.up = nn.Linear(self.hidden_size, self.intermediate_dim, bias=use_bias)
        self.out = nn.Linear(self.intermediate_dim, self.hidden_size, bias=use_bias)

    def forward(self, x):
        activated = F.silu(self.gate(x))
        return self.out(activated * self.up(x))
class SparrowDecoderLayer(nn.Module):
    """Pre-norm transformer block.

    Computes ``h = x + attention(input_norm(x))`` and then
    ``out = h + linear(pos_attn_norm(h))``. Both norms use ``config.norm_eps``.
    """

    def __init__(self, config: SparrowConfig = None, layer_idx: int = None):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.attention = SparrowAttention(config=config)
        self.linear = SparrowLinear(config=config)
        # Use the configured epsilon rather than RMSNorm's built-in default.
        self.input_norm = RMSNorm(dim=config.hidden_size, eps=config.norm_eps)
        self.pos_attn_norm = RMSNorm(dim=config.hidden_size, eps=config.norm_eps)
        self.layer_idx = layer_idx

    def forward(self, x, use_kv_cache):
        h = x + self.attention(x=self.input_norm(x), use_kv_cache=use_kv_cache)
        # The second residual must carry h, the post-attention stream. Adding the
        # normalized block input instead would drop the attention output from
        # the residual path.
        return h + self.linear(self.pos_attn_norm(h))
class SparrowModel(PreTrainedModel):
    """Sparrow decoder stack.

    Pipeline: token embedding -> dropout -> ``num_hidden_layers`` decoder
    layers -> final RMSNorm. Returns the normalized hidden states.
    """

    config_class = SparrowConfig

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.vocab_size = self.config.vocab_size
        self.num_hidden_layers = self.config.num_hidden_layers
        self.token_embedding = nn.Embedding(self.config.vocab_size, self.config.hidden_size)
        self.dropout = nn.Dropout(self.config.dropout)
        self.decoder = nn.ModuleList(
            SparrowDecoderLayer(config=self.config, layer_idx=layer_idx)
            for layer_idx in range(self.num_hidden_layers)
        )
        # Use the configured epsilon rather than RMSNorm's built-in default.
        self.norm = RMSNorm(dim=self.config.hidden_size, eps=self.config.norm_eps)
        self.apply(self.weights_init)

    def weights_init(self, module):
        """Initialize one submodule in place (applied via ``self.apply``).

        ``nn.Linear`` and ``nn.Embedding`` weights are drawn from N(0, 0.02).
        Linear biases and the embedding's padding row, if any, are zeroed.
        Other module types are left untouched.
        """
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    def forward(self, input_ids, use_kv_cache=False):
        x = self.dropout(self.token_embedding(input_ids))
        for layer in self.decoder:
            x = layer(x=x, use_kv_cache=use_kv_cache)
        return self.norm(x)
class SparrowModelForCausalLM(SparrowModel):
    """Sparrow decoder with a language-modelling head.

    The output projection shares its weight matrix with ``token_embedding``.
    """

    def __init__(self, config):
        super().__init__(config)
        self.output = nn.Linear(self.config.hidden_size, self.config.vocab_size, bias=self.config.mlp_bias)
        # SparrowModel.__init__ ran weights_init before this layer existed.
        # Initialize it here, otherwise the tied embedding/output matrix would
        # keep nn.Linear's default init instead of N(0, 0.02).
        self.weights_init(self.output)
        self.token_embedding.weight = self.output.weight
        self.loss = None
        # GPT-2 style scaled init for the projections that write into the residual
        # stream: attention `wo` and the feed-forward `linear.out`.
        residual_std = 0.02 / math.sqrt(2 * self.config.num_hidden_layers)
        for name, param in self.named_parameters():
            if name.endswith(("wo.weight", "linear.out.weight")):
                torch.nn.init.normal_(param, mean=0.0, std=residual_std)

    def forward(self, input_ids, labels=None, use_kv_cache=False):
        """Run the model and, when ``labels`` are given, the cross-entropy loss.

        With ``labels``, logits are computed for every position. They are
        compared with ``labels`` position by position (no shift is applied
        here), and ``-100`` labels are ignored. Without ``labels``, only the
        last position's logits are computed. The loss is also stored on
        ``self.loss``.
        """
        x = super().forward(input_ids, use_kv_cache)
        if labels is not None:
            logits = self.output(x)
            self.loss = F.cross_entropy(
                logits.reshape(-1, logits.size(-1)), labels.reshape(-1), ignore_index=-100
            )
        else:
            logits = self.output(x[:, [-1], :])
            self.loss = None
        return CausalLMOutputWithPast(loss=self.loss, logits=logits)

    @staticmethod
    def _apply_repetition_penalty(logits, token_ids, penalty):
        """Penalize, in place, every token id present in ``token_ids[0]``.

        Positive logits are divided by ``penalty`` and negative ones multiplied
        by it. A penalty > 1 therefore always lowers a repeated token's score;
        dividing a negative logit would raise it instead.
        """
        if penalty == 1.0:
            return logits
        seen = torch.unique(token_ids[0])
        picked = logits[:, seen]
        logits[:, seen] = torch.where(picked < 0, picked * penalty, picked / penalty)
        return logits

    def _reset_kv_cache(self):
        """Clear every attention key/value cache so a new prompt cannot reuse stale entries."""
        for module in self.modules():
            if hasattr(module, "k_cache"):
                module.k_cache, module.v_cache = None, None

    def _beam_search(self, input_ids, eos, max_new_tokens, temperature, repetition_penalty, beam_size):
        """Return the new token ids of the highest-scoring beam, trailing ``eos`` excluded."""
        prompt_len = input_ids.shape[1]
        beams = [(input_ids, 0.0)]  # (sequence, cumulative log-probability)
        for _ in range(max_new_tokens):
            candidates = []
            for seq, score in beams:
                if seq.shape[1] > prompt_len and seq[0, -1].item() == eos:
                    # Finished beams are carried over unchanged, not extended past eos.
                    candidates.append((seq, score))
                    continue
                # Beams have different prefixes, but each attention layer caches a
                # single sequence, so the cache cannot be shared across beams.
                logits = self(seq, labels=None, use_kv_cache=False).logits[:, -1, :]
                logits = self._apply_repetition_penalty(logits, seq, repetition_penalty)
                if temperature > 0:
                    logits = logits / temperature
                log_probs = F.log_softmax(logits, dim=-1)
                top_log_probs, top_ids = torch.topk(log_probs, beam_size, dim=-1)
                for i in range(beam_size):
                    next_seq = torch.cat((seq, top_ids[:, i:i + 1]), dim=1)
                    candidates.append((next_seq, score + top_log_probs[:, i].item()))
            beams = sorted(candidates, key=lambda cand: cand[1], reverse=True)[:beam_size]
            if all(seq[0, -1].item() == eos for seq, _ in beams):
                break
        best = beams[0][0][0, prompt_len:].tolist()
        if best and best[-1] == eos:
            best = best[:-1]
        return best

    def _sample(self, input_ids, eos, max_new_tokens, temperature, top_k, repetition_penalty, use_kv_cache):
        """Greedy (temperature 0) or top-k/temperature sampling; stops before ``eos``."""
        generated = []
        for _ in range(max_new_tokens):
            logits = self(input_ids, labels=None, use_kv_cache=use_kv_cache).logits[:, -1, :]
            logits = self._apply_repetition_penalty(logits, input_ids, repetition_penalty)
            if temperature == 0.0:
                idx_next = torch.argmax(logits, dim=-1, keepdim=True)
            else:
                logits = logits / temperature
                if top_k is not None:
                    v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                    logits[logits < v[:, [-1]]] = -float('Inf')
                idx_next = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
            if idx_next.item() == eos:
                break
            input_ids = torch.cat((input_ids, idx_next), dim=1)
            generated.append(idx_next.item())
        return generated

    @torch.no_grad()
    def generate(self, input_ids, eos=1, max_new_tokens=50, temperature=0.7, top_k=None, repetition_penalty=1.,
                 use_kv_cache=True, use_beam_search=False, beam_size=3):
        """Generate continuation tokens for a single prompt.

        Args:
            input_ids: prompt token ids. Only batch row 0 drives the repetition
                penalty, and ``.item()`` is called on sampled ids, so pass a
                batch of one.
            eos: token id that stops generation; it is not included in the result.
            max_new_tokens: maximum number of tokens to generate.
            temperature: logits divisor. ``0.0`` selects argmax decoding on the
                sampling path and disables scaling in beam search.
            top_k: if set, sample only among the ``top_k`` highest logits.
            repetition_penalty: values > 1 discourage tokens already present.
            use_kv_cache: reuse attention key/value projections between steps.
                Sampling path only; beam search always recomputes.
            use_beam_search: run beam search instead of sampling.
            beam_size: number of beams kept per step.

        Returns:
            List of generated token ids, prompt excluded.
        """
        self._reset_kv_cache()
        if use_beam_search:
            return self._beam_search(input_ids, eos, max_new_tokens, temperature, repetition_penalty, beam_size)
        return self._sample(input_ids, eos, max_new_tokens, temperature, top_k, repetition_penalty, use_kv_cache)