klemenk commited on
Commit
3ebd72a
·
verified ·
1 Parent(s): 3c98b6b

Sync modeling_auristream.py from TuKoResearch/AuriStream100M_40Pred_librilight_200k

Browse files
Files changed (1) hide show
  1. modeling_auristream.py +594 -0
modeling_auristream.py ADDED
@@ -0,0 +1,594 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ AuriStream sequence model definition.
3
+ """
4
+
5
+ import math
6
+ import inspect
7
+ import random
8
+ import torch
9
+ import torch.nn as nn
10
+ from torch.nn import functional as F
11
+ import numpy as np
12
+ from huggingface_hub import PyTorchModelHubMixin
13
+ from transformers.modeling_outputs import BaseModelOutput, CausalLMOutput
14
+ from transformers import PreTrainedModel
15
+ from .configuration_auristream import AuriStreamConfig
16
+
17
+
18
class AuriStream(PreTrainedModel):
    """Causal (GPT-style) transformer over sequences of cochlear tokens.

    The model embeds integer tokens, runs them through a stack of pre-norm
    transformer blocks, and decodes next-token logits with ``coch_head``.
    When the config defines ``n_pred_steps``, additional linear heads predict
    tokens further into the future (t+2, t+3, ...), and the training loss is
    averaged over all heads.
    """

    config_class = AuriStreamConfig

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        # If `use_rope` is present in the config and False, use learned absolute
        # position embeddings (`wpe`). Otherwise positions are handled by rotary
        # embeddings inside the attention layers and no `wpe` table is created.
        if hasattr(config, 'use_rope') and not config.use_rope:
            self.transformer = nn.ModuleDict(dict(
                wte = nn.Embedding(config.vocab_size, config.n_embd),
                wpe = nn.Embedding(config.seq_len, config.n_embd),
                drop = nn.Dropout(config.dropout),
                h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
                ln_f = RMSNorm(config.n_embd, bias=config.bias),
            ))
        else:
            self.transformer = nn.ModuleDict(dict(
                wte = nn.Embedding(config.vocab_size, config.n_embd),
                drop = nn.Dropout(config.dropout),
                h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
                ln_f = RMSNorm(config.n_embd, bias=config.bias),
            ))

        # `n_pred_steps` is the total number of prediction heads: `coch_head`
        # predicts t+1 and these extra heads predict t+2 ... t+n_pred_steps.
        if hasattr(config, 'n_pred_steps'):
            self.future_heads = nn.ModuleList([nn.Linear(config.n_embd, config.vocab_size, bias=False) for _ in range(config.n_pred_steps - 1)])
        else:
            self.future_heads = None

        # Primary decoding head (next-token logits over the token vocabulary).
        self.coch_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # init all weights
        self.apply(self._init_weights)
        # apply special scaled init to the residual projections, per GPT-2 paper
        for pn, p in self.named_parameters():
            if pn.endswith('c_proj.weight'):
                torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))

    def get_num_params(self, non_embedding=True):
        """Return the total number of parameters in the model.

        NOTE(review): the `non_embedding` flag is currently ignored — the count
        always includes every parameter, embeddings included. Confirm whether
        position-embedding subtraction was meant to be implemented here.
        """
        n_params = sum(p.numel() for p in self.parameters())
        return n_params

    def _init_weights(self, module):
        # Standard GPT-2 style init: N(0, 0.02) weights, zero biases.
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, seq, tgt=None, output_logits=False, output_hidden_states=False, return_dict=False, up_until_layer=None):
        """Run the transformer over a token sequence.

        Parameters:
            seq: LongTensor of shape (b, t) with input token ids.
            tgt: optional LongTensor of shape (b, t) with per-position targets;
                when given, a cross-entropy loss is computed (averaged over all
                prediction heads when `future_heads` exist).
            output_logits: when True (together with `tgt` and `return_dict`),
                the returned `logits` field is a list with one entry per head.
            output_hidden_states: collect the residual-stream activation before
                each block (plus the final activation).
            return_dict: return a `CausalLMOutput` instead of a tuple. Only
                honored when `tgt` is given; with `tgt=None` a `(logits, None)`
                tuple is always returned.
            up_until_layer: stop after this block index (early exit used for
                probing intermediate representations).

        Returns:
            `(logits, loss)` tuple by default, or a `BaseModelOutput` /
            `CausalLMOutput` depending on the flags above.
        """

        # forward the GPT model itself
        tok_emb = self.transformer.wte(seq) # token embeddings of shape (b, t, n_embd)

        # If `wpe` exists (use_rope disabled), add learned positional embeddings.
        if hasattr(self.transformer, 'wpe'):
            pos = torch.arange(0, seq.size(1), dtype=torch.long, device=seq.device)
            pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd)
            x = self.transformer.drop(tok_emb + pos_emb)
        else:
            x = self.transformer.drop(tok_emb)

        # Collect the residual stream *entering* each block; the output of the
        # last executed block is appended after the loop.
        all_hidden_states = []
        for block_idx, block in enumerate(self.transformer.h):
            # Forward the block
            all_hidden_states.append(x)
            if up_until_layer is not None and block_idx == up_until_layer:
                break
            x = block(x)

        # append the last hidden state if we did not exit early
        if up_until_layer is None or block_idx == len(self.transformer.h) - 1:
            all_hidden_states.append(x)

        # Early return of raw hidden states. NOTE(review): this path returns a
        # BaseModelOutput regardless of `return_dict`, and `x` here is *before*
        # the final RMSNorm (`ln_f`) — confirm both are intended.
        if output_hidden_states and not output_logits:
            model_output = BaseModelOutput(
                last_hidden_state=x,
                hidden_states=all_hidden_states,
            )
            return model_output

        x = self.transformer.ln_f(x)
        logits = self.coch_head(x)

        if tgt is not None:

            if output_logits:
                all_logits = [logits]

            # Primary next-token loss from `coch_head`.
            loss = F.cross_entropy(
                logits.reshape(-1, self.config.vocab_size), tgt.reshape(-1),
            )

            # If we have more than one future head, compute the loss for each head.
            # Head i predicts token t+(i+2), so the last (i+1) positions have no
            # target and are sliced off the logits, and the targets are shifted.
            if self.future_heads is not None:
                for i, head in enumerate(self.future_heads):
                    future_logits = head(x[:, :-(i+1)])
                    loss += F.cross_entropy(
                        future_logits.reshape(-1, self.config.vocab_size), tgt[:, (i+1):].reshape(-1),
                    )
                    if output_logits:
                        all_logits.append(future_logits)
                # divide loss by number of future heads (+1 for coch_head)
                loss = loss / (len(self.future_heads) + 1)

            if return_dict:
                if output_logits:
                    if output_hidden_states:
                        model_output = CausalLMOutput(
                            loss=loss,
                            logits=all_logits,
                            hidden_states=all_hidden_states,
                        )
                    else:
                        model_output = CausalLMOutput(
                            loss=loss,
                            logits=all_logits,
                        )
                else:
                    if output_hidden_states:
                        model_output = CausalLMOutput(
                            loss=loss,
                            logits=logits,
                            hidden_states=all_hidden_states,
                        )
                    else:
                        model_output = CausalLMOutput(
                            loss=loss,
                            logits=logits,
                        )
                return model_output

            return logits, loss

        return logits, None

    def sample_logits(self, logits: torch.FloatTensor, temperature: float = 0.9,
                      top_k: int = 500, top_p: float = 0.5) -> torch.LongTensor:
        """Sample token ids from a logits tensor.

        Parameters:
            logits (torch.FloatTensor): logits over the vocabulary (last dim).
            temperature (float): sampling temperature; 0.0 means pure argmax.
            top_k (int): keep only the k highest-probability tokens (None to skip).
            top_p (float): nucleus sampling threshold (None to skip).
        Returns:
            torch.LongTensor: sampled ids with the logits' leading shape.
        """
        # If temperature is 0.0, use argmax (deterministic decoding).
        if temperature == 0.0:
            return torch.argmax(logits, dim=-1)

        # Apply temperature (creates a new tensor; caller's logits are untouched).
        logits = logits / temperature

        # Apply top-k filtering if specified: everything below the k-th largest
        # logit is masked to -inf.
        if top_k is not None:
            v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
            logits[logits < v[..., [-1]]] = -float('Inf')

        # Apply top-p (nucleus) filtering if specified.
        if top_p is not None:
            # Sort the logits in descending order
            sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
            # Compute the sorted softmax probabilities
            sorted_probs = F.softmax(sorted_logits, dim=-1)
            # Compute the cumulative probabilities
            cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
            # Create a mask for tokens to remove
            sorted_indices_to_remove = cumulative_probs > top_p
            # Shift the mask right to keep at least one token
            sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
            sorted_indices_to_remove[..., 0] = 0
            # Scatter the mask back to the original (unsorted) indices
            indices_to_remove = sorted_indices_to_remove.scatter(dim=-1, index=sorted_indices, src=sorted_indices_to_remove)
            logits[indices_to_remove] = -float('Inf')

        # Compute softmax probabilities
        probs = F.softmax(logits, dim=-1)
        # Flatten probabilities to (batch_size * sequence_length, vocab_size)
        flat_probs = probs.view(-1, probs.size(-1))
        # Sample one token per row of the flattened distribution
        sampled = torch.multinomial(flat_probs, num_samples=1)
        # Reshape to original shape except for the last dimension
        sampled = sampled.view(*logits.shape[:-1])
        return sampled

    @torch.no_grad()
    def generate(self, seq: torch.Tensor, n_tokens: int = 1, temp=1.0,
                 top_k=500, top_p=0.5, seed=None):
        """Autoregressively generate tokens, using a manual KV cache.

        Parameters:
            seq: LongTensor of shape (b, t) — conditioning token sequence
                (the code below unpacks exactly two dims via `seq.size()`).
            n_tokens: number of tokens to generate.
            temp: sampling temperature (see `sample_logits`).
            top_k / top_p: sampling filters (see `sample_logits`).
                NOTE(review): the *first* sampled token below is drawn without
                passing top_k/top_p, so it uses `sample_logits`' defaults —
                confirm this asymmetry is intended.
            seed: optional RNG seed for reproducible sampling.

        Returns:
            pred_coch: LongTensor of shape (b, n_tokens) — generated tokens.
            all_logits: FloatTensor of shape (b, n_tokens, vocab_size).
        """

        # Set seed if provided (python, numpy, and torch RNGs)
        if seed is not None:
            random.seed(seed)
            np.random.seed(seed)
            torch.manual_seed(seed)

        # make a list of logits to return
        all_logits = []
        device = seq.device

        # grab shape of the conditioning sequence
        b, t = seq.size()

        # TODO: double check this works then delete the block bellow:
        # pass the given input through the model to get the predictions and cache
        # the k and v values for each transformer block in the process
        # pos = torch.arange(0, t, dtype=torch.long, device=device)
        # tok_emb = self.transformer.wte(seq) # token embeddings of shape (b, t, n_embd)
        # pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd)
        # x = self.transformer.drop(tok_emb + pos_emb)

        #### Embed conditioning sequence into KV cache

        tok_emb = self.transformer.wte(seq) # token embeddings of shape (b, t, n_embd)
        # If `wpe` exists (use_rope disabled), add learned positional embeddings.
        if hasattr(self.transformer, 'wpe'):
            pos = torch.arange(0, seq.size(1), dtype=torch.long, device=seq.device)
            pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd)
            x = self.transformer.drop(tok_emb + pos_emb)
        else:
            x = self.transformer.drop(tok_emb)

        # Initialize list to store k and v for each transformer block
        k_list = []
        v_list = []
        for block_idx, block in enumerate(self.transformer.h):
            # Pass through the transformer block, and store k and v
            x, k, v = block(x, return_kv=True)
            k_list.append(k)
            v_list.append(v)
        # k_cache and v_cache have shape (n_layer, b, n_head, t, n_embd//n_head)
        k_cache = torch.stack(k_list, dim=0)
        v_cache = torch.stack(v_list, dim=0)
        # Pass through the final layer norm
        x = self.transformer.ln_f(x)

        # First prediction of the model is the decoding of the last time bin
        logits = self.coch_head(x[:, [-1]])
        predictions = [self.sample_logits(logits, temperature=temp)]
        all_logits.append(logits)

        ### Predict future tokens

        # Now we pass the last time bin through the model to predict the next time bin
        # we subtract 1 from max_new_tokens because we already predicted the first time bin
        # using the last embedding of the input
        for i in range(n_tokens-1):

            # TODO: double check this works then delete the block bellow:
            # # Get the emb and pos embedding of just the last token
            # pos = torch.arange(t+i, t+i+1, dtype=torch.long, device=device) # shape (t)
            # tok_emb = self.transformer.wte(predictions[-1]) # token embeddings of shape (b, t, n_embd)
            # pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd)
            # x = self.transformer.drop(tok_emb + pos_emb)

            # Get the emb and pos embedding of just the last token
            tok_emb = self.transformer.wte(predictions[-1]) # token embeddings of shape (b, 1, n_embd)
            # If `wpe` exists (use_rope disabled), add its positional embedding.
            if hasattr(self.transformer, 'wpe'):
                pos = torch.arange(t+i, t+i+1, dtype=torch.long, device=device) # single position
                pos_emb = self.transformer.wpe(pos) # position embeddings of shape (1, n_embd)
                x = self.transformer.drop(tok_emb + pos_emb)
            else:
                x = self.transformer.drop(tok_emb)

            # Pass the single new token through each block against the cache,
            # collecting the grown caches for the next step.
            k_list = []
            v_list = []
            for block_idx, block in enumerate(self.transformer.h):
                x, k, v = block(x, k_cache=k_cache[block_idx], v_cache=v_cache[block_idx])
                k_list.append(k)
                v_list.append(v)
            x = self.transformer.ln_f(x)
            # create the cache with the new embeddings
            k_cache = torch.stack(k_list, dim=0)
            v_cache = torch.stack(v_list, dim=0)
            # predict next time bin
            logits = self.coch_head(x)
            predictions.append(self.sample_logits(logits, temperature=temp, top_k=top_k, top_p=top_p))
            all_logits.append(logits)

        pred_coch = torch.cat(predictions, dim=1)
        all_logits = torch.cat(all_logits, dim=1)

        return pred_coch, all_logits


    def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):
        """Build an AdamW optimizer with weight decay only on >=2D tensors.

        Biases and norm weights (1D) are excluded from decay; matmul weights
        and embeddings (2D) are decayed, per common GPT training practice.
        """
        # start with all of the candidate parameters
        param_dict = {pn: p for pn, p in self.named_parameters()}
        # filter out those that do not require grad
        param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}
        # create optim groups. Any parameters that is 2D will be weight decayed, otherwise no.
        # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't.
        decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
        nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
        optim_groups = [
            {'params': decay_params, 'weight_decay': weight_decay},
            {'params': nodecay_params, 'weight_decay': 0.0}
        ]
        num_decay_params = sum(p.numel() for p in decay_params)
        num_nodecay_params = sum(p.numel() for p in nodecay_params)
        print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
        print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
        # Create AdamW optimizer and use the fused version if it is available
        fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
        use_fused = fused_available and device_type == 'cuda'
        extra_args = dict(fused=True) if use_fused else dict()
        optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args)
        print(f"using fused AdamW: {use_fused}")

        return optimizer

    def estimate_mfu(self, fwdbwd_per_iter, T, dt, gpu_type='A40'):
        """Estimate model FLOPs utilization (MFU) against a GPU's bf16 peak.

        NOTE(review): `self.unsharded_param_count` is not defined anywhere in
        this class — calling this method raises AttributeError unless that
        attribute is set externally (e.g. by a training harness). Also, an
        unrecognized `gpu_type` leaves `flops_promised` undefined (NameError);
        confirm whether a fallback/else branch is wanted.
        """
        # first estimate the number of flops we do per iteration.
        # see PaLM paper Appendix B as ref: https://arxiv.org/abs/2204.02311
        N = self.unsharded_param_count
        cfg = self.config
        L, H, Q = cfg.n_layer, cfg.n_head, cfg.n_embd//cfg.n_head
        # L, H, Q, T = cfg.n_layer, cfg.n_head, cfg.n_embd//cfg.n_head, cfg.block_size
        flops_per_token = 6*N + 12*L*H*Q*T
        flops_per_fwdbwd = flops_per_token * T
        flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter
        # express our flops throughput as ratio of the GPU's bfloat16 peak flops
        flops_achieved = flops_per_iter * (1.0/dt) # per second

        # grab promised flops based on GPU type
        if gpu_type == 'A40':
            flops_promised = 149.7e12 # A40 GPU bfloat16 peak flops is 149.7 TFLOPS
        elif gpu_type == 'A100':
            flops_promised = 312e12 # A100 GPU bfloat16 peak flops is 312 TFLOPS
        elif gpu_type == 'H100':
            flops_promised = 756e12 # H100 GPU bfloat16 peak flops is 756 TFLOPS
        elif gpu_type == 'TPUv4':
            flops_promised = 275e12
        elif gpu_type == 'TPUv5e':
            flops_promised = 197e12

        mfu = flops_achieved / flops_promised
        return mfu
391
+
392
+
393
+ #########################################################
394
+ ##### Layer Definitions #####
395
+ #########################################################
396
+
397
+
398
class Block(nn.Module):
    """Pre-norm transformer block: causal self-attention then an MLP,
    each wrapped in a residual connection.
    """

    def __init__(self, config):
        super().__init__()
        self.attn = CausalSelfAttention(config)
        self.mlp = MLP(config)
        self.attn_scale = 1.0  # (1 / (2 * config.n_layer)**0.5)
        self.norm1 = RMSNorm(config.n_embd, bias=config.bias)
        self.norm2 = RMSNorm(config.n_embd, bias=config.bias)

    def forward(self, x, return_kv=False, k_cache=None, v_cache=None):
        """Run the block.

        With `k_cache`/`v_cache` supplied, the attention is computed against
        the cached keys/values (incremental decoding) and the grown caches are
        returned. With `return_kv=True` (and no cache), the fast attention
        path still runs on the full sequence but the fresh k/v are returned
        so a cache can be built. Both of those paths return `(x, k, v)`;
        the plain path returns just `x`.
        """
        have_cache = k_cache is not None and v_cache is not None

        if have_cache or return_kv:
            if have_cache:
                # Incremental step: attend the new token(s) against the cache.
                attn_out, k, v = self.attn.kv_cache_forward(self.norm1(x), k_cache, v_cache)
            else:
                # Full-sequence pass that also hands back k/v for caching.
                attn_out, k, v = self.attn(self.norm1(x), return_kv=True)
            h = x + attn_out
            h = h + self.mlp(self.norm2(h))
            return h, k, v

        # Plain training/inference path (attn_scale is 1.0, kept for parity
        # with the commented-out depth-scaled variant above).
        h = x + self.attn_scale * self.attn(self.norm1(x))
        return h + self.mlp(self.norm2(h))
429
+
430
+
431
class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention with optional rotary embeddings.

    Three execution paths:
      * fast path via ``F.scaled_dot_product_attention`` (flash/efficient kernels);
      * manual attention when k/v or attention maps must be returned;
      * ``kv_cache_forward`` for incremental decoding against cached k/v.
    """

    def __init__(self, config):
        super().__init__()
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.head_dim = self.n_embd // self.n_head
        assert self.n_embd % self.n_head == 0
        # key, query, value projections for all heads, but in a batch
        self.c_attn = nn.Linear(self.n_embd, 3 * self.n_embd, bias=False)
        # output projection
        self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)

        # Rotary base frequency; overridable via config.rope_theta.
        rope_theta = 500000
        if hasattr(config, 'rope_theta') and config.rope_theta is not None:
            rope_theta = config.rope_theta

        self.rotary = Rotary(self.head_dim, base=rope_theta)

        # Explicit use_rope=False disables rotary entirely (the model then
        # relies on the learned `wpe` table created in AuriStream.__init__).
        if hasattr(config, 'use_rope') and not config.use_rope:
            self.rotary = None

    def forward(self, x, return_kv=False, return_attn_maps=False):
        """Attend over `x` of shape (B, T, C); causal masking always applies.

        Returns `y`, or `(y, attn_maps)` if `return_attn_maps`, or `(y, k, v)`
        if `return_kv` (k/v shaped (B, n_head, T, head_dim), post-rotary).
        """

        B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        qkv = self.c_attn(x)
        q, k, v = qkv.split(self.n_embd, dim=2)
        k = k.view(B, T, self.n_head, self.head_dim)
        q = q.view(B, T, self.n_head, self.head_dim)
        v = v.view(B, T, self.n_head, self.head_dim)

        # Apply rotary position embedding to q and k (not v).
        if self.rotary is not None:
            cos, sin = self.rotary(q)
            q = apply_rotary_emb(q, cos, sin)
            k = apply_rotary_emb(k, cos, sin)

        if not return_kv and not return_attn_maps:
            # Fast fused kernel path; is_causal builds the causal mask internally.
            y = F.scaled_dot_product_attention(
                q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2),
                is_causal=True)
        else:
            # manual implementation of attention (needed to expose att / k / v)
            q = q.transpose(1, 2)
            k = k.transpose(1, 2)
            v = v.transpose(1, 2)
            att = torch.einsum('bnsh,bnkh->bnsk', q, k) * (1.0 / math.sqrt(k.size(-1)))
            # Strictly-upper-triangular mask blocks attention to future positions.
            mask = torch.triu(torch.ones(T, T), diagonal=1).to(dtype=torch.bool).to(x.device)
            mask = mask.view(1, 1, T, T)
            masked_att = att.masked_fill(mask, float('-inf'))
            # upcast to float32 for numerical stability, as per llama implementation
            masked_att = F.softmax(masked_att, dim=-1, dtype=torch.float32).to(q.dtype)
            # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
            y = torch.einsum('bnsk,bnkh->bnsh', masked_att, v)

        y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side

        # output projection
        y = self.c_proj(y)

        # Return attention maps if requested. NOTE(review): this softmaxes the
        # *unmasked* scores `att`, not `masked_att` — confirm that returning
        # non-causal maps is intended.
        if return_attn_maps:
            return y, F.softmax(att, dim=-1)

        # return key and value caches if requested
        if return_kv:
            return y, k, v

        return y

    def kv_cache_forward(self, x, k_cache=None, v_cache=None):
        """Incremental attention: project the new token(s) in `x`, append to
        the caches, and attend q against the full (cached + new) k/v.

        No causal mask is applied; with a single-token query attending only to
        past positions that is equivalent to causal attention.

        NOTE(review): unlike `forward`, rotary embeddings are NOT applied to
        q/k here, while the cached k coming from `forward(return_kv=True)` IS
        post-rotary — verify this mismatch is intended when use_rope is on.
        """
        B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)

        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)

        # append cached keys and values with new keys and values
        if k_cache is not None:
            k = torch.cat((k_cache, k), dim=2)
        if v_cache is not None:
            v = torch.cat((v_cache, v), dim=2)

        # manual implementation of attention
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        att = F.softmax(att, dim=-1)
        y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)

        y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side

        # output projection
        y = self.c_proj(y)

        return y, k, v
