Dionyssos committed on
Commit 2f811de · 1 Parent(s): cecf050
Files changed (1)
  1. audiocraft.py +3 -3
audiocraft.py CHANGED
@@ -66,7 +66,7 @@ class AudioGen(torch.nn.Module):
     ):
         torch.manual_seed(42) # https://github.com/facebookresearch/audiocraft/issues/111#issuecomment-1614732858
         n_draw = int(duration * 50 / (max_tokens * N_REPEAT)) + 1
-        print(f'{n_draw=} {duration=}seconds < {prompt=}')
+        print(f'{n_draw=} {duration=}seconds < {prompt=} | {max_tokens=}')
         with torch.autocast(device_type='cpu', dtype=torch.bfloat16):
             gen_tokens = self.lm.generate(
                 text_condition=[prompt] * N_REPEAT + [''] * N_REPEAT,#['dogs', 'dogs...!', '', '']
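
Note: the n_draw computation above paces generation at the formula's 50-tokens-per-second frame rate, counting how many sampling passes of max_tokens steps each (across N_REPEAT parallel prompt copies) are needed to cover the requested duration; the diff only extends the print to also report max_tokens. A worked example with hypothetical values:

    duration, max_tokens, N_REPEAT = 30, 500, 2  # hypothetical values
    n_draw = int(duration * 50 / (max_tokens * N_REPEAT)) + 1
    assert n_draw == 2  # int(1500 / 1000) + 1 = two draws for 30 s of audio
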
@@ -540,7 +540,7 @@ class LMModel(nn.Module):
 
         out_codes[:, :, [0, 1, 2, 3], torch.tensor([3, 2, 1, 0]) + offset + 1] = next_token
         # Sink Attn
-        if (offset > 0) and (offset % 476) == 0:
+        if (offset > 0) and (offset % 71) == 0:
             n_preserve = 4
             self.transformer._flush(n_preserve=n_preserve)
             cache_position = n_preserve
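
Note: this hunk shortens the attention-sink flush period, so the KV cache is now flushed every 71 decoding steps instead of every 476, while the first 4 cached positions are still preserved as sink tokens and cache_position is rewound to match. A minimal sketch of the pattern, assuming k_history/v_history are [B, H, T, D] tensors as in StreamingMultiheadAttention; the class below is illustrative, not the repository's _flush implementation:

    import torch

    class SinkKVCache:
        # Illustrative periodic flush that keeps the first n_preserve cached
        # positions ("attention sinks") and drops everything after them.
        def __init__(self, n_preserve=4, period=71):
            self.n_preserve = n_preserve
            self.period = period
            self.k_history = None  # [B, H, T, D], grown along dim 2 each step
            self.v_history = None

        def maybe_flush(self, offset):
            # Mirrors: if (offset > 0) and (offset % 71) == 0: _flush(n_preserve=4)
            if offset > 0 and offset % self.period == 0:
                self.k_history = self.k_history[:, :, :self.n_preserve, :]
                self.v_history = self.v_history[:, :, :self.n_preserve, :]
                return self.n_preserve  # caller resets cache_position to this
            return None
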
@@ -640,7 +640,7 @@ class StreamingMultiheadAttention(nn.Module):
 
         k = self.k_history
         v = self.v_history
-        print(q.shape, k.shape, v.shape,'Self Atts')
+
         # -> kv CACHE ONLY APPLIES if not self.cross_attention
 
         x = torch.nn.functional.scaled_dot_product_attention(
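
Note: the last hunk only removes a shape-debugging print; q still attends over the accumulated k_history/v_history via PyTorch's scaled_dot_product_attention. A self-contained sketch of that call, with hypothetical shapes:

    import torch
    import torch.nn.functional as F

    B, H, Tq, Tkv, D = 1, 8, 1, 71, 64           # hypothetical: one new query step
    q = torch.randn(B, H, Tq, D)
    k = torch.randn(B, H, Tkv, D)                # stands in for self.k_history
    v = torch.randn(B, H, Tkv, D)                # stands in for self.v_history
    x = F.scaled_dot_product_attention(q, k, v)  # -> [B, H, Tq, D]
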