Dionyssos committed on
Commit
4813448
·
1 Parent(s): 52d3c83
Files changed (3) hide show
  1. README.md +2 -2
  2. app.py +4 -6
  3. audiocraft.py +6 -5
README.md CHANGED
@@ -1,8 +1,8 @@
1
  ---
2
  title: Audiogen
3
  emoji: 🍍
4
- colorFrom: gray
5
- colorTo: gray
6
  sdk: gradio
7
  sdk_version: 5.41.1
8
  app_file: app.py
 
1
  ---
2
  title: Audiogen
3
  emoji: 🍍
4
+ colorFrom: green
5
+ colorTo: blue
6
  sdk: gradio
7
  sdk_version: 5.41.1
8
  app_file: app.py
app.py CHANGED
@@ -35,7 +35,7 @@ def audionar_tts(text=None,
35
  lang='Romanian',
36
  soundscape='frogs',
37
  max_tokens=24,
38
- cache_lim=3):
39
 
40
  # https://huggingface.co/dkounadis/artificial-styletts2/blob/main/msinference.py
41
 
@@ -98,12 +98,11 @@ def audionar_tts(text=None,
98
  speech_duration_secs = len(x) / 16000
99
  target_duration = max(speech_duration_secs + 0.74, 2.0)
100
  # Sink Attn
101
- audiogen.cache_lim = min( max(0, int(cache_lim)), 2000)
102
-
103
  background_audio = audiogen.generate(
104
  soundscape[:64], # to have shape of cross attention not grow large of T5 Num tokens
105
  duration=target_duration,
106
- max_tokens=min( max(7, int(max_tokens)), 288 ) # limit sounds tokens (clone beyond)
 
107
  ).numpy()
108
 
109
  # PAD
@@ -140,7 +139,6 @@ def audionar_tts(text=None,
140
  soundfile.write(wavfile, final_audio, 16000) # soundfile needs [time, channels]
141
  return wavfile
142
 
143
-
144
  # TTS
145
 
146
 
@@ -165,7 +163,7 @@ with gr.Blocks() as demo:
165
  )
166
  cache_lim = gr.Number(
167
  label="Flush kv",
168
- value=24,
169
  )
170
  n_tokens = gr.Number(
171
  label="Tokens",
 
35
  lang='Romanian',
36
  soundscape='frogs',
37
  max_tokens=24,
38
+ cache_lim=-1):
39
 
40
  # https://huggingface.co/dkounadis/artificial-styletts2/blob/main/msinference.py
41
 
 
98
  speech_duration_secs = len(x) / 16000
99
  target_duration = max(speech_duration_secs + 0.74, 2.0)
100
  # Sink Attn
 
 
101
  background_audio = audiogen.generate(
102
  soundscape[:64], # to have shape of cross attention not grow large of T5 Num tokens
103
  duration=target_duration,
104
+ max_tokens=min( max(7, int(max_tokens)), 288 ), # limit sounds tokens (clone beyond)
105
+ cache_lim=min( max(6, int(cache_lim)), 2000),
106
  ).numpy()
107
 
108
  # PAD
 
139
  soundfile.write(wavfile, final_audio, 16000) # soundfile needs [time, channels]
140
  return wavfile
141
 
 
142
  # TTS
143
 
144
 
 
163
  )
164
  cache_lim = gr.Number(
165
  label="Flush kv",
166
+ value=71,
167
  )
168
  n_tokens = gr.Number(
169
  label="Tokens",
audiocraft.py CHANGED
@@ -63,6 +63,7 @@ class AudioGen(torch.nn.Module):
63
  prompt='dogs mewo',
64
  duration=2.24, # seconds of audio
65
  max_tokens=24, # actual num of A/R iterations - above is obtained as clone
 
66
  ):
67
  torch.manual_seed(42) # https://github.com/facebookresearch/audiocraft/issues/111#issuecomment-1614732858
68
  n_draw = int(duration * 50 / (max_tokens * N_REPEAT)) + 1
