Dionyssos committed on
Commit
a7e2983
·
1 Parent(s): 0f06964

user defines num tokens

Browse files
Files changed (2) hide show
  1. app.py +3 -3
  2. audiocraft.py +9 -8
app.py CHANGED
@@ -39,7 +39,7 @@ language_names = ['Ancient greek',
39
  def audionar_tts(text=None,
40
  lang='Romanian',
41
  soundscape='frogs',
42
- cache_lim=24):
43
 
44
  # https://huggingface.co/dkounadis/artificial-styletts2/blob/main/msinference.py
45
 
@@ -119,7 +119,7 @@ def audionar_tts(text=None,
119
  background_audio = audiogen.generate(
120
  soundscape,
121
  duration=target_duration,
122
- cache_lim=max(4, int(cache_lim)) # at least allow 10 A/R stEps
123
  ).numpy()
124
 
125
  # PAD
@@ -272,7 +272,7 @@ with gr.Blocks(theme='huggingface') as demo:
272
  label="AudioGen Txt"
273
  )
274
  kv_input = gr.Number(
275
- label="Diversy",
276
  value=24,
277
  )
278
  generate_button = gr.Button("Generate Audio", variant="primary")
 
39
  def audionar_tts(text=None,
40
  lang='Romanian',
41
  soundscape='frogs',
42
+ max_tokens=24):
43
 
44
  # https://huggingface.co/dkounadis/artificial-styletts2/blob/main/msinference.py
45
 
 
119
  background_audio = audiogen.generate(
120
  soundscape,
121
  duration=target_duration,
122
+ max_tokens=max(4, int(max_tokens)) # at least allow 10 A/R stEps
123
  ).numpy()
124
 
125
  # PAD
 
272
  label="AudioGen Txt"
273
  )
274
  kv_input = gr.Number(
275
+ label="Num Tokens",
276
  value=24,
277
  )
278
  generate_button = gr.Button("Generate Audio", variant="primary")
audiocraft.py CHANGED
@@ -62,15 +62,14 @@ class AudioGen(torch.nn.Module):
62
  def generate(self,
63
  prompt='dogs mewo',
64
  duration=2.24, # seconds of audio
65
- cache_lim=71, # flush kv cache after cache_lim tok
66
  ):
67
  torch.manual_seed(42) # https://github.com/facebookresearch/audiocraft/issues/111#issuecomment-1614732858
68
- self.lm.cache_lim = cache_lim
69
- self.lm.n_draw = int(.8 * duration) + 1 # different beam every 0.47 seconds of audio
70
  with torch.autocast(device_type='cpu', dtype=torch.bfloat16):
71
  gen_tokens = self.lm.generate(
72
  text_condition=[prompt] * N_REPEAT + [''] * N_REPEAT,#['dogs', 'dogs...!', '', '']
73
- max_tokens=max(int(duration / (N_REPEAT * self.lm.n_draw) * 50) + 5, 12))
74
 
75
  # OOM if vocode all tokens
76
  x = []
@@ -81,9 +80,11 @@ class AudioGen(torch.nn.Module):
81
  decoded_chunk = self.compression_model.decode(gen_tokens[:, :, i-7:i+self._chunk_len])
82
 
83
  x.append(decoded_chunk)
84
-
85
  x = torch.cat(x, 2) # [bs, 1, 114000]
86
 
 
 
87
  x = _shift(x) # clone() to have xN
88
 
89
  return x.reshape(-1) #x / (x.abs().max() + 1e-7)
@@ -430,7 +431,6 @@ class LMModel(nn.Module):
430
  dim = 1536
431
  ):
432
  super().__init__()
433
- self.cache_lim = -1
434
  self.t5 = T5()
435
  self.card = card # 2048
436
  self.n_draw = 1 # draw > 1 tokens of different CFG scale
@@ -468,6 +468,7 @@ class LMModel(nn.Module):
468
  # divide large probs with exp(prob) If prob=.001 then 1/exp(1*.001) -> almost by 0 --> exp doesnt really produce (0, Inf)
