klemenk committed on
Commit 3bd534c · verified · 1 Parent(s): 98c4724

Sync modeling_auristream.py from TuKoResearch/AuriStream7BDeep_7k

Files changed (1)
  1. modeling_auristream.py +596 -0
modeling_auristream.py ADDED
@@ -0,0 +1,596 @@
"""
AuriStream sequence model definition.
"""

import math
import inspect
import random
import torch
import torch.nn as nn
from torch.nn import functional as F
import numpy as np
from huggingface_hub import PyTorchModelHubMixin
from transformers.modeling_outputs import BaseModelOutput, CausalLMOutput
from transformers import PreTrainedModel
from .configuration_auristream import AuriStreamConfig


class AuriStream(PreTrainedModel):
    config_class = AuriStreamConfig

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        # If use_rope is present in the config and False, add a learned positional
        # embedding table (wpe) to the transformer; otherwise rotary embeddings are
        # applied inside the attention layers.
        if hasattr(config, 'use_rope') and not config.use_rope:
            self.transformer = nn.ModuleDict(dict(
                wte = nn.Embedding(config.vocab_size, config.n_embd),
                wpe = nn.Embedding(config.seq_len, config.n_embd),
                drop = nn.Dropout(config.dropout),
                h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
                ln_f = RMSNorm(config.n_embd, bias=config.bias),
            ))
        else:
            self.transformer = nn.ModuleDict(dict(
                wte = nn.Embedding(config.vocab_size, config.n_embd),
                drop = nn.Dropout(config.dropout),
                h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
                ln_f = RMSNorm(config.n_embd, bias=config.bias),
            ))

        # If n_pred_steps is defined in the config, it sets the number of linear
        # heads used for multi-step future prediction.
        if hasattr(config, 'n_pred_steps'):
            self.future_heads = nn.ModuleList([nn.Linear(config.n_embd, config.vocab_size, bias=False) for _ in range(config.n_pred_steps - 1)])
        else:
            self.future_heads = None

        self.coch_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # # init all weights
        # self.apply(self._init_weights)
        # # apply special scaled init to the residual projections, per GPT-2 paper
        # for pn, p in self.named_parameters():
        #     if pn.endswith('c_proj.weight'):
        #         torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))

        self.dwa = None
        if config.skip_connections:
            self.dwa = DWA(config.n_layer + 1)

    def get_num_params(self, non_embedding=True):
        """
        Return the total number of parameters in the model.
        The `non_embedding` flag is kept for interface compatibility; embedding
        parameters are currently always included in the count.
        """
        n_params = sum(p.numel() for p in self.parameters())
        return n_params

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, seq, tgt=None, output_hidden_states=False, return_dict=False, up_until_layer=None):
        """
        Input: seq: torch.Tensor of token ids, shape (b, t)
               tgt: torch.Tensor of target token ids, shape (b, t), or None
        """

        # forward the GPT model itself
        tok_emb = self.transformer.wte(seq)  # token embeddings of shape (b, t, n_embd)

        # if wpe exists in self.transformer, apply the learned positional embedding
        if hasattr(self.transformer, 'wpe'):
            pos = torch.arange(0, seq.size(1), dtype=torch.long, device=seq.device)
            pos_emb = self.transformer.wpe(pos)  # position embeddings of shape (t, n_embd)
            x = self.transformer.drop(tok_emb + pos_emb)
        else:
            x = self.transformer.drop(tok_emb)

        # Initialize embedding accumulators
        if self.dwa is not None:
            x = self.dwa.init_accumulators(x)

        all_hidden_states = []
        for block_idx, block in enumerate(self.transformer.h):
            # Forward the block
            all_hidden_states.append(x)
            if up_until_layer is not None and block_idx == up_until_layer:
                break
            x = block(x)
            if self.dwa is not None:
                x = self.dwa(x)

        # append the last hidden state if we did not exit early
        if up_until_layer is None or block_idx == len(self.transformer.h) - 1:
            all_hidden_states.append(x)

        if output_hidden_states:
            model_output = BaseModelOutput(
                last_hidden_state=x,
                hidden_states=all_hidden_states,
            )
            return model_output

        x = self.transformer.ln_f(x)
        logits = self.coch_head(x)

        if tgt is not None:
            loss = F.cross_entropy(
                logits.reshape(-1, self.config.vocab_size), tgt.reshape(-1),
            )

            # If we have more than one future head, compute the loss for each head
            if self.future_heads is not None:
                for i, head in enumerate(self.future_heads):
                    future_logits = head(x[:, :-(i+1)])
                    loss += F.cross_entropy(
                        future_logits.reshape(-1, self.config.vocab_size), tgt[:, (i+1):].reshape(-1),
                    )
                # average the loss over the main head and all future heads
                loss = loss / (len(self.future_heads) + 1)

            if return_dict:
                model_output = CausalLMOutput(
                    loss=loss,
                    logits=logits,
                )
                return model_output

            return logits, loss

        return logits, None

    def sample_logits(self, logits: torch.FloatTensor, temperature: float = 0.9,
                      top_k: int = 500, top_p: float = 0.5) -> torch.LongTensor:
        """
        Samples token ids from the distribution defined by logits.
        Parameters:
            logits (torch.FloatTensor): The logits of the distribution
            temperature (float): The sampling temperature; if 0.0, argmax is used
            top_k (int): The number of top-k tokens to consider during sampling
            top_p (float): The cumulative probability threshold for nucleus (top-p) sampling
        Returns:
            torch.LongTensor: The sampled token ids
        """
        # If temperature is 0.0, use argmax
        if temperature == 0.0:
            return torch.argmax(logits, dim=-1)

        # Apply temperature
        logits = logits / temperature

        # Apply top-k filtering if specified
        if top_k is not None:
            v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
            logits[logits < v[..., [-1]]] = -float('Inf')

        # Apply top-p (nucleus) filtering if specified
        if top_p is not None:
            # Sort the logits in descending order
            sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
            # Compute the sorted softmax probabilities
            sorted_probs = F.softmax(sorted_logits, dim=-1)
            # Compute the cumulative probabilities
            cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
            # Create a mask for tokens to remove
            sorted_indices_to_remove = cumulative_probs > top_p
            # Shift the mask right to keep at least one token
            sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
            sorted_indices_to_remove[..., 0] = 0
            # Scatter the mask back to the original indices
            indices_to_remove = sorted_indices_to_remove.scatter(dim=-1, index=sorted_indices, src=sorted_indices_to_remove)
            logits[indices_to_remove] = -float('Inf')

        # Compute softmax probabilities
        probs = F.softmax(logits, dim=-1)
        # Flatten probabilities to (batch_size * sequence_length, vocab_size)
        flat_probs = probs.view(-1, probs.size(-1))
        # Sample from the distribution
        sampled = torch.multinomial(flat_probs, num_samples=1)
        # Reshape to original shape except for the last dimension
        sampled = sampled.view(*logits.shape[:-1])
        return sampled
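
    # Note: for logits of shape (b, t, vocab_size), sample_logits returns integer
    # token ids of shape (b, t). With temperature == 0.0 the top-k and top-p
    # filters are bypassed entirely and the argmax token is taken instead.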

    @torch.no_grad()
    def generate(self, seq: torch.Tensor, n_tokens: int = 1, temp=1.0,
                 top_k=500, top_p=0.5, seed=None):
        """
        Parameters:
            seq: torch.Tensor of token ids, shape (b, t)
                Input token sequence to condition the generation on
            n_tokens: int
                Number of tokens to predict
            temp: float
                Temperature for sampling logits
            top_k: int
                Number of top-k tokens to consider during sampling
            top_p: float
                Cumulative probability threshold for nucleus (top-p) sampling
            seed: int
                Random seed for sampling

        Returns:
            pred_coch: torch.Tensor of shape (b, n_tokens)
                The sampled cochlear token ids
            all_logits: torch.Tensor of shape (b, n_tokens, vocab_size)
                The logits for each predicted time step
        """

        # Set seed if provided
        if seed is not None:
            random.seed(seed)
            np.random.seed(seed)
            torch.manual_seed(seed)

        # make a list of logits to return
        all_logits = []
        device = seq.device

        # grab shape of the input token sequence
        b, t = seq.size()

        # TODO: double check this works then delete the block below:
        # pass the given input through the model to get the predictions and cache
        # the k and v values for each transformer block in the process
        # pos = torch.arange(0, t, dtype=torch.long, device=device)
        # tok_emb = self.transformer.wte(seq)  # token embeddings of shape (b, t, n_embd)
        # pos_emb = self.transformer.wpe(pos)  # position embeddings of shape (t, n_embd)
        # x = self.transformer.drop(tok_emb + pos_emb)

        #### Embed conditioning sequence into KV cache

        tok_emb = self.transformer.wte(seq)  # token embeddings of shape (b, t, n_embd)
        # if wpe exists in self.transformer, apply the learned positional embedding
        if hasattr(self.transformer, 'wpe'):
            pos = torch.arange(0, seq.size(1), dtype=torch.long, device=seq.device)
            pos_emb = self.transformer.wpe(pos)  # position embeddings of shape (t, n_embd)
            x = self.transformer.drop(tok_emb + pos_emb)
        else:
            x = self.transformer.drop(tok_emb)

        # Initialize list to store k and v for each transformer block
        k_list = []
        v_list = []
        for block_idx, block in enumerate(self.transformer.h):
            # Pass through the transformer block, and store k and v
            x, k, v = block(x, return_kv=True)
            k_list.append(k)
            v_list.append(v)
        # k_cache and v_cache have shape (n_layer, b, n_head, t, n_embd//n_head)
        k_cache = torch.stack(k_list, dim=0)
        v_cache = torch.stack(v_list, dim=0)
        # Pass through the final layer norm
        x = self.transformer.ln_f(x)

        # First prediction of the model is the decoding of the last time bin
        logits = self.coch_head(x[:, [-1]])
        predictions = [self.sample_logits(logits, temperature=temp, top_k=top_k, top_p=top_p)]
        all_logits.append(logits)

        ### Predict future tokens

        # Now we pass the last time bin through the model to predict the next time bin.
        # We subtract 1 from n_tokens because we already predicted the first time bin
        # using the last embedding of the input.
        for i in range(n_tokens-1):

            # TODO: double check this works then delete the block below:
            # # Get the emb and pos embedding of just the last token
            # pos = torch.arange(t+i, t+i+1, dtype=torch.long, device=device)  # shape (1)
            # tok_emb = self.transformer.wte(predictions[-1])  # token embeddings of shape (b, 1, n_embd)
            # pos_emb = self.transformer.wpe(pos)  # position embeddings of shape (1, n_embd)
            # x = self.transformer.drop(tok_emb + pos_emb)

            # Get the emb and pos embedding of just the last token
            tok_emb = self.transformer.wte(predictions[-1])  # token embeddings of shape (b, 1, n_embd)
            # if wpe exists in self.transformer, apply the learned positional embedding
            if hasattr(self.transformer, 'wpe'):
                pos = torch.arange(t+i, t+i+1, dtype=torch.long, device=device)  # shape (1)
                pos_emb = self.transformer.wpe(pos)  # position embeddings of shape (1, n_embd)
                x = self.transformer.drop(tok_emb + pos_emb)
            else:
                x = self.transformer.drop(tok_emb)

            # Pass through the transformer blocks
            k_list = []
            v_list = []
            for block_idx, block in enumerate(self.transformer.h):
                x, k, v = block(x, k_cache=k_cache[block_idx], v_cache=v_cache[block_idx])
                k_list.append(k)
                v_list.append(v)
            x = self.transformer.ln_f(x)
            # update the cache with the new keys and values
            k_cache = torch.stack(k_list, dim=0)
            v_cache = torch.stack(v_list, dim=0)
            # predict next time bin
            logits = self.coch_head(x)
            predictions.append(self.sample_logits(logits, temperature=temp, top_k=top_k, top_p=top_p))
            all_logits.append(logits)

        pred_coch = torch.cat(predictions, dim=1)
        all_logits = torch.cat(all_logits, dim=1)

        return pred_coch, all_logits

    def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):
        # start with all of the candidate parameters
        param_dict = {pn: p for pn, p in self.named_parameters()}
        # filter out those that do not require grad
        param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}
        # create optim groups. Any parameter that is 2D will be weight decayed, otherwise not.
        # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't.
        decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
        nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
        optim_groups = [
            {'params': decay_params, 'weight_decay': weight_decay},
            {'params': nodecay_params, 'weight_decay': 0.0}
        ]
        num_decay_params = sum(p.numel() for p in decay_params)
        num_nodecay_params = sum(p.numel() for p in nodecay_params)
        print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
        print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
        # Create AdamW optimizer and use the fused version if it is available
        fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
        use_fused = fused_available and device_type == 'cuda'
        extra_args = dict(fused=True) if use_fused else dict()
        optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args)
        print(f"using fused AdamW: {use_fused}")

        return optimizer

    def estimate_mfu(self, fwdbwd_per_iter, T, dt, gpu_type='A40'):
        """ estimate model flops utilization (MFU) in units of the selected accelerator's bfloat16 peak FLOPS """
        # first estimate the number of flops we do per iteration.
        # see PaLM paper Appendix B as ref: https://arxiv.org/abs/2204.02311
        N = self.get_num_params()  # total (unsharded) parameter count
        cfg = self.config
        L, H, Q = cfg.n_layer, cfg.n_head, cfg.n_embd//cfg.n_head
        # L, H, Q, T = cfg.n_layer, cfg.n_head, cfg.n_embd//cfg.n_head, cfg.block_size
        flops_per_token = 6*N + 12*L*H*Q*T
        flops_per_fwdbwd = flops_per_token * T
        flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter
        # express our flops throughput as a ratio of the accelerator's peak flops
        flops_achieved = flops_per_iter * (1.0/dt)  # per second

        # grab promised flops based on GPU type
        if gpu_type == 'A40':
            flops_promised = 149.7e12  # A40 bfloat16 peak flops is 149.7 TFLOPS
        elif gpu_type == 'A100':
            flops_promised = 312e12  # A100 bfloat16 peak flops is 312 TFLOPS
        elif gpu_type == 'H100':
            flops_promised = 756e12  # H100 PCIe bfloat16 peak flops is 756 TFLOPS
        elif gpu_type == 'TPUv4':
            flops_promised = 275e12  # TPUv4 bfloat16 peak flops is 275 TFLOPS
        elif gpu_type == 'TPUv5e':
            flops_promised = 197e12  # TPUv5e bfloat16 peak flops is 197 TFLOPS
        else:
            raise ValueError(f"unknown gpu_type: {gpu_type}")

        mfu = flops_achieved / flops_promised
        return mfu


#########################################################
#####              Layer Definitions                #####
#########################################################


class Block(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.attn = CausalSelfAttention(config)
        self.mlp = MLP(config)
        self.attn_scale = 1.0  # (1 / (2 * config.n_layer)**0.5)
        self.norm1 = RMSNorm(config.n_embd, bias=config.bias)
        self.norm2 = RMSNorm(config.n_embd, bias=config.bias)

    def forward(self, x, return_kv=False, k_cache=None, v_cache=None):
        # If we are given a key and value cache, use the pre-computed values to minimize
        # the computation cost
        if k_cache is not None and v_cache is not None:
            # Pass the key and value cache to the attention layer, obtain new key and value caches
            x_attn, k, v = self.attn.kv_cache_forward(self.norm1(x), k_cache, v_cache)
            x = x + x_attn
            x = x + self.mlp(self.norm2(x))
            return x, k, v
        # We may also want to encode a whole block of keys and values at once using the
        # fast flash attention implementation while still returning the key and value caches
        elif return_kv:
            # Run the attention layer and obtain the new key and value caches
            x_attn, k, v = self.attn(self.norm1(x), return_kv=True)
            x = x + x_attn
            x = x + self.mlp(self.norm2(x))
            return x, k, v

        x = x + self.attn_scale * self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x


class CausalSelfAttention(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.head_dim = self.n_embd // self.n_head
        assert self.n_embd % self.n_head == 0
        # key, query, value projections for all heads, but in a batch
        self.c_attn = nn.Linear(self.n_embd, 3 * self.n_embd, bias=False)
        # output projection
        self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)

        rope_theta = 10000
        if hasattr(config, 'rope_theta') and config.rope_theta is not None:
            rope_theta = config.rope_theta

        self.rotary = Rotary(self.head_dim, base=rope_theta)

        if hasattr(config, 'use_rope') and not config.use_rope:
            self.rotary = None

    def forward(self, x, return_kv=False, return_attn_maps=False):

        B, T, C = x.size()  # batch size, sequence length, embedding dimensionality (n_embd)
        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        qkv = self.c_attn(x)
        q, k, v = qkv.split(self.n_embd, dim=2)
        k = k.view(B, T, self.n_head, self.head_dim)
        q = q.view(B, T, self.n_head, self.head_dim)
        v = v.view(B, T, self.n_head, self.head_dim)

        if self.rotary is not None:
            cos, sin = self.rotary(q)
            q = apply_rotary_emb(q, cos, sin)
            k = apply_rotary_emb(k, cos, sin)

        if not return_kv and not return_attn_maps:
            y = F.scaled_dot_product_attention(
                q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2),
                is_causal=True)
        else:
            # manual implementation of attention
            q = q.transpose(1, 2)
            k = k.transpose(1, 2)
            v = v.transpose(1, 2)
            att = torch.einsum('bnsh,bnkh->bnsk', q, k) * (1.0 / math.sqrt(k.size(-1)))
            mask = torch.triu(torch.ones(T, T), diagonal=1).to(dtype=torch.bool).to(x.device)
            mask = mask.view(1, 1, T, T)
            masked_att = att.masked_fill(mask, float('-inf'))
            # upcast to float32 for numerical stability, as per the llama implementation
            masked_att = F.softmax(masked_att, dim=-1, dtype=torch.float32).to(q.dtype)
            # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
            y = torch.einsum('bnsk,bnkh->bnsh', masked_att, v)

        y = y.transpose(1, 2).contiguous().view(B, T, C)  # re-assemble all head outputs side by side

        # output projection
        y = self.c_proj(y)

        # return attention maps if requested (post-softmax, causally masked)
        if return_attn_maps:
            return y, masked_att

        # return key and value caches if requested
        if return_kv:
            return y, k, v

        return y

    def kv_cache_forward(self, x, k_cache=None, v_cache=None):
        B, T, C = x.size()  # batch size, sequence length, embedding dimensionality (n_embd)

        # calculate query, key, values for all heads in batch
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        k = k.view(B, T, self.n_head, C // self.n_head)
        q = q.view(B, T, self.n_head, C // self.n_head)
        v = v.view(B, T, self.n_head, C // self.n_head)

        # rotate the new queries and keys at their absolute positions,
        # i.e. offset by the number of positions already in the cache
        if self.rotary is not None:
            offset = k_cache.size(2) if k_cache is not None else 0
            cos, sin = self.rotary(q, offset=offset)
            q = apply_rotary_emb(q, cos, sin)
            k = apply_rotary_emb(k, cos, sin)

        # move head forward to be the batch dim: (B, nh, T, hs)
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        # append cached keys and values with new keys and values
        if k_cache is not None:
            k = torch.cat((k_cache, k), dim=2)
        if v_cache is not None:
            v = torch.cat((v_cache, v), dim=2)

        # manual implementation of attention; no causal mask is applied because this
        # path assumes a single new query token attending to all cached positions
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        att = F.softmax(att, dim=-1)
        y = att @ v  # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)

        y = y.transpose(1, 2).contiguous().view(B, T, C)  # re-assemble all head outputs side by side

        # output projection
        y = self.c_proj(y)

        return y, k, v


class MLP(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)
        # note: despite the attribute name, this is a SiLU activation
        self.gelu = nn.SiLU()
        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        x = self.c_fc(x)
        x = self.gelu(x)
        x = self.c_proj(x)
        x = self.dropout(x)
        return x


class Rotary(torch.nn.Module):
    def __init__(self, dim, base=10000, learned=False):
        super().__init__()
        # Compute the base inverse frequencies (`learned` is reserved for a
        # learnable-frequency variant and is currently unused).
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, x, offset=0):
        seq_len = x.shape[1]
        # Create a tensor of positions, offset by the length of any KV cache.
        t = torch.arange(offset, offset + seq_len, device=x.device).type_as(self.inv_freq)
        # Outer product to compute the rotation angles.
        freqs = torch.outer(t, self.inv_freq).to(x.device)
        cos_cached = freqs.cos()
        sin_cached = freqs.sin()
        return cos_cached[None, :, None, :], sin_cached[None, :, None, :]


def apply_rotary_emb(x, cos, sin):
    assert x.ndim == 4  # multihead attention expected: (B, T, n_head, head_dim)
    d = x.shape[3] // 2
    x1 = x[..., :d]
    x2 = x[..., d:]
    y1 = x1 * cos + x2 * sin
    y2 = x1 * (-sin) + x2 * cos
    return torch.cat([y1, y2], dim=3)
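
# Note on shapes: `cos` and `sin` broadcast as (1, T, 1, head_dim // 2) against
# x of shape (B, T, n_head, head_dim): the first half of each head's channels is
# rotated against the second half (a rotate-half style RoPE).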


class RMSNorm(nn.Module):
    """ Root Mean Square Normalization """
    def __init__(self, dim: int, weight: bool = True, bias: bool = False, eps: float = 1e-6):
        super().__init__()
        # `bias` is accepted for interface compatibility but not used
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim)) if weight else None

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        output = self._norm(x.float()).type_as(x)
        if self.weight is not None:
            return output * self.weight
        return output


class DWA(nn.Module):
    """
    Depth Weighted Average layer that averages representations across the layers
    of a transformer. From: https://arxiv.org/pdf/2402.02622
    """

    def __init__(self, n_layers: int):
        super().__init__()
        # alphas[i, j] weights layer i's representation in the average taken after
        # layer j; the identity initialization reproduces the plain forward pass
        self.alphas = nn.Parameter(torch.zeros(n_layers, n_layers))
        self.alphas.data = torch.eye(n_layers)
        self.accumulators = []

    def init_accumulators(self, x):
        self.accumulators = [x]
        return x * self.alphas[0, 0]

    def forward(self, x):
        self.accumulators.append(x)
        x = 0.0
        for i in range(len(self.accumulators)):
            x = x + self.alphas[i, len(self.accumulators)-1] * self.accumulators[i]
        return x
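
For orientation, a minimal usage sketch (not part of the committed file): it assumes the companion configuration_auristream.py ships in the same Hub repo and that the repo's auto_map registers AuriStream for trust_remote_code loading. The repo id comes from the commit message; the sampling settings are illustrative.

import torch
from transformers import AutoModel

# Load the model code and weights straight from the Hub (assumes auto_map is configured).
model = AutoModel.from_pretrained(
    "TuKoResearch/AuriStream7BDeep_7k", trust_remote_code=True
).eval()

# A dummy (b, t) batch of cochlear token ids standing in for a real prompt.
prompt = torch.randint(0, model.config.vocab_size, (1, 32))

# Continue the sequence for 16 steps; generate returns the sampled tokens
# and the logits for each predicted step.
with torch.no_grad():
    tokens, logits = model.generate(prompt, n_tokens=16, temp=0.8,
                                    top_k=500, top_p=0.9, seed=0)

print(tokens.shape)  # torch.Size([1, 16])
print(logits.shape)  # torch.Size([1, 16, vocab_size])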