cache
Browse files
- README.md +2 -2
- app.py +4 -6
- audiocraft.py +6 -5
README.md
CHANGED
|
@@ -1,8 +1,8 @@
|
|
| 1 |
---
|
| 2 |
title: Audiogen
|
| 3 |
emoji: 🍍
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
sdk: gradio
|
| 7 |
sdk_version: 5.41.1
|
| 8 |
app_file: app.py
|
|
|
|
| 1 |
---
|
| 2 |
title: Audiogen
|
| 3 |
emoji: 🍍
|
| 4 |
+
colorFrom: green
|
| 5 |
+
colorTo: blue
|
| 6 |
sdk: gradio
|
| 7 |
sdk_version: 5.41.1
|
| 8 |
app_file: app.py
|
app.py
CHANGED
|
@@ -35,7 +35,7 @@ def audionar_tts(text=None,
|
|
| 35 |
lang='Romanian',
|
| 36 |
soundscape='frogs',
|
| 37 |
max_tokens=24,
|
| 38 |
-
cache_lim=
|
| 39 |
|
| 40 |
# https://huggingface.co/dkounadis/artificial-styletts2/blob/main/msinference.py
|
| 41 |
|
|
@@ -98,12 +98,11 @@ def audionar_tts(text=None,
|
|
| 98 |
speech_duration_secs = len(x) / 16000
|
| 99 |
target_duration = max(speech_duration_secs + 0.74, 2.0)
|
| 100 |
# Sink Attn
|
| 101 |
-
audiogen.cache_lim = min( max(0, int(cache_lim)), 2000)
|
| 102 |
-
|
| 103 |
background_audio = audiogen.generate(
|
| 104 |
soundscape[:64], # to have shape of cross attention not grow large of T5 Num tokens
|
| 105 |
duration=target_duration,
|
| 106 |
-
max_tokens=min( max(7, int(max_tokens)), 288 ) # limit sounds tokens (clone beyond)
|
|
|
|
| 107 |
).numpy()
|
| 108 |
|
| 109 |
# PAD
|
|
@@ -140,7 +139,6 @@ def audionar_tts(text=None,
|
|
| 140 |
soundfile.write(wavfile, final_audio, 16000) # soundfile needs [time, channels]
|
| 141 |
return wavfile
|
| 142 |
|
| 143 |
-
|
| 144 |
# TTS
|
| 145 |
|
| 146 |
|
|
@@ -165,7 +163,7 @@ with gr.Blocks() as demo:
|
|
| 165 |
)
|
| 166 |
cache_lim = gr.Number(
|
| 167 |
label="Flush kv",
|
| 168 |
-
value=
|
| 169 |
)
|
| 170 |
n_tokens = gr.Number(
|
| 171 |
label="Tokens",
|
|
|
|
| 35 |
lang='Romanian',
|
| 36 |
soundscape='frogs',
|
| 37 |
max_tokens=24,
|
| 38 |
+
cache_lim=-1):
|
| 39 |
|
| 40 |
# https://huggingface.co/dkounadis/artificial-styletts2/blob/main/msinference.py
|
| 41 |
|
|
|
|
| 98 |
speech_duration_secs = len(x) / 16000
|
| 99 |
target_duration = max(speech_duration_secs + 0.74, 2.0)
|
| 100 |
# Sink Attn
|
|
|
|
|
|
|
| 101 |
background_audio = audiogen.generate(
|
| 102 |
soundscape[:64], # to have shape of cross attention not grow large of T5 Num tokens
|
| 103 |
duration=target_duration,
|
| 104 |
+
max_tokens=min( max(7, int(max_tokens)), 288 ), # limit sounds tokens (clone beyond)
|
| 105 |
+
cache_lim=min( max(6, int(cache_lim)), 2000),
|
| 106 |
).numpy()
|
| 107 |
|
| 108 |
# PAD
|
|
|
|
| 139 |
soundfile.write(wavfile, final_audio, 16000) # soundfile needs [time, channels]
|
| 140 |
return wavfile
|
| 141 |
|
|
|
|
| 142 |
# TTS
|
| 143 |
|
| 144 |
|
|
|
|
| 163 |
)
|
| 164 |
cache_lim = gr.Number(
|
| 165 |
label="Flush kv",
|
| 166 |
+
value=71,
|
| 167 |
)
|
| 168 |
n_tokens = gr.Number(
|
| 169 |
label="Tokens",
|
audiocraft.py
CHANGED
|
@@ -63,6 +63,7 @@ class AudioGen(torch.nn.Module):
|
|
| 63 |
prompt='dogs mewo',
|
| 64 |
duration=2.24, # seconds of audio
|
| 65 |
max_tokens=24, # actual num of A/R iterations - above is obtained as clone
|
|
|
|
| 66 |
):
|
| 67 |
torch.manual_seed(42) # https://github.com/facebookresearch/audiocraft/issues/111#issuecomment-1614732858
|
| 68 |
n_draw = int(duration * 50 / (max_tokens * N_REPEAT)) + 1
|
|
@@ -70,7 +71,8 @@ class AudioGen(torch.nn.Module):
|
|
| 70 |
with torch.autocast(device_type='cpu', dtype=torch.bfloat16):
|
| 71 |
gen_tokens = self.lm.generate(
|
| 72 |
text_condition=[prompt] * N_REPEAT + [''] * N_REPEAT,#['dogs', 'dogs...!', '', '']
|
| 73 |
-
max_tokens=max_tokens)
|
|
|
|
| 74 |
|
| 75 |
# OOM if vocode all tokens
|
| 76 |
x = []
|
|
@@ -435,7 +437,6 @@ class LMModel(nn.Module):
|
|
| 435 |
self.card = card # 2048
|
| 436 |
self.n_draw = 1 # draw > 1 tokens of different CFG scale
|
| 437 |
# batch size > 1 is slower from n_draw as calls transformer on larger batch
|
| 438 |
-
self.cache_lim = 71
|
| 439 |
self.emb = nn.ModuleList([nn.Embedding(self.card + 1, dim) for _ in range(n_q)]) # EMBEDDING HAS 2049
|
| 440 |
self.transformer = StreamingTransformer()
|
| 441 |
self.out_norm = nn.LayerNorm(dim, eps=1e-5)
|
|
@@ -475,8 +476,8 @@ class LMModel(nn.Module):
|
|
| 475 |
@torch.no_grad()
|
| 476 |
def generate(self,
|
| 477 |
max_tokens=None,
|
| 478 |
-
text_condition=None
|
| 479 |
-
):
|
| 480 |
self.transformer._flush() # perhaps long kv cache has been filled on previous call for unrelated sounds
|
| 481 |
x = self.t5(text_condition)
|
| 482 |
bs = x.shape[0] // 2 # has null conditions - bs*2*N_REPEAT applies in builders.py
|
|
@@ -540,7 +541,7 @@ class LMModel(nn.Module):
|
|
| 540 |
|
| 541 |
out_codes[:, :, [0, 1, 2, 3], torch.tensor([3, 2, 1, 0]) + offset + 1] = next_token
|
| 542 |
# Sink Attn
|
| 543 |
-
if (offset > 0) and (offset % self.cache_lim) == 0:
|
| 544 |
n_preserve = 4
|
| 545 |
self.transformer._flush(n_preserve=n_preserve)
|
| 546 |
cache_position = n_preserve
|
|
|
|
| 63 |
prompt='dogs mewo',
|
| 64 |
duration=2.24, # seconds of audio
|
| 65 |
max_tokens=24, # actual num of A/R iterations - above is obtained as clone
|
| 66 |
+
cache_lim=71,
|
| 67 |
):
|
| 68 |
torch.manual_seed(42) # https://github.com/facebookresearch/audiocraft/issues/111#issuecomment-1614732858
|
| 69 |
n_draw = int(duration * 50 / (max_tokens * N_REPEAT)) + 1
|
|
|
|
| 71 |
with torch.autocast(device_type='cpu', dtype=torch.bfloat16):
|
| 72 |
gen_tokens = self.lm.generate(
|
| 73 |
text_condition=[prompt] * N_REPEAT + [''] * N_REPEAT,#['dogs', 'dogs...!', '', '']
|
| 74 |
+
max_tokens=max_tokens,
|
| 75 |
+
cache_lim=cache_lim)
|
| 76 |
|
| 77 |
# OOM if vocode all tokens
|
| 78 |
x = []
|
|
|
|
| 437 |
self.card = card # 2048
|
| 438 |
self.n_draw = 1 # draw > 1 tokens of different CFG scale
|
| 439 |
# batch size > 1 is slower from n_draw as calls transformer on larger batch
|
|
|
|
| 440 |
self.emb = nn.ModuleList([nn.Embedding(self.card + 1, dim) for _ in range(n_q)]) # EMBEDDING HAS 2049
|
| 441 |
self.transformer = StreamingTransformer()
|
| 442 |
self.out_norm = nn.LayerNorm(dim, eps=1e-5)
|
|
|
|
| 476 |
@torch.no_grad()
|
| 477 |
def generate(self,
|
| 478 |
max_tokens=None,
|
| 479 |
+
text_condition=None,
|
| 480 |
+
cache_lim=71):
|
| 481 |
self.transformer._flush() # perhaps long kv cache has been filled on previous call for unrelated sounds
|
| 482 |
x = self.t5(text_condition)
|
| 483 |
bs = x.shape[0] // 2 # has null conditions - bs*2*N_REPEAT applies in builders.py
|
|
|
|
| 541 |
|
| 542 |
out_codes[:, :, [0, 1, 2, 3], torch.tensor([3, 2, 1, 0]) + offset + 1] = next_token
|
| 543 |
# Sink Attn
|
| 544 |
+
if (offset > 0) and (offset % cache_lim) == 0:
|
| 545 |
n_preserve = 4
|
| 546 |
self.transformer._flush(n_preserve=n_preserve)
|
| 547 |
cache_position = n_preserve
|