@@ -70,7 +71,8 @@ class AudioGen(torch.nn.Module):
70
  with torch.autocast(device_type='cpu', dtype=torch.bfloat16):
71
  gen_tokens = self.lm.generate(
72
  text_condition=[prompt] * N_REPEAT + [''] * N_REPEAT,#['dogs', 'dogs...!', '', '']
73
- max_tokens=max_tokens)
 
74
 
75
  # OOM if vocode all tokens
76
  x = []
@@ -435,7 +437,6 @@ class LMModel(nn.Module):
435
  self.card = card # 2048
436
  self.n_draw = 1 # draw > 1 tokens of different CFG scale
437
  # batch size > 1 is slower from n_draw as calls transformer on larger batch
438
- self.cache_lim = 71
439
  self.emb = nn.ModuleList([nn.Embedding(self.card + 1, dim) for _ in range(n_q)]) # EMBEDDING HAS 2049
440
  self.transformer = StreamingTransformer()
441
  self.out_norm = nn.LayerNorm(dim, eps=1e-5)
@@ -475,8 +476,8 @@ class LMModel(nn.Module):
475
  @torch.no_grad()
476
  def generate(self,
477
  max_tokens=None,
478
- text_condition=None
479
- ):
480
  self.transformer._flush() # perhaps long kv cache has been filled on previous call for unrelated sounds
481
  x = self.t5(text_condition)
482
  bs = x.shape[0] // 2 # has null conditions - bs*2*N_REPEAT applys in builders.py
@@ -540,7 +541,7 @@ class LMModel(nn.Module):
540
 
541
  out_codes[:, :, [0, 1, 2, 3], torch.tensor([3, 2, 1, 0]) + offset + 1] = next_token
542
  # Sink Attn
543
- if (offset > 0) and (offset % self.cache_lim) == 0:
544
  n_preserve = 4
545
  self.transformer._flush(n_preserve=n_preserve)
546
  cache_position = n_preserve
 
63
  prompt='dogs mewo',
64
  duration=2.24, # seconds of audio
65
  max_tokens=24, # actual num of A/R iterations - above is obtained as clone
66
+ cache_lim=71,
67
  ):
68
  torch.manual_seed(42) # https://github.com/facebookresearch/audiocraft/issues/111#issuecomment-1614732858
69
  n_draw = int(duration * 50 / (max_tokens * N_REPEAT)) + 1
 
71
  with torch.autocast(device_type='cpu', dtype=torch.bfloat16):
72
  gen_tokens = self.lm.generate(
73
  text_condition=[prompt] * N_REPEAT + [''] * N_REPEAT,#['dogs', 'dogs...!', '', '']
74
+ max_tokens=max_tokens,
75
+ cache_lim=cache_lim)
76
 
77
  # OOM if vocode all tokens
78
  x = []
 
437
  self.card = card # 2048
438
  self.n_draw = 1 # draw > 1 tokens of different CFG scale
439
  # batch size > 1 is slower from n_draw as calls transformer on larger batch
 
440
  self.emb = nn.ModuleList([nn.Embedding(self.card + 1, dim) for _ in range(n_q)]) # EMBEDDING HAS 2049
441
  self.transformer = StreamingTransformer()
442
  self.out_norm = nn.LayerNorm(dim, eps=1e-5)
 
476
  @torch.no_grad()
477
  def generate(self,
478
  max_tokens=None,
479
+ text_condition=None,
480
+ cache_lim=71):
481
  self.transformer._flush() # perhaps long kv cache has been filled on previous call for unrelated sounds
482
  x = self.t5(text_condition)
483
  bs = x.shape[0] // 2 # has null conditions - bs*2*N_REPEAT applys in builders.py
 
541
 
542
  out_codes[:, :, [0, 1, 2, 3], torch.tensor([3, 2, 1, 0]) + offset + 1] = next_token
543
  # Sink Attn
544
+ if (offset > 0) and (offset % cache_lim) == 0:
545
  n_preserve = 4
546
  self.transformer._flush(n_preserve=n_preserve)
547
  cache_position = n_preserve