klemenk committed on
Commit
943347f
·
verified ·
1 Parent(s): 9194438

Update modeling_auristream.py

Browse files
Files changed (1) hide show
  1. modeling_auristream.py +621 -51
modeling_auristream.py CHANGED
@@ -1,3 +1,599 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  """
2
  AuriStream sequence model definition.
3
  """
@@ -248,7 +844,7 @@ class AuriStream(PreTrainedModel):
248
  v_list = []
249
  for block_idx, block in enumerate(self.transformer.h):
250
  # Pass through the transformer block, and store k and v
251
- x, k, v = block(x, pos=pos, return_kv=True)
252
  k_list.append(k)
253
  v_list.append(v)
254
  # k_cache and v_cache have shape (n_layer, b, n_head, t, n_embd//n_head)
@@ -290,7 +886,7 @@ class AuriStream(PreTrainedModel):
290
  k_list = []
291
  v_list = []
292
  for block_idx, block in enumerate(self.transformer.h):
293
- x, k, v = block(x, pos=pos, k_cache=k_cache[block_idx], v_cache=v_cache[block_idx])
294
  k_list.append(k)
295
  v_list.append(v)
296
  x = self.transformer.ln_f(x)
@@ -300,8 +896,6 @@ class AuriStream(PreTrainedModel):
300
  # predict next time bin
301
  logits = self.coch_head(x)
302
  predictions.append(self.sample_logits(logits, temperature=temp, top_k=top_k, top_p=top_p))
303
- print(f"logits {logits.argmax()}")
304
- lk
305
  all_logits.append(logits)
306
 
307
  pred_coch = torch.cat(predictions, dim=1)
@@ -381,12 +975,12 @@ class Block(nn.Module):
381
  self.norm1 = RMSNorm(config.n_embd, bias=config.bias)
382
  self.norm2 = RMSNorm(config.n_embd, bias=config.bias)
383
 
384
- def forward(self, x, pos=None, return_kv=False, k_cache=None, v_cache=None):
385
  # If we are given a key and value cache, we will use the pre-computed values to minimize
386
  # the computation cost
387
  if k_cache is not None and v_cache is not None:
388
  # Pass the key and value cache to the attention layer, obtain new key and value caches
389
- x_attn, k, v = self.attn.kv_cache_forward(self.norm1(x), pos=pos, k_cache=k_cache, v_cache=v_cache)
390
  x = x + x_attn
391
  x = x + self.mlp(self.norm2(x))
392
  return x, k, v
@@ -474,52 +1068,29 @@ class CausalSelfAttention(nn.Module):
474
 
475
  return y
476
 
 
 
477
 
478
- def kv_cache_forward(
479
- self,
480
- x: torch.Tensor,
481
- pos: torch.Tensor,
482
- k_cache: torch.Tensor | None = None,
483
- v_cache: torch.Tensor | None = None,
484
- return_attn_maps: bool = False,
485
- ):
486
- B, T, C = x.size()
487
-
488
- q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
489
- q = q.view(B, T, self.n_head, self.head_dim) # (B, T, n_head, d)
490
- k = k.view(B, T, self.n_head, self.head_dim)
491
- v = v.view(B, T, self.n_head, self.head_dim)
492
-
493
- if self.rotary is not None:
494
- cos, sin = self.rotary(q, t=pos) # cos/sin match (B, T, n_head, d)
495
- q = apply_rotary_emb(q, cos, sin)
496
- k = apply_rotary_emb(k, cos, sin)
497
-
498
- q = q.transpose(1, 2) # (B, n_head, T, d)
499
- k = k.transpose(1, 2)
500
- v = v.transpose(1, 2)
501
 
 
502
  if k_cache is not None:
503
- k = torch.cat([k_cache, k], dim=2) # time dim grows
504
  if v_cache is not None:
505
- v = torch.cat([v_cache, v], dim=2)
506
 
507
- if not return_attn_maps:
508
- y = F.scaled_dot_product_attention(
509
- q, k, v,
510
- is_causal=True)
511
- else:
512
- # manual implementation of attention
513
- att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
514
- att = F.softmax(att, dim=-1)
515
- y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
516
-
517
- y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
518
-
519
- # output projection
520
- y = self.c_proj(y)
521
 
522
- y = y.transpose(1, 2).contiguous().view(B, T, C)
 
 
523
  y = self.c_proj(y)
524
 
525
  return y, k, v
@@ -556,11 +1127,10 @@ class Rotary(torch.nn.Module):
556
  self.register_buffer("inv_freq", inv_freq)