469
  p = p.argmax(dim=3, keepdim=True) # [bs, 4, n_draw, 24]
470
  tok = ix.gather(dim=3, index=p).to(torch.int64) # [bs, 4, n_draw, 1]
 
471
  return tok[:, :, :, 0].transpose(1, 2) # [bs, n_draw, 4]
472
 
473
  @torch.no_grad()
@@ -537,7 +538,7 @@ class LMModel(nn.Module):
537
 
538
  out_codes[:, :, [0, 1, 2, 3], torch.tensor([3, 2, 1, 0]) + offset + 1] = next_token
539
  # Sink Attn
540
- if (offset > 0) and (offset % self.cache_lim) == 0:
541
  n_preserve = 4
542
  self.transformer._flush(n_preserve=n_preserve)
543
  cache_position = n_preserve
@@ -726,5 +727,5 @@ if __name__ == '__main__':
726
 
727
  import audiofile # pip uninstall flash-attn
728
  model = AudioGen().to('cpu')
729
- x = model.generate(prompt='swims in lake frogs', duration=6.4).cpu().numpy()
730
  audiofile.write('_sound_.wav', x, 16000)
 
62
  def generate(self,
63
  prompt='dogs mewo',
64
  duration=2.24, # seconds of audio
65
+ max_tokens=71, # actual num of A/R iterations - above is obtained as clone
66
  ):
67
  torch.manual_seed(42) # https://github.com/facebookresearch/audiocraft/issues/111#issuecomment-1614732858
68
+ n_draw = int(duration * 50 / (max_tokens * N_REPEAT)) + 1
 
69
  with torch.autocast(device_type='cpu', dtype=torch.bfloat16):
70
  gen_tokens = self.lm.generate(
71
  text_condition=[prompt] * N_REPEAT + [''] * N_REPEAT,#['dogs', 'dogs...!', '', '']
72
+ max_tokens=max_tokens)
73
 
74
  # OOM if vocode all tokens
75
  x = []
 
80
  decoded_chunk = self.compression_model.decode(gen_tokens[:, :, i-7:i+self._chunk_len])
81
 
82
  x.append(decoded_chunk)
83
+
84
  x = torch.cat(x, 2) # [bs, 1, 114000]
85
 
86
+ x = x.repeat(1, 1, n_draw)
87
+
88
  x = _shift(x) # clone() to have xN
89
 
90
  return x.reshape(-1) #x / (x.abs().max() + 1e-7)
 
431
  dim = 1536
432
  ):
433
  super().__init__()
 
434
  self.t5 = T5()
435
  self.card = card # 2048
436
  self.n_draw = 1 # draw > 1 tokens of different CFG scale
 
468
  # divide large probs with exp(prob) If prob=.001 then 1/exp(1*.001) -> almost by 0 --> exp doesnt really produce (0, Inf)
469
  p = p.argmax(dim=3, keepdim=True) # [bs, 4, n_draw, 24]
470
  tok = ix.gather(dim=3, index=p).to(torch.int64) # [bs, 4, n_draw, 1]
471
+
472
  return tok[:, :, :, 0].transpose(1, 2) # [bs, n_draw, 4]
473
 
474
  @torch.no_grad()
 
538
 
539
  out_codes[:, :, [0, 1, 2, 3], torch.tensor([3, 2, 1, 0]) + offset + 1] = next_token
540
  # Sink Attn
541
+ if (offset > 0) and (offset % 476) == 0:
542
  n_preserve = 4
543
  self.transformer._flush(n_preserve=n_preserve)
544
  cache_position = n_preserve
 
727
 
728
  import audiofile # pip uninstall flash-attn
729
  model = AudioGen().to('cpu')
730
+ x = model.generate(prompt='swims in lake frogs', duration=56.4, max_tokens=24).cpu().numpy()
731
  audiofile.write('_sound_.wav', x, 16000)