Spaces:
Running
Running
user defines num tokens
Browse files- app.py +3 -3
- audiocraft.py +9 -8
app.py
CHANGED
|
@@ -39,7 +39,7 @@ language_names = ['Ancient greek',
|
|
| 39 |
def audionar_tts(text=None,
|
| 40 |
lang='Romanian',
|
| 41 |
soundscape='frogs',
|
| 42 |
-
|
| 43 |
|
| 44 |
# https://huggingface.co/dkounadis/artificial-styletts2/blob/main/msinference.py
|
| 45 |
|
|
@@ -119,7 +119,7 @@ def audionar_tts(text=None,
|
|
| 119 |
background_audio = audiogen.generate(
|
| 120 |
soundscape,
|
| 121 |
duration=target_duration,
|
| 122 |
-
|
| 123 |
).numpy()
|
| 124 |
|
| 125 |
# PAD
|
|
@@ -272,7 +272,7 @@ with gr.Blocks(theme='huggingface') as demo:
|
|
| 272 |
label="AudioGen Txt"
|
| 273 |
)
|
| 274 |
kv_input = gr.Number(
|
| 275 |
-
label="
|
| 276 |
value=24,
|
| 277 |
)
|
| 278 |
generate_button = gr.Button("Generate Audio", variant="primary")
|
|
|
|
| 39 |
def audionar_tts(text=None,
|
| 40 |
lang='Romanian',
|
| 41 |
soundscape='frogs',
|
| 42 |
+
max_tokens=24):
|
| 43 |
|
| 44 |
# https://huggingface.co/dkounadis/artificial-styletts2/blob/main/msinference.py
|
| 45 |
|
|
|
|
| 119 |
background_audio = audiogen.generate(
|
| 120 |
soundscape,
|
| 121 |
duration=target_duration,
|
| 122 |
+
max_tokens=max(4, int(max_tokens)) # at least allow 10 A/R stEps
|
| 123 |
).numpy()
|
| 124 |
|
| 125 |
# PAD
|
|
|
|
| 272 |
label="AudioGen Txt"
|
| 273 |
)
|
| 274 |
kv_input = gr.Number(
|
| 275 |
+
label="Num Tokens",
|
| 276 |
value=24,
|
| 277 |
)
|
| 278 |
generate_button = gr.Button("Generate Audio", variant="primary")
|
audiocraft.py
CHANGED
|
@@ -62,15 +62,14 @@ class AudioGen(torch.nn.Module):
|
|
| 62 |
def generate(self,
|
| 63 |
prompt='dogs mewo',
|
| 64 |
duration=2.24, # seconds of audio
|
| 65 |
-
|
| 66 |
):
|
| 67 |
torch.manual_seed(42) # https://github.com/facebookresearch/audiocraft/issues/111#issuecomment-1614732858
|
| 68 |
-
|
| 69 |
-
self.lm.n_draw = int(.8 * duration) + 1 # different beam every 0.47 seconds of audio
|
| 70 |
with torch.autocast(device_type='cpu', dtype=torch.bfloat16):
|
| 71 |
gen_tokens = self.lm.generate(
|
| 72 |
text_condition=[prompt] * N_REPEAT + [''] * N_REPEAT,#['dogs', 'dogs...!', '', '']
|
| 73 |
-
max_tokens=
|
| 74 |
|
| 75 |
# OOM if vocode all tokens
|
| 76 |
x = []
|
|
@@ -81,9 +80,11 @@ class AudioGen(torch.nn.Module):
|
|
| 81 |
decoded_chunk = self.compression_model.decode(gen_tokens[:, :, i-7:i+self._chunk_len])
|
| 82 |
|
| 83 |
x.append(decoded_chunk)
|
| 84 |
-
|
| 85 |
x = torch.cat(x, 2) # [bs, 1, 114000]
|
| 86 |
|
|
|
|
|
|
|
| 87 |
x = _shift(x) # clone() to have xN
|
| 88 |
|
| 89 |
return x.reshape(-1) #x / (x.abs().max() + 1e-7)
|
|
@@ -430,7 +431,6 @@ class LMModel(nn.Module):
|
|
| 430 |
dim = 1536
|
| 431 |
):
|
| 432 |
super().__init__()
|
| 433 |
-
self.cache_lim = -1
|
| 434 |
self.t5 = T5()
|
| 435 |
self.card = card # 2048
|
| 436 |
self.n_draw = 1 # draw > 1 tokens of different CFG scale
|
|
@@ -468,6 +468,7 @@ class LMModel(nn.Module):
|
|
| 468 |
# divide large probs with exp(prob) If prob=.001 then 1/exp(1*.001) -> almost by 0 --> exp doesnt really produce (0, Inf)
|
| 469 |
p = p.argmax(dim=3, keepdim=True) # [bs, 4, n_draw, 24]
|
| 470 |
tok = ix.gather(dim=3, index=p).to(torch.int64) # [bs, 4, n_draw, 1]
|
|
|
|
| 471 |
return tok[:, :, :, 0].transpose(1, 2) # [bs, n_draw, 4]
|
| 472 |
|
| 473 |
@torch.no_grad()
|
|
@@ -537,7 +538,7 @@ class LMModel(nn.Module):
|
|
| 537 |
|
| 538 |
out_codes[:, :, [0, 1, 2, 3], torch.tensor([3, 2, 1, 0]) + offset + 1] = next_token
|
| 539 |
# Sink Attn
|
| 540 |
-
if (offset > 0) and (offset %
|
| 541 |
n_preserve = 4
|
| 542 |
self.transformer._flush(n_preserve=n_preserve)
|
| 543 |
cache_position = n_preserve
|
|
@@ -726,5 +727,5 @@ if __name__ == '__main__':
|
|
| 726 |
|
| 727 |
import audiofile # pip uninstall flash-attn
|
| 728 |
model = AudioGen().to('cpu')
|
| 729 |
-
x = model.generate(prompt='swims in lake frogs', duration=
|
| 730 |
audiofile.write('_sound_.wav', x, 16000)
|
|
|
|
| 62 |
def generate(self,
|
| 63 |
prompt='dogs mewo',
|
| 64 |
duration=2.24, # seconds of audio
|
| 65 |
+
max_tokens=71, # actual num of A/R iterations - above is obtained as clone
|
| 66 |
):
|
| 67 |
torch.manual_seed(42) # https://github.com/facebookresearch/audiocraft/issues/111#issuecomment-1614732858
|
| 68 |
+
n_draw = int(duration * 50 / (max_tokens * N_REPEAT)) + 1
|
|
|
|
| 69 |
with torch.autocast(device_type='cpu', dtype=torch.bfloat16):
|
| 70 |
gen_tokens = self.lm.generate(
|
| 71 |
text_condition=[prompt] * N_REPEAT + [''] * N_REPEAT,#['dogs', 'dogs...!', '', '']
|
| 72 |
+
max_tokens=max_tokens)
|
| 73 |
|
| 74 |
# OOM if vocode all tokens
|
| 75 |
x = []
|
|
|
|
| 80 |
decoded_chunk = self.compression_model.decode(gen_tokens[:, :, i-7:i+self._chunk_len])
|
| 81 |
|
| 82 |
x.append(decoded_chunk)
|
| 83 |
+
|
| 84 |
x = torch.cat(x, 2) # [bs, 1, 114000]
|
| 85 |
|
| 86 |
+
x = x.repeat(1, 1, n_draw)
|
| 87 |
+
|
| 88 |
x = _shift(x) # clone() to have xN
|
| 89 |
|
| 90 |
return x.reshape(-1) #x / (x.abs().max() + 1e-7)
|
|
|
|
| 431 |
dim = 1536
|
| 432 |
):
|
| 433 |
super().__init__()
|
|
|
|
| 434 |
self.t5 = T5()
|
| 435 |
self.card = card # 2048
|
| 436 |
self.n_draw = 1 # draw > 1 tokens of different CFG scale
|
|
|
|
| 468 |
# divide large probs with exp(prob) If prob=.001 then 1/exp(1*.001) -> almost by 0 --> exp doesnt really produce (0, Inf)
|
| 469 |
p = p.argmax(dim=3, keepdim=True) # [bs, 4, n_draw, 24]
|
| 470 |
tok = ix.gather(dim=3, index=p).to(torch.int64) # [bs, 4, n_draw, 1]
|
| 471 |
+
|
| 472 |
return tok[:, :, :, 0].transpose(1, 2) # [bs, n_draw, 4]
|
| 473 |
|
| 474 |
@torch.no_grad()
|
|
|
|
| 538 |
|
| 539 |
out_codes[:, :, [0, 1, 2, 3], torch.tensor([3, 2, 1, 0]) + offset + 1] = next_token
|
| 540 |
# Sink Attn
|
| 541 |
+
if (offset > 0) and (offset % 476) == 0:
|
| 542 |
n_preserve = 4
|
| 543 |
self.transformer._flush(n_preserve=n_preserve)
|
| 544 |
cache_position = n_preserve
|
|
|
|
| 727 |
|
| 728 |
import audiofile # pip uninstall flash-attn
|
| 729 |
model = AudioGen().to('cpu')
|
| 730 |
+
x = model.generate(prompt='swims in lake frogs', duration=56.4, max_tokens=24).cpu().numpy()
|
| 731 |
audiofile.write('_sound_.wav', x, 16000)
|