557
  self.learned = learned # (optional) Save the flag if needed later
558
 
559
- def forward(self, x, t=None):
560
  seq_len = x.shape[1]
561
- if t is None:
562
- # Create a tensor of positions.
563
- t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
564
  # Outer product to compute angles; this uses the (possibly learnable) frequencies.
565
  freqs = torch.outer(t, self.inv_freq).to(x.device)
566
  cos_cached = freqs.cos()
@@ -591,4 +1161,4 @@ class RMSNorm(nn.Module):
591
  output = self._norm(x.float()).type_as(x)
592
  if self.weight is not None:
593
  return output * self.weight
594
- return output
 
1
+ # """
2
+ # AuriStream sequence model definition.
3
+ # """
4
+
5
+ # import math
6
+ # import inspect
7
+ # import random
8
+ # import torch
9
+ # import torch.nn as nn
10
+ # from torch.nn import functional as F
11
+ # import numpy as np
12
+ # from huggingface_hub import PyTorchModelHubMixin
13
+ # from transformers.modeling_outputs import BaseModelOutput, CausalLMOutput
14
+ # from transformers import PreTrainedModel
15
+ # from .configuration_auristream import AuriStreamConfig
16
+
17
+
18
+ # class AuriStream(PreTrainedModel):
19
+ # config_class = AuriStreamConfig
20
+
21
+ # def __init__(self, config):
22
+ # super().__init__(config)
23
+ # self.config = config
24
+
25
+ # # if use_rope is in the config and false, initialize a wpe layer in transformer
26
+ # if hasattr(config, 'use_rope') and not config.use_rope:
27
+ # self.transformer = nn.ModuleDict(dict(
28
+ # wte = nn.Embedding(config.vocab_size, config.n_embd),
29
+ # wpe = nn.Embedding(config.seq_len, config.n_embd),
30
+ # drop = nn.Dropout(config.dropout),
31
+ # h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
32
+ # ln_f = RMSNorm(config.n_embd, bias=config.bias),
33
+ # ))
34
+ # else:
35
+ # self.transformer = nn.ModuleDict(dict(
36
+ # wte = nn.Embedding(config.vocab_size, config.n_embd),
37
+ # drop = nn.Dropout(config.dropout),
38
+ # h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
39
+ # ln_f = RMSNorm(config.n_embd, bias=config.bias),
40
+ # ))
41
+
42
+ # # check if n_pred_steps is defined in the config, this is the number of linear heads for prediction
43
+ # if hasattr(config, 'n_pred_steps'):
44
+ # self.future_heads = nn.ModuleList([nn.Linear(config.n_embd, config.vocab_size, bias=False) for _ in range(config.n_pred_steps - 1)])
45
+ # else:
46
+ # self.future_heads = None
47
+
48
+ # self.coch_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
49
+
50
+ # # init all weights
51
+ # self.apply(self._init_weights)
52
+ # # apply special scaled init to the residual projections, per GPT-2 paper
53
+ # for pn, p in self.named_parameters():
54
+ # if pn.endswith('c_proj.weight'):
55
+ # torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))
56
+
57
+ # def get_num_params(self, non_embedding=True):
58
+ # """
59
+ # Return the number of parameters in the model.
60
+ # For non-embedding count (default), the position embeddings get subtracted.
61
+ # The token embeddings would too, except due to the parameter sharing these
62
+ # params are actually used as weights in the final layer, so we include them.
63
+ # """
64
+ # n_params = sum(p.numel() for p in self.parameters())
65
+ # return n_params
66
+
67
+ # def _init_weights(self, module):
68
+ # if isinstance(module, nn.Linear):
69
+ # torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
70
+ # if module.bias is not None:
71
+ # torch.nn.init.zeros_(module.bias)
72
+ # elif isinstance(module, nn.Embedding):
73
+ # torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
74
+
75
+ # def forward(self, seq, tgt=None, output_hidden_states=False, return_dict=False, up_until_layer=None):
76
+ # """
77
+ # Input: coch: torch.Tensor of shape (b, t)
78
+ # tgt_coch: torch.Tensor of shape (b, t) or None
79
+ # """
80
+
81
+ # # forward the GPT model itself
82
+ # tok_emb = self.transformer.wte(seq) # token embeddings of shape (b, t, n_embd)
83
+
84
+ # # if wpe exists in self.transformer apply learned positional embedding
85
+ # if hasattr(self.transformer, 'wpe'):
86
+ # pos = torch.arange(0, seq.size(1), dtype=torch.long, device=seq.device)
87
+ # pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd)
88
+ # x = self.transformer.drop(tok_emb + pos_emb)
89
+ # else:
90
+ # x = self.transformer.drop(tok_emb)
91
+
92
+ # all_hidden_states = []
93
+ # for block_idx, block in enumerate(self.transformer.h):
94
+ # # Forward the block
95
+ # all_hidden_states.append(x)
96
+ # if up_until_layer is not None and block_idx == up_until_layer:
97
+ # break
98
+ # x = block(x)
99
+
100
+ # # append the last hidden state if we did not exit early
101
+ # if up_until_layer is None or block_idx == len(self.transformer.h) - 1:
102
+ # all_hidden_states.append(x)
103
+
104
+ # if output_hidden_states:
105
+ # model_output = BaseModelOutput(
106
+ # last_hidden_state=x,
107
+ # hidden_states=all_hidden_states,
108
+ # )
109
+ # return model_output
110
+
111
+ # x = self.transformer.ln_f(x)
112
+ # logits = self.coch_head(x)
113
+
114
+ # if tgt is not None:
115
+ # loss = F.cross_entropy(
116
+ # logits.reshape(-1, self.config.vocab_size), tgt.reshape(-1),
117
+ # )
118
+
119
+ # # If we have more than one future head, compute the loss for each head
120
+ # if self.future_heads is not None:
121
+ # for i, head in enumerate(self.future_heads):
122
+ # future_logits = head(x[:, :-(i+1)])
123
+ # loss += F.cross_entropy(
124
+ # future_logits.reshape(-1, self.config.vocab_size), tgt[:, (i+1):].reshape(-1),
125
+ # )
126
+ # # divide loss by number of future heads
127
+ # loss = loss / (len(self.future_heads) + 1)
128
+
129
+ # if return_dict:
130
+ # model_output = CausalLMOutput(
131
+ # loss=loss,
132
+ # logits=logits,
133
+ # )
134
+ # return model_output
135
+
136
+ # return logits, loss
137
+
138
+ # return logits, None
139
+
140
+ # def sample_logits(self, logits: torch.FloatTensor, temperature: float = 0.9,
141
+ # top_k: int = 500, top_p: float = 0.5) -> torch.LongTensor:
142
+ # """
143
+ # Samples an integer from the distribution of logits
144
+ # Parameters:
145
+ # logits (torch.FloatTensor): The logits of the distribution
146
+ # temp (float): The temperature of the sampling, if 0.0, then argmax is used
147
+ # top_k (int): The number of top k tokens to consider during sampling
148
+ # top_p (float): The cumulative probability threshold for nucleus (top-p) sampling
149
+ # Returns:
150
+ # torch.LongTensor: The sampled integer
151
+ # """
152
+ # # If temperature is 0.0, use argmax
153
+ # if temperature == 0.0:
154
+ # return torch.argmax(logits, dim=-1)
155
+
156
+ # # Apply temperature
157
+ # logits = logits / temperature
158
+
159
+ # # Apply top-k filtering if specified
160
+ # if top_k is not None:
161
+ # v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
162
+ # logits[logits < v[..., [-1]]] = -float('Inf')
163
+
164
+ # # Apply top-p (nucleus) filtering if specified
165
+ # if top_p is not None:
166
+ # # Sort the logits in descending order
167
+ # sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
168
+ # # Compute the sorted softmax probabilities
169
+ # sorted_probs = F.softmax(sorted_logits, dim=-1)
170
+ # # Compute the cumulative probabilities
171
+ # cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
172
+ # # Create a mask for tokens to remove
173
+ # sorted_indices_to_remove = cumulative_probs > top_p
174
+ # # Shift the mask right to keep at least one token
175
+ # sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
176
+ # sorted_indices_to_remove[..., 0] = 0
177
+ # # Scatter the mask back to the original indices
178
+ # indices_to_remove = sorted_indices_to_remove.scatter(dim=-1, index=sorted_indices, src=sorted_indices_to_remove)
179
+ # logits[indices_to_remove] = -float('Inf')
180
+
181
+ # # Compute softmax probabilities
182
+ # probs = F.softmax(logits, dim=-1)
183
+ # # Flatten probabilities to (batch_size * sequence_length, vocab_size)
184
+ # flat_probs = probs.view(-1, probs.size(-1))
185
+ # # Sample from the distribution
186
+ # sampled = torch.multinomial(flat_probs, num_samples=1)
187
+ # # Reshape to original shape except for the last dimension
188
+ # sampled = sampled.view(*logits.shape[:-1])
189
+ # return sampled
190
+
191
+ # @torch.no_grad()
192
+ # def generate(self, seq: torch.Tensor, n_tokens: int = 1, temp=1.0,
193
+ # top_k=500, top_p=0.5, seed=None):
194
+ # """
195
+ # Parameters:
196
+ # seq: torch.Tensor of shape (b, t, n_freq_bins)
197
+ # Input cochleagram to use for generation
198
+ # n_tokens: int
199
+ # Number of time bins to predict
200
+ # temp: float
201
+ # Temperature for sampling logits
202
+ # seed: int
203
+ # Random seed for sampling
204
+
205
+ # Returns:
206
+ # pred_coch: torch.Tensor of shape (b, t, n_freq_bins)
207
+ # The predicted cochleagram
208
+ # all_logits: (optional if return_logits is True) torch.Tensor of shape (b, n_tokens, n_freq_bins)
209
+ # The logits for each time step
210
+ # all_embs: (optional if return_embs is not None) list of torch.Tensor
211
+ # The embeddings for each transformer block
212
+ # """
213
+
214
+ # # Set seed if provided
215
+ # if seed is not None:
216
+ # random.seed(seed)
217
+ # np.random.seed(seed)
218
+ # torch.manual_seed(seed)
219
+
220
+ # # make a list of logits to return
221
+ # all_logits = []
222
+ # device = seq.device
223
+
224
+ # # grab shape of the cochleagram
225
+ # b, t = seq.size()
226
+
227
+ # # TODO: double check this works then delete the block below:
228
+ # # pass the given input through the model to get the predictions and cache
229
+ # # the k and v values for each transformer block in the process
230
+ # # pos = torch.arange(0, t, dtype=torch.long, device=device)
231
+ # # tok_emb = self.transformer.wte(seq) # token embeddings of shape (b, t, n_embd)
232
+ # # pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd)
233
+ # # x = self.transformer.drop(tok_emb + pos_emb)
234
+
235
+ # #### Embed conditioning sequence into KV cache
236
+
237
+ # tok_emb = self.transformer.wte(seq) # token embeddings of shape (b, t, n_embd)
238
+ # # if wpe exists in self.transformer apply learned positional embedding
239
+ # if hasattr(self.transformer, 'wpe'):
240
+ # pos = torch.arange(0, seq.size(1), dtype=torch.long, device=seq.device)
241
+ # pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd)
242
+ # x = self.transformer.drop(tok_emb + pos_emb)
243
+ # else:
244
+ # x = self.transformer.drop(tok_emb)
245
+
246
+ # # Initialize list to store k and v for each transformer block
247
+ # k_list = []
248
+ # v_list = []
249
+ # for block_idx, block in enumerate(self.transformer.h):
250
+ # # Pass through the transformer block, and store k and v
251
+ # x, k, v = block(x, pos=pos, return_kv=True)
252
+ # k_list.append(k)
253
+ # v_list.append(v)
254
+ # # k_cache and v_cache have shape (n_layer, b, n_head, t, n_embd//n_head)
255
+ # k_cache = torch.stack(k_list, dim=0)
256
+ # v_cache = torch.stack(v_list, dim=0)
257
+ # # Pass through the final layer norm
258
+ # x = self.transformer.ln_f(x)
259
+
260
+ # # First prediction of the model is the decoding of the last time bin
261
+ # logits = self.coch_head(x[:, [-1]])
262
+ # predictions = [self.sample_logits(logits, temperature=temp)]
263
+ # all_logits.append(logits)
264
+
265
+ # ### Predict future tokens
266
+
267
+ # # Now we pass the last time bin through the model to predict the next time bin
268
+ # # we subtract 1 from n_tokens because we already predicted the first time bin
269
+ # # using the last embedding of the input
270
+ # for i in range(n_tokens-1):
271
+
272
+ # # TODO: double check this works then delete the block below:
273
+ # # # Get the emb and pos embedding of just the last token
274
+ # # pos = torch.arange(t+i, t+i+1, dtype=torch.long, device=device) # shape (t)
275
+ # # tok_emb = self.transformer.wte(predictions[-1]) # token embeddings of shape (b, t, n_embd)
276
+ # # pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd)
277
+ # # x = self.transformer.drop(tok_emb + pos_emb)
278
+
279
+ # # Get the emb and pos embedding of just the last token
280
+ # tok_emb = self.transformer.wte(predictions[-1]) # token embeddings of shape (b, t, n_embd)
281
+ # # if wpe exists in self.transformer apply learned positional embedding
282
+ # if hasattr(self.transformer, 'wpe'):
283
+ # pos = torch.arange(t+i, t+i+1, dtype=torch.long, device=device) # shape (t)
284
+ # pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd)
285
+ # x = self.transformer.drop(tok_emb + pos_emb)
286
+ # else:
287
+ # x = self.transformer.drop(tok_emb)
288
+
289
+ # # Pass through transformer block
290
+ # k_list = []
291
+ # v_list = []
292
+ # for block_idx, block in enumerate(self.transformer.h):
293
+ # x, k, v = block(x, pos=pos, k_cache=k_cache[block_idx], v_cache=v_cache[block_idx])
294
+ # k_list.append(k)
295
+ # v_list.append(v)
296
+ # x = self.transformer.ln_f(x)
297
+ # # create the cache with the new embeddings
298
+ # k_cache = torch.stack(k_list, dim=0)
299
+ # v_cache = torch.stack(v_list, dim=0)
300
+ # # predict next time bin
301
+ # logits = self.coch_head(x)
302
+ # predictions.append(self.sample_logits(logits, temperature=temp, top_k=top_k, top_p=top_p))
303
+ # print(f"logits {logits.argmax()}")
304
+ # lk
305
+ # all_logits.append(logits)
306
+
307
+ # pred_coch = torch.cat(predictions, dim=1)
308
+ # all_logits = torch.cat(all_logits, dim=1)
309
+
310
+ # return pred_coch, all_logits
311
+
312
+
313
+ # def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):
314
+ # # start with all of the candidate parameters
315
+ # param_dict = {pn: p for pn, p in self.named_parameters()}
316
+ # # filter out those that do not require grad
317
+ # param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}
318
+ # # create optim groups. Any parameters that is 2D will be weight decayed, otherwise no.
319
+ # # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't.
320
+ # decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
321
+ # nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
322
+ # optim_groups = [
323
+ # {'params': decay_params, 'weight_decay': weight_decay},
324
+ # {'params': nodecay_params, 'weight_decay': 0.0}
325
+ # ]
326
+ # num_decay_params = sum(p.numel() for p in decay_params)
327
+ # num_nodecay_params = sum(p.numel() for p in nodecay_params)
328
+ # print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
329
+ # print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
330
+ # # Create AdamW optimizer and use the fused version if it is available
331
+ # fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
332
+ # use_fused = fused_available and device_type == 'cuda'
333
+ # extra_args = dict(fused=True) if use_fused else dict()
334
+ # optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args)
335
+ # print(f"using fused AdamW: {use_fused}")
336
+
337
+ # return optimizer
338
+
339
+ # def estimate_mfu(self, fwdbwd_per_iter, T, dt, gpu_type='A40'):
340
+ # """ estimate model flops utilization (MFU) as a fraction of the bfloat16 peak FLOPS of the device selected by gpu_type """
341
+ # # first estimate the number of flops we do per iteration.
342
+ # # see PaLM paper Appendix B as ref: https://arxiv.org/abs/2204.02311
343
+ # N = self.unsharded_param_count
344
+ # cfg = self.config
345
+ # L, H, Q = cfg.n_layer, cfg.n_head, cfg.n_embd//cfg.n_head
346
+ # # L, H, Q, T = cfg.n_layer, cfg.n_head, cfg.n_embd//cfg.n_head, cfg.block_size
347
+ # flops_per_token = 6*N + 12*L*H*Q*T
348
+ # flops_per_fwdbwd = flops_per_token * T
349
+ # flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter
350
+ # # express our flops throughput as ratio of A100 bfloat16 peak flops
351
+ # flops_achieved = flops_per_iter * (1.0/dt) # per second
352
+
353
+ # # grab promised flops based on GPU type
354
+ # if gpu_type == 'A40':
355
+ # flops_promised = 149.7e12 # A40 GPU bfloat16 peak flops is 149.7 TFLOPS
356
+ # elif gpu_type == 'A100':
357
+ # flops_promised = 312e12 # A100 GPU bfloat16 peak flops is 312 TFLOPS
358
+ # elif gpu_type == 'H100':
359
+ # flops_promised = 756e12 # H100 GPU bfloat16 peak flops is 756 TFLOPS
360
+ # elif gpu_type == 'TPUv4':
361
+ # flops_promised = 275e12
362
+ # elif gpu_type == 'TPUv5e':
363
+ # flops_promised = 197e12
364
+
365
+ # mfu = flops_achieved / flops_promised
366
+ # return mfu
367
+
368
+
369
+ # #########################################################
370
+ # ##### Layer Definitions #####
371
+ # #########################################################
372
+
373
+
374
+ # class Block(nn.Module):
375
+
376
+ # def __init__(self, config):
377
+ # super().__init__()
378
+ # self.attn = CausalSelfAttention(config)
379
+ # self.mlp = MLP(config)
380
+ # self.attn_scale = 1.0 # (1 / (2 * config.n_layer)**0.5)
381
+ # self.norm1 = RMSNorm(config.n_embd, bias=config.bias)
382
+ # self.norm2 = RMSNorm(config.n_embd, bias=config.bias)
383
+
384
+ # def forward(self, x, pos=None, return_kv=False, k_cache=None, v_cache=None):
385
+ # # If we are given a key and value cache, we will use the pre-computed values to minimize
386
+ # # the computation cost
387
+ # if k_cache is not None and v_cache is not None:
388
+ # # Pass the key and value cache to the attention layer, obtain new key and value caches
389
+ # x_attn, k, v = self.attn.kv_cache_forward(self.norm1(x), pos=pos, k_cache=k_cache, v_cache=v_cache)
390
+ # x = x + x_attn
391
+ # x = x + self.mlp(self.norm2(x))
392
+ # return x, k, v
393
+ # # We might want to encode the caches of a whole block of keys and values at once using the
394
+ # # fast flash attention implementation while still returning the key and value caches
395
+ # elif return_kv:
396
+ # # Pass the key and value cache to the attention layer, obtain new key and value caches
397
+ # x_attn, k, v = self.attn(self.norm1(x), return_kv=True)
398
+ # x = x + x_attn
399
+ # x = x + self.mlp(self.norm2(x))
400
+ # return x, k, v
401
+
402
+ # x = x + self.attn_scale * self.attn(self.norm1(x))
403
+ # x = x + self.mlp(self.norm2(x))
404
+ # return x
405
+
406
+
407
+ # class CausalSelfAttention(nn.Module):
408
+
409
+ # def __init__(self, config):
410
+ # super().__init__()
411
+ # self.n_head = config.n_head
412
+ # self.n_embd = config.n_embd
413
+ # self.head_dim = self.n_embd // self.n_head
414
+ # assert self.n_embd % self.n_head == 0
415
+ # # key, query, value projections for all heads, but in a batch
416
+ # self.c_attn = nn.Linear(self.n_embd, 3 * self.n_embd, bias=False)
417
+ # # output projection
418
+ # self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)
419
+
420
+ # rope_theta = 500000
421
+ # if hasattr(config, 'rope_theta') and config.rope_theta is not None:
422
+ # rope_theta = config.rope_theta
423
+
424
+ # self.rotary = Rotary(self.head_dim, base=rope_theta)
425
+
426
+ # if hasattr(config, 'use_rope') and not config.use_rope:
427
+ # self.rotary = None
428
+
429
+ # def forward(self, x, return_kv=False, return_attn_maps=False):
430
+
431
+ # B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
432
+ # # calculate query, key, values for all heads in batch and move head forward to be the batch dim
433
+ # qkv = self.c_attn(x)
434
+ # q, k, v = qkv.split(self.n_embd, dim=2)
435
+ # k = k.view(B, T, self.n_head, self.head_dim)
436
+ # q = q.view(B, T, self.n_head, self.head_dim)
437
+ # v = v.view(B, T, self.n_head, self.head_dim)
438
+
439
+ # if self.rotary is not None:
440
+ # cos, sin = self.rotary(q)
441
+ # q = apply_rotary_emb(q, cos, sin)
442
+ # k = apply_rotary_emb(k, cos, sin)
443
+
444
+ # if not return_kv and not return_attn_maps:
445
+ # y = F.scaled_dot_product_attention(
446
+ # q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2),
447
+ # is_causal=True)
448
+ # else:
449
+ # # manual implementation of attention
450
+ # q = q.transpose(1, 2)
451
+ # k = k.transpose(1, 2)
452
+ # v = v.transpose(1, 2)
453
+ # att = torch.einsum('bnsh,bnkh->bnsk', q, k) * (1.0 / math.sqrt(k.size(-1)))
454
+ # mask = torch.triu(torch.ones(T, T), diagonal=1).to(dtype=torch.bool).to(x.device)
455
+ # mask = mask.view(1, 1, T, T)
456
+ # masked_att = att.masked_fill(mask, float('-inf'))
457
+ # # upcast to float32 for numerical stability, as per llama implementation
458
+ # masked_att = F.softmax(masked_att, dim=-1, dtype=torch.float32).to(q.dtype)
459
+ # # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
460
+ # y = torch.einsum('bnsk,bnkh->bnsh', masked_att, v)
461
+
462
+ # y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
463
+
464
+ # # output projection
465
+ # y = self.c_proj(y)
466
+
467
+ # # return attention maps if requested
468
+ # if return_attn_maps:
469
+ # return y, F.softmax(att, dim=-1)
470
+
471
+ # # return key and value caches if requested
472
+ # if return_kv:
473
+ # return y, k, v
474
+
475
+ # return y
476
+
477
+
478
+ # def kv_cache_forward(
479
+ # self,
480
+ # x: torch.Tensor,
481
+ # pos: torch.Tensor,
482
+ # k_cache: torch.Tensor | None = None,
483
+ # v_cache: torch.Tensor | None = None,
484
+ # return_attn_maps: bool = False,
485
+ # ):
486
+ # B, T, C = x.size()
487
+
488
+ # q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
489
+ # q = q.view(B, T, self.n_head, self.head_dim) # (B, T, n_head, d)
490
+ # k = k.view(B, T, self.n_head, self.head_dim)
491
+ # v = v.view(B, T, self.n_head, self.head_dim)
492
+
493
+ # if self.rotary is not None:
494
+ # cos, sin = self.rotary(q, t=pos) # cos/sin match (B, T, n_head, d)
495
+ # q = apply_rotary_emb(q, cos, sin)
496
+ # k = apply_rotary_emb(k, cos, sin)
497
+
498
+ # q = q.transpose(1, 2) # (B, n_head, T, d)
499
+ # k = k.transpose(1, 2)
500
+ # v = v.transpose(1, 2)
501
+
502
+ # if k_cache is not None:
503
+ # k = torch.cat([k_cache, k], dim=2) # time dim grows
504
+ # if v_cache is not None:
505
+ # v = torch.cat([v_cache, v], dim=2)
506
+
507
+ # if not return_attn_maps:
508
+ # y = F.scaled_dot_product_attention(
509
+ # q, k, v,
510
+ # is_causal=True)
511
+ # else:
512
+ # # manual implementation of attention
513
+ # att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
514
+ # att = F.softmax(att, dim=-1)
515
+ # y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
516
+
517
+ # y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
518
+
519
+ # # output projection
520
+ # y = self.c_proj(y)
521
+
522
+ # y = y.transpose(1, 2).contiguous().view(B, T, C)
523
+ # y = self.c_proj(y)
524
+
525
+ # return y, k, v
526
+
527
+
528
+ # class MLP(nn.Module):
529
+
530
+ # def __init__(self, config):
531
+ # super().__init__()
532
+ # self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)
533
+ # self.gelu = nn.GELU()
534
+ # self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
535
+ # self.dropout = nn.Dropout(config.dropout)
536
+
537
+ # def forward(self, x):
538
+ # x = self.c_fc(x)
539
+ # x = self.gelu(x)
540
+ # x = self.c_proj(x)
541
+ # x = self.dropout(x)
542
+ # return x
543
+
544
+
545
+ # class Rotary(torch.nn.Module):
546
+ # def __init__(self, dim, base=500000, learned=True):
547
+ # super().__init__()
548
+ # # Compute the base inverse frequencies as before.
549
+ # inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
550
+ # # If learned is True, register as a parameter; otherwise, as a buffer.
551
+ # if learned:
552
+ # # Initialize randomly and register as a parameter.
553
+ # self.inv_freq = torch.nn.Parameter(inv_freq)
554
+ # nn.init.normal_(self.inv_freq, mean=0.0, std=0.02)
555
+ # else:
556
+ # self.register_buffer("inv_freq", inv_freq)
557
+ # self.learned = learned # (optional) Save the flag if needed later
558
+
559
+ # def forward(self, x, t=None):
560
+ # seq_len = x.shape[1]
561
+ # if t is None:
562
+ # # Create a tensor of positions.
563
+ # t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
564
+ # # Outer product to compute angles; this uses the (possibly learnable) frequencies.
565
+ # freqs = torch.outer(t, self.inv_freq).to(x.device)
566
+ # cos_cached = freqs.cos()
567
+ # sin_cached = freqs.sin()
568
+ # return cos_cached[None, :, None, :], sin_cached[None, :, None, :]
569
+
570
+ # def apply_rotary_emb(x, cos, sin):
571
+ # assert x.ndim == 4 # multihead attention expected
572
+ # d = x.shape[3] // 2
573
+ # x1 = x[..., :d]
574
+ # x2 = x[..., d:]
575
+ # y1 = x1 * cos + x2 * sin
576
+ # y2 = x1 * (-sin) + x2 * cos
577
+ # return torch.cat([y1, y2], dim=3)
578
+
579
+
580
+ # class RMSNorm(nn.Module):
581
+ # """ Root Mean Square Normalization """
582
+ # def __init__(self, dim: int, weight: bool = True, bias: bool = False, eps: float = 1e-6):
583
+ # super().__init__()
584
+ # self.eps = eps
585
+ # self.weight = nn.Parameter(torch.ones(dim)) if weight else None
586
+
587
+ # def _norm(self, x):
588
+ # return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
589
+
590
+ # def forward(self, x):
591
+ # output = self._norm(x.float()).type_as(x)
592
+ # if self.weight is not None:
593
+ # return output * self.weight
594
+ # return output
595
+
596
+
597
  """
598
  AuriStream sequence model definition.
599
  """
 
