klemenk committed on
Commit
4d46702
·
verified ·
1 Parent(s): a1c7f44

Create modeling_auristream.py

Browse files
Files changed (1) hide show
  1. modeling_auristream.py +516 -0
modeling_auristream.py ADDED
@@ -0,0 +1,516 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ AuriStream sequence model definition.
3
+ """
4
+
5
+ import math
6
+ import inspect
7
+ import random
8
+ import torch
9
+ import torch.nn as nn
10
+ from torch.nn import functional as F
11
+ import numpy as np
12
+ from huggingface_hub import PyTorchModelHubMixin
13
+ from transformers.modeling_outputs import BaseModelOutput, CausalLMOutput
14
+ from transformers import PreTrainedModel
15
+
16
+
17
class AuriStream(PreTrainedModel):
    """Autoregressive transformer over discrete (cochlear) token sequences.

    GPT-style decoder: token embedding, optional learned absolute position
    embedding (only when ``config.use_rope`` is False), dropout, a stack of
    ``Block``s, a final ``RMSNorm``, and linear prediction head(s).
    """

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        # If use_rope is present in the config and False, fall back to a
        # learned absolute position embedding table (wpe) of size seq_len;
        # otherwise positions are handled by rotary embeddings inside each
        # attention layer and no wpe module is created.
        if hasattr(config, 'use_rope') and not config.use_rope:
            self.transformer = nn.ModuleDict(dict(
                wte = nn.Embedding(config.vocab_size, config.n_embd),
                wpe = nn.Embedding(config.seq_len, config.n_embd),
                drop = nn.Dropout(config.dropout),
                h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
                ln_f = RMSNorm(config.n_embd, bias=config.bias),
            ))
        else:
            self.transformer = nn.ModuleDict(dict(
                wte = nn.Embedding(config.vocab_size, config.n_embd),
                drop = nn.Dropout(config.dropout),
                h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
                ln_f = RMSNorm(config.n_embd, bias=config.bias),
            ))

        # Optional n_pred_steps: number of future positions predicted per
        # frame. One extra linear head per step beyond the first (the first
        # step is served by coch_head below).
        if hasattr(config, 'n_pred_steps'):
            self.future_heads = nn.ModuleList([nn.Linear(config.n_embd, config.vocab_size, bias=False) for _ in range(config.n_pred_steps - 1)])
        else:
            self.future_heads = None

        # Main next-token prediction head over the vocabulary.
        self.coch_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # init all weights
        self.apply(self._init_weights)
        # apply special scaled init to the residual projections, per GPT-2 paper
        for pn, p in self.named_parameters():
            if pn.endswith('c_proj.weight'):
                torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))
55
+ def get_num_params(self, non_embedding=True):
56
+ """
57
+ Return the number of parameters in the model.
58
+ For non-embedding count (default), the position embeddings get subtracted.
59
+ The token embeddings would too, except due to the parameter sharing these
60
+ params are actually used as weights in the final layer, so we include them.
61
+ """
62
+ n_params = sum(p.numel() for p in self.parameters())
63
+ return n_params
64
+
65
+ def _init_weights(self, module):
66
+ if isinstance(module, nn.Linear):
67
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
68
+ if module.bias is not None:
69
+ torch.nn.init.zeros_(module.bias)
70
+ elif isinstance(module, nn.Embedding):
71
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
72
+
73
    def forward(self, seq, tgt=None, output_hidden_states=False, return_dict=False, up_until_layer=None):
        """
        Run the decoder over a batch of token id sequences.

        Parameters:
            seq (torch.LongTensor): input token ids of shape (b, t).
            tgt (torch.LongTensor or None): next-token targets of shape (b, t);
                when given, a cross-entropy loss is computed (averaged over the
                main head plus any future-prediction heads).
            output_hidden_states (bool): if True, return a BaseModelOutput with
                per-layer hidden states. NOTE(review): this path returns BEFORE
                the final RMSNorm and output head are applied — confirm intended.
            return_dict (bool): if True (and tgt is given), return a
                CausalLMOutput instead of a (logits, loss) tuple.
            up_until_layer (int or None): stop before running the block with
                this index (used for layer-wise probing).

        Returns:
            BaseModelOutput when output_hidden_states is True; otherwise a
            CausalLMOutput (return_dict with targets) or a (logits, loss)
            tuple, with loss None when tgt is None.
        """

        # Token embeddings of shape (b, t, n_embd).
        tok_emb = self.transformer.wte(seq)

        # Learned absolute position embeddings are applied only when the wpe
        # table exists (use_rope=False config); otherwise rotary embeddings
        # inside the attention layers handle positions.
        if hasattr(self.transformer, 'wpe'):
            pos = torch.arange(0, seq.size(1), dtype=torch.long, device=seq.device)
            pos_emb = self.transformer.wpe(pos)  # position embeddings of shape (t, n_embd)
            x = self.transformer.drop(tok_emb + pos_emb)
        else:
            x = self.transformer.drop(tok_emb)

        # Entry i of all_hidden_states is the INPUT to block i.
        all_hidden_states = []
        for block_idx, block in enumerate(self.transformer.h):
            all_hidden_states.append(x)
            # Probing early-exit: stop before running block `up_until_layer`.
            if up_until_layer is not None and block_idx == up_until_layer:
                break
            x = block(x)

        # Append the last hidden state if we did not exit early.
        # NOTE(review): when up_until_layer is None this condition is False,
        # so the final block's output is never appended — confirm intended.
        if up_until_layer is not None and block_idx != up_until_layer:
            all_hidden_states.append(x)

        if output_hidden_states:
            model_output = BaseModelOutput(
                last_hidden_state=x,
                hidden_states=all_hidden_states,
            )
            return model_output

        # Final norm and main prediction head.
        x = self.transformer.ln_f(x)
        logits = self.coch_head(x)

        if tgt is not None:
            # Main next-token loss over all positions.
            loss = F.cross_entropy(
                logits.reshape(-1, self.config.vocab_size), tgt.reshape(-1),
            )

            # Each extra head predicts position offset i+1, so trim i+1
            # frames from the end of x and the start of tgt.
            if self.future_heads is not None:
                for i, head in enumerate(self.future_heads):
                    future_logits = head(x[:, :-(i+1)])
                    loss += F.cross_entropy(
                        future_logits.reshape(-1, self.config.vocab_size), tgt[:, (i+1):].reshape(-1),
                    )
                # Average over all heads (main + future).
                loss = loss / (len(self.future_heads) + 1)

            if return_dict:
                model_output = CausalLMOutput(
                    loss=loss,
                    logits=logits,
                )
                return model_output

            return logits, loss

        return logits, None
138
+ def sample_logits(self, logits: torch.FloatTensor, temperature: float = 0.9,
139
+ top_k: int = 500, top_p: float = 0.5) -> torch.LongTensor:
140
+ """
141
+ Samples an integer from the distribution of logits
142
+
143
+ Parameters:
144
+ logits (torch.FloatTensor): The logits of the distribution
145
+ temp (float): The temperature of the sampling, if 0.0, then argmax is used
146
+ top_k (int): The number of top k tokens to consider during sampling
147
+ top_p (float): The cumulative probability threshold for nucleus (top-p) sampling
148
+ Returns:
149
+ torch.LongTensor: The sampled integer
150
+ """
151
+ # If temperature is 0.0, use argmax
152
+ if temperature == 0.0:
153
+ return torch.argmax(logits, dim=-1)
154
+
155
+ # Apply temperature
156
+ logits = logits / temperature
157
+
158
+ # Apply top-k filtering if specified
159
+ if top_k is not None:
160
+ v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
161
+ logits[logits < v[..., [-1]]] = -float('Inf')
162
+
163
+ # Apply top-p (nucleus) filtering if specified
164
+ if top_p is not None:
165
+ # Sort the logits in descending order
166
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
167
+ # Compute the sorted softmax probabilities
168
+ sorted_probs = F.softmax(sorted_logits, dim=-1)
169
+ # Compute the cumulative probabilities
170
+ cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
171
+ # Create a mask for tokens to remove
172
+ sorted_indices_to_remove = cumulative_probs > top_p
173
+ # Shift the mask right to keep at least one token
174
+ sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
175
+ sorted_indices_to_remove[..., 0] = 0
176
+ # Scatter the mask back to the original indices
177
+ indices_to_remove = sorted_indices_to_remove.scatter(dim=-1, index=sorted_indices, src=sorted_indices_to_remove)
178
+ logits[indices_to_remove] = -float('Inf')
179
+
180
+ # Compute softmax probabilities
181
+ probs = F.softmax(logits, dim=-1)
182
+ # Flatten probabilities to (batch_size * sequence_length, vocab_size)
183
+ flat_probs = probs.view(-1, probs.size(-1))
184
+ # Sample from the distribution
185
+ sampled = torch.multinomial(flat_probs, num_samples=1)
186
+ # Reshape to original shape except for the last dimension
187
+ sampled = sampled.view(*logits.shape[:-1])
188
+ return sampled
189
+
190
    @torch.no_grad()
    def generate(self, seq: torch.Tensor, n_tokens: int = 1, temp=1.0,
                 top_k=500, top_p=0.5, seed=None):
        """
        Autoregressively sample `n_tokens` continuation tokens for `seq`
        using a per-layer key/value cache.

        Parameters:
            seq: torch.LongTensor of shape (b, t)
                Conditioning token ids.
            n_tokens: int
                Number of new tokens to sample.
            temp: float
                Temperature for sampling logits (0.0 -> greedy).
            top_k: int
                Top-k cutoff forwarded to sample_logits.
            top_p: float
                Nucleus cutoff forwarded to sample_logits.
            seed: int or None
                Seeds python/numpy/torch RNGs when given.

        Returns:
            pred_coch: torch.LongTensor of shape (b, n_tokens)
                The sampled token ids.
            all_logits: torch.FloatTensor of shape (b, n_tokens, vocab_size)
                The logits for each sampled step.
        """

        # Seed all RNGs for reproducible sampling.
        if seed is not None:
            random.seed(seed)
            np.random.seed(seed)
            torch.manual_seed(seed)

        # Logits collected for each sampled step.
        all_logits = []
        device = seq.device

        # Batch size and conditioning length.
        b, t = seq.size()

        #### Embed the conditioning sequence and build the KV cache.

        tok_emb = self.transformer.wte(seq)  # token embeddings of shape (b, t, n_embd)
        # Learned absolute position embeddings only when wpe exists
        # (use_rope=False config); otherwise rotary handles positions.
        if hasattr(self.transformer, 'wpe'):
            pos = torch.arange(0, seq.size(1), dtype=torch.long, device=seq.device)
            pos_emb = self.transformer.wpe(pos)  # position embeddings of shape (t, n_embd)
            x = self.transformer.drop(tok_emb + pos_emb)
        else:
            x = self.transformer.drop(tok_emb)

        # Run each block once over the whole prefix, collecting its keys and
        # values so later steps only need to process one new token.
        k_list = []
        v_list = []
        for block_idx, block in enumerate(self.transformer.h):
            x, k, v = block(x, return_kv=True)
            k_list.append(k)
            v_list.append(v)
        # k_cache and v_cache have shape (n_layer, b, n_head, t, n_embd//n_head)
        k_cache = torch.stack(k_list, dim=0)
        v_cache = torch.stack(v_list, dim=0)
        # Pass through the final layer norm
        x = self.transformer.ln_f(x)

        # First prediction decodes the last position of the prefix.
        # NOTE(review): this first call omits top_k/top_p (sample_logits
        # defaults apply) while the loop below passes them explicitly —
        # confirm intended.
        logits = self.coch_head(x[:, [-1]])
        predictions = [self.sample_logits(logits, temperature=temp)]
        all_logits.append(logits)

        #### Predict future tokens one at a time with the KV cache.

        # n_tokens-1 iterations because the first token was already sampled
        # above from the last embedding of the input.
        for i in range(n_tokens-1):

            # Embed only the most recently sampled token.
            tok_emb = self.transformer.wte(predictions[-1])  # (b, 1, n_embd)
            if hasattr(self.transformer, 'wpe'):
                # Absolute position of the new token is t + i.
                pos = torch.arange(t+i, t+i+1, dtype=torch.long, device=device)
                pos_emb = self.transformer.wpe(pos)
                x = self.transformer.drop(tok_emb + pos_emb)
            else:
                x = self.transformer.drop(tok_emb)

            # One cached step per block; each returns its extended cache.
            k_list = []
            v_list = []
            for block_idx, block in enumerate(self.transformer.h):
                x, k, v = block(x, k_cache=k_cache[block_idx], v_cache=v_cache[block_idx])
                k_list.append(k)
                v_list.append(v)
            x = self.transformer.ln_f(x)
            # Rebuild the stacked caches including the new token.
            k_cache = torch.stack(k_list, dim=0)
            v_cache = torch.stack(v_list, dim=0)
            # Sample the next token.
            logits = self.coch_head(x)
            predictions.append(self.sample_logits(logits, temperature=temp, top_k=top_k, top_p=top_p))
            all_logits.append(logits)

        pred_coch = torch.cat(predictions, dim=1)
        all_logits = torch.cat(all_logits, dim=1)

        return pred_coch, all_logits
309
+
310
+ def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):
311
+ # start with all of the candidate parameters
312
+ param_dict = {pn: p for pn, p in self.named_parameters()}
313
+ # filter out those that do not require grad
314
+ param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}
315
+ # create optim groups. Any parameters that is 2D will be weight decayed, otherwise no.
316
+ # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't.
317
+ decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
318
+ nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
319
+ optim_groups = [
320
+ {'params': decay_params, 'weight_decay': weight_decay},
321
+ {'params': nodecay_params, 'weight_decay': 0.0}
322
+ ]
323
+ num_decay_params = sum(p.numel() for p in decay_params)
324
+ num_nodecay_params = sum(p.numel() for p in nodecay_params)
325
+ print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
326
+ print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
327
+ # Create AdamW optimizer and use the fused version if it is available
328
+ fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
329
+ use_fused = fused_available and device_type == 'cuda'
330
+ extra_args = dict(fused=True) if use_fused else dict()
331
+ optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args)
332
+ print(f"using fused AdamW: {use_fused}")
333
+
334
+ return optimizer
335
+
336
+ def estimate_mfu(self, fwdbwd_per_iter, T, dt, gpu_type='A40'):
337
+ """ estimate model flops utilization (MFU) in units of A100 bfloat16 peak FLOPS """
338
+ # first estimate the number of flops we do per iteration.
339
+ # see PaLM paper Appendix B as ref: https://arxiv.org/abs/2204.02311
340
+ N = self.unsharded_param_count
341
+ cfg = self.config
342
+ L, H, Q = cfg.n_layer, cfg.n_head, cfg.n_embd//cfg.n_head
343
+ # L, H, Q, T = cfg.n_layer, cfg.n_head, cfg.n_embd//cfg.n_head, cfg.block_size
344
+ flops_per_token = 6*N + 12*L*H*Q*T
345
+ flops_per_fwdbwd = flops_per_token * T
346
+ flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter
347
+ # express our flops throughput as ratio of A100 bfloat16 peak flops
348
+ flops_achieved = flops_per_iter * (1.0/dt) # per second
349
+
350
+ # grab promised flops based on GPU type
351
+ if gpu_type == 'A40':
352
+ flops_promised = 149.7e12 # A40 GPU bfloat16 peak flops is 149.7 TFLOPS
353
+ elif gpu_type == 'A100':
354
+ flops_promised = 312e12 # A100 GPU bfloat16 peak flops is 312 TFLOPS
355
+ elif gpu_type == 'H100':
356
+ flops_promised = 756e12 # H100 GPU bfloat16 peak flops is 756 TFLOPS
357
+ elif gpu_type == 'TPUv4':
358
+ flops_promised = 275e12
359
+ elif gpu_type == 'TPUv5e':
360
+ flops_promised = 197e12
361
+
362
+ mfu = flops_achieved / flops_promised
363
+ return mfu
364
+
365
+
366
+ #########################################################
367
+ ##### Layer Definitions #####
368
+ #########################################################
369
+
370
+
371
class Block(nn.Module):
    """Pre-norm transformer decoder block: causal attention + MLP, both residual."""

    def __init__(self, config):
        super().__init__()
        self.attn = CausalSelfAttention(config)
        self.mlp = MLP(config)
        self.attn_scale = 1.0  # (1 / (2 * config.n_layer)**0.5)
        self.norm1 = RMSNorm(config.n_embd, bias=config.bias)
        self.norm2 = RMSNorm(config.n_embd, bias=config.bias)

    def forward(self, x, return_kv=False, k_cache=None, v_cache=None):
        """Run the block.

        Three modes:
          * plain:                 forward(x) -> x
          * cache-building:        forward(x, return_kv=True) -> (x, k, v)
          * incremental decoding:  forward(x, k_cache=..., v_cache=...) -> (x, k, v)
        """
        if k_cache is not None and v_cache is not None:
            # Incremental decoding: reuse cached keys/values; the attention
            # layer returns the caches extended with this step's k/v.
            attn_out, k, v = self.attn.kv_cache_forward(self.norm1(x), k_cache, v_cache)
        elif return_kv:
            # Full-sequence pass (fast flash-attention path) that also hands
            # back the keys/values so a cache can be built.
            attn_out, k, v = self.attn(self.norm1(x), return_kv=True)
        else:
            # Plain training/inference path (attn_scale applied here only,
            # matching the cached paths which add the raw attention output).
            x = x + self.attn_scale * self.attn(self.norm1(x))
            return x + self.mlp(self.norm2(x))

        x = x + attn_out
        x = x + self.mlp(self.norm2(x))
        return x, k, v
403
+
404
class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention with optional rotary embeddings.

    Supports the three call patterns `Block` relies on:
      * forward(x)                            -> y
      * forward(x, return_kv=True)            -> (y, k, v)
      * kv_cache_forward(x, k_cache, v_cache) -> (y, k, v)
    where k/v have shape (B, n_head, T, head_dim), matching the per-layer
    caches that `AuriStream.generate` stacks and slices.
    """

    def __init__(self, config):
        super().__init__()
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.head_dim = self.n_embd // self.n_head
        assert self.n_embd % self.n_head == 0
        # key, query, value projections for all heads, but in a batch
        self.c_attn = nn.Linear(self.n_embd, 3 * self.n_embd, bias=False)
        # output projection
        self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)

        rope_theta = 500000
        if hasattr(config, 'rope_theta') and config.rope_theta is not None:
            rope_theta = config.rope_theta

        self.rotary = Rotary(self.head_dim, base=rope_theta)

        # use_rope=False disables rotary embeddings entirely (the model then
        # relies on learned absolute position embeddings instead).
        if hasattr(config, 'use_rope') and not config.use_rope:
            self.rotary = None

    def _qkv(self, x):
        """Project x to per-head q, k, v of shape (B, T, n_head, head_dim)."""
        B, T, _ = x.size()
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        shape = (B, T, self.n_head, self.head_dim)
        return q.view(shape), k.view(shape), v.view(shape)

    def forward(self, x, return_kv=False):
        """Causal attention over the full sequence.

        Returns y, or (y, k, v) when return_kv is True, where k and v are the
        post-rotary per-head tensors of shape (B, n_head, T, head_dim), ready
        to serve as caches for kv_cache_forward.
        """
        B, T, C = x.size()  # batch size, sequence length, embedding dimensionality (n_embd)
        q, k, v = self._qkv(x)

        if self.rotary is not None:
            cos, sin = self.rotary(q)
            q = apply_rotary_emb(q, cos, sin)
            k = apply_rotary_emb(k, cos, sin)

        # Move head dim forward for scaled_dot_product_attention.
        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
        y = F.scaled_dot_product_attention(q, k, v, is_causal=True)

        y = y.transpose(1, 2).contiguous().view(B, T, C)  # re-assemble all head outputs side by side
        y = self.c_proj(y)
        if return_kv:
            # Bug fix: return_kv was previously ignored and only y was
            # returned, which broke Block.forward(..., return_kv=True)
            # (it unpacks three values) and therefore generation.
            return y, k, v
        return y

    def kv_cache_forward(self, x, k_cache, v_cache):
        """Incremental attention for newly appended token(s).

        Bug fix: this method is called by Block.forward during generation but
        was missing entirely.

        Parameters:
            x: (B, T_new, C) embeddings of the new token(s); the generate
               loop passes T_new == 1.
            k_cache, v_cache: (B, n_head, T_past, head_dim) cached tensors.
        Returns:
            (y, k, v) with the caches extended to length T_past + T_new.
        """
        B, T, C = x.size()
        q, k, v = self._qkv(x)

        if self.rotary is not None:
            # Rotate with the absolute positions following the cached prefix
            # (cached keys were already rotated at their own positions).
            t_past = k_cache.size(2)
            pos = torch.arange(t_past, t_past + T, device=x.device).type_as(self.rotary.inv_freq)
            freqs = torch.outer(pos, self.rotary.inv_freq).to(x.device)
            cos = freqs.cos()[None, :, None, :]
            sin = freqs.sin()[None, :, None, :]
            q = apply_rotary_emb(q, cos, sin)
            k = apply_rotary_emb(k, cos, sin)

        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
        k = torch.cat([k_cache, k], dim=2)
        v = torch.cat([v_cache, v], dim=2)
        # The new query may attend to the whole (past + new) sequence; with a
        # single new token this is exactly causal attention. NOTE(review):
        # with T_new > 1 this would not mask among the new tokens — the
        # generate loop only ever passes one token at a time.
        y = F.scaled_dot_product_attention(q, k, v, is_causal=False)
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        y = self.c_proj(y)
        return y, k, v
