klemenk committed on
Commit 9d832b1 · verified · 1 Parent(s): 82a43ae

Sync modeling_auristream.py from TuKoResearch/AuriStream100M_40Pred_librilight_200k

Files changed (1)
  1. modeling_auristream.py +580 -0
modeling_auristream.py ADDED
@@ -0,0 +1,580 @@
+ """
+ AuriStream sequence model definition.
+ """
+
+ import math
+ import inspect
+ import random
+ import torch
+ import torch.nn as nn
+ from torch.nn import functional as F
+ import numpy as np
+ from huggingface_hub import PyTorchModelHubMixin
+ from transformers.modeling_outputs import BaseModelOutput, CausalLMOutput
+ from transformers import PreTrainedModel
+ from .configuration_auristream import AuriStreamConfig
+
+
+ class AuriStream(PreTrainedModel):
+     config_class = AuriStreamConfig
+
+     def __init__(self, config):
+         super().__init__(config)
+         self.config = config
+
+         # if use_rope is present in the config and False, initialize a wpe (learned positional embedding) layer in the transformer
+         if hasattr(config, 'use_rope') and not config.use_rope:
+             self.transformer = nn.ModuleDict(dict(
+                 wte = nn.Embedding(config.vocab_size, config.n_embd),
+                 wpe = nn.Embedding(config.seq_len, config.n_embd),
+                 drop = nn.Dropout(config.dropout),
+                 h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
+                 ln_f = RMSNorm(config.n_embd, bias=config.bias),
+             ))
+         else:
+             self.transformer = nn.ModuleDict(dict(
+                 wte = nn.Embedding(config.vocab_size, config.n_embd),
+                 drop = nn.Dropout(config.dropout),
+                 h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
+                 ln_f = RMSNorm(config.n_embd, bias=config.bias),
+             ))
+
+         # if n_pred_steps is defined in the config, it sets the number of linear heads used for future prediction
+         if hasattr(config, 'n_pred_steps'):
+             self.future_heads = nn.ModuleList([nn.Linear(config.n_embd, config.vocab_size, bias=False) for _ in range(config.n_pred_steps - 1)])
+         else:
+             self.future_heads = None
+
+         self.coch_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
+
+         # init all weights
+         self.apply(self._init_weights)
+         # apply special scaled init to the residual projections, per GPT-2 paper
+         for pn, p in self.named_parameters():
+             if pn.endswith('c_proj.weight'):
+                 torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))
+
+     def get_num_params(self, non_embedding=True):
+         """
+         Return the number of parameters in the model.
+         Note: all parameters (including embeddings) are counted; the
+         non_embedding flag is kept for interface compatibility but not used here.
+         """
+         n_params = sum(p.numel() for p in self.parameters())
+         return n_params
+
+     def _init_weights(self, module):
+         if isinstance(module, nn.Linear):
+             torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
+             if module.bias is not None:
+                 torch.nn.init.zeros_(module.bias)
+         elif isinstance(module, nn.Embedding):
+             torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
+
+     def forward(self, seq, tgt=None, output_logits=False, output_hidden_states=False, return_dict=False, up_until_layer=None):
+         """
+         Input: seq: torch.Tensor of shape (b, t)
+                tgt: torch.Tensor of shape (b, t) or None
+         """
+
+         # forward the GPT model itself
+         tok_emb = self.transformer.wte(seq) # token embeddings of shape (b, t, n_embd)
+
+         # if wpe exists in self.transformer apply the learned positional embedding
+         if hasattr(self.transformer, 'wpe'):
+             pos = torch.arange(0, seq.size(1), dtype=torch.long, device=seq.device)
+             pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd)
+             x = self.transformer.drop(tok_emb + pos_emb)
+         else:
+             x = self.transformer.drop(tok_emb)
+
+         all_hidden_states = []
+         for block_idx, block in enumerate(self.transformer.h):
+             # Forward the block
+             all_hidden_states.append(x)
+             if up_until_layer is not None and block_idx == up_until_layer:
+                 break
+             x = block(x)
+
+         # append the last hidden state if we did not exit early
+         if up_until_layer is None or block_idx == len(self.transformer.h) - 1:
+             all_hidden_states.append(x)
+
+         if output_hidden_states and not output_logits:
+             model_output = BaseModelOutput(
+                 last_hidden_state=x,
+                 hidden_states=all_hidden_states,
+             )
+             return model_output
+
+         x = self.transformer.ln_f(x)
+         logits = self.coch_head(x)
+
+         if tgt is not None:
+
+             if output_logits:
+                 all_logits = [logits]
+
+             loss = F.cross_entropy(
+                 logits.reshape(-1, self.config.vocab_size), tgt.reshape(-1),
+             )
+
+             # If there are additional future heads, compute the loss for each head
+             if self.future_heads is not None:
+                 for i, head in enumerate(self.future_heads):
+                     future_logits = head(x[:, :-(i+1)])
+                     loss += F.cross_entropy(
+                         future_logits.reshape(-1, self.config.vocab_size), tgt[:, (i+1):].reshape(-1),
+                     )
+                     if output_logits:
+                         all_logits.append(future_logits)
+                 # divide the loss by the total number of prediction heads
+                 loss = loss / (len(self.future_heads) + 1)
+
+             if return_dict:
+                 if output_logits:
+                     model_output = CausalLMOutput(
+                         loss=loss,
+                         logits=all_logits,
+                     )
+                 else:
+                     model_output = CausalLMOutput(
+                         loss=loss,
+                         logits=logits,
+                     )
+                 return model_output
+
+             return logits, loss
+
+         return logits, None
+
+     def sample_logits(self, logits: torch.FloatTensor, temperature: float = 0.9,
+                       top_k: int = 500, top_p: float = 0.5) -> torch.LongTensor:
+         """
+         Samples an integer from the distribution of logits
+         Parameters:
+             logits (torch.FloatTensor): The logits of the distribution
+             temperature (float): The temperature of the sampling; if 0.0, then argmax is used
+             top_k (int): The number of top-k tokens to consider during sampling
+             top_p (float): The cumulative probability threshold for nucleus (top-p) sampling
+         Returns:
+             torch.LongTensor: The sampled integer
+         """
+         # If temperature is 0.0, use argmax
+         if temperature == 0.0:
+             return torch.argmax(logits, dim=-1)
+
+         # Apply temperature
+         logits = logits / temperature
+
+         # Apply top-k filtering if specified
+         if top_k is not None:
+             v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
+             logits[logits < v[..., [-1]]] = -float('Inf')
+
+         # Apply top-p (nucleus) filtering if specified
+         if top_p is not None:
+             # Sort the logits in descending order
+             sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
+             # Compute the sorted softmax probabilities
+             sorted_probs = F.softmax(sorted_logits, dim=-1)
+             # Compute the cumulative probabilities
+             cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
+             # Create a mask for tokens to remove
+             sorted_indices_to_remove = cumulative_probs > top_p
+             # Shift the mask right to keep at least one token
+             sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+             sorted_indices_to_remove[..., 0] = 0
+             # Scatter the mask back to the original indices
+             indices_to_remove = sorted_indices_to_remove.scatter(dim=-1, index=sorted_indices, src=sorted_indices_to_remove)
+             logits[indices_to_remove] = -float('Inf')
+
+         # Compute softmax probabilities
+         probs = F.softmax(logits, dim=-1)
+         # Flatten probabilities to (batch_size * sequence_length, vocab_size)
+         flat_probs = probs.view(-1, probs.size(-1))
+         # Sample from the distribution
+         sampled = torch.multinomial(flat_probs, num_samples=1)
+         # Reshape to original shape except for the last dimension
+         sampled = sampled.view(*logits.shape[:-1])
+         return sampled
+
+     @torch.no_grad()
+     def generate(self, seq: torch.Tensor, n_tokens: int = 1, temp=1.0,
+                  top_k=500, top_p=0.5, seed=None):
+         """
+         Parameters:
+             seq: torch.Tensor of shape (b, t)
+                 Input cochlear token sequence to condition the generation on
+             n_tokens: int
+                 Number of time bins to predict
+             temp: float
+                 Temperature for sampling logits
+             top_k: int
+                 Number of top-k tokens to consider during sampling
+             top_p: float
+                 Cumulative probability threshold for nucleus (top-p) sampling
+             seed: int
+                 Random seed for sampling
+
+         Returns:
+             pred_coch: torch.Tensor of shape (b, n_tokens)
+                 The predicted cochlear token sequence
+             all_logits: torch.Tensor of shape (b, n_tokens, vocab_size)
+                 The logits for each predicted time step
+         """
+
+         # Set seed if provided
+         if seed is not None:
+             random.seed(seed)
+             np.random.seed(seed)
+             torch.manual_seed(seed)
+
+         # make a list of logits to return
+         all_logits = []
+         device = seq.device
+
+         # grab the shape of the input token sequence
+         b, t = seq.size()
+
+         # TODO: double check this works then delete the block below:
+         # pass the given input through the model to get the predictions and cache
+         # the k and v values for each transformer block in the process
+         # pos = torch.arange(0, t, dtype=torch.long, device=device)
+         # tok_emb = self.transformer.wte(seq) # token embeddings of shape (b, t, n_embd)
+         # pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd)
+         # x = self.transformer.drop(tok_emb + pos_emb)
+
+         #### Embed the conditioning sequence into the KV cache
+
+         tok_emb = self.transformer.wte(seq) # token embeddings of shape (b, t, n_embd)
+         # if wpe exists in self.transformer apply the learned positional embedding
+         if hasattr(self.transformer, 'wpe'):
+             pos = torch.arange(0, seq.size(1), dtype=torch.long, device=seq.device)
+             pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd)
+             x = self.transformer.drop(tok_emb + pos_emb)
+         else:
+             x = self.transformer.drop(tok_emb)
+
+         # Initialize lists to store k and v for each transformer block
+         k_list = []
+         v_list = []
+         for block_idx, block in enumerate(self.transformer.h):
+             # Pass through the transformer block, and store k and v
+             x, k, v = block(x, return_kv=True)
+             k_list.append(k)
+             v_list.append(v)
+         # k_cache and v_cache have shape (n_layer, b, n_head, t, n_embd//n_head)
+         k_cache = torch.stack(k_list, dim=0)
+         v_cache = torch.stack(v_list, dim=0)
+         # Pass through the final layer norm
+         x = self.transformer.ln_f(x)
+
+         # First prediction of the model is the decoding of the last time bin
+         logits = self.coch_head(x[:, [-1]])
+         predictions = [self.sample_logits(logits, temperature=temp)]
+         all_logits.append(logits)
+
+         ### Predict future tokens
+
+         # Now we pass the last time bin through the model to predict the next time bin;
+         # we subtract 1 from n_tokens because we already predicted the first time bin
+         # using the last embedding of the input
+         for i in range(n_tokens-1):
+
+             # TODO: double check this works then delete the block below:
+             # # Get the emb and pos embedding of just the last token
+             # pos = torch.arange(t+i, t+i+1, dtype=torch.long, device=device) # shape (1,)
+             # tok_emb = self.transformer.wte(predictions[-1]) # token embeddings of shape (b, 1, n_embd)
+             # pos_emb = self.transformer.wpe(pos) # position embeddings of shape (1, n_embd)
+             # x = self.transformer.drop(tok_emb + pos_emb)
+
+             # Get the emb and pos embedding of just the last token
+             tok_emb = self.transformer.wte(predictions[-1]) # token embeddings of shape (b, 1, n_embd)
+             # if wpe exists in self.transformer apply the learned positional embedding
+             if hasattr(self.transformer, 'wpe'):
+                 pos = torch.arange(t+i, t+i+1, dtype=torch.long, device=device) # shape (1,)
+                 pos_emb = self.transformer.wpe(pos) # position embeddings of shape (1, n_embd)
+                 x = self.transformer.drop(tok_emb + pos_emb)
+             else:
+                 x = self.transformer.drop(tok_emb)
+
+             # Pass through the transformer blocks
+             k_list = []
+             v_list = []
+             for block_idx, block in enumerate(self.transformer.h):
+                 x, k, v = block(x, k_cache=k_cache[block_idx], v_cache=v_cache[block_idx])
+                 k_list.append(k)
+                 v_list.append(v)
+             x = self.transformer.ln_f(x)
+             # update the cache with the new keys and values
+             k_cache = torch.stack(k_list, dim=0)
+             v_cache = torch.stack(v_list, dim=0)
+             # predict the next time bin
+             logits = self.coch_head(x)
+             predictions.append(self.sample_logits(logits, temperature=temp, top_k=top_k, top_p=top_p))
+             all_logits.append(logits)
+
+         pred_coch = torch.cat(predictions, dim=1)
+         all_logits = torch.cat(all_logits, dim=1)
+
+         return pred_coch, all_logits
+
+
+     def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):
+         # start with all of the candidate parameters
+         param_dict = {pn: p for pn, p in self.named_parameters()}
+         # filter out those that do not require grad
+         param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}
+         # create optim groups. Any parameter that is 2D will be weight decayed, otherwise not.
+         # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't.
+         decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
+         nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
+         optim_groups = [
+             {'params': decay_params, 'weight_decay': weight_decay},
+             {'params': nodecay_params, 'weight_decay': 0.0}
+         ]
+         num_decay_params = sum(p.numel() for p in decay_params)
+         num_nodecay_params = sum(p.numel() for p in nodecay_params)
+         print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
+         print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
+         # Create AdamW optimizer and use the fused version if it is available
+         fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
+         use_fused = fused_available and device_type == 'cuda'
+         extra_args = dict(fused=True) if use_fused else dict()
+         optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args)
+         print(f"using fused AdamW: {use_fused}")
+
+         return optimizer
+
+     def estimate_mfu(self, fwdbwd_per_iter, T, dt, gpu_type='A40'):
+         """ estimate model flops utilization (MFU) in units of the selected GPU's bfloat16 peak FLOPS """
+         # first estimate the number of flops we do per iteration.
+         # see PaLM paper Appendix B as ref: https://arxiv.org/abs/2204.02311
+         N = self.get_num_params()  # total parameter count
+         cfg = self.config
+         L, H, Q = cfg.n_layer, cfg.n_head, cfg.n_embd//cfg.n_head
+         # L, H, Q, T = cfg.n_layer, cfg.n_head, cfg.n_embd//cfg.n_head, cfg.block_size
+         flops_per_token = 6*N + 12*L*H*Q*T
+         flops_per_fwdbwd = flops_per_token * T
+         flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter
+         # express our flops throughput as a ratio of the GPU's bfloat16 peak flops
+         flops_achieved = flops_per_iter * (1.0/dt) # per second
+
+         # grab promised flops based on GPU type
+         if gpu_type == 'A40':
+             flops_promised = 149.7e12 # A40 GPU bfloat16 peak flops is 149.7 TFLOPS
+         elif gpu_type == 'A100':
+             flops_promised = 312e12 # A100 GPU bfloat16 peak flops is 312 TFLOPS
+         elif gpu_type == 'H100':
+             flops_promised = 756e12 # H100 GPU bfloat16 peak flops is 756 TFLOPS
+         elif gpu_type == 'TPUv4':
+             flops_promised = 275e12
+         elif gpu_type == 'TPUv5e':
+             flops_promised = 197e12
+         else:
+             raise ValueError(f"unknown gpu_type: {gpu_type}")
+
+         mfu = flops_achieved / flops_promised
+         return mfu
+
+
+ #########################################################
+ #####               Layer Definitions                #####
+ #########################################################
+
+
+ class Block(nn.Module):
+
+     def __init__(self, config):
+         super().__init__()
+         self.attn = CausalSelfAttention(config)
+         self.mlp = MLP(config)
+         self.attn_scale = 1.0 # (1 / (2 * config.n_layer)**0.5)
+         self.norm1 = RMSNorm(config.n_embd, bias=config.bias)
+         self.norm2 = RMSNorm(config.n_embd, bias=config.bias)
+
+     def forward(self, x, return_kv=False, k_cache=None, v_cache=None):
+         # If we are given a key and value cache, we will use the pre-computed values to minimize
+         # the computation cost
+         if k_cache is not None and v_cache is not None:
+             # Pass the key and value cache to the attention layer, obtain new key and value caches
+             x_attn, k, v = self.attn.kv_cache_forward(self.norm1(x), k_cache, v_cache)
+             x = x + x_attn
+             x = x + self.mlp(self.norm2(x))
+             return x, k, v
+         # We may also want to encode a whole sequence of keys and values in one forward pass
+         # while still returning the key and value caches for later cached decoding
+         elif return_kv:
+             # Run attention over the full sequence and also return the new key and value caches
+             x_attn, k, v = self.attn(self.norm1(x), return_kv=True)
+             x = x + x_attn
+             x = x + self.mlp(self.norm2(x))
+             return x, k, v
+
+         x = x + self.attn_scale * self.attn(self.norm1(x))
+         x = x + self.mlp(self.norm2(x))
+         return x
+
+
+ class CausalSelfAttention(nn.Module):
+
+     def __init__(self, config):
+         super().__init__()
+         self.n_head = config.n_head
+         self.n_embd = config.n_embd
+         self.head_dim = self.n_embd // self.n_head
+         assert self.n_embd % self.n_head == 0
+         # key, query, value projections for all heads, but in a batch
+         self.c_attn = nn.Linear(self.n_embd, 3 * self.n_embd, bias=False)
+         # output projection
+         self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)
+
+         rope_theta = 500000
+         if hasattr(config, 'rope_theta') and config.rope_theta is not None:
+             rope_theta = config.rope_theta
+
+         self.rotary = Rotary(self.head_dim, base=rope_theta)
+
+         if hasattr(config, 'use_rope') and not config.use_rope:
+             self.rotary = None
+
+     def forward(self, x, return_kv=False, return_attn_maps=False):
+
+         B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
+         # calculate query, key, values for all heads in batch and move head forward to be the batch dim
+         qkv = self.c_attn(x)
+         q, k, v = qkv.split(self.n_embd, dim=2)
+         k = k.view(B, T, self.n_head, self.head_dim)
+         q = q.view(B, T, self.n_head, self.head_dim)
+         v = v.view(B, T, self.n_head, self.head_dim)
+
+         if self.rotary is not None:
+             cos, sin = self.rotary(q)
+             q = apply_rotary_emb(q, cos, sin)
+             k = apply_rotary_emb(k, cos, sin)
+
+         if not return_kv and not return_attn_maps:
+             y = F.scaled_dot_product_attention(
+                 q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2),
+                 is_causal=True)
+         else:
+             # manual implementation of attention
+             q = q.transpose(1, 2)
+             k = k.transpose(1, 2)
+             v = v.transpose(1, 2)
+             att = torch.einsum('bnsh,bnkh->bnsk', q, k) * (1.0 / math.sqrt(k.size(-1)))
+             mask = torch.triu(torch.ones(T, T), diagonal=1).to(dtype=torch.bool).to(x.device)
+             mask = mask.view(1, 1, T, T)
+             masked_att = att.masked_fill(mask, float('-inf'))
+             # upcast to float32 for numerical stability, as per the llama implementation
+             masked_att = F.softmax(masked_att, dim=-1, dtype=torch.float32).to(q.dtype)
+             # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
+             y = torch.einsum('bnsk,bnkh->bnsh', masked_att, v)
+
+         y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
+
+         # output projection
+         y = self.c_proj(y)
+
+         # return attention maps if requested
+         if return_attn_maps:
+             return y, F.softmax(att, dim=-1)
+
+         # return key and value caches if requested
+         if return_kv:
+             return y, k, v
+
+         return y
+
+     def kv_cache_forward(self, x, k_cache=None, v_cache=None):
+         B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
+
+         # calculate query, key, values for all heads in batch and move head forward to be the batch dim
+         q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
+         k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
+         q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
+         v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
+
+         # append cached keys and values with the new keys and values
+         if k_cache is not None:
+             k = torch.cat((k_cache, k), dim=2)
+         if v_cache is not None:
+             v = torch.cat((v_cache, v), dim=2)
+
+         # manual implementation of attention
+         att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
+         att = F.softmax(att, dim=-1)
+         y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
+
+         y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
+
+         # output projection
+         y = self.c_proj(y)
+
+         return y, k, v
+
+
+ class MLP(nn.Module):
+
+     def __init__(self, config):
+         super().__init__()
+         self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)
+         self.gelu = nn.SiLU()  # note: SiLU activation, despite the attribute name
+         self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
+         self.dropout = nn.Dropout(config.dropout)
+
+     def forward(self, x):
+         x = self.c_fc(x)
+         x = self.gelu(x)
+         x = self.c_proj(x)
+         x = self.dropout(x)
+         return x
+
+
+ class Rotary(torch.nn.Module):
+     def __init__(self, dim, base=500000, learned=True):
+         super().__init__()
+         # Compute the standard RoPE inverse frequencies.
+         inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
+         # If learned is True, register as a parameter; otherwise, as a buffer.
+         if learned:
+             # Initialize randomly and register as a parameter.
+             self.inv_freq = torch.nn.Parameter(inv_freq)
+             nn.init.normal_(self.inv_freq, mean=0.0, std=0.02)
+         else:
+             self.register_buffer("inv_freq", inv_freq)
+         self.learned = learned # (optional) Save the flag if needed later
+
+     def forward(self, x):
+         seq_len = x.shape[1]
+         # Create a tensor of positions.
+         t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
+         # Outer product to compute angles; this uses the (possibly learnable) frequencies.
+         freqs = torch.outer(t, self.inv_freq).to(x.device)
+         cos_cached = freqs.cos()
+         sin_cached = freqs.sin()
+         return cos_cached[None, :, None, :], sin_cached[None, :, None, :]
+
+ def apply_rotary_emb(x, cos, sin):
+     assert x.ndim == 4 # multihead attention expected
+     d = x.shape[3] // 2
+     x1 = x[..., :d]
+     x2 = x[..., d:]
+     y1 = x1 * cos + x2 * sin
+     y2 = x1 * (-sin) + x2 * cos
+     return torch.cat([y1, y2], dim=3)
+
+
+ class RMSNorm(nn.Module):
+     """ Root Mean Square Normalization """
+     def __init__(self, dim: int, weight: bool = True, bias: bool = False, eps: float = 1e-6):
+         super().__init__()
+         self.eps = eps
+         # note: the bias argument is accepted for interface compatibility but no bias term is created
+         self.weight = nn.Parameter(torch.ones(dim)) if weight else None
+
+     def _norm(self, x):
+         return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+
+     def forward(self, x):
+         output = self._norm(x.float()).type_as(x)
+         if self.weight is not None:
+             return output * self.weight
+         return output