klemenk committed on
Commit 4e01699 · verified
1 Parent(s): e7e99d4

Sync modeling_auristream.py from TuKoResearch/AuriStream200M_100Pred_librilight_200k

Files changed (1)
modeling_auristream.py +621 -0
modeling_auristream.py ADDED
@@ -0,0 +1,621 @@
+ """
+ AuriStream sequence model definition.
+ """
+
+ import math
+ import inspect
+ import random
+ import torch
+ import torch.nn as nn
+ from torch.nn import functional as F
+ import numpy as np
+ from huggingface_hub import PyTorchModelHubMixin
+ from transformers.modeling_outputs import BaseModelOutput, CausalLMOutput
+ from transformers import PreTrainedModel
+ from .configuration_auristream import AuriStreamConfig
+
+
+ class AuriStream(PreTrainedModel):
+     config_class = AuriStreamConfig
+
+     def __init__(self, config):
+         super().__init__(config)
+         self.config = config
+
+         # If use_rope is present in the config and False, include a learned
+         # positional-embedding table (wpe) in the transformer.
+         if hasattr(config, 'use_rope') and not config.use_rope:
+             self.transformer = nn.ModuleDict(dict(
+                 wte = nn.Embedding(config.vocab_size, config.n_embd),
+                 wpe = nn.Embedding(config.seq_len, config.n_embd),
+                 drop = nn.Dropout(config.dropout),
+                 h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
+                 ln_f = RMSNorm(config.n_embd, bias=config.bias),
+             ))
+         else:
+             self.transformer = nn.ModuleDict(dict(
+                 wte = nn.Embedding(config.vocab_size, config.n_embd),
+                 drop = nn.Dropout(config.dropout),
+                 h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
+                 ln_f = RMSNorm(config.n_embd, bias=config.bias),
+             ))
+
+         # If n_pred_steps is defined in the config, it sets the number of linear
+         # prediction heads (one per future step beyond the next token).
+         if hasattr(config, 'n_pred_steps'):
+             self.future_heads = nn.ModuleList([nn.Linear(config.n_embd, config.vocab_size, bias=False) for _ in range(config.n_pred_steps - 1)])
+         else:
+             self.future_heads = None
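+         # With n_pred_steps horizons, coch_head predicts the next token while
+         # future_heads[i] predicts the target shifted a further i+1 steps ahead,
+         # one linear readout per prediction horizon (see forward()).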
+
+         self.coch_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
+
+         # init all weights
+         self.apply(self._init_weights)
+         # apply special scaled init to the residual projections, per GPT-2 paper
+         for pn, p in self.named_parameters():
+             if pn.endswith('c_proj.weight'):
+                 torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))
+
+     def get_num_params(self, non_embedding=True):
+         """
+         Return the total number of parameters in the model.
+         The `non_embedding` flag is kept for API compatibility, but all
+         parameters (including embeddings) are currently counted.
+         """
+         n_params = sum(p.numel() for p in self.parameters())
+         return n_params
+
+     def _init_weights(self, module):
+         if isinstance(module, nn.Linear):
+             torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
+             if module.bias is not None:
+                 torch.nn.init.zeros_(module.bias)
+         elif isinstance(module, nn.Embedding):
+             torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
+
+     def forward(self, seq, tgt=None, output_logits=False, output_hidden_states=False, return_dict=False, up_until_layer=None):
+         """
+         Input: seq: torch.Tensor of shape (b, t), input token ids
+                tgt: torch.Tensor of shape (b, t) or None, next-token targets
+         """
+
+         # forward the GPT model itself
+         tok_emb = self.transformer.wte(seq)  # token embeddings of shape (b, t, n_embd)
+
+         # if wpe exists in self.transformer, apply the learned positional embedding
+         if hasattr(self.transformer, 'wpe'):
+             pos = torch.arange(0, seq.size(1), dtype=torch.long, device=seq.device)
+             pos_emb = self.transformer.wpe(pos)  # position embeddings of shape (t, n_embd)
+             x = self.transformer.drop(tok_emb + pos_emb)
+         else:
+             x = self.transformer.drop(tok_emb)
+
+         all_hidden_states = []
+         for block_idx, block in enumerate(self.transformer.h):
+             # record the hidden state entering each block
+             all_hidden_states.append(x)
+             if up_until_layer is not None and block_idx == up_until_layer:
+                 break
+             x = block(x)
+
+         # append the last hidden state if we did not exit early
+         if up_until_layer is None or block_idx == len(self.transformer.h) - 1:
+             all_hidden_states.append(x)
+
+         if output_hidden_states and not output_logits:
+             return BaseModelOutput(
+                 last_hidden_state=x,
+                 hidden_states=all_hidden_states,
+             )
+
+         x = self.transformer.ln_f(x)
+         logits = self.coch_head(x)
+
+         if output_logits:
+             all_logits = [logits]
+
+         if tgt is not None:
+             loss = F.cross_entropy(
+                 logits.reshape(-1, self.config.vocab_size), tgt.reshape(-1),
+             )
+
+         # If we have future heads, compute logits (and loss) for each extra horizon
+         if self.future_heads is not None:
+             for i, head in enumerate(self.future_heads):
+                 future_logits = head(x[:, :-(i+1)])
+
+                 if tgt is not None:
+                     loss += F.cross_entropy(
+                         future_logits.reshape(-1, self.config.vocab_size), tgt[:, (i+1):].reshape(-1),
+                     )
+                 if output_logits:
+                     all_logits.append(future_logits)
+
+             if tgt is not None:
+                 # average the summed loss over all prediction heads
+                 loss = loss / (len(self.future_heads) + 1)
+
+         if return_dict:
+             # fields that were not requested are simply left as None
+             return CausalLMOutput(
+                 loss=loss if tgt is not None else None,
+                 logits=all_logits if output_logits else logits,
+                 hidden_states=all_hidden_states if output_hidden_states else None,
+             )
+
+         if tgt is not None:
+             return logits, loss
+
+         return logits, None
+
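+     # Output conventions: with return_dict=False the forward returns (logits, loss);
+     # output_hidden_states without output_logits short-circuits to a BaseModelOutput
+     # carrying the per-layer hidden states, which is the cheap path for probing
+     # intermediate representations (optionally truncated with up_until_layer).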
+     def sample_logits(self, logits: torch.FloatTensor, temperature: float = 0.9,
+                       top_k: int = 500, top_p: float = 0.5) -> torch.LongTensor:
+         """
+         Samples an integer from the distribution of logits.
+         Parameters:
+             logits (torch.FloatTensor): The logits of the distribution
+             temperature (float): The sampling temperature; if 0.0, argmax is used
+             top_k (int): The number of top-k tokens to consider during sampling
+             top_p (float): The cumulative probability threshold for nucleus (top-p) sampling
+         Returns:
+             torch.LongTensor: The sampled integer(s)
+         """
+         # If temperature is 0.0, use argmax
+         if temperature == 0.0:
+             return torch.argmax(logits, dim=-1)
+
+         # Apply temperature
+         logits = logits / temperature
+
+         # Apply top-k filtering if specified
+         if top_k is not None:
+             v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
+             logits[logits < v[..., [-1]]] = -float('Inf')
+
+         # Apply top-p (nucleus) filtering if specified
+         if top_p is not None:
+             # Sort the logits in descending order
+             sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
+             # Compute the sorted softmax probabilities
+             sorted_probs = F.softmax(sorted_logits, dim=-1)
+             # Compute the cumulative probabilities
+             cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
+             # Create a mask for tokens to remove
+             sorted_indices_to_remove = cumulative_probs > top_p
+             # Shift the mask right to keep at least one token
+             sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+             sorted_indices_to_remove[..., 0] = 0
+             # Scatter the mask back to the original indices
+             indices_to_remove = sorted_indices_to_remove.scatter(dim=-1, index=sorted_indices, src=sorted_indices_to_remove)
+             logits[indices_to_remove] = -float('Inf')
+
+         # Compute softmax probabilities
+         probs = F.softmax(logits, dim=-1)
+         # Flatten probabilities to (batch_size * sequence_length, vocab_size)
+         flat_probs = probs.view(-1, probs.size(-1))
+         # Sample from the distribution
+         sampled = torch.multinomial(flat_probs, num_samples=1)
+         # Reshape to original shape except for the last dimension
+         sampled = sampled.view(*logits.shape[:-1])
+         return sampled
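+     # Filtering order: temperature rescales the logits, top-k keeps the k highest,
+     # and top-p then keeps the smallest prefix of the sorted distribution whose
+     # cumulative probability exceeds top_p (always retaining the top token), before
+     # the final softmax renormalizes the survivors and torch.multinomial draws from them.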
+
+     @torch.no_grad()
+     def generate(self, seq: torch.Tensor, n_tokens: int = 1, temp=1.0,
+                  top_k=500, top_p=0.5, seed=None):
+         """
+         Parameters:
+             seq: torch.Tensor of shape (b, t)
+                 Input token sequence to condition the generation on
+             n_tokens: int
+                 Number of time bins to predict
+             temp: float
+                 Temperature for sampling logits
+             top_k: int
+                 Top-k cutoff for sampling
+             top_p: float
+                 Nucleus (top-p) cutoff for sampling
+             seed: int
+                 Random seed for sampling
+
+         Returns:
+             pred_coch: torch.Tensor of shape (b, n_tokens)
+                 The predicted token sequence
+             all_logits: torch.Tensor of shape (b, n_tokens, vocab_size)
+                 The logits for each predicted time step
+         """
+
+         # Set seed if provided
+         if seed is not None:
+             random.seed(seed)
+             np.random.seed(seed)
+             torch.manual_seed(seed)
+
+         # make a list of logits to return
+         all_logits = []
+         device = seq.device
+
+         # grab the shape of the input token sequence
+         b, t = seq.size()
+
+         #### Embed the conditioning sequence and fill the KV cache
+
+         tok_emb = self.transformer.wte(seq)  # token embeddings of shape (b, t, n_embd)
+         # if wpe exists in self.transformer, apply the learned positional embedding
+         if hasattr(self.transformer, 'wpe'):
+             pos = torch.arange(0, seq.size(1), dtype=torch.long, device=seq.device)
+             pos_emb = self.transformer.wpe(pos)  # position embeddings of shape (t, n_embd)
+             x = self.transformer.drop(tok_emb + pos_emb)
+         else:
+             x = self.transformer.drop(tok_emb)
+
+         # Initialize lists to store k and v for each transformer block
+         k_list = []
+         v_list = []
+         for block_idx, block in enumerate(self.transformer.h):
+             # Pass through the transformer block, and store k and v
+             x, k, v = block(x, return_kv=True)
+             k_list.append(k)
+             v_list.append(v)
+         # k_cache and v_cache have shape (n_layer, b, n_head, t, n_embd//n_head)
+         k_cache = torch.stack(k_list, dim=0)
+         v_cache = torch.stack(v_list, dim=0)
+         # Pass through the final layer norm
+         x = self.transformer.ln_f(x)
+
+         # The first prediction of the model is the decoding of the last time bin
+         logits = self.coch_head(x[:, [-1]])
+         predictions = [self.sample_logits(logits, temperature=temp, top_k=top_k, top_p=top_p)]
+         all_logits.append(logits)
+
+         #### Predict future tokens
+
+         # Now we feed each newly sampled token back through the model to predict the
+         # next one. We iterate n_tokens - 1 times because the first time bin was
+         # already predicted from the last embedding of the input.
+         for i in range(n_tokens-1):
+
+             # Get the token (and positional) embedding of just the last prediction
+             tok_emb = self.transformer.wte(predictions[-1])  # token embeddings of shape (b, 1, n_embd)
+             # if wpe exists in self.transformer, apply the learned positional embedding
+             if hasattr(self.transformer, 'wpe'):
+                 pos = torch.arange(t+i, t+i+1, dtype=torch.long, device=device)  # position of the new token
+                 pos_emb = self.transformer.wpe(pos)  # position embeddings of shape (1, n_embd)
+                 x = self.transformer.drop(tok_emb + pos_emb)
+             else:
+                 x = self.transformer.drop(tok_emb)
+
+             # Pass through the transformer blocks, reusing and extending the KV cache
+             k_list = []
+             v_list = []
+             for block_idx, block in enumerate(self.transformer.h):
+                 x, k, v = block(x, k_cache=k_cache[block_idx], v_cache=v_cache[block_idx])
+                 k_list.append(k)
+                 v_list.append(v)
+             x = self.transformer.ln_f(x)
+             # rebuild the cache with the new keys and values appended
+             k_cache = torch.stack(k_list, dim=0)
+             v_cache = torch.stack(v_list, dim=0)
+             # predict the next time bin
+             logits = self.coch_head(x)
+             predictions.append(self.sample_logits(logits, temperature=temp, top_k=top_k, top_p=top_p))
+             all_logits.append(logits)
+
+         pred_coch = torch.cat(predictions, dim=1)
+         all_logits = torch.cat(all_logits, dim=1)
+
+         return pred_coch, all_logits
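+     # Generation runs in two phases: a prefill pass embeds the conditioning
+     # sequence and fills the per-layer key/value caches, then the incremental loop
+     # feeds back only the last sampled token, so each step attends over the cache
+     # instead of re-running the full sequence through the network.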
+
+
+     def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):
+         # start with all of the candidate parameters
+         param_dict = {pn: p for pn, p in self.named_parameters()}
+         # filter out those that do not require grad
+         param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}
+         # create optim groups. Any parameter that is 2D will be weight decayed, otherwise not.
+         # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't.
+         decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
+         nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
+         optim_groups = [
+             {'params': decay_params, 'weight_decay': weight_decay},
+             {'params': nodecay_params, 'weight_decay': 0.0}
+         ]
+         num_decay_params = sum(p.numel() for p in decay_params)
+         num_nodecay_params = sum(p.numel() for p in nodecay_params)
+         print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
+         print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
+         # Create AdamW optimizer and use the fused version if it is available
+         fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
+         use_fused = fused_available and device_type == 'cuda'
+         extra_args = dict(fused=True) if use_fused else dict()
+         optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args)
+         print(f"using fused AdamW: {use_fused}")
+
+         return optimizer
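+     # The dim >= 2 rule routes matmul weights and embeddings to the weight-decay
+     # group while biases, RMSNorm gains, and the 1-D learned RoPE frequencies stay
+     # undecayed, a grouping convention popularized by nanoGPT.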
+
+     def estimate_mfu(self, fwdbwd_per_iter, T, dt, gpu_type='A40'):
+         """ estimate model flops utilization (MFU) as a ratio of the chosen accelerator's bfloat16 peak FLOPS """
+         # first estimate the number of flops we do per iteration.
+         # see PaLM paper Appendix B as ref: https://arxiv.org/abs/2204.02311
+         N = self.get_num_params()
+         cfg = self.config
+         L, H, Q = cfg.n_layer, cfg.n_head, cfg.n_embd//cfg.n_head
+         flops_per_token = 6*N + 12*L*H*Q*T
+         flops_per_fwdbwd = flops_per_token * T
+         flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter
+         # express our flops throughput as a ratio of the accelerator's peak flops
+         flops_achieved = flops_per_iter * (1.0/dt)  # per second
+
+         # grab promised flops based on GPU type
+         if gpu_type == 'A40':
+             flops_promised = 149.7e12  # A40 bfloat16 peak flops is 149.7 TFLOPS
+         elif gpu_type == 'A100':
+             flops_promised = 312e12  # A100 bfloat16 peak flops is 312 TFLOPS
+         elif gpu_type == 'H100':
+             flops_promised = 756e12  # H100 (PCIe) bfloat16 peak flops is 756 TFLOPS
+         elif gpu_type == 'TPUv4':
+             flops_promised = 275e12
+         elif gpu_type == 'TPUv5e':
+             flops_promised = 197e12
+         else:
+             raise ValueError(f"unknown gpu_type: {gpu_type}")
+
+         mfu = flops_achieved / flops_promised
+         return mfu
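+     # The 6*N term counts the forward+backward matmul FLOPs per token (2N forward,
+     # 4N backward) and 12*L*H*Q*T adds the attention score/value FLOPs, following
+     # the PaLM Appendix-B accounting cited above.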
+
+
+ #########################################################
+ #####               Layer Definitions               #####
+ #########################################################
+
+
+ class Block(nn.Module):
+
+     def __init__(self, config):
+         super().__init__()
+         self.attn = CausalSelfAttention(config)
+         self.mlp = MLP(config)
+         self.attn_scale = 1.0  # (1 / (2 * config.n_layer)**0.5)
+         self.norm1 = RMSNorm(config.n_embd, bias=config.bias)
+         self.norm2 = RMSNorm(config.n_embd, bias=config.bias)
+
+     def forward(self, x, return_kv=False, k_cache=None, v_cache=None):
+         # If we are given a key and value cache, use the pre-computed values to
+         # minimize the computation cost
+         if k_cache is not None and v_cache is not None:
+             # Pass the key and value cache to the attention layer, obtain new key and value caches
+             x_attn, k, v = self.attn.kv_cache_forward(self.norm1(x), k_cache, v_cache)
+             x = x + x_attn
+             x = x + self.mlp(self.norm2(x))
+             return x, k, v
+         # We may also want to encode a whole block of keys and values at once using
+         # the fast flash-attention implementation while still returning the key and
+         # value caches
+         elif return_kv:
+             # Run regular attention, but also return the key and value caches
+             x_attn, k, v = self.attn(self.norm1(x), return_kv=True)
+             x = x + x_attn
+             x = x + self.mlp(self.norm2(x))
+             return x, k, v
+
+         x = x + self.attn_scale * self.attn(self.norm1(x))
+         x = x + self.mlp(self.norm2(x))
+         return x
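+ # Pre-norm residual block: attention and MLP each read an RMSNorm'd copy of the
+ # stream and add their output back. attn_scale is currently 1.0, so it is only a
+ # hook for optional residual scaling (e.g. 1/sqrt(2*n_layer)).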
+
+
+ class CausalSelfAttention(nn.Module):
+
+     def __init__(self, config):
+         super().__init__()
+         self.n_head = config.n_head
+         self.n_embd = config.n_embd
+         self.head_dim = self.n_embd // self.n_head
+         assert self.n_embd % self.n_head == 0
+         # key, query, value projections for all heads, but in a batch
+         self.c_attn = nn.Linear(self.n_embd, 3 * self.n_embd, bias=False)
+         # output projection
+         self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)
+
+         rope_theta = 500000
+         if hasattr(config, 'rope_theta') and config.rope_theta is not None:
+             rope_theta = config.rope_theta
+
+         # only build the rotary module when RoPE is actually in use
+         if hasattr(config, 'use_rope') and not config.use_rope:
+             self.rotary = None
+         else:
+             self.rotary = Rotary(self.head_dim, base=rope_theta)
+
+     def forward(self, x, return_kv=False, return_attn_maps=False):
+
+         B, T, C = x.size()  # batch size, sequence length, embedding dimensionality (n_embd)
+         # calculate query, key, values for all heads in batch and move head forward to be the batch dim
+         qkv = self.c_attn(x)
+         q, k, v = qkv.split(self.n_embd, dim=2)
+         k = k.view(B, T, self.n_head, self.head_dim)
+         q = q.view(B, T, self.n_head, self.head_dim)
+         v = v.view(B, T, self.n_head, self.head_dim)
+
+         if self.rotary is not None:
+             cos, sin = self.rotary(q)
+             q = apply_rotary_emb(q, cos, sin)
+             k = apply_rotary_emb(k, cos, sin)
+
+         if not return_kv and not return_attn_maps:
+             y = F.scaled_dot_product_attention(
+                 q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2),
+                 is_causal=True)
+         else:
+             # manual implementation of attention
+             q = q.transpose(1, 2)
+             k = k.transpose(1, 2)
+             v = v.transpose(1, 2)
+             att = torch.einsum('bnsh,bnkh->bnsk', q, k) * (1.0 / math.sqrt(k.size(-1)))
+             mask = torch.triu(torch.ones(T, T), diagonal=1).to(dtype=torch.bool).to(x.device)
+             mask = mask.view(1, 1, T, T)
+             masked_att = att.masked_fill(mask, float('-inf'))
+             # upcast to float32 for numerical stability, as per the llama implementation
+             masked_att = F.softmax(masked_att, dim=-1, dtype=torch.float32).to(q.dtype)
+             # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
+             y = torch.einsum('bnsk,bnkh->bnsh', masked_att, v)
+
+         y = y.transpose(1, 2).contiguous().view(B, T, C)  # re-assemble all head outputs side by side
+
+         # output projection
+         y = self.c_proj(y)
+
+         # return the (causally masked) attention maps if requested
+         if return_attn_maps:
+             return y, masked_att
+
+         # return key and value caches if requested
+         if return_kv:
+             return y, k, v
+
+         return y
+
+     def kv_cache_forward(self, x, k_cache=None, v_cache=None):
+         B, T, C = x.size()  # batch size, sequence length, embedding dimensionality (n_embd)
+
+         # calculate query, key, values for all heads in batch
+         q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
+         k = k.view(B, T, self.n_head, C // self.n_head)
+         q = q.view(B, T, self.n_head, C // self.n_head)
+         v = v.view(B, T, self.n_head, C // self.n_head)
+
+         if self.rotary is not None:
+             # offset the rotary angles by the cache length so the new tokens are
+             # rotated to their correct absolute positions
+             offset = k_cache.size(2) if k_cache is not None else 0
+             cos, sin = self.rotary(q, offset=offset)
+             q = apply_rotary_emb(q, cos, sin)
+             k = apply_rotary_emb(k, cos, sin)
+
+         k = k.transpose(1, 2)  # (B, nh, T, hs)
+         q = q.transpose(1, 2)  # (B, nh, T, hs)
+         v = v.transpose(1, 2)  # (B, nh, T, hs)
+
+         # append the new keys and values to the cached ones
+         if k_cache is not None:
+             k = torch.cat((k_cache, k), dim=2)
+         if v_cache is not None:
+             v = torch.cat((v_cache, v), dim=2)
+
+         # manual attention without a causal mask: this path is used with a single
+         # new token whose query may attend to every cached position
+         att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
+         att = F.softmax(att, dim=-1)
+         y = att @ v  # (B, nh, T, T_total) x (B, nh, T_total, hs) -> (B, nh, T, hs)
+
+         y = y.transpose(1, 2).contiguous().view(B, T, C)  # re-assemble all head outputs side by side
+
+         # output projection
+         y = self.c_proj(y)
+
+         return y, k, v
+
+
+ class MLP(nn.Module):
+
+     def __init__(self, config):
+         super().__init__()
+         self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)
+         # NOTE: despite the attribute name, the activation is SiLU, not GELU
+         self.gelu = nn.SiLU()
+         self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
+         self.dropout = nn.Dropout(config.dropout)
+
+     def forward(self, x):
+         x = self.c_fc(x)
+         x = self.gelu(x)
+         x = self.c_proj(x)
+         x = self.dropout(x)
+         return x
+
+
+ class Rotary(torch.nn.Module):
+     def __init__(self, dim, base=500000, learned=True):
+         super().__init__()
+         # Compute the standard RoPE inverse frequencies.
+         inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
+         # If learned is True, register as a parameter; otherwise, as a buffer.
+         if learned:
+             # Register as a parameter and re-initialize randomly.
+             self.inv_freq = torch.nn.Parameter(inv_freq)
+             nn.init.normal_(self.inv_freq, mean=0.0, std=0.02)
+         else:
+             self.register_buffer("inv_freq", inv_freq)
+         self.learned = learned  # save the flag in case it is needed later
+
+     def forward(self, x, offset=0):
+         seq_len = x.shape[1]
+         # Create a tensor of absolute positions, shifted by `offset` so that cached
+         # decoding can rotate new tokens to their true positions.
+         t = torch.arange(offset, offset + seq_len, device=x.device).type_as(self.inv_freq)
+         # Outer product to compute angles; this uses the (possibly learnable) frequencies.
+         freqs = torch.outer(t, self.inv_freq).to(x.device)
+         cos_cached = freqs.cos()
+         sin_cached = freqs.sin()
+         return cos_cached[None, :, None, :], sin_cached[None, :, None, :]
+
+
+ def apply_rotary_emb(x, cos, sin):
+     assert x.ndim == 4  # multihead attention expected
+     d = x.shape[3] // 2
+     x1 = x[..., :d]
+     x2 = x[..., d:]
+     y1 = x1 * cos + x2 * sin
+     y2 = x1 * (-sin) + x2 * cos
+     return torch.cat([y1, y2], dim=3)
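+ # apply_rotary_emb uses the "rotate-half" convention: features are split into two
+ # halves (x1, x2) and each pair is rotated by a position-dependent angle, with
+ # cos/sin broadcast over the batch and head dimensions.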
+
+
+ class RMSNorm(nn.Module):
+     """ Root Mean Square Normalization """
+     def __init__(self, dim: int, weight: bool = True, bias: bool = False, eps: float = 1e-6):
+         super().__init__()
+         self.eps = eps
+         # `bias` is accepted for call-site compatibility, but RMSNorm has no bias term
+         self.weight = nn.Parameter(torch.ones(dim)) if weight else None
+
+     def _norm(self, x):
+         return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+
+     def forward(self, x):
+         output = self._norm(x.float()).type_as(x)
+         if self.weight is not None:
+             return output * self.weight
+         return output
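
A minimal usage sketch (assumptions: the checkpoint's config registers AuriStream via auto_map so that trust_remote_code resolves this file, and the repo id is the sync source named in the commit message; the input ids below are placeholder data):

import torch
from transformers import AutoModel

# Load the model, executing this modeling_auristream.py as remote code.
model = AutoModel.from_pretrained(
    "TuKoResearch/AuriStream200M_100Pred_librilight_200k",
    trust_remote_code=True,
)
model.eval()

# Placeholder batch of token ids of shape (batch, time).
seq = torch.randint(0, model.config.vocab_size, (1, 128))

with torch.no_grad():
    # Per-layer hidden states for probing (cheap path, no logits computed).
    out = model(seq, output_hidden_states=True)
    # Autoregressive continuation using the KV-cache generate loop.
    pred, logits = model.generate(seq, n_tokens=50, temp=0.9, top_k=500, top_p=0.5)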