Commit: kv flush
audiocraft.py  +3 -3
@@ -66,7 +66,7 @@ class AudioGen(torch.nn.Module):
     ):
         torch.manual_seed(42) # https://github.com/facebookresearch/audiocraft/issues/111#issuecomment-1614732858
         n_draw = int(duration * 50 / (max_tokens * N_REPEAT)) + 1
-        print(f'{n_draw=} {duration=}seconds < {prompt=}')
+        print(f'{n_draw=} {duration=}seconds < {prompt=} | {max_tokens=}')
         with torch.autocast(device_type='cpu', dtype=torch.bfloat16):
             gen_tokens = self.lm.generate(
                 text_condition=[prompt] * N_REPEAT + [''] * N_REPEAT,#['dogs', 'dogs...!', '', '']
@@ -540,7 +540,7 @@ class LMModel(nn.Module):
 
         out_codes[:, :, [0, 1, 2, 3], torch.tensor([3, 2, 1, 0]) + offset + 1] = next_token
         # Sink Attn
-        if (offset > 0) and (offset %
+        if (offset > 0) and (offset % 71) == 0:
             n_preserve = 4
             self.transformer._flush(n_preserve=n_preserve)
             cache_position = n_preserve
@@ -640,7 +640,7 @@ class StreamingMultiheadAttention(nn.Module):
 
         k = self.k_history
         v = self.v_history
-
+
         # -> kv CACHE ONLY APPLIES if not self.cross_attention
 
         x = torch.nn.functional.scaled_dot_product_attention(
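For context on the change (not part of the commit): the `# Sink Attn` branch resets the streaming KV cache every 71 decoding steps, keeping only the first `n_preserve = 4` cached positions as attention-sink tokens, and rewinds `cache_position` so later steps write right after them. The diff does not show `_flush` itself; the snippet below is a minimal hypothetical sketch of that pattern, assuming `k_history`/`v_history` are `(batch, heads, cached_len, head_dim)` tensors as the `StreamingMultiheadAttention` hunk suggests (`SinkKVCacheSketch` and `append` are illustrative names, not repo API).

import torch

# Hypothetical sketch, not the repo's actual _flush: keep only the first
# n_preserve cached positions (the attention-sink tokens).
class SinkKVCacheSketch:
    def __init__(self, batch=1, heads=8, head_dim=64):
        # KV history as hinted by k_history / v_history in the diff:
        # shape (batch, heads, cached_len, head_dim)
        self.k_history = torch.zeros(batch, heads, 0, head_dim)
        self.v_history = torch.zeros(batch, heads, 0, head_dim)

    def append(self, k_step, v_step):
        # Append one decoding step's K/V, each (batch, heads, 1, head_dim).
        self.k_history = torch.cat([self.k_history, k_step], dim=2)
        self.v_history = torch.cat([self.v_history, v_step], dim=2)

    def _flush(self, n_preserve=4):
        # Truncate the cache to the sink tokens; generation continues from there.
        self.k_history = self.k_history[:, :, :n_preserve, :]
        self.v_history = self.v_history[:, :, :n_preserve, :]

cache = SinkKVCacheSketch()
for offset in range(150):
    kv = torch.randn(1, 8, 1, 64)
    cache.append(kv, kv.clone())
    if (offset > 0) and (offset % 71) == 0:  # same cadence as the diff
        cache._flush(n_preserve=4)
print(cache.k_history.shape[2])  # cache stays bounded instead of growing to 150

Given the `duration * 50` token rate in the first hunk, a flush every 71 steps would correspond to roughly 1.4 seconds of audio, trading a periodic loss of long-range context for a KV cache whose memory stays bounded during streaming generation.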