450
+
451
class MLP(nn.Module):
    """Position-wise feed-forward network: Linear -> GELU -> Linear -> Dropout."""

    def __init__(self, config):
        super().__init__()
        hidden = 4 * config.n_embd  # standard 4x transformer expansion
        self.c_fc = nn.Linear(config.n_embd, hidden, bias=config.bias)
        self.gelu = nn.GELU()
        self.c_proj = nn.Linear(hidden, config.n_embd, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        """Apply the MLP; the last-dimension size is preserved."""
        return self.dropout(self.c_proj(self.gelu(self.c_fc(x))))
467
+
468
class Rotary(torch.nn.Module):
    """Rotary position embedding frequency generator.

    Produces (cos, sin) tables shaped (1, T, 1, dim/2), broadcastable over
    (B, T, n_head, head_dim/2). With learned=True the inverse frequencies are
    a trainable parameter re-initialized from N(0, 0.02) (the base-derived
    values are discarded); otherwise they are the standard fixed RoPE
    frequencies for `base`.
    """

    def __init__(self, dim, base=500000, learned=True):
        super().__init__()
        # Standard RoPE inverse frequencies: base^(-2i/dim).
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        if learned:
            # Trainable frequencies, re-initialized randomly.
            self.inv_freq = torch.nn.Parameter(inv_freq)
            nn.init.normal_(self.inv_freq, mean=0.0, std=0.02)
        else:
            self.register_buffer("inv_freq", inv_freq)
        self.learned = learned  # kept for later inspection

    def forward(self, x):
        """Return (cos, sin) for positions 0..x.shape[1]-1."""
        positions = torch.arange(x.shape[1], device=x.device).type_as(self.inv_freq)
        # Angle per (position, frequency) pair.
        angles = torch.outer(positions, self.inv_freq).to(x.device)
        return angles.cos()[None, :, None, :], angles.sin()[None, :, None, :]
491
+
492
def apply_rotary_emb(x, cos, sin):
    """Rotate the two head-dim halves of x by the given (cos, sin) tables.

    x: (B, T, n_head, head_dim); cos/sin broadcast as (1, T, 1, head_dim/2).
    Returns a tensor of the same shape as x.
    """
    assert x.ndim == 4  # multihead attention expected
    half = x.shape[3] // 2
    lo, hi = x[..., :half], x[..., half:]
    # 2-D rotation applied pairwise across the split halves.
    rotated = (lo * cos + hi * sin, hi * cos - lo * sin)
    return torch.cat(rotated, dim=3)
500
+
501
+
502
class RMSNorm(nn.Module):
    """Root Mean Square normalization (no mean-centering, optional gain).

    The `bias` argument is accepted for interface compatibility with callers
    but is not used by this implementation.
    """

    def __init__(self, dim: int, weight: bool = True, bias: bool = False, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim)) if weight else None

    def _norm(self, x):
        # Divide by the RMS over the last dimension (eps for stability).
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        # Normalize in float32 for stability, then cast back to input dtype.
        y = self._norm(x.float()).type_as(x)
        return y if self.weight is None else y * self.weight