844
  v_list = []
845
  for block_idx, block in enumerate(self.transformer.h):
846
  # Pass through the transformer block, and store k and v
847
+ x, k, v = block(x, return_kv=True)
848
  k_list.append(k)
849
  v_list.append(v)
850
  # k_cache and v_cache have shape (n_layer, b, n_head, t, n_embd//n_head)
 
886
  k_list = []
887
  v_list = []
888
  for block_idx, block in enumerate(self.transformer.h):
889
+ x, k, v = block(x, k_cache=k_cache[block_idx], v_cache=v_cache[block_idx])
890
  k_list.append(k)
891
  v_list.append(v)
892
  x = self.transformer.ln_f(x)
 
896
  # predict next time bin
897
  logits = self.coch_head(x)
898
  predictions.append(self.sample_logits(logits, temperature=temp, top_k=top_k, top_p=top_p))
 
 
899
  all_logits.append(logits)
900
 
901
  pred_coch = torch.cat(predictions, dim=1)
 
975
  self.norm1 = RMSNorm(config.n_embd, bias=config.bias)
976
  self.norm2 = RMSNorm(config.n_embd, bias=config.bias)
977
 
978
+ def forward(self, x, return_kv=False, k_cache=None, v_cache=None):
979
  # If we are given a key and value cache, we will use the pre-computed values to minimize