527
+
528
+
529
class MLP(nn.Module):
    """Position-wise feed-forward network with a 4x hidden expansion.

    Note: the activation attribute is named ``gelu`` (for checkpoint
    compatibility) but the actual nonlinearity is SiLU.
    """

    def __init__(self, config):
        super().__init__()
        self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)
        self.gelu = nn.SiLU()
        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        """Expand -> activate -> project back -> dropout."""
        return self.dropout(self.c_proj(self.gelu(self.c_fc(x))))
544
+
545
+
546
class Rotary(torch.nn.Module):
    """Rotary positional embedding with optionally learnable frequencies."""

    def __init__(self, dim, base=500000, learned=True):
        super().__init__()
        # Standard RoPE inverse frequencies for a head dimension of `dim`.
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        if learned:
            # Trainable frequencies. Note the deterministic values computed
            # above are immediately overwritten by a random normal init.
            self.inv_freq = torch.nn.Parameter(inv_freq)
            nn.init.normal_(self.inv_freq, mean=0.0, std=0.02)
        else:
            self.register_buffer("inv_freq", inv_freq)
        self.learned = learned  # kept for later inspection

    def forward(self, x):
        """Return (cos, sin) tables shaped (1, T, 1, dim//2) for x of shape
        (B, T, n_head, head_dim)."""
        positions = torch.arange(x.shape[1], device=x.device).type_as(self.inv_freq)
        # Angle per (position, frequency) pair.
        angles = torch.outer(positions, self.inv_freq).to(x.device)
        return angles.cos()[None, :, None, :], angles.sin()[None, :, None, :]
569
+
570
def apply_rotary_emb(x, cos, sin):
    """Rotate `x` by the given cos/sin tables (half-split RoPE layout).

    The last dimension is split in half and the two halves are mixed as a
    2-D rotation per (position, frequency) pair.
    """
    assert x.ndim == 4  # (batch, seq, heads, head_dim) expected
    half = x.shape[3] // 2
    first, second = x[..., :half], x[..., half:]
    rotated_first = first * cos + second * sin
    rotated_second = second * cos - first * sin
    return torch.cat([rotated_first, rotated_second], dim=3)
578
+
579
+
580
class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization (no mean-centering).

    The `bias` argument is accepted for interface compatibility but is
    currently ignored: no bias term is ever created or applied. (Adding one
    would change the state dict and break existing checkpoints.)
    """

    def __init__(self, dim: int, weight: bool = True, bias: bool = False, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim)) if weight else None

    def _norm(self, x):
        # Scale each feature vector by the reciprocal of its RMS.
        inv_rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return x * inv_rms

    def forward(self, x):
        # Normalize in float32 for stability, then cast back to x's dtype.
        normed = self._norm(x.float()).type_as(x)
        return normed if self.weight is None else normed * self.weight