klemenk commited on
Commit
eb83fd9
·
verified ·
1 Parent(s): c977eda

Create modeling_auristream.py

Browse files
Files changed (1) hide show
  1. modeling_auristream.py +517 -0
modeling_auristream.py ADDED
@@ -0,0 +1,517 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ AuriStream sequence model definition.
3
+ """
4
+
5
+ import math
6
+ import inspect
7
+ import random
8
+ import torch
9
+ import torch.nn as nn
10
+ from torch.nn import functional as F
11
+ import numpy as np
12
+ from huggingface_hub import PyTorchModelHubMixin
13
+ from transformers.modeling_outputs import BaseModelOutput, CausalLMOutput
14
+ from transformers import PreTrainedModel
15
+ from .configuration_auristream import AuriStreamConfig
16
+
17
+
18
+ class AuriStream(PreTrainedModel):
19
+ config_class = AuriStreamConfig
20
+
21
+ def __init__(self, config):
22
+ super().__init__(config)
23
+ self.config = config
24
+
25
+ # if use_rope is in the config and false, initialize a wpe layer in transformer
26
+ if hasattr(config, 'use_rope') and not config.use_rope:
27
+ self.transformer = nn.ModuleDict(dict(
28
+ wte = nn.Embedding(config.vocab_size, config.n_embd),
29
+ wpe = nn.Embedding(config.seq_len, config.n_embd),
30
+ drop = nn.Dropout(config.dropout),
31
+ h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
32
+ ln_f = RMSNorm(config.n_embd, bias=config.bias),
33
+ ))
34
+ else:
35
+ self.transformer = nn.ModuleDict(dict(
36
+ wte = nn.Embedding(config.vocab_size, config.n_embd),
37
+ drop = nn.Dropout(config.dropout),
38
+ h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
39
+ ln_f = RMSNorm(config.n_embd, bias=config.bias),
40
+ ))
41
+
42
+ # check if n_pred_steps is defined in the config, this is the number of linear heads for prediction
43
+ if hasattr(config, 'n_pred_steps'):
44
+ self.future_heads = nn.ModuleList([nn.Linear(config.n_embd, config.vocab_size, bias=False) for _ in range(config.n_pred_steps - 1)])
45
+ else:
46
+ self.future_heads = None
47
+
48
+ self.coch_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
49
+
50
+ # init all weights
51
+ self.apply(self._init_weights)
52
+ # apply special scaled init to the residual projections, per GPT-2 paper
53
+ for pn, p in self.named_parameters():
54
+ if pn.endswith('c_proj.weight'):
55
+ torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))
56
+
57
+ def get_num_params(self, non_embedding=True):
58
+ """
59
+ Return the number of parameters in the model.
60
+ For non-embedding count (default), the position embeddings get subtracted.
61
+ The token embeddings would too, except due to the parameter sharing these
62
+ params are actually used as weights in the final layer, so we include them.
63
+ """
64
+ n_params = sum(p.numel() for p in self.parameters())
65
+ return n_params
66
+
67
+ def _init_weights(self, module):
68
+ if isinstance(module, nn.Linear):
69
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
70
+ if module.bias is not None:
71
+ torch.nn.init.zeros_(module.bias)
72
+ elif isinstance(module, nn.Embedding):
73
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
74
+
75
+ def forward(self, seq, tgt=None, output_hidden_states=False, return_dict=False, up_until_layer=None):
76
+ """
77
+ Input: coch: torch.Tensor of shape (b, t)
78
+ tgt_coch: torch.Tensor of shape (b, t) or None
79
+ """
80
+
81
+ # forward the GPT model itself
82
+ tok_emb = self.transformer.wte(seq) # token embeddings of shape (b, t, n_embd)
83
+
84
+ # if wpe exists in self.transformer apply leanred positional embedding
85
+ if hasattr(self.transformer, 'wpe'):
86
+ pos = torch.arange(0, seq.size(1), dtype=torch.long, device=seq.device)
87
+ pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd)
88
+ x = self.transformer.drop(tok_emb + pos_emb)
89
+ else:
90
+ x = self.transformer.drop(tok_emb)
91
+
92
+ all_hidden_states = []
93
+ for block_idx, block in enumerate(self.transformer.h):
94
+ # Forward the block
95
+ all_hidden_states.append(x)
96
+ if up_until_layer is not None and block_idx == up_until_layer:
97
+ break
98
+ x = block(x)
99
+
100
+ # append the last hidden state if we did not exit early
101
+ if up_until_layer is not None and block_idx != up_until_layer:
102
+ all_hidden_states.append(x)
103
+
104
+ if output_hidden_states:
105
+ model_output = BaseModelOutput(
106
+ last_hidden_state=x,
107
+ hidden_states=all_hidden_states,
108
+ )
109
+ return model_output
110
+
111
+ x = self.transformer.ln_f(x)
112
+ logits = self.coch_head(x)
113
+
114
+ if tgt is not None:
115
+ loss = F.cross_entropy(
116
+ logits.reshape(-1, self.config.vocab_size), tgt.reshape(-1),
117
+ )
118
+
119
+ # If we have more than one future head, compute the loss for each head
120
+ if self.future_heads is not None:
121
+ for i, head in enumerate(self.future_heads):
122
+ future_logits = head(x[:, :-(i+1)])
123
+ loss += F.cross_entropy(
124
+ future_logits.reshape(-1, self.config.vocab_size), tgt[:, (i+1):].reshape(-1),
125
+ )
126
+ # divide loss by number of future heads
127
+ loss = loss / (len(self.future_heads) + 1)
128
+
129
+ if return_dict:
130
+ model_output = CausalLMOutput(
131
+ loss=loss,
132
+ logits=logits,
133
+ )
134
+ return model_output
135
+
136
+ return logits, loss
137
+
138
+ return logits, None
139
+
140
+ def sample_logits(self, logits: torch.FloatTensor, temperature: float = 0.9,
141
+ top_k: int = 500, top_p: float = 0.5) -> torch.LongTensor:
142
+ """
143
+ Samples an integer from the distribution of logits
144
+ Parameters:
145
+ logits (torch.FloatTensor): The logits of the distribution
146
+ temp (float): The temperature of the sampling, if 0.0, then argmax is used
147
+ top_k (int): The number of top k tokens to consider during sampling
148
+ top_p (float): The cumulative probability threshold for nucleus (top-p) sampling
149
+ Returns:
150
+ torch.LongTensor: The sampled integer
151
+ """
152
+ # If temperature is 0.0, use argmax
153
+ if temperature == 0.0:
154
+ return torch.argmax(logits, dim=-1)
155
+
156
+ # Apply temperature
157
+ logits = logits / temperature
158
+
159
+ # Apply top-k filtering if specified
160
+ if top_k is not None:
161
+ v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
162
+ logits[logits < v[..., [-1]]] = -float('Inf')
163
+
164
+ # Apply top-p (nucleus) filtering if specified
165
+ if top_p is not None:
166
+ # Sort the logits in descending order
167
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
168
+ # Compute the sorted softmax probabilities
169
+ sorted_probs = F.softmax(sorted_logits, dim=-1)
170
+ # Compute the cumulative probabilities
171
+ cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
172
+ # Create a mask for tokens to remove
173
+ sorted_indices_to_remove = cumulative_probs > top_p
174
+ # Shift the mask right to keep at least one token
175
+ sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
176
+ sorted_indices_to_remove[..., 0] = 0
177
+ # Scatter the mask back to the original indices
178
+ indices_to_remove = sorted_indices_to_remove.scatter(dim=-1, index=sorted_indices, src=sorted_indices_to_remove)
179
+ logits[indices_to_remove] = -float('Inf')
180
+
181
+ # Compute softmax probabilities
182
+ probs = F.softmax(logits, dim=-1)
183
+ # Flatten probabilities to (batch_size * sequence_length, vocab_size)
184
+ flat_probs = probs.view(-1, probs.size(-1))
185
+ # Sample from the distribution
186
+ sampled = torch.multinomial(flat_probs, num_samples=1)
187
+ # Reshape to original shape except for the last dimension
188
+ sampled = sampled.view(*logits.shape[:-1])
189
+ return sampled
190
+
191
+ @torch.no_grad()
192
+ def generate(self, seq: torch.Tensor, n_tokens: int = 1, temp=1.0,
193
+ top_k=500, top_p=0.5, seed=None):
194
+ """
195
+ Parameters:
196
+ seq: torch.Tensor of shape (b, t, n_freq_bins)
197
+ Input cochleagram to use for generation
198
+ n_tokens: int
199
+ Number of time bins to predict
200
+ temp: float
201
+ Temperature for sampling logits
202
+ seed: int
203
+ Random seed for sampling
204
+
205
+ Returns:
206
+ pred_coch: torch.Tensor of shape (b, t, n_freq_bins)
207
+ The predicted cochleagram
208
+ all_logits: (optional if return_logits is True) torch.Tensor of shape (b, n_tokens, n_freq_bins)
209
+ The logits for each time step
210
+ all_embs: (optional if return_embs is not None) list of torch.Tensor
211
+ The embeddings for each transformer block
212
+ """
213
+
214
+ # Set seed if provided
215
+ if seed is not None:
216
+ random.seed(seed)
217
+ np.random.seed(seed)
218
+ torch.manual_seed(seed)
219
+
220
+ # make a list of logits to return
221
+ all_logits = []
222
+ device = seq.device
223
+
224
+ # grab shape of the cochleagram
225
+ b, t = seq.size()
226
+
227
+ # TODO: double check this works then delete the block bellow:
228
+ # pass the given input through the model to get the predictions and cache
229
+ # the k and v values for each transformer block in the process
230
+ # pos = torch.arange(0, t, dtype=torch.long, device=device)
231
+ # tok_emb = self.transformer.wte(seq) # token embeddings of shape (b, t, n_embd)
232
+ # pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd)
233
+ # x = self.transformer.drop(tok_emb + pos_emb)
234
+
235
+ #### Embed conditioning sequence into KV cache
236
+
237
+ tok_emb = self.transformer.wte(seq) # token embeddings of shape (b, t, n_embd)
238
+ # if wpe exists in self.transformer apply leanred positional embedding
239
+ if hasattr(self.transformer, 'wpe'):
240
+ pos = torch.arange(0, seq.size(1), dtype=torch.long, device=seq.device)
241
+ pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd)
242
+ x = self.transformer.drop(tok_emb + pos_emb)
243
+ else:
244
+ x = self.transformer.drop(tok_emb)
245
+
246
+ # Initialize list to store k and v for each transformer block
247
+ k_list = []
248
+ v_list = []
249
+ for block_idx, block in enumerate(self.transformer.h):
250
+ # Pass through the transformer block, and store k and v
251
+ x, k, v = block(x, return_kv=True)
252
+ k_list.append(k)
253
+ v_list.append(v)
254
+ # k_cache and v_cache have shape (n_layer, b, n_head, t, n_embd//n_head)
255
+ k_cache = torch.stack(k_list, dim=0)
256
+ v_cache = torch.stack(v_list, dim=0)
257
+ # Pass through the final layer norm
258
+ x = self.transformer.ln_f(x)
259
+
260
+ # First prediction of the model is the decoding of the last time bin
261
+ logits = self.coch_head(x[:, [-1]])
262
+ predictions = [self.sample_logits(logits, temperature=temp)]
263
+ all_logits.append(logits)
264
+
265
+ ### Predict future tokens
266
+
267
+ # Now we pass the last time bin through the model to predict the next time bin
268
+ # we subtract 1 from max_new_tokens because we already predicted the first time bin
269
+ # using the last embedding of the input
270
+ for i in range(n_tokens-1):
271
+
272
+ # TODO: double check this works then delete the block bellow:
273
+ # # Get the emb and pos embedding of just the last token
274
+ # pos = torch.arange(t+i, t+i+1, dtype=torch.long, device=device) # shape (t)
275
+ # tok_emb = self.transformer.wte(predictions[-1]) # token embeddings of shape (b, t, n_embd)
276
+ # pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd)
277
+ # x = self.transformer.drop(tok_emb + pos_emb)
278
+
279
+ # Get the emb and pos embedding of just the last token
280
+ tok_emb = self.transformer.wte(predictions[-1]) # token embeddings of shape (b, t, n_embd)
281
+ # if wpe exists in self.transformer apply leanred positional embedding
282
+ if hasattr(self.transformer, 'wpe'):
283
+ pos = torch.arange(t+i, t+i+1, dtype=torch.long, device=device) # shape (t)
284
+ pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd)
285
+ x = self.transformer.drop(tok_emb + pos_emb)
286
+ else:
287
+ x = self.transformer.drop(tok_emb)
288
+
289
+ # Pass through transformer block
290
+ k_list = []
291
+ v_list = []
292
+ for block_idx, block in enumerate(self.transformer.h):
293
+ x, k, v = block(x, k_cache=k_cache[block_idx], v_cache=v_cache[block_idx])
294
+ k_list.append(k)
295
+ v_list.append(v)
296
+ x = self.transformer.ln_f(x)
297
+ # create the cache with the new embeddings
298
+ k_cache = torch.stack(k_list, dim=0)
299
+ v_cache = torch.stack(v_list, dim=0)
300
+ # predict next time bin
301
+ logits = self.coch_head(x)
302
+ predictions.append(self.sample_logits(logits, temperature=temp, top_k=top_k, top_p=top_p))
303
+ all_logits.append(logits)
304
+
305
+ pred_coch = torch.cat(predictions, dim=1)
306
+ all_logits = torch.cat(all_logits, dim=1)
307
+
308
+ return pred_coch, all_logits
309
+
310
+
311
+ def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):
312
+ # start with all of the candidate parameters
313
+ param_dict = {pn: p for pn, p in self.named_parameters()}
314
+ # filter out those that do not require grad
315
+ param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}
316
+ # create optim groups. Any parameters that is 2D will be weight decayed, otherwise no.
317
+ # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't.
318
+ decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
319
+ nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
320
+ optim_groups = [
321
+ {'params': decay_params, 'weight_decay': weight_decay},
322
+ {'params': nodecay_params, 'weight_decay': 0.0}
323
+ ]
324
+ num_decay_params = sum(p.numel() for p in decay_params)
325
+ num_nodecay_params = sum(p.numel() for p in nodecay_params)
326
+ print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
327
+ print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
328
+ # Create AdamW optimizer and use the fused version if it is available
329
+ fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
330
+ use_fused = fused_available and device_type == 'cuda'
331
+ extra_args = dict(fused=True) if use_fused else dict()
332
+ optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args)
333
+ print(f"using fused AdamW: {use_fused}")
334
+
335
+ return optimizer
336
+
337
+ def estimate_mfu(self, fwdbwd_per_iter, T, dt, gpu_type='A40'):
338
+ """ estimate model flops utilization (MFU) in units of A100 bfloat16 peak FLOPS """
339
+ # first estimate the number of flops we do per iteration.
340
+ # see PaLM paper Appendix B as ref: https://arxiv.org/abs/2204.02311
341
+ N = self.unsharded_param_count
342
+ cfg = self.config
343
+ L, H, Q = cfg.n_layer, cfg.n_head, cfg.n_embd//cfg.n_head
344
+ # L, H, Q, T = cfg.n_layer, cfg.n_head, cfg.n_embd//cfg.n_head, cfg.block_size
345
+ flops_per_token = 6*N + 12*L*H*Q*T
346
+ flops_per_fwdbwd = flops_per_token * T
347
+ flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter
348
+ # express our flops throughput as ratio of A100 bfloat16 peak flops
349
+ flops_achieved = flops_per_iter * (1.0/dt) # per second
350
+
351
+ # grab promised flops based on GPU type
352
+ if gpu_type == 'A40':
353
+ flops_promised = 149.7e12 # A40 GPU bfloat16 peak flops is 149.7 TFLOPS
354
+ elif gpu_type == 'A100':
355
+ flops_promised = 312e12 # A100 GPU bfloat16 peak flops is 312 TFLOPS
356
+ elif gpu_type == 'H100':
357
+ flops_promised = 756e12 # H100 GPU bfloat16 peak flops is 756 TFLOPS
358
+ elif gpu_type == 'TPUv4':
359
+ flops_promised = 275e12
360
+ elif gpu_type == 'TPUv5e':
361
+ flops_promised = 197e12
362
+
363
+ mfu = flops_achieved / flops_promised
364
+ return mfu
365
+
366
+
367
+ #########################################################
368
+ ##### Layer Definitions #####
369
+ #########################################################
370
+
371
+
372
+ class Block(nn.Module):
373
+
374
+ def __init__(self, config):
375
+ super().__init__()
376
+ self.attn = CausalSelfAttention(config)
377
+ self.mlp = MLP(config)
378
+ self.attn_scale = 1.0 # (1 / (2 * config.n_layer)**0.5)
379
+ self.norm1 = RMSNorm(config.n_embd, bias=config.bias)
380
+ self.norm2 = RMSNorm(config.n_embd, bias=config.bias)
381
+
382
+ def forward(self, x, return_kv=False, k_cache=None, v_cache=None):
383
+ # If we are given a key and value cache, we will use the pre-computed values to minimize
384
+ # the computation cost
385
+ if k_cache is not None and v_cache is not None:
386
+ # Pass the key and value cache to the attention layer, obtain new key and value caches
387
+ x_attn, k, v = self.attn.kv_cache_forward(self.norm1(x), k_cache, v_cache)
388
+ x = x + x_attn
389
+ x = x + self.mlp(self.norm2(x))
390
+ return x, k, v
391
+ # We might want to encode the caches of a whole block of keys and values at once using the
392
+ # fast flash attention impelmentation while still returning the key and value caches
393
+ elif return_kv:
394
+ # Pass the key and value cache to the attention layer, obtain new key and value caches
395
+ x_attn, k, v = self.attn(self.norm1(x), return_kv=True)
396
+ x = x + x_attn
397
+ x = x + self.mlp(self.norm2(x))
398
+ return x, k, v
399
+
400
+ x = x + self.attn_scale * self.attn(self.norm1(x))
401
+ x = x + self.mlp(self.norm2(x))
402
+ return x
403
+
404
+
405
+ class CausalSelfAttention(nn.Module):
406
+
407
+ def __init__(self, config):
408
+ super().__init__()
409
+ self.n_head = config.n_head
410
+ self.n_embd = config.n_embd
411
+ self.head_dim = self.n_embd // self.n_head
412
+ assert self.n_embd % self.n_head == 0
413
+ # key, query, value projections for all heads, but in a batch
414
+ self.c_attn = nn.Linear(self.n_embd, 3 * self.n_embd, bias=False)
415
+ # output projection
416
+ self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)
417
+
418
+ rope_theta = 500000
419
+ if hasattr(config, 'rope_theta') and config.rope_theta is not None:
420
+ rope_theta = config.rope_theta
421
+
422
+ self.rotary = Rotary(self.head_dim, base=rope_theta)
423
+
424
+ if hasattr(config, 'use_rope') and not config.use_rope:
425
+ self.rotary = None
426
+
427
+ def forward(self, x, return_kv=False, ):
428
+
429
+ B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
430
+ # calculate query, key, values for all heads in batch and move head forward to be the batch dim
431
+ qkv = self.c_attn(x)
432
+ q, k, v = qkv.split(self.n_embd, dim=2)
433
+ k = k.view(B, T, self.n_head, self.head_dim)
434
+ q = q.view(B, T, self.n_head, self.head_dim)
435
+ v = v.view(B, T, self.n_head, self.head_dim)
436
+
437
+ if self.rotary is not None:
438
+ cos, sin = self.rotary(q)
439
+ q = apply_rotary_emb(q, cos, sin)
440
+ k = apply_rotary_emb(k, cos, sin)
441
+
442
+ y = F.scaled_dot_product_attention(
443
+ q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2),
444
+ is_causal=True)
445
+
446
+ y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
447
+ # output projection
448
+ y = self.c_proj(y)
449
+ return y
450
+
451
+
452
+ class MLP(nn.Module):
453
+
454
+ def __init__(self, config):
455
+ super().__init__()
456
+ self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)
457
+ self.gelu = nn.GELU()
458
+ self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
459
+ self.dropout = nn.Dropout(config.dropout)
460
+
461
+ def forward(self, x):
462
+ x = self.c_fc(x)
463
+ x = self.gelu(x)
464
+ x = self.c_proj(x)
465
+ x = self.dropout(x)
466
+ return x
467
+
468
+
469
+ class Rotary(torch.nn.Module):
470
+ def __init__(self, dim, base=500000, learned=True):
471
+ super().__init__()
472
+ # Compute the base inverse frequencies as before.
473
+ inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
474
+ # If learned is True, register as a parameter; otherwise, as a buffer.
475
+ if learned:
476
+ # Initialize randomly and register as a parameter.
477
+ self.inv_freq = torch.nn.Parameter(inv_freq)
478
+ nn.init.normal_(self.inv_freq, mean=0.0, std=0.02)
479
+ else:
480
+ self.register_buffer("inv_freq", inv_freq)
481
+ self.learned = learned # (optional) Save the flag if needed later
482
+
483
+ def forward(self, x):
484
+ seq_len = x.shape[1]
485
+ # Create a tensor of positions.
486
+ t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
487
+ # Outer product to compute angles; this uses the (possibly learnable) frequencies.
488
+ freqs = torch.outer(t, self.inv_freq).to(x.device)
489
+ cos_cached = freqs.cos()
490
+ sin_cached = freqs.sin()
491
+ return cos_cached[None, :, None, :], sin_cached[None, :, None, :]
492
+
493
+ def apply_rotary_emb(x, cos, sin):
494
+ assert x.ndim == 4 # multihead attention expected
495
+ d = x.shape[3] // 2
496
+ x1 = x[..., :d]
497
+ x2 = x[..., d:]
498
+ y1 = x1 * cos + x2 * sin
499
+ y2 = x1 * (-sin) + x2 * cos
500
+ return torch.cat([y1, y2], dim=3)
501
+
502
+
503
+ class RMSNorm(nn.Module):
504
+ """ Root Mean Square Normalization """
505
+ def __init__(self, dim: int, weight: bool = True, bias: bool = False, eps: float = 1e-6):
506
+ super().__init__()
507
+ self.eps = eps
508
+ self.weight = nn.Parameter(torch.ones(dim)) if weight else None
509
+
510
+ def _norm(self, x):
511
+ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
512
+
513
+ def forward(self, x):
514
+ output = self._norm(x.float()).type_as(x)
515
+ if self.weight is not None:
516
+ return output * self.weight
517
+ return output