980
  # the computation cost
981
  if k_cache is not None and v_cache is not None:
982
  # Pass the key and value cache to the attention layer, obtain new key and value caches
983
+ x_attn, k, v = self.attn.kv_cache_forward(self.norm1(x), k_cache, v_cache)
984
  x = x + x_attn
985
  x = x + self.mlp(self.norm2(x))
986
  return x, k, v
 
1068
 
1069
  return y
1070
 
1071
+ def kv_cache_forward(self, x, k_cache=None, v_cache=None):
1072
+ B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
1073
 
1074
+ # calculate query, key, values for all heads in batch and move head forward to be the batch dim
1075
+ q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
1076
+ k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
1077
+ q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
1078
+ v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1079
 
1080
+ # append cached keys and values with new keys and values
1081
  if k_cache is not None:
1082
+ k = torch.cat((k_cache, k), dim=2)
1083
  if v_cache is not None:
1084
+ v = torch.cat((v_cache, v), dim=2)
1085
 
1086
+ # manual implementation of attention
1087
+ att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
1088
+ att = F.softmax(att, dim=-1)
1089
+ y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
 
 
 
 
 
 
 
 
 
 
1090
 
1091
+ y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
1092
+
1093
+ # output projection
1094
  y = self.c_proj(y)
1095
 
1096
  return y, k, v
 
1127
  self.register_buffer("inv_freq", inv_freq)
1128
  self.learned = learned # (optional) Save the flag if needed later
1129
 
1130
+ def forward(self, x):
1131
  seq_len = x.shape[1]
1132
+ # Create a tensor of positions.
1133
+ t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
 
1134
  # Outer product to compute angles; this uses the (possibly learnable) frequencies.
1135
  freqs = torch.outer(t, self.inv_freq).to(x.device)
1136
  cos_cached = freqs.cos()
 
1161
  output = self._norm(x.float()).type_as(x)
1162
  if self.weight is not None:
1163
  return output * self.weight
1164
+ return output