klemenk committed
Commit 586a353 · verified · 1 Parent(s): bb957c9

Update modeling_auristream.py

Files changed (1):
  1. modeling_auristream.py +101 -16
modeling_auristream.py CHANGED
@@ -165,7 +165,6 @@ class AuriStream(PreTrainedModel):
                     )
                     if output_logits:
                         all_logits.append(future_logits)
-            loss = loss / (len(self.future_heads) + 1)
 
             if return_dict:
                 if output_logits:
@@ -195,12 +194,47 @@ class AuriStream(PreTrainedModel):
                 return model_output
 
             return logits, loss
+
+        else:
+            if output_logits:
+                all_logits = [logits]
+
+            # future multi-step heads (unchanged)
+            if self.future_heads is not None:
+                for i, head in enumerate(self.future_heads):
+                    future_logits = head(x[:, :-(i + 1)])
+                    if output_logits:
+                        all_logits.append(future_logits)
 
+            if return_dict:
+                if output_logits:
+                    if output_hidden_states:
+                        model_output = CausalLMOutput(
+                            logits=all_logits,
+                            hidden_states=hs_to_return,
+                        )
+                    else:
+                        model_output = CausalLMOutput(
+                            logits=all_logits,
+                        )
+                else:
+                    if output_hidden_states:
+                        model_output = CausalLMOutput(
+                            logits=logits,
+                            hidden_states=hs_to_return,
+                        )
+                    else:
+                        model_output = CausalLMOutput(
+                            logits=logits,
+                        )
+                return model_output
+
+            return logits, loss
+
         return logits, None
 
-
     def sample_logits(self, logits: torch.FloatTensor, temperature: float = 0.9,
-                      top_k: int = 500, top_p: float = 0.5) -> torch.LongTensor:
+                      top_k: int = None, top_p: float = None) -> torch.LongTensor:
         """
         Samples an integer from the distribution of logits
         Parameters:
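Note on the signature change above: top_k and top_p now default to None, so sample_logits falls back to plain temperature sampling unless filtering is explicitly requested (the old defaults, top_k=500 and top_p=0.5, filtered the distribution even when the caller asked for nothing). The body of sample_logits is unchanged and not shown in this diff; the sketch below is a generic temperature / top-k / nucleus implementation for orientation only, not the repository's exact code, and the name sample_logits_sketch is hypothetical.

import torch
import torch.nn.functional as F

def sample_logits_sketch(logits: torch.Tensor, temperature: float = 0.9,
                         top_k: int = None, top_p: float = None) -> torch.LongTensor:
    """Temperature sampling with optional top-k / nucleus filtering; logits: (..., vocab)."""
    logits = logits / max(temperature, 1e-8)                  # temperature scaling
    if top_k is not None:
        # keep only the top_k largest logits
        kth = torch.topk(logits, top_k, dim=-1).values[..., -1:]
        logits = logits.masked_fill(logits < kth, float("-inf"))
    if top_p is not None:
        # nucleus: drop tokens once the probability mass ranked before them exceeds top_p
        sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
        probs_sorted = torch.softmax(sorted_logits, dim=-1)
        remove = (probs_sorted.cumsum(dim=-1) - probs_sorted) > top_p
        logits = sorted_logits.masked_fill(remove, float("-inf")).gather(-1, sorted_idx.argsort(dim=-1))
    probs = F.softmax(logits, dim=-1)
    flat = probs.reshape(-1, probs.size(-1))                  # torch.multinomial expects <= 2 dims
    return torch.multinomial(flat, num_samples=1).reshape(*probs.shape[:-1], 1)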
@@ -252,7 +286,7 @@ class AuriStream(PreTrainedModel):
 
     @torch.no_grad()
     def generate(self, seq: torch.Tensor, n_tokens: int = 1, temp=1.0,
-                 top_k=500, top_p=0.5, seed=None):
+                 top_k=None, top_p=None, seed=None):
         """
         Parameters:
             seq: torch.Tensor of shape (b, t, n_freq_bins)
@@ -321,7 +355,7 @@ class AuriStream(PreTrainedModel):
 
         # First prediction of the model is the decoding of the last time bin
         logits = self.coch_head(x[:, [-1]])
-        predictions = [self.sample_logits(logits, temperature=temp)]
+        predictions = [self.sample_logits(logits, temperature=temp, top_k=top_k, top_p=top_p)]
         all_logits.append(logits)
 
         ### Predict future tokens
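The second hunk above is the companion bug fix: generate previously forwarded only temperature to sample_logits, so caller-supplied top_k/top_p were silently ignored; they are now threaded through. A hypothetical call, assuming a loaded checkpoint (shapes follow the docstring; the checkpoint path and bin count are placeholders, not values from this repository):

import torch

model = AuriStream.from_pretrained("...")   # placeholder path; AuriStream subclasses PreTrainedModel
prompt = torch.randn(1, 100, 64)            # (b, t, n_freq_bins); dummy values, 64 is a placeholder bin count
out = model.generate(prompt, n_tokens=50, temp=1.0,
                     top_k=None, top_p=None,  # new defaults: pure temperature sampling
                     seed=0)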
@@ -534,31 +568,82 @@ class CausalSelfAttention(nn.Module):
 
         return y
 
-    def kv_cache_forward(self, x, k_cache=None, v_cache=None):
-        B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
+    # def kv_cache_forward(self, x, k_cache=None, v_cache=None):
+    #     B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
+
+    #     # calculate query, key, values for all heads in batch and move head forward to be the batch dim
+    #     q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
+    #     k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
+    #     q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
+    #     v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
 
+    #     # append cached keys and values with new keys and values
+    #     if k_cache is not None:
+    #         k = torch.cat((k_cache, k), dim=2)
+    #     if v_cache is not None:
+    #         v = torch.cat((v_cache, v), dim=2)
+
+    #     if self.rotary is not None:
+    #         cos, sin = self.rotary(q)
+    #         q = apply_rotary_emb(q, cos, sin)
+    #         k = apply_rotary_emb(k, cos, sin)
+
+    #     # manual implementation of attention
+    #     att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
+    #     att = F.softmax(att, dim=-1)
+    #     y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
+
+    #     y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
+
+    #     # output projection
+    #     y = self.c_proj(y)
+
+    #     return y, k, v
+
+    def kv_cache_forward(self, x, k_cache=None, v_cache=None):
+        B, T, C = x.size() # T=1 for single new token
+
         # calculate query, key, values for all heads in batch and move head forward to be the batch dim
         q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
-        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
-        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
-        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
-
-        # append cached keys and values with new keys and values
+        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, 1, hs)
+        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, 1, hs)
+        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, 1, hs)
+
+        # Apply RoPE BEFORE concatenation, using correct absolute position
+        if self.rotary is not None:
+            # Determine the position of the new token
+            cache_len = k_cache.shape[2] if k_cache is not None else 0
+
+            # Create a dummy tensor with the correct sequence position for rotary computation
+            # We need shape (B, cache_len + 1, nh, hs) but only use the last position
+            dummy = torch.zeros(B, cache_len + T, self.n_head, self.head_dim,
+                                device=q.device, dtype=q.dtype)
+            cos, sin = self.rotary(dummy)
+
+            # Extract rotary embeddings for only the new token position
+            cos = cos[:, cache_len:cache_len+T, :, :]
+            sin = sin[:, cache_len:cache_len+T, :, :]
+
+            # Apply rotary embeddings to new q and k only
+            q = apply_rotary_emb(q, cos, sin)
+            k = apply_rotary_emb(k, cos, sin)
+
+        # NOW concatenate with cache (cached keys already have correct RoPE applied)
         if k_cache is not None:
            k = torch.cat((k_cache, k), dim=2)
         if v_cache is not None:
             v = torch.cat((v_cache, v), dim=2)
-
+
         # manual implementation of attention
         att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
         att = F.softmax(att, dim=-1)
         y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
-
+
         y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
-
+
         # output projection
         y = self.c_proj(y)
-
+
         return y, k, v
 
 
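The substantive change in this hunk: the old kv_cache_forward (retained above as a comment) concatenated the cache first and applied the rotary embedding afterwards, so the single incoming token was rotated as if it sat at position 0, and the cached keys, already rotated once, were rotated again. The rewrite builds the cos/sin tables at the token's true absolute position (cache_len, via a zero dummy of length cache_len + T sliced down to the last T positions), rotates only the incoming q/k, and only then concatenates. Below is a minimal self-contained sketch of the invariant this restores; rope_cos_sin and rotate are illustrative stand-ins for the model's rotary module and apply_rotary_emb, not the repository's code:

import torch

def rope_cos_sin(positions, head_dim, base=10000.0):
    # standard RoPE angle tables for the given absolute positions
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
    ang = positions.float()[:, None] * inv_freq[None, :]   # (T, head_dim/2)
    return ang.cos(), ang.sin()

def rotate(x, cos, sin):
    # x: (T, head_dim); rotate (even, odd) channel pairs by position-dependent angles
    x1, x2 = x[..., 0::2], x[..., 1::2]
    return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)

T, hs = 8, 16
k = torch.randn(T, hs)

# full-sequence pass: every position rotated at its own index
cos, sin = rope_cos_sin(torch.arange(T), hs)
k_full = rotate(k, cos, sin)

# incremental pass (the fixed behavior): each new token is rotated at its
# absolute position, then appended to the cache, and never rotated again
cache = []
for t in range(T):
    cos_t, sin_t = rope_cos_sin(torch.tensor([t]), hs)
    cache.append(rotate(k[t:t + 1], cos_t, sin_t))
k_incr = torch.cat(cache, dim=0)

assert torch.allclose(k_full, k_incr, atol=1e-6)  # identical keys either way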
 
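A related property the cached path relies on, though the diff does not spell it out: with T=1 no causal mask is needed, because the single new query is allowed to attend to the entire cache, which is exactly the last row of full causal attention. A quick self-contained check of that equivalence (plain scaled dot-product attention, RoPE omitted):

import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, T, nh, hs = 2, 6, 2, 8
q, k, v = (torch.randn(B, nh, T, hs) for _ in range(3))

# full causal attention; keep only the output at the last position
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
att = att.masked_fill(~torch.tril(torch.ones(T, T, dtype=torch.bool)), float("-inf"))
y_full = (F.softmax(att, dim=-1) @ v)[:, :, -1:]

# incremental step: the last query attends, unmasked, to cached k/v plus its own k/v
k_step = torch.cat((k[:, :, :-1], k[:, :, -1:]), dim=2)   # mirrors kv_cache_forward's torch.cat
v_step = torch.cat((v[:, :, :-1], v[:, :, -1:]), dim=2)
att_step = (q[:, :, -1:] @ k_step.transpose(-2, -1)) * (1.0 / math.sqrt(k_step.size(-1)))
y_step = F.softmax(att_step, dim=-1) @ v_step

assert torch.allclose(y_full, y_step, atol=1e-6)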