Spaces:
Runtime error
Runtime error
Upload folder using huggingface_hub
Browse files- .gitattributes +6 -0
- .gitignore +8 -0
- LICENSE +22 -0
- README.md +147 -8
- audio_prompts/EARS p004 freeform.mp3 +3 -0
- audio_prompts/EARS p005 freeform.mp3 +3 -0
- audio_prompts/EARS p028 freeform.mp3 +3 -0
- audio_prompts/EARS p036 freeform.mp3 +3 -0
- audio_prompts/LICENSE +26 -0
- audio_prompts/expresso_02_ex03-ex01_calm_005.mp3 +3 -0
- audio_prompts/freesound_demon_chant(use_forcespeaker).mp3 +3 -0
- autoencoder.py +1225 -0
- gradio_app.py +994 -0
- inference.py +462 -0
- inference_blockwise.py +220 -0
- model.py +642 -0
- requirements.txt +8 -0
- sampler_presets.json +62 -0
- text_presets.txt +40 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,9 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
audio_prompts/EARS[[:space:]]p004[[:space:]]freeform.mp3 filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
audio_prompts/EARS[[:space:]]p005[[:space:]]freeform.mp3 filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
audio_prompts/EARS[[:space:]]p028[[:space:]]freeform.mp3 filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
audio_prompts/EARS[[:space:]]p036[[:space:]]freeform.mp3 filter=lfs diff=lfs merge=lfs -text
|
| 40 |
+
audio_prompts/expresso_02_ex03-ex01_calm_005.mp3 filter=lfs diff=lfs merge=lfs -text
|
| 41 |
+
audio_prompts/freesound_demon_chant(use_forcespeaker).mp3 filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
.DS_Store
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[cod]
|
| 4 |
+
.venv/
|
| 5 |
+
venv/
|
| 6 |
+
.env
|
| 7 |
+
.idea/
|
| 8 |
+
.vscode/
|
LICENSE
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2025 Jordan Darefsky
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
| 22 |
+
|
README.md
CHANGED
|
@@ -1,12 +1,151 @@
|
|
| 1 |
---
|
| 2 |
-
title:
|
| 3 |
-
|
| 4 |
-
colorFrom: indigo
|
| 5 |
-
colorTo: purple
|
| 6 |
sdk: gradio
|
| 7 |
-
sdk_version:
|
| 8 |
-
app_file: app.py
|
| 9 |
-
pinned: false
|
| 10 |
---
|
|
|
|
| 11 |
|
| 12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
+
title: echo-tts
|
| 3 |
+
app_file: gradio_app.py
|
|
|
|
|
|
|
| 4 |
sdk: gradio
|
| 5 |
+
sdk_version: 5.49.1
|
|
|
|
|
|
|
| 6 |
---
|
| 7 |
+
# Echo-TTS
|
| 8 |
|
| 9 |
+
A multi-speaker text-to-speech model with speaker reference conditioning. See the [blog post](https://jordandarefsky.com/blog/2025/echo/) for technical details.
|
| 10 |
+
|
| 11 |
+
**Model:** [jordand/echo-tts-base](https://huggingface.co/jordand/echo-tts-base) | **Demo:** [echo-tts-preview](https://huggingface.co/spaces/jordand/echo-tts-preview)
|
| 12 |
+
|
| 13 |
+
This work was made possible by the TPU Research Cloud (TRC).
|
| 14 |
+
|
| 15 |
+
## Responsible Use
|
| 16 |
+
|
| 17 |
+
Don't use this model to:
|
| 18 |
+
- Impersonate real people without their consent
|
| 19 |
+
- Generate deceptive audio (e.g., fraud, misinformation, deepfakes)
|
| 20 |
+
|
| 21 |
+
You are responsible for complying with local laws regarding biometric data and voice cloning.
|
| 22 |
+
|
| 23 |
+
## Installation
|
| 24 |
+
|
| 25 |
+
```bash
|
| 26 |
+
pip install -r requirements.txt
|
| 27 |
+
```
|
| 28 |
+
|
| 29 |
+
Requires Python 3.10+ and a CUDA-capable GPU with at least 8GB VRAM.
|
| 30 |
+
|
| 31 |
+
## Quick Start
|
| 32 |
+
|
| 33 |
+
### Gradio UI
|
| 34 |
+
|
| 35 |
+
```bash
|
| 36 |
+
python gradio_app.py
|
| 37 |
+
```
|
| 38 |
+
|
| 39 |
+
### Python API
|
| 40 |
+
|
| 41 |
+
```python
|
| 42 |
+
from inference import (
|
| 43 |
+
load_model_from_hf,
|
| 44 |
+
load_fish_ae_from_hf,
|
| 45 |
+
load_pca_state_from_hf,
|
| 46 |
+
load_audio,
|
| 47 |
+
sample_pipeline,
|
| 48 |
+
sample_euler_cfg_independent_guidances,
|
| 49 |
+
)
|
| 50 |
+
from functools import partial
|
| 51 |
+
import torchaudio
|
| 52 |
+
|
| 53 |
+
# Load models (downloads from HuggingFace on first run)
|
| 54 |
+
model = load_model_from_hf(delete_blockwise_modules=True)
|
| 55 |
+
fish_ae = load_fish_ae_from_hf()
|
| 56 |
+
pca_state = load_pca_state_from_hf()
|
| 57 |
+
|
| 58 |
+
# Load speaker reference (or set to None for no reference)
|
| 59 |
+
speaker_audio = load_audio("speaker.wav").cuda()
|
| 60 |
+
|
| 61 |
+
# Configure sampler
|
| 62 |
+
sample_fn = partial(
|
| 63 |
+
sample_euler_cfg_independent_guidances,
|
| 64 |
+
num_steps=40,
|
| 65 |
+
cfg_scale_text=3.0,
|
| 66 |
+
cfg_scale_speaker=8.0,
|
| 67 |
+
cfg_min_t=0.5,
|
| 68 |
+
cfg_max_t=1.0,
|
| 69 |
+
truncation_factor=None,
|
| 70 |
+
rescale_k=None,
|
| 71 |
+
rescale_sigma=None,
|
| 72 |
+
speaker_kv_scale=None,
|
| 73 |
+
speaker_kv_max_layers=None,
|
| 74 |
+
speaker_kv_min_t=None,
|
| 75 |
+
sequence_length=640, # (~30 seconds)
|
| 76 |
+
)
|
| 77 |
+
|
| 78 |
+
# Generate
|
| 79 |
+
text = "[S1] Hello, this is a test of the Echo TTS model."
|
| 80 |
+
audio_out, _ = sample_pipeline(
|
| 81 |
+
model=model,
|
| 82 |
+
fish_ae=fish_ae,
|
| 83 |
+
pca_state=pca_state,
|
| 84 |
+
sample_fn=sample_fn,
|
| 85 |
+
text_prompt=text,
|
| 86 |
+
speaker_audio=speaker_audio,
|
| 87 |
+
rng_seed=0,
|
| 88 |
+
)
|
| 89 |
+
|
| 90 |
+
torchaudio.save("output.wav", audio_out[0].cpu(), 44100)
|
| 91 |
+
```
|
| 92 |
+
|
| 93 |
+
See also:
|
| 94 |
+
- `inference.py` -- lower-level usage example at the bottom of the file
|
| 95 |
+
- `inference_blockwise.py` -- examples of blockwise/continuation generation
|
| 96 |
+
|
| 97 |
+
## Low VRAM (8GB)
|
| 98 |
+
|
| 99 |
+
In `gradio_app.py`, adjust:
|
| 100 |
+
|
| 101 |
+
```python
|
| 102 |
+
FISH_AE_DTYPE = torch.bfloat16 # instead of float32
|
| 103 |
+
DEFAULT_SAMPLE_LATENT_LENGTH = 576 # (< 640 depending on what fits) instead of 640
|
| 104 |
+
```
|
| 105 |
+
|
| 106 |
+
## Tips
|
| 107 |
+
|
| 108 |
+
### Generation Length
|
| 109 |
+
|
| 110 |
+
Echo is trained to generate up to 30 seconds of audio (640 latents) given text and reference audio. Since the supplied text always corresponded to β€30 seconds of audio during training, the model will attempt to fit any text prompt at inference into the 30 seconds of generated audio (and thus, e.g., long text prompts may result in faster speaking rates). On the other hand, shorter text prompts will work and will produce shorter outputs (as the model generates latent padding automatically).
|
| 111 |
+
|
| 112 |
+
If "Sample Latent Length" (in Custom Shapes in gradio)/sequence_length is set to less than 640, the model will attempt to generate the prefix corresponding to that length. I.e., if you set this to 320, and supply ~30 seconds worth of text, the model will likely generate the first half of the text (rather than try to fit the entirety of the text into the first 15 seconds).
|
| 113 |
+
|
| 114 |
+
### Reference Audio
|
| 115 |
+
|
| 116 |
+
You can condition on up to 5 minutes of reference audio, but shorter clips (e.g., 10 seconds or shorter) work well too.
|
| 117 |
+
|
| 118 |
+
### Force Speaker (KV Scaling)
|
| 119 |
+
|
| 120 |
+
Sometimes out-of-distribution text for a given reference speaker will cause the model to generate a different speaker entirely. Enabling "Force Speaker" (which scales speaker KV for a portion of timesteps, default scale 1.5) generally fixes this. However, high values may introduce artifacts or "overconditioning." Aim for the lowest scale that produces the correct speaker: 1.0 is baseline, 1.5 is the default when enabled and will usually force the speaker, but lower values (e.g., 1.3, 1.1) may suffice.
|
| 121 |
+
|
| 122 |
+
### Text Prompt Format
|
| 123 |
+
|
| 124 |
+
Text prompts use the format from [WhisperD](https://huggingface.co/jordand/whisper-d-v1a). Colons, semicolons, and emdashes are normalized to commas (see inference.py tokenizer_encode) by default, and "[S1] " will be added to the beginning of the prompt if not already present. Commas generally function as pauses. Exclamation points (and other non-bland punctuation) may lead to increased expressiveness but also potentially lower quality on occasion; improving controllability is an important direction for future work.
|
| 125 |
+
|
| 126 |
+
The included text presets are stylistically in-distribution with the WhisperD transcription style.
|
| 127 |
+
|
| 128 |
+
### Blockwise Generation
|
| 129 |
+
|
| 130 |
+
`inference_blockwise.py` includes blockwise sampling, which allows generating audio in smaller blocks as well as producing continuations of existing audio (where the prefix and continuation are up to 30 seconds combined). The model released on HF is a fully fine-tuned model (not the LoRA as described in the blog). Blockwise generation enables audio streaming (not included in current code) since the S1-DAC decoder is causal. Blockwise functionality hasn't been thoroughly tested and may benefit from different (e.g., smaller) CFG scales.
|
| 131 |
+
|
| 132 |
+
## License
|
| 133 |
+
|
| 134 |
+
Code in this repo is MITβlicensed except where file headers specify otherwise (e.g., autoencoder.py is Apacheβ2.0).
|
| 135 |
+
|
| 136 |
+
Regardless of our model license, audio outputs are CC-BY-NC-SA-4.0 due to the dependency on the Fish Speech S1-DAC autoencoder, which is CC-BY-NC-SA-4.0.
|
| 137 |
+
|
| 138 |
+
We have chosen to release the Echo-TTS weights under CC-BY-NC-SA-4.0.
|
| 139 |
+
|
| 140 |
+
For included audio prompts, see `audio_prompts/LICENSE`.
|
| 141 |
+
|
| 142 |
+
## Citation
|
| 143 |
+
|
| 144 |
+
```bibtex
|
| 145 |
+
@misc{darefsky2025echo,
|
| 146 |
+
author = {Darefsky, Jordan},
|
| 147 |
+
title = {Echo-TTS},
|
| 148 |
+
year = {2025},
|
| 149 |
+
url = {https://jordandarefsky.com/blog/2025/echo/}
|
| 150 |
+
}
|
| 151 |
+
```
|
audio_prompts/EARS p004 freeform.mp3
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:68947a209bc11064f749ca0a61b7959243df83565a0e462b87dfc0ffe03aa7b0
|
| 3 |
+
size 1526439
|
audio_prompts/EARS p005 freeform.mp3
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:07344d073eb3e22c249ebfe15f31f4ba63fd9f17c71aeee93da199ff3b53fc45
|
| 3 |
+
size 1351147
|
audio_prompts/EARS p028 freeform.mp3
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:8351eed5982f1fb5763a475c0fb69dba98a4bb49b0f2bbab12b978ff2b0fedeb
|
| 3 |
+
size 1211565
|
audio_prompts/EARS p036 freeform.mp3
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ce77dbb86ea7c29edf2b9804ce9c9315334e9cfeef532dc0c50898a09bae1583
|
| 3 |
+
size 1227585
|
audio_prompts/LICENSE
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
The audio files in this folder are provided for demonstration purposes and
|
| 2 |
+
are sourced from the following datasets. Please refer to their original
|
| 3 |
+
licenses for terms of use.
|
| 4 |
+
|
| 5 |
+
EARS Dataset (CC-BY-NC-4.0)
|
| 6 |
+
---------------------------
|
| 7 |
+
- EARS p004 freeform.mp3
|
| 8 |
+
- EARS p005 freeform.mp3
|
| 9 |
+
- EARS p028 freeform.mp3
|
| 10 |
+
- EARS p036 freeform.mp3
|
| 11 |
+
|
| 12 |
+
Source: https://github.com/facebookresearch/ears_dataset
|
| 13 |
+
|
| 14 |
+
Expresso Dataset (CC-BY-NC-4.0)
|
| 15 |
+
-------------------------------
|
| 16 |
+
- expresso_02_ex03-ex01_calm_005.wav
|
| 17 |
+
|
| 18 |
+
Source: https://speechbot.github.io/expresso/
|
| 19 |
+
|
| 20 |
+
Freesound (CC0)
|
| 21 |
+
---------------
|
| 22 |
+
- freesound_demon_chant(use_forcespeaker).mp3
|
| 23 |
+
|
| 24 |
+
Source: https://freesound.org/s/419507/
|
| 25 |
+
Author: DylanTheFish
|
| 26 |
+
|
audio_prompts/expresso_02_ex03-ex01_calm_005.mp3
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:98855b5b8b6c265a643edeb23ce5cd772391cb90754822e2a0370ea5188225f5
|
| 3 |
+
size 4802350
|
audio_prompts/freesound_demon_chant(use_forcespeaker).mp3
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:471f67fff5ea613ec4617b9822b1396da123a1133f199925436a2c40e5d1eb91
|
| 3 |
+
size 303438
|
autoencoder.py
ADDED
|
@@ -0,0 +1,1225 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
# This file contains portions adapted from:
|
| 4 |
+
# β’ Descript Audio Codec (DAC) β MIT License (full text appended below)
|
| 5 |
+
# β’ Fish-Speech S1 DAC Autoencoder β reference implementation (Apache-2.0 / CC-BY-NC),
|
| 6 |
+
# rewritten here in a single-file Torch module for interoperability and transparency.
|
| 7 |
+
#
|
| 8 |
+
# OVERALL LICENSE (this file): Apache-2.0, except where explicitly marked:
|
| 9 |
+
# # SPDX-License-Identifier: MIT
|
| 10 |
+
# Keep these notices and the embedded MIT text if you redistribute this file.
|
| 11 |
+
|
| 12 |
+
# NOTE
|
| 13 |
+
# Self-contained autoencoder implementation of Fish-S1-DAC (inlining DAC code to avoid dependencies).
|
| 14 |
+
# Code in this module has been largely copy-and-pasted from the Fish-S1-DAC and DAC repositories,
|
| 15 |
+
# and refactored with help from ChatGPT/Claude (these models also helped with licensing).
|
| 16 |
+
# Thus, it differs stylistically from the rest of the codebase (and is likely internally inconsistent as well).
|
| 17 |
+
|
| 18 |
+
from __future__ import annotations
|
| 19 |
+
|
| 20 |
+
import math
|
| 21 |
+
from dataclasses import dataclass
|
| 22 |
+
from typing import List, Optional, Tuple, Union
|
| 23 |
+
|
| 24 |
+
import numpy as np
|
| 25 |
+
import torch
|
| 26 |
+
from torch import Tensor, nn
|
| 27 |
+
from torch.nn import functional as F
|
| 28 |
+
from torch.nn.utils.parametrizations import weight_norm
|
| 29 |
+
from torch.nn.utils.parametrize import remove_parametrizations
|
| 30 |
+
|
| 31 |
+
from einops import rearrange
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
# --------------------------------------------------------------------
|
| 35 |
+
# Shared helpers
|
| 36 |
+
# --------------------------------------------------------------------
|
| 37 |
+
|
| 38 |
+
def find_multiple(n: int, k: int) -> int:
|
| 39 |
+
return n if n % k == 0 else n + k - (n % k)
|
| 40 |
+
|
| 41 |
+
def unpad1d(x: Tensor, paddings: Tuple[int, int]) -> Tensor:
|
| 42 |
+
"""Remove padding from x, handling properly zero padding. Only for 1d!"""
|
| 43 |
+
padding_left, padding_right = paddings
|
| 44 |
+
assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
|
| 45 |
+
assert (padding_left + padding_right) <= x.shape[-1]
|
| 46 |
+
end = x.shape[-1] - padding_right
|
| 47 |
+
return x[..., padding_left:end]
|
| 48 |
+
|
| 49 |
+
def get_extra_padding_for_conv1d(
|
| 50 |
+
x: Tensor, kernel_size: int, stride: int, padding_total: int = 0
|
| 51 |
+
) -> int:
|
| 52 |
+
"""See pad_for_conv1d; enough right pad so striding evenly covers length."""
|
| 53 |
+
length = x.shape[-1]
|
| 54 |
+
n_frames = (length - kernel_size + padding_total) / stride + 1
|
| 55 |
+
ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
|
| 56 |
+
return ideal_length - length
|
| 57 |
+
|
| 58 |
+
def pad1d(
|
| 59 |
+
x: Tensor,
|
| 60 |
+
paddings: Tuple[int, int],
|
| 61 |
+
mode: str = "zeros",
|
| 62 |
+
value: float = 0.0,
|
| 63 |
+
) -> Tensor:
|
| 64 |
+
"""
|
| 65 |
+
Reflectβsafe 1D pad: if reflect would underflow on small inputs, insert
|
| 66 |
+
temporary right zero-pad before reflecting.
|
| 67 |
+
"""
|
| 68 |
+
length = x.shape[-1]
|
| 69 |
+
padding_left, padding_right = paddings
|
| 70 |
+
assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
|
| 71 |
+
if mode == "reflect":
|
| 72 |
+
max_pad = max(padding_left, padding_right)
|
| 73 |
+
extra_pad = 0
|
| 74 |
+
if length <= max_pad:
|
| 75 |
+
extra_pad = max_pad - length + 1
|
| 76 |
+
x = F.pad(x, (0, extra_pad))
|
| 77 |
+
padded = F.pad(x, (padding_left, padding_right), mode, value)
|
| 78 |
+
end = padded.shape[-1] - extra_pad
|
| 79 |
+
return padded[..., :end]
|
| 80 |
+
else:
|
| 81 |
+
return F.pad(x, (padding_left, padding_right), mode, value)
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
# --------------------------------------------------------------------
|
| 85 |
+
# DAC Layers (adapted) β MIT
|
| 86 |
+
# Original: https://github.com/descriptinc/descript-audio-codec/blob/main/dac/nn/layers.py
|
| 87 |
+
# SPDX-License-Identifier: MIT
|
| 88 |
+
# --------------------------------------------------------------------
|
| 89 |
+
|
| 90 |
+
def WNConv1d(*args, **kwargs):
|
| 91 |
+
return weight_norm(nn.Conv1d(*args, **kwargs))
|
| 92 |
+
|
| 93 |
+
def WNConvTranspose1d(*args, **kwargs):
|
| 94 |
+
return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
|
| 95 |
+
|
| 96 |
+
@torch.jit.script
|
| 97 |
+
def snake(x: Tensor, alpha: Tensor) -> Tensor:
|
| 98 |
+
shape = x.shape
|
| 99 |
+
x = x.reshape(shape[0], shape[1], -1)
|
| 100 |
+
x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
|
| 101 |
+
x = x.reshape(shape)
|
| 102 |
+
return x
|
| 103 |
+
|
| 104 |
+
class Snake1d(nn.Module):
|
| 105 |
+
def __init__(self, channels: int):
|
| 106 |
+
super().__init__()
|
| 107 |
+
self.alpha = nn.Parameter(torch.ones(1, channels, 1))
|
| 108 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 109 |
+
return snake(x, self.alpha)
|
| 110 |
+
|
| 111 |
+
# --------------------------------------------------------------------
|
| 112 |
+
# DAC Vector Quantize (adapted) β MIT
|
| 113 |
+
# Original: https://github.com/descriptinc/descript-audio-codec/blob/main/dac/nn/quantize.py
|
| 114 |
+
# SPDX-License-Identifier: MIT
|
| 115 |
+
# --------------------------------------------------------------------
|
| 116 |
+
|
| 117 |
+
class VectorQuantize(nn.Module):
|
| 118 |
+
"""
|
| 119 |
+
VQ with factorized, l2-normalized codes (ViTβVQGAN style).
|
| 120 |
+
I/O in (B, D, T).
|
| 121 |
+
"""
|
| 122 |
+
def __init__(self, input_dim: int, codebook_size: int, codebook_dim: int):
|
| 123 |
+
super().__init__()
|
| 124 |
+
self.codebook_size = codebook_size
|
| 125 |
+
self.codebook_dim = codebook_dim
|
| 126 |
+
self.in_proj = WNConv1d(input_dim, codebook_dim, kernel_size=1)
|
| 127 |
+
self.out_proj = WNConv1d(codebook_dim, input_dim, kernel_size=1)
|
| 128 |
+
self.codebook = nn.Embedding(codebook_size, codebook_dim)
|
| 129 |
+
|
| 130 |
+
def forward(self, z: Tensor):
|
| 131 |
+
z_e = self.in_proj(z) # (B, D, T)
|
| 132 |
+
z_q, indices = self.decode_latents(z_e)
|
| 133 |
+
commitment_loss = F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2])
|
| 134 |
+
codebook_loss = F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2])
|
| 135 |
+
z_q = z_e + (z_q - z_e).detach() # straightβthrough
|
| 136 |
+
z_q = self.out_proj(z_q)
|
| 137 |
+
return z_q, commitment_loss, codebook_loss, indices, z_e
|
| 138 |
+
|
| 139 |
+
def embed_code(self, embed_id: Tensor) -> Tensor:
|
| 140 |
+
return F.embedding(embed_id, self.codebook.weight)
|
| 141 |
+
|
| 142 |
+
def decode_code(self, embed_id: Tensor) -> Tensor:
|
| 143 |
+
return self.embed_code(embed_id).transpose(1, 2)
|
| 144 |
+
|
| 145 |
+
def decode_latents(self, latents: Tensor) -> Tuple[Tensor, Tensor]:
|
| 146 |
+
encodings = rearrange(latents, "b d t -> (b t) d")
|
| 147 |
+
codebook = self.codebook.weight
|
| 148 |
+
encodings = F.normalize(encodings)
|
| 149 |
+
codebook = F.normalize(codebook)
|
| 150 |
+
dist = (
|
| 151 |
+
encodings.pow(2).sum(1, keepdim=True)
|
| 152 |
+
- 2 * encodings @ codebook.t()
|
| 153 |
+
+ codebook.pow(2).sum(1, keepdim=True).t()
|
| 154 |
+
)
|
| 155 |
+
indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
|
| 156 |
+
z_q = self.decode_code(indices)
|
| 157 |
+
return z_q, indices
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
class ResidualVectorQuantize(nn.Module):
|
| 161 |
+
"""SoundStream-style residual VQ stack."""
|
| 162 |
+
def __init__(
|
| 163 |
+
self,
|
| 164 |
+
input_dim: int = 512,
|
| 165 |
+
n_codebooks: int = 9,
|
| 166 |
+
codebook_size: int = 1024,
|
| 167 |
+
codebook_dim: Union[int, List[int]] = 8,
|
| 168 |
+
quantizer_dropout: float = 0.0,
|
| 169 |
+
):
|
| 170 |
+
super().__init__()
|
| 171 |
+
if isinstance(codebook_dim, int):
|
| 172 |
+
codebook_dim = [codebook_dim for _ in range(n_codebooks)]
|
| 173 |
+
|
| 174 |
+
self.n_codebooks = n_codebooks
|
| 175 |
+
self.codebook_dim = codebook_dim
|
| 176 |
+
self.codebook_size = codebook_size
|
| 177 |
+
|
| 178 |
+
self.quantizers = nn.ModuleList([
|
| 179 |
+
VectorQuantize(input_dim, codebook_size, codebook_dim[i])
|
| 180 |
+
for i in range(n_codebooks)
|
| 181 |
+
])
|
| 182 |
+
self.quantizer_dropout = quantizer_dropout
|
| 183 |
+
|
| 184 |
+
def forward(self, z: Tensor, n_quantizers: Optional[int] = None):
|
| 185 |
+
z_q = 0
|
| 186 |
+
residual = z
|
| 187 |
+
commitment_loss = 0
|
| 188 |
+
codebook_loss = 0
|
| 189 |
+
|
| 190 |
+
codebook_indices = []
|
| 191 |
+
latents = []
|
| 192 |
+
|
| 193 |
+
if n_quantizers is None:
|
| 194 |
+
n_quantizers = self.n_codebooks
|
| 195 |
+
if self.training:
|
| 196 |
+
n_quantizers = torch.ones((z.shape[0],)) * self.n_codebooks + 1
|
| 197 |
+
dropout = torch.randint(1, self.n_codebooks + 1, (z.shape[0],))
|
| 198 |
+
n_dropout = int(z.shape[0] * self.quantizer_dropout)
|
| 199 |
+
n_quantizers[:n_dropout] = dropout[:n_dropout]
|
| 200 |
+
n_quantizers = n_quantizers.to(z.device)
|
| 201 |
+
|
| 202 |
+
for i, quantizer in enumerate(self.quantizers):
|
| 203 |
+
if self.training is False and i >= n_quantizers:
|
| 204 |
+
break
|
| 205 |
+
|
| 206 |
+
z_q_i, commit_i, codebk_i, indices_i, z_e_i = quantizer(residual)
|
| 207 |
+
|
| 208 |
+
mask = (torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers)
|
| 209 |
+
z_q = z_q + z_q_i * mask[:, None, None]
|
| 210 |
+
residual = residual - z_q_i
|
| 211 |
+
|
| 212 |
+
commitment_loss += (commit_i * mask).mean()
|
| 213 |
+
codebook_loss += (codebk_i * mask).mean()
|
| 214 |
+
|
| 215 |
+
codebook_indices.append(indices_i)
|
| 216 |
+
latents.append(z_e_i)
|
| 217 |
+
|
| 218 |
+
codes = torch.stack(codebook_indices, dim=1)
|
| 219 |
+
latents = torch.cat(latents, dim=1)
|
| 220 |
+
|
| 221 |
+
return z_q, codes, latents, commitment_loss, codebook_loss
|
| 222 |
+
|
| 223 |
+
def from_codes(self, codes: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
|
| 224 |
+
z_q = 0.0
|
| 225 |
+
z_p = []
|
| 226 |
+
n_codebooks = codes.shape[1]
|
| 227 |
+
for i in range(n_codebooks):
|
| 228 |
+
z_p_i = self.quantizers[i].decode_code(codes[:, i, :])
|
| 229 |
+
z_p.append(z_p_i)
|
| 230 |
+
z_q_i = self.quantizers[i].out_proj(z_p_i)
|
| 231 |
+
z_q = z_q + z_q_i
|
| 232 |
+
return z_q, torch.cat(z_p, dim=1), codes
|
| 233 |
+
|
| 234 |
+
def from_latents(self, latents: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
|
| 235 |
+
z_q = 0
|
| 236 |
+
z_p = []
|
| 237 |
+
codes = []
|
| 238 |
+
dims = np.cumsum([0] + [q.codebook_dim for q in self.quantizers])
|
| 239 |
+
n_codebooks = np.where(dims <= latents.shape[1])[0].max(axis=0, keepdims=True)[0]
|
| 240 |
+
for i in range(n_codebooks):
|
| 241 |
+
j, k = dims[i], dims[i + 1]
|
| 242 |
+
z_p_i, codes_i = self.quantizers[i].decode_latents(latents[:, j:k, :])
|
| 243 |
+
z_p.append(z_p_i)
|
| 244 |
+
codes.append(codes_i)
|
| 245 |
+
z_q_i = self.quantizers[i].out_proj(z_p_i)
|
| 246 |
+
z_q = z_q + z_q_i
|
| 247 |
+
return z_q, torch.cat(z_p, dim=1), torch.stack(codes, dim=1)
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
# --------------------------------------------------------------------
|
| 251 |
+
# S1 DAC rvq
|
| 252 |
+
# --------------------------------------------------------------------
|
| 253 |
+
|
| 254 |
+
@dataclass
|
| 255 |
+
class VQResult:
|
| 256 |
+
z: Tensor
|
| 257 |
+
codes: Tensor
|
| 258 |
+
latents: Tensor
|
| 259 |
+
codebook_loss: Tensor
|
| 260 |
+
commitment_loss: Tensor
|
| 261 |
+
semantic_distill_z: Optional[Tensor] = None
|
| 262 |
+
|
| 263 |
+
|
| 264 |
+
class CausalConvNet(nn.Module):
|
| 265 |
+
def __init__(
|
| 266 |
+
self,
|
| 267 |
+
in_channels,
|
| 268 |
+
out_channels,
|
| 269 |
+
kernel_size,
|
| 270 |
+
dilation=1,
|
| 271 |
+
stride=1,
|
| 272 |
+
groups=1,
|
| 273 |
+
padding=None,
|
| 274 |
+
):
|
| 275 |
+
super().__init__()
|
| 276 |
+
self.conv = nn.Conv1d(
|
| 277 |
+
in_channels, out_channels, kernel_size,
|
| 278 |
+
stride=stride, dilation=dilation, groups=groups,
|
| 279 |
+
)
|
| 280 |
+
self.stride = stride
|
| 281 |
+
self.kernel_size = (kernel_size - 1) * dilation + 1
|
| 282 |
+
self.dilation = dilation
|
| 283 |
+
self.padding = self.kernel_size - self.stride
|
| 284 |
+
|
| 285 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 286 |
+
pad = self.padding
|
| 287 |
+
extra = get_extra_padding_for_conv1d(x, self.kernel_size, self.stride, pad)
|
| 288 |
+
x = pad1d(x, (pad, extra), mode="constant", value=0)
|
| 289 |
+
return self.conv(x).contiguous()
|
| 290 |
+
|
| 291 |
+
def weight_norm(self, name="weight", dim=0):
|
| 292 |
+
self.conv = weight_norm(self.conv, name=name, dim=dim)
|
| 293 |
+
return self
|
| 294 |
+
|
| 295 |
+
def remove_weight_norm(self):
|
| 296 |
+
self.conv = remove_parametrizations(self.conv)
|
| 297 |
+
return self
|
| 298 |
+
|
| 299 |
+
|
| 300 |
+
class CausalTransConvNet(nn.Module):
|
| 301 |
+
def __init__(self, in_channels, out_channels, kernel_size, dilation=1, stride=1, padding=None):
|
| 302 |
+
super().__init__()
|
| 303 |
+
self.conv = nn.ConvTranspose1d(
|
| 304 |
+
in_channels, out_channels, kernel_size,
|
| 305 |
+
stride=stride, dilation=dilation
|
| 306 |
+
)
|
| 307 |
+
self.stride = stride
|
| 308 |
+
self.kernel_size = kernel_size
|
| 309 |
+
|
| 310 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 311 |
+
x = self.conv(x)
|
| 312 |
+
pad = self.kernel_size - self.stride
|
| 313 |
+
padding_right = math.ceil(pad)
|
| 314 |
+
padding_left = pad - padding_right
|
| 315 |
+
x = unpad1d(x, (padding_left, padding_right))
|
| 316 |
+
return x.contiguous()
|
| 317 |
+
|
| 318 |
+
def weight_norm(self, name="weight", dim=0):
|
| 319 |
+
self.conv = weight_norm(self.conv, name=name, dim=dim)
|
| 320 |
+
return self
|
| 321 |
+
|
| 322 |
+
def remove_weight_norm(self):
|
| 323 |
+
self.conv = remove_parametrizations(self.conv)
|
| 324 |
+
return self
|
| 325 |
+
|
| 326 |
+
|
| 327 |
+
def CausalWNConv1d(*args, **kwargs):
|
| 328 |
+
return CausalConvNet(*args, **kwargs).weight_norm()
|
| 329 |
+
|
| 330 |
+
def CausalWNConvTranspose1d(*args, **kwargs):
|
| 331 |
+
return CausalTransConvNet(*args, **kwargs).weight_norm()
|
| 332 |
+
|
| 333 |
+
class ConvNeXtBlock(nn.Module):
|
| 334 |
+
r"""ConvNeXt Block (1D).
|
| 335 |
+
DwConv -> (N, C, L) β (N, L, C) -> LN -> Linear -> GELU -> Linear -> (N, C, L) with residual
|
| 336 |
+
"""
|
| 337 |
+
def __init__(
|
| 338 |
+
self,
|
| 339 |
+
dim: int,
|
| 340 |
+
layer_scale_init_value: float = 1e-6,
|
| 341 |
+
mlp_ratio: float = 4.0,
|
| 342 |
+
kernel_size: int = 7,
|
| 343 |
+
dilation: int = 1,
|
| 344 |
+
):
|
| 345 |
+
super().__init__()
|
| 346 |
+
convnet_type = CausalConvNet
|
| 347 |
+
self.dwconv = convnet_type(
|
| 348 |
+
dim, dim, kernel_size=kernel_size,
|
| 349 |
+
groups=dim, dilation=dilation,
|
| 350 |
+
) # depthwise conv
|
| 351 |
+
self.norm = nn.LayerNorm(dim, eps=1e-6)
|
| 352 |
+
self.pwconv1 = nn.Linear(dim, int(mlp_ratio * dim))
|
| 353 |
+
self.act = nn.GELU()
|
| 354 |
+
self.pwconv2 = nn.Linear(int(mlp_ratio * dim), dim)
|
| 355 |
+
self.gamma = (
|
| 356 |
+
nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
|
| 357 |
+
if layer_scale_init_value > 0 else None
|
| 358 |
+
)
|
| 359 |
+
|
| 360 |
+
def forward(self, x: Tensor, apply_residual: bool = True) -> Tensor:
|
| 361 |
+
inp = x
|
| 362 |
+
x = self.dwconv(x)
|
| 363 |
+
x = x.permute(0, 2, 1) # (N, C, L) -> (N, L, C)
|
| 364 |
+
x = self.norm(x)
|
| 365 |
+
x = self.pwconv1(x)
|
| 366 |
+
x = self.act(x)
|
| 367 |
+
x = self.pwconv2(x)
|
| 368 |
+
if self.gamma is not None:
|
| 369 |
+
x = self.gamma * x
|
| 370 |
+
x = x.permute(0, 2, 1) # (N, L, C) -> (N, C, L)
|
| 371 |
+
if apply_residual:
|
| 372 |
+
x = inp + x
|
| 373 |
+
return x
|
| 374 |
+
|
| 375 |
+
|
| 376 |
+
class DownsampleResidualVectorQuantize(nn.Module):
|
| 377 |
+
def __init__(
|
| 378 |
+
self,
|
| 379 |
+
input_dim: int = 1024,
|
| 380 |
+
n_codebooks: int = 9,
|
| 381 |
+
codebook_dim: int = 8,
|
| 382 |
+
quantizer_dropout: float = 0.5,
|
| 383 |
+
codebook_size: int = 1024,
|
| 384 |
+
semantic_codebook_size: int = 4096,
|
| 385 |
+
downsample_factor: Tuple[int, ...] = (2, 2),
|
| 386 |
+
downsample_dims: Optional[Tuple[int, ...]] = None,
|
| 387 |
+
pre_module: Optional[nn.Module] = None,
|
| 388 |
+
post_module: Optional[nn.Module] = None,
|
| 389 |
+
semantic_predictor_module: Optional[nn.Module] = None,
|
| 390 |
+
):
|
| 391 |
+
super().__init__()
|
| 392 |
+
|
| 393 |
+
if downsample_dims is None:
|
| 394 |
+
downsample_dims = tuple(input_dim for _ in range(len(downsample_factor)))
|
| 395 |
+
|
| 396 |
+
all_dims = (input_dim,) + tuple(downsample_dims)
|
| 397 |
+
|
| 398 |
+
self.semantic_quantizer = ResidualVectorQuantize(
|
| 399 |
+
input_dim=input_dim,
|
| 400 |
+
n_codebooks=1,
|
| 401 |
+
codebook_size=semantic_codebook_size,
|
| 402 |
+
codebook_dim=codebook_dim,
|
| 403 |
+
quantizer_dropout=0.0,
|
| 404 |
+
)
|
| 405 |
+
|
| 406 |
+
self.quantizer = ResidualVectorQuantize(
|
| 407 |
+
input_dim=input_dim,
|
| 408 |
+
n_codebooks=n_codebooks,
|
| 409 |
+
codebook_size=codebook_size,
|
| 410 |
+
codebook_dim=codebook_dim,
|
| 411 |
+
quantizer_dropout=quantizer_dropout,
|
| 412 |
+
)
|
| 413 |
+
|
| 414 |
+
convnet_type = CausalConvNet
|
| 415 |
+
transconvnet_type = CausalTransConvNet
|
| 416 |
+
|
| 417 |
+
self.downsample = nn.Sequential(
|
| 418 |
+
*[
|
| 419 |
+
nn.Sequential(
|
| 420 |
+
convnet_type(all_dims[idx], all_dims[idx + 1], kernel_size=factor, stride=factor),
|
| 421 |
+
ConvNeXtBlock(dim=all_dims[idx + 1]),
|
| 422 |
+
)
|
| 423 |
+
for idx, factor in enumerate(downsample_factor)
|
| 424 |
+
]
|
| 425 |
+
)
|
| 426 |
+
|
| 427 |
+
self.upsample = nn.Sequential(
|
| 428 |
+
*[
|
| 429 |
+
nn.Sequential(
|
| 430 |
+
transconvnet_type(all_dims[idx + 1], all_dims[idx], kernel_size=factor, stride=factor),
|
| 431 |
+
ConvNeXtBlock(dim=all_dims[idx]),
|
| 432 |
+
)
|
| 433 |
+
for idx, factor in reversed(list(enumerate(downsample_factor)))
|
| 434 |
+
]
|
| 435 |
+
)
|
| 436 |
+
|
| 437 |
+
self.apply(self._init_weights)
|
| 438 |
+
self.pre_module = pre_module if pre_module is not None else nn.Identity()
|
| 439 |
+
self.post_module = post_module if post_module is not None else nn.Identity()
|
| 440 |
+
self.semantic_predictor_module = (
|
| 441 |
+
semantic_predictor_module if semantic_predictor_module is not None else nn.Identity()
|
| 442 |
+
)
|
| 443 |
+
|
| 444 |
+
@staticmethod
|
| 445 |
+
def _init_weights(m):
|
| 446 |
+
if isinstance(m, (nn.Conv1d, nn.Linear)):
|
| 447 |
+
nn.init.trunc_normal_(m.weight, std=0.02)
|
| 448 |
+
if getattr(m, "bias", None) is not None:
|
| 449 |
+
nn.init.constant_(m.bias, 0)
|
| 450 |
+
|
| 451 |
+
def forward(self, z: Tensor, n_quantizers: Optional[int] = None, semantic_len: Optional[Tensor] = None, **kwargs):
|
| 452 |
+
# z: (B, D, T)
|
| 453 |
+
original_shape = z.shape
|
| 454 |
+
if semantic_len is None:
|
| 455 |
+
semantic_len = torch.LongTensor([z.shape[-1]])
|
| 456 |
+
|
| 457 |
+
z = self.downsample(z)
|
| 458 |
+
z = self.pre_module(z) # (B, D, T) or (B, T, D) depending on module; original uses channels-first in/out
|
| 459 |
+
|
| 460 |
+
semantic_z, semantic_codes, semantic_latents, semantic_commitment_loss, semantic_codebook_loss = \
|
| 461 |
+
self.semantic_quantizer(z)
|
| 462 |
+
residual_z = z - semantic_z
|
| 463 |
+
residual_z, codes, latents, commitment_loss, codebook_loss = self.quantizer(residual_z, n_quantizers=n_quantizers)
|
| 464 |
+
z = semantic_z + residual_z
|
| 465 |
+
commitment_loss = commitment_loss + semantic_commitment_loss
|
| 466 |
+
codebook_loss = codebook_loss + semantic_codebook_loss
|
| 467 |
+
codes = torch.cat([semantic_codes, codes], dim=1)
|
| 468 |
+
latents = torch.cat([semantic_latents, latents], dim=1)
|
| 469 |
+
z = self.post_module(z)
|
| 470 |
+
z = self.upsample(z)
|
| 471 |
+
|
| 472 |
+
# Pad or crop z to match original shape (time dimension)
|
| 473 |
+
diff = original_shape[-1] - z.shape[-1]
|
| 474 |
+
right = 0
|
| 475 |
+
left = abs(diff) - right
|
| 476 |
+
if diff > 0:
|
| 477 |
+
z = F.pad(z, (left, right))
|
| 478 |
+
elif diff < 0:
|
| 479 |
+
z = z[..., left:]
|
| 480 |
+
|
| 481 |
+
return VQResult(
|
| 482 |
+
z=z, codes=codes, latents=latents,
|
| 483 |
+
commitment_loss=commitment_loss, codebook_loss=codebook_loss,
|
| 484 |
+
)
|
| 485 |
+
|
| 486 |
+
def decode(self, indices: Tensor) -> Tensor:
|
| 487 |
+
new_indices = torch.zeros_like(indices)
|
| 488 |
+
new_indices[:, 0] = torch.clamp(indices[:, 0], max=self.semantic_quantizer.codebook_size - 1)
|
| 489 |
+
new_indices[:, 1:] = torch.clamp(indices[:, 1:], max=self.quantizer.codebook_size - 1)
|
| 490 |
+
|
| 491 |
+
z_q_semantic = self.semantic_quantizer.from_codes(new_indices[:, :1])[0]
|
| 492 |
+
z_q_residual = self.quantizer.from_codes(new_indices[:, 1:])[0]
|
| 493 |
+
z_q = z_q_semantic + z_q_residual
|
| 494 |
+
z_q = self.post_module(z_q)
|
| 495 |
+
z_q = self.upsample(z_q)
|
| 496 |
+
return z_q
|
| 497 |
+
|
| 498 |
+
|
| 499 |
+
# --------------------------------------------------------------------
|
| 500 |
+
# Transformer stack
|
| 501 |
+
# --------------------------------------------------------------------
|
| 502 |
+
|
| 503 |
+
@dataclass
|
| 504 |
+
class ModelArgs:
|
| 505 |
+
block_size: int = 2048
|
| 506 |
+
n_layer: int = 8
|
| 507 |
+
n_head: int = 8
|
| 508 |
+
dim: int = 512
|
| 509 |
+
intermediate_size: int = 1536
|
| 510 |
+
n_local_heads: int = -1
|
| 511 |
+
head_dim: int = 64
|
| 512 |
+
rope_base: float = 10000
|
| 513 |
+
norm_eps: float = 1e-5
|
| 514 |
+
dropout_rate: float = 0.1
|
| 515 |
+
attn_dropout_rate: float = 0.1
|
| 516 |
+
channels_first: bool = True # to be compatible with conv1d input/output
|
| 517 |
+
pos_embed_type: str = "rope" # "rope" or "conformer"
|
| 518 |
+
max_relative_position: int = 128
|
| 519 |
+
|
| 520 |
+
def __post_init__(self):
|
| 521 |
+
if self.n_local_heads == -1:
|
| 522 |
+
self.n_local_heads = self.n_head
|
| 523 |
+
if self.intermediate_size is None:
|
| 524 |
+
hidden_dim = 4 * self.dim
|
| 525 |
+
n_hidden = int(2 * hidden_dim / 3)
|
| 526 |
+
self.intermediate_size = find_multiple(n_hidden, 256)
|
| 527 |
+
assert self.pos_embed_type in ["rope", "conformer"]
|
| 528 |
+
|
| 529 |
+
|
| 530 |
+
class KVCache(nn.Module):
|
| 531 |
+
def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim, dtype=torch.bfloat16):
|
| 532 |
+
super().__init__()
|
| 533 |
+
cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
|
| 534 |
+
self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype))
|
| 535 |
+
self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype))
|
| 536 |
+
|
| 537 |
+
def update(self, input_pos: Tensor, k_val: Tensor, v_val: Tensor):
|
| 538 |
+
# input_pos: [S], k_val: [B, H, S, D]
|
| 539 |
+
assert input_pos.shape[0] == k_val.shape[2]
|
| 540 |
+
k_out = self.k_cache
|
| 541 |
+
v_out = self.v_cache
|
| 542 |
+
k_out[:, :, input_pos] = k_val
|
| 543 |
+
v_out[:, :, input_pos] = v_val
|
| 544 |
+
return (
|
| 545 |
+
k_out[:, :, : input_pos.max() + 1, :],
|
| 546 |
+
v_out[:, :, : input_pos.max() + 1, :],
|
| 547 |
+
)
|
| 548 |
+
|
| 549 |
+
def clear_cache(self, prompt_len: int):
|
| 550 |
+
self.k_cache[:, :, prompt_len:, :].fill_(0)
|
| 551 |
+
self.v_cache[:, :, prompt_len:, :].fill_(0)
|
| 552 |
+
|
| 553 |
+
|
| 554 |
+
class Transformer(nn.Module):
|
| 555 |
+
def __init__(self, config: ModelArgs) -> None:
|
| 556 |
+
super().__init__()
|
| 557 |
+
self.config = config
|
| 558 |
+
|
| 559 |
+
self.layers = nn.ModuleList(TransformerBlock(config) for _ in range(config.n_layer))
|
| 560 |
+
self.norm = RMSNorm(config.dim, eps=config.norm_eps)
|
| 561 |
+
|
| 562 |
+
if config.pos_embed_type == "rope":
|
| 563 |
+
freqs_cis = precompute_freqs_cis(self.config.block_size, self.config.head_dim, self.config.rope_base)
|
| 564 |
+
self.register_buffer("freqs_cis", freqs_cis)
|
| 565 |
+
else:
|
| 566 |
+
self.register_buffer("freqs_cis", None)
|
| 567 |
+
|
| 568 |
+
causal_mask = torch.tril(torch.ones(self.config.block_size, self.config.block_size, dtype=torch.bool))
|
| 569 |
+
self.register_buffer("causal_mask", causal_mask)
|
| 570 |
+
|
| 571 |
+
self.max_batch_size = -1
|
| 572 |
+
self.max_seq_length = -1
|
| 573 |
+
self.use_kv_cache = False
|
| 574 |
+
|
| 575 |
+
def setup_caches(self, max_batch_size, max_seq_length):
|
| 576 |
+
head_dim = self.config.dim // self.config.n_head
|
| 577 |
+
max_seq_length = find_multiple(max_seq_length, 8)
|
| 578 |
+
self.max_seq_length = max_seq_length
|
| 579 |
+
self.max_batch_size = max_batch_size
|
| 580 |
+
dtype = self.norm.weight.dtype
|
| 581 |
+
device = self.norm.weight.device
|
| 582 |
+
|
| 583 |
+
for b in self.layers:
|
| 584 |
+
b.attention.kv_cache = KVCache(
|
| 585 |
+
max_batch_size, max_seq_length, self.config.n_local_heads, head_dim, dtype
|
| 586 |
+
).to(device)
|
| 587 |
+
|
| 588 |
+
self.use_kv_cache = True
|
| 589 |
+
|
| 590 |
+
def forward(self, x: Tensor, input_pos: Optional[Tensor] = None, mask: Optional[Tensor] = None) -> Tensor:
|
| 591 |
+
if self.config.pos_embed_type == "rope":
|
| 592 |
+
assert self.freqs_cis is not None
|
| 593 |
+
freqs_cis = self.freqs_cis[input_pos]
|
| 594 |
+
else:
|
| 595 |
+
freqs_cis = None
|
| 596 |
+
|
| 597 |
+
if mask is None:
|
| 598 |
+
if not self.training and self.use_kv_cache:
|
| 599 |
+
mask = self.causal_mask[None, None, input_pos]
|
| 600 |
+
mask = mask[..., : input_pos.max() + 1]
|
| 601 |
+
else:
|
| 602 |
+
mask = self.causal_mask[None, None, input_pos]
|
| 603 |
+
mask = mask[..., input_pos]
|
| 604 |
+
|
| 605 |
+
for layer in self.layers:
|
| 606 |
+
x = layer(x, input_pos, freqs_cis, mask)
|
| 607 |
+
x = self.norm(x)
|
| 608 |
+
return x
|
| 609 |
+
|
| 610 |
+
|
| 611 |
+
class TransformerBlock(nn.Module):
|
| 612 |
+
def __init__(self, config: ModelArgs) -> None:
|
| 613 |
+
super().__init__()
|
| 614 |
+
self.attention = Attention(config)
|
| 615 |
+
self.feed_forward = FeedForward(config)
|
| 616 |
+
self.ffn_norm = RMSNorm(config.dim, eps=config.norm_eps)
|
| 617 |
+
self.attention_norm = RMSNorm(config.dim, eps=config.norm_eps)
|
| 618 |
+
self.attention_layer_scale = LayerScale(config.dim, inplace=True)
|
| 619 |
+
self.ffn_layer_scale = LayerScale(config.dim, inplace=True)
|
| 620 |
+
|
| 621 |
+
def forward(self, x: Tensor, input_pos: Tensor, freqs_cis: Tensor, mask: Tensor) -> Tensor:
|
| 622 |
+
h = x + self.attention_layer_scale(
|
| 623 |
+
self.attention(self.attention_norm(x), freqs_cis, mask, input_pos)
|
| 624 |
+
)
|
| 625 |
+
out = h + self.ffn_layer_scale(self.feed_forward(self.ffn_norm(h)))
|
| 626 |
+
return out
|
| 627 |
+
|
| 628 |
+
|
| 629 |
+
class Attention(nn.Module):
|
| 630 |
+
def __init__(self, config: ModelArgs):
|
| 631 |
+
super().__init__()
|
| 632 |
+
assert config.dim % config.n_head == 0
|
| 633 |
+
|
| 634 |
+
total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
|
| 635 |
+
self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False)
|
| 636 |
+
self.wo = nn.Linear(config.head_dim * config.n_head, config.dim, bias=False)
|
| 637 |
+
self.kv_cache = None
|
| 638 |
+
|
| 639 |
+
self.n_head = config.n_head
|
| 640 |
+
self.head_dim = config.head_dim
|
| 641 |
+
self.n_local_heads = config.n_local_heads
|
| 642 |
+
self.dim = config.dim
|
| 643 |
+
self.attn_dropout_rate = config.attn_dropout_rate
|
| 644 |
+
self.pos_embed_type = config.pos_embed_type
|
| 645 |
+
|
| 646 |
+
if self.pos_embed_type == "conformer":
|
| 647 |
+
self.max_relative_position = config.max_relative_position
|
| 648 |
+
num_pos_embeddings = 2 * config.max_relative_position + 1
|
| 649 |
+
self.rel_pos_embeddings = nn.Parameter(torch.zeros(num_pos_embeddings, self.head_dim))
|
| 650 |
+
nn.init.normal_(self.rel_pos_embeddings, mean=0.0, std=0.02)
|
| 651 |
+
|
| 652 |
+
def _compute_conformer_pos_scores(self, q: Tensor, seqlen: int) -> Tensor:
|
| 653 |
+
positions = torch.arange(seqlen, device=q.device)
|
| 654 |
+
relative_positions = positions.unsqueeze(1) - positions.unsqueeze(0) # [S, S]
|
| 655 |
+
relative_positions = torch.clamp(relative_positions + self.max_relative_position,
|
| 656 |
+
0, 2 * self.max_relative_position)
|
| 657 |
+
rel_embeddings = self.rel_pos_embeddings[relative_positions] # [S, S, D]
|
| 658 |
+
q = q.transpose(1, 2) # [B, S, H, D]
|
| 659 |
+
rel_logits = torch.matmul(q, rel_embeddings.transpose(-2, -1)) # [B, S, H, S]
|
| 660 |
+
rel_logits = rel_logits.transpose(1, 2) # [B, H, S, S]
|
| 661 |
+
return rel_logits
|
| 662 |
+
|
| 663 |
+
def forward(self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
|
| 664 |
+
bsz, seqlen, _ = x.shape
|
| 665 |
+
|
| 666 |
+
kv_size = self.n_local_heads * self.head_dim
|
| 667 |
+
q, k, v = self.wqkv(x).split([kv_size, kv_size, kv_size], dim=-1)
|
| 668 |
+
context_seqlen = seqlen
|
| 669 |
+
|
| 670 |
+
q = q.view(bsz, seqlen, self.n_head, self.head_dim)
|
| 671 |
+
k = k.view(bsz, context_seqlen, self.n_local_heads, self.head_dim)
|
| 672 |
+
v = v.view(bsz, context_seqlen, self.n_local_heads, self.head_dim)
|
| 673 |
+
|
| 674 |
+
if self.pos_embed_type == "rope":
|
| 675 |
+
q = apply_rotary_emb(q, freqs_cis)
|
| 676 |
+
k = apply_rotary_emb(k, freqs_cis)
|
| 677 |
+
|
| 678 |
+
q, k, v = map(lambda t: t.transpose(1, 2), (q, k, v))
|
| 679 |
+
|
| 680 |
+
if self.kv_cache is not None:
|
| 681 |
+
k, v = self.kv_cache.update(input_pos, k, v)
|
| 682 |
+
|
| 683 |
+
k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
|
| 684 |
+
v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
|
| 685 |
+
|
| 686 |
+
if self.pos_embed_type == "conformer":
|
| 687 |
+
scale = 1.0 / math.sqrt(self.head_dim)
|
| 688 |
+
scores = torch.matmul(q, k.transpose(-2, -1)) * scale
|
| 689 |
+
rel_scores = self._compute_conformer_pos_scores(q, seqlen)
|
| 690 |
+
scores = scores + rel_scores
|
| 691 |
+
if mask is not None:
|
| 692 |
+
scores = scores.masked_fill(~mask, float("-inf"))
|
| 693 |
+
attn = F.softmax(scores, dim=-1)
|
| 694 |
+
if self.attn_dropout_rate > 0 and self.training:
|
| 695 |
+
attn = F.dropout(attn, p=self.attn_dropout_rate)
|
| 696 |
+
y = torch.matmul(attn, v)
|
| 697 |
+
else:
|
| 698 |
+
y = F.scaled_dot_product_attention(
|
| 699 |
+
q, k, v,
|
| 700 |
+
dropout_p=self.attn_dropout_rate if self.training else 0.0,
|
| 701 |
+
attn_mask=mask,
|
| 702 |
+
)
|
| 703 |
+
y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.head_dim * self.n_head)
|
| 704 |
+
y = self.wo(y)
|
| 705 |
+
return y
|
| 706 |
+
|
| 707 |
+
|
| 708 |
+
class FeedForward(nn.Module):
|
| 709 |
+
def __init__(self, config: ModelArgs) -> None:
|
| 710 |
+
super().__init__()
|
| 711 |
+
self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False)
|
| 712 |
+
self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False)
|
| 713 |
+
self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False)
|
| 714 |
+
self.dropout = nn.Dropout(config.dropout_rate)
|
| 715 |
+
|
| 716 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 717 |
+
return self.w2(self.dropout(F.silu(self.w1(x)) * self.w3(x)))
|
| 718 |
+
|
| 719 |
+
|
| 720 |
+
class RMSNorm(nn.Module):
|
| 721 |
+
def __init__(self, dim: int, eps: float = 1e-5):
|
| 722 |
+
super().__init__()
|
| 723 |
+
self.eps = eps
|
| 724 |
+
self.weight = nn.Parameter(torch.ones(dim))
|
| 725 |
+
|
| 726 |
+
def _norm(self, x):
|
| 727 |
+
return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)
|
| 728 |
+
|
| 729 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 730 |
+
output = self._norm(x.float()).type_as(x)
|
| 731 |
+
return output * self.weight
|
| 732 |
+
|
| 733 |
+
|
| 734 |
+
class LayerScale(nn.Module):
|
| 735 |
+
def __init__(self, dim: int, init_values: Union[float, Tensor] = 1e-2, inplace: bool = False) -> None:
|
| 736 |
+
super().__init__()
|
| 737 |
+
self.inplace = inplace
|
| 738 |
+
self.gamma = nn.Parameter(init_values * torch.ones(dim))
|
| 739 |
+
|
| 740 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 741 |
+
return x.mul_(self.gamma) if self.inplace else x * self.gamma
|
| 742 |
+
|
| 743 |
+
|
| 744 |
+
class WindowLimitedTransformer(Transformer):
|
| 745 |
+
"""Transformer with window-limited causal attention."""
|
| 746 |
+
def __init__(
|
| 747 |
+
self,
|
| 748 |
+
config: ModelArgs,
|
| 749 |
+
input_dim: int = 512,
|
| 750 |
+
window_size: Optional[int] = None,
|
| 751 |
+
causal: bool = True,
|
| 752 |
+
look_ahead_conv: Optional[nn.Module] = None,
|
| 753 |
+
):
|
| 754 |
+
super().__init__(config)
|
| 755 |
+
self.window_size = window_size
|
| 756 |
+
self.causal = causal
|
| 757 |
+
self.channels_first = config.channels_first
|
| 758 |
+
self.look_ahead_conv = look_ahead_conv if look_ahead_conv is not None else nn.Identity()
|
| 759 |
+
self.input_proj = nn.Linear(input_dim, config.dim) if input_dim != config.dim else nn.Identity()
|
| 760 |
+
self.output_proj = nn.Linear(config.dim, input_dim) if input_dim != config.dim else nn.Identity()
|
| 761 |
+
|
| 762 |
+
def make_window_limited_mask(self, max_length: int, x_lens: Optional[Tensor] = None) -> Tensor:
|
| 763 |
+
if self.causal:
|
| 764 |
+
mask = torch.tril(torch.ones(max_length, max_length))
|
| 765 |
+
row_indices = torch.arange(max_length).view(-1, 1)
|
| 766 |
+
window_size = self.window_size or max_length
|
| 767 |
+
valid_range = (row_indices - window_size + 1).clamp(min=0)
|
| 768 |
+
column_indices = torch.arange(max_length)
|
| 769 |
+
mask = (column_indices >= valid_range) & mask.bool()
|
| 770 |
+
else:
|
| 771 |
+
raise NotImplementedError
|
| 772 |
+
mask = mask.bool()[None, None]
|
| 773 |
+
return mask
|
| 774 |
+
|
| 775 |
+
def make_mask(self, max_length: int, x_lens: Optional[Tensor] = None) -> Tensor:
|
| 776 |
+
if self.causal:
|
| 777 |
+
mask = torch.tril(torch.ones(max_length, max_length))
|
| 778 |
+
else:
|
| 779 |
+
mask = torch.ones(max_length, max_length)
|
| 780 |
+
mask = mask.bool()[None, None]
|
| 781 |
+
for i, x_len in enumerate(x_lens):
|
| 782 |
+
mask[:x_len, i] = 0
|
| 783 |
+
mask = mask.bool()[None, None]
|
| 784 |
+
return mask
|
| 785 |
+
|
| 786 |
+
def forward(self, x: Tensor, x_lens: Optional[Tensor] = None) -> Tensor:
|
| 787 |
+
if self.channels_first:
|
| 788 |
+
x = x.transpose(1, 2)
|
| 789 |
+
x = self.input_proj(x)
|
| 790 |
+
x = self.look_ahead_conv(x)
|
| 791 |
+
input_pos = torch.arange(x.shape[1], device=x.device)
|
| 792 |
+
max_length = x.shape[1]
|
| 793 |
+
if self.window_size is not None:
|
| 794 |
+
mask = self.make_window_limited_mask(max_length, x_lens)
|
| 795 |
+
else:
|
| 796 |
+
mask = self.make_mask(max_length, x_lens)
|
| 797 |
+
mask = mask.to(x.device)
|
| 798 |
+
x = super().forward(x, input_pos, mask)
|
| 799 |
+
x = self.output_proj(x)
|
| 800 |
+
if self.channels_first:
|
| 801 |
+
x = x.transpose(1, 2)
|
| 802 |
+
return x
|
| 803 |
+
|
| 804 |
+
|
| 805 |
+
def precompute_freqs_cis(
|
| 806 |
+
seq_len: int, n_elem: int, base: int = 10000, dtype: torch.dtype = torch.bfloat16
|
| 807 |
+
) -> Tensor:
|
| 808 |
+
freqs = 1.0 / (base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem))
|
| 809 |
+
t = torch.arange(seq_len, device=freqs.device)
|
| 810 |
+
freqs = torch.outer(t, freqs)
|
| 811 |
+
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
|
| 812 |
+
cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
|
| 813 |
+
return cache.to(dtype=dtype)
|
| 814 |
+
|
| 815 |
+
def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
|
| 816 |
+
xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
|
| 817 |
+
freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2)
|
| 818 |
+
x_out2 = torch.stack(
|
| 819 |
+
[
|
| 820 |
+
xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
|
| 821 |
+
xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
|
| 822 |
+
],
|
| 823 |
+
-1,
|
| 824 |
+
)
|
| 825 |
+
x_out2 = x_out2.flatten(3)
|
| 826 |
+
return x_out2.type_as(x)
|
| 827 |
+
|
| 828 |
+
|
| 829 |
+
def init_weights(m):
|
| 830 |
+
if isinstance(m, nn.Conv1d):
|
| 831 |
+
nn.init.trunc_normal_(m.weight, std=0.02)
|
| 832 |
+
nn.init.constant_(m.bias, 0)
|
| 833 |
+
|
| 834 |
+
|
| 835 |
+
# --------------------------------------------------------------------
|
| 836 |
+
# Top-level AE
|
| 837 |
+
# --------------------------------------------------------------------
|
| 838 |
+
|
| 839 |
+
class EncoderBlock(nn.Module):
|
| 840 |
+
def __init__(
|
| 841 |
+
self,
|
| 842 |
+
dim: int = 16,
|
| 843 |
+
stride: int = 1,
|
| 844 |
+
causal: bool = False,
|
| 845 |
+
n_t_layer: int = 0,
|
| 846 |
+
transformer_general_config=None,
|
| 847 |
+
):
|
| 848 |
+
super().__init__()
|
| 849 |
+
conv_class = CausalWNConv1d if causal else WNConv1d
|
| 850 |
+
transformer_module = (
|
| 851 |
+
nn.Identity()
|
| 852 |
+
if n_t_layer == 0
|
| 853 |
+
else WindowLimitedTransformer(
|
| 854 |
+
causal=causal,
|
| 855 |
+
input_dim=dim,
|
| 856 |
+
window_size=512,
|
| 857 |
+
config=transformer_general_config(
|
| 858 |
+
n_layer=n_t_layer,
|
| 859 |
+
n_head=dim // 64,
|
| 860 |
+
dim=dim,
|
| 861 |
+
intermediate_size=dim * 3,
|
| 862 |
+
),
|
| 863 |
+
)
|
| 864 |
+
)
|
| 865 |
+
self.block = nn.Sequential(
|
| 866 |
+
# three multiβreceptiveβfield residual units
|
| 867 |
+
ResidualUnit(dim // 2, dilation=1, causal=causal),
|
| 868 |
+
ResidualUnit(dim // 2, dilation=3, causal=causal),
|
| 869 |
+
ResidualUnit(dim // 2, dilation=9, causal=causal),
|
| 870 |
+
Snake1d(dim // 2),
|
| 871 |
+
conv_class(dim // 2, dim, kernel_size=2 * stride, stride=stride, padding=math.ceil(stride / 2)),
|
| 872 |
+
transformer_module,
|
| 873 |
+
)
|
| 874 |
+
|
| 875 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 876 |
+
return self.block(x)
|
| 877 |
+
|
| 878 |
+
|
| 879 |
+
class ResidualUnit(nn.Module):
|
| 880 |
+
def __init__(self, dim: int = 16, dilation: int = 1, causal: bool = False):
|
| 881 |
+
super().__init__()
|
| 882 |
+
conv_class = CausalWNConv1d if causal else WNConv1d
|
| 883 |
+
pad = ((7 - 1) * dilation) // 2
|
| 884 |
+
self.block = nn.Sequential(
|
| 885 |
+
Snake1d(dim),
|
| 886 |
+
conv_class(dim, dim, kernel_size=7, dilation=dilation, padding=pad),
|
| 887 |
+
Snake1d(dim),
|
| 888 |
+
conv_class(dim, dim, kernel_size=1),
|
| 889 |
+
)
|
| 890 |
+
self.causal = causal
|
| 891 |
+
|
| 892 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 893 |
+
y = self.block(x)
|
| 894 |
+
pad = x.shape[-1] - y.shape[-1]
|
| 895 |
+
if pad > 0:
|
| 896 |
+
if self.causal:
|
| 897 |
+
x = x[..., :-pad]
|
| 898 |
+
else:
|
| 899 |
+
x = x[..., pad // 2 : -pad // 2]
|
| 900 |
+
return x + y
|
| 901 |
+
|
| 902 |
+
|
| 903 |
+
class Encoder(nn.Module):
|
| 904 |
+
def __init__(
|
| 905 |
+
self,
|
| 906 |
+
d_model: int = 64,
|
| 907 |
+
strides: List[int] = [2, 4, 8, 8],
|
| 908 |
+
d_latent: int = 64,
|
| 909 |
+
n_transformer_layers: List[int] = [0, 0, 4, 4],
|
| 910 |
+
transformer_general_config: Optional[ModelArgs] = None,
|
| 911 |
+
causal: bool = False,
|
| 912 |
+
):
|
| 913 |
+
super().__init__()
|
| 914 |
+
conv_class = CausalWNConv1d if causal else WNConv1d
|
| 915 |
+
layers: List[nn.Module] = [conv_class(1, d_model, kernel_size=7, padding=3)]
|
| 916 |
+
for stride, n_t_layer in zip(strides, n_transformer_layers):
|
| 917 |
+
d_model *= 2
|
| 918 |
+
layers.append(
|
| 919 |
+
EncoderBlock(
|
| 920 |
+
d_model, stride=stride, causal=causal,
|
| 921 |
+
n_t_layer=n_t_layer, transformer_general_config=transformer_general_config,
|
| 922 |
+
)
|
| 923 |
+
)
|
| 924 |
+
layers += [Snake1d(d_model), conv_class(d_model, d_latent, kernel_size=3, padding=1)]
|
| 925 |
+
self.block = nn.Sequential(*layers)
|
| 926 |
+
self.enc_dim = d_model
|
| 927 |
+
|
| 928 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 929 |
+
return self.block(x)
|
| 930 |
+
|
| 931 |
+
|
| 932 |
+
class DecoderBlock(nn.Module):
|
| 933 |
+
def __init__(
|
| 934 |
+
self,
|
| 935 |
+
input_dim: int = 16,
|
| 936 |
+
output_dim: int = 8,
|
| 937 |
+
stride: int = 1,
|
| 938 |
+
causal: bool = False,
|
| 939 |
+
n_t_layer: int = 0,
|
| 940 |
+
transformer_general_config=None,
|
| 941 |
+
):
|
| 942 |
+
super().__init__()
|
| 943 |
+
conv_trans_class = CausalWNConvTranspose1d if causal else WNConvTranspose1d
|
| 944 |
+
transformer_module = (
|
| 945 |
+
nn.Identity()
|
| 946 |
+
if n_t_layer == 0
|
| 947 |
+
else WindowLimitedTransformer(
|
| 948 |
+
causal=causal,
|
| 949 |
+
input_dim=input_dim,
|
| 950 |
+
window_size=None,
|
| 951 |
+
config=transformer_general_config(
|
| 952 |
+
n_layer=n_t_layer,
|
| 953 |
+
n_head=input_dim // 64,
|
| 954 |
+
dim=input_dim,
|
| 955 |
+
intermediate_size=input_dim * 3,
|
| 956 |
+
),
|
| 957 |
+
)
|
| 958 |
+
)
|
| 959 |
+
self.block = nn.Sequential(
|
| 960 |
+
Snake1d(input_dim),
|
| 961 |
+
conv_trans_class(input_dim, output_dim, kernel_size=2 * stride, stride=stride, padding=math.ceil(stride / 2)),
|
| 962 |
+
ResidualUnit(output_dim, dilation=1, causal=causal),
|
| 963 |
+
ResidualUnit(output_dim, dilation=3, causal=causal),
|
| 964 |
+
ResidualUnit(output_dim, dilation=9, causal=causal),
|
| 965 |
+
)
|
| 966 |
+
|
| 967 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 968 |
+
return self.block(x)
|
| 969 |
+
|
| 970 |
+
|
| 971 |
+
class Decoder(nn.Module):
|
| 972 |
+
def __init__(
|
| 973 |
+
self,
|
| 974 |
+
input_channel: int,
|
| 975 |
+
channels: int,
|
| 976 |
+
rates: List[int],
|
| 977 |
+
d_out: int = 1,
|
| 978 |
+
causal: bool = False,
|
| 979 |
+
n_transformer_layers: List[int] = [0, 0, 0, 0],
|
| 980 |
+
transformer_general_config=None,
|
| 981 |
+
):
|
| 982 |
+
super().__init__()
|
| 983 |
+
conv_class = CausalWNConv1d if causal else WNConv1d
|
| 984 |
+
layers: List[nn.Module] = [conv_class(input_channel, channels, kernel_size=7, padding=3)]
|
| 985 |
+
for i, (stride, n_t_layer) in enumerate(zip(rates, n_transformer_layers)):
|
| 986 |
+
input_dim = channels // 2**i
|
| 987 |
+
output_dim = channels // 2 ** (i + 1)
|
| 988 |
+
layers.append(
|
| 989 |
+
DecoderBlock(
|
| 990 |
+
input_dim, output_dim, stride, causal=causal,
|
| 991 |
+
n_t_layer=n_t_layer, transformer_general_config=transformer_general_config,
|
| 992 |
+
)
|
| 993 |
+
)
|
| 994 |
+
layers += [Snake1d(output_dim), conv_class(output_dim, d_out, kernel_size=7, padding=3), nn.Tanh()]
|
| 995 |
+
self.model = nn.Sequential(*layers)
|
| 996 |
+
|
| 997 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 998 |
+
return self.model(x)
|
| 999 |
+
|
| 1000 |
+
|
| 1001 |
+
class DAC(nn.Module):
|
| 1002 |
+
def __init__(
|
| 1003 |
+
self,
|
| 1004 |
+
encoder_dim: int = 64,
|
| 1005 |
+
encoder_rates: List[int] = [2, 4, 8, 8],
|
| 1006 |
+
latent_dim: Optional[int] = None,
|
| 1007 |
+
decoder_dim: int = 1536,
|
| 1008 |
+
decoder_rates: List[int] = [8, 8, 4, 2],
|
| 1009 |
+
quantizer: Optional[nn.Module] = None,
|
| 1010 |
+
sample_rate: int = 44100,
|
| 1011 |
+
causal: bool = True,
|
| 1012 |
+
encoder_transformer_layers: List[int] = [0, 0, 0, 0],
|
| 1013 |
+
decoder_transformer_layers: List[int] = [0, 0, 0, 0],
|
| 1014 |
+
transformer_general_config=None,
|
| 1015 |
+
):
|
| 1016 |
+
super().__init__()
|
| 1017 |
+
|
| 1018 |
+
self.encoder_dim = encoder_dim
|
| 1019 |
+
self.encoder_rates = encoder_rates
|
| 1020 |
+
self.decoder_dim = decoder_dim
|
| 1021 |
+
self.decoder_rates = decoder_rates
|
| 1022 |
+
self.sample_rate = sample_rate
|
| 1023 |
+
|
| 1024 |
+
if latent_dim is None:
|
| 1025 |
+
latent_dim = encoder_dim * (2 ** len(encoder_rates))
|
| 1026 |
+
self.latent_dim = latent_dim
|
| 1027 |
+
|
| 1028 |
+
self.hop_length = int(np.prod(encoder_rates))
|
| 1029 |
+
self.encoder = Encoder(
|
| 1030 |
+
encoder_dim, encoder_rates, latent_dim, causal=causal,
|
| 1031 |
+
n_transformer_layers=encoder_transformer_layers,
|
| 1032 |
+
transformer_general_config=transformer_general_config,
|
| 1033 |
+
)
|
| 1034 |
+
self.quantizer = quantizer
|
| 1035 |
+
self.decoder = Decoder(
|
| 1036 |
+
latent_dim, decoder_dim, decoder_rates, causal=causal,
|
| 1037 |
+
n_transformer_layers=decoder_transformer_layers,
|
| 1038 |
+
transformer_general_config=transformer_general_config,
|
| 1039 |
+
)
|
| 1040 |
+
self.sample_rate = sample_rate
|
| 1041 |
+
self.apply(init_weights)
|
| 1042 |
+
|
| 1043 |
+
self.delay = self.get_delay()
|
| 1044 |
+
self.frame_length = self.hop_length * 4
|
| 1045 |
+
|
| 1046 |
+
def get_output_length(self, input_length: int) -> int:
|
| 1047 |
+
length = input_length
|
| 1048 |
+
for stride in self.encoder_rates:
|
| 1049 |
+
length = math.ceil(length / stride)
|
| 1050 |
+
return length
|
| 1051 |
+
|
| 1052 |
+
def get_delay(self) -> int:
|
| 1053 |
+
l_out = self.get_output_length(0)
|
| 1054 |
+
L = l_out
|
| 1055 |
+
|
| 1056 |
+
layers = [layer for layer in self.modules() if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d))]
|
| 1057 |
+
for layer in reversed(layers):
|
| 1058 |
+
d = layer.dilation[0]
|
| 1059 |
+
k = layer.kernel_size[0]
|
| 1060 |
+
s = layer.stride[0]
|
| 1061 |
+
if isinstance(layer, nn.ConvTranspose1d):
|
| 1062 |
+
L = ((L - d * (k - 1) - 1) / s) + 1
|
| 1063 |
+
elif isinstance(layer, nn.Conv1d):
|
| 1064 |
+
L = (L - 1) * s + d * (k - 1) + 1
|
| 1065 |
+
L = math.ceil(L)
|
| 1066 |
+
|
| 1067 |
+
l_in = L
|
| 1068 |
+
return (l_in - l_out) // 2
|
| 1069 |
+
|
| 1070 |
+
def preprocess(self, audio_data: Tensor, sample_rate: Optional[int]) -> Tensor:
|
| 1071 |
+
if sample_rate is None:
|
| 1072 |
+
sample_rate = self.sample_rate
|
| 1073 |
+
assert sample_rate == self.sample_rate
|
| 1074 |
+
|
| 1075 |
+
length = audio_data.shape[-1]
|
| 1076 |
+
right_pad = math.ceil(length / self.hop_length) * self.hop_length - length
|
| 1077 |
+
audio_data = F.pad(audio_data, (0, right_pad))
|
| 1078 |
+
return audio_data
|
| 1079 |
+
|
| 1080 |
+
def encode(
|
| 1081 |
+
self,
|
| 1082 |
+
audio_data: Tensor,
|
| 1083 |
+
audio_lengths: Optional[Tensor] = None,
|
| 1084 |
+
n_quantizers: Optional[int] = None,
|
| 1085 |
+
**kwargs,
|
| 1086 |
+
):
|
| 1087 |
+
"""Encode audio to quantized code indices."""
|
| 1088 |
+
if audio_data.ndim == 2:
|
| 1089 |
+
audio_data = audio_data.unsqueeze(1)
|
| 1090 |
+
length = audio_data.shape[-1]
|
| 1091 |
+
right_pad = math.ceil(length / self.frame_length) * self.frame_length - length
|
| 1092 |
+
audio_data = F.pad(audio_data, (0, right_pad))
|
| 1093 |
+
if audio_lengths is None:
|
| 1094 |
+
audio_lengths = torch.LongTensor([length + right_pad]).to(audio_data.device)
|
| 1095 |
+
|
| 1096 |
+
z = self.encoder(audio_data)
|
| 1097 |
+
vq_results = self.quantizer(z, n_quantizers, **kwargs)
|
| 1098 |
+
indices = vq_results.codes
|
| 1099 |
+
indices_lens = torch.ceil(audio_lengths / self.frame_length).long()
|
| 1100 |
+
return indices, indices_lens
|
| 1101 |
+
|
| 1102 |
+
def decode(self, indices: Tensor, feature_lengths: Tensor):
|
| 1103 |
+
"""Decode code indices to audio."""
|
| 1104 |
+
if indices.ndim == 2:
|
| 1105 |
+
indices = indices[None]
|
| 1106 |
+
z = self.quantizer.decode(indices)
|
| 1107 |
+
audio_lengths = feature_lengths * self.frame_length
|
| 1108 |
+
return self.decoder(z), audio_lengths
|
| 1109 |
+
|
| 1110 |
+
def encode_to_codes(self, audio: Tensor, audio_lengths: Optional[Tensor] = None, n_quantizers: Optional[int] = None, **kw):
|
| 1111 |
+
return self.encode(audio, audio_lengths, n_quantizers, **kw)
|
| 1112 |
+
|
| 1113 |
+
def decode_codes(self, indices: Tensor, feature_lengths: Tensor):
|
| 1114 |
+
return self.decode(indices, feature_lengths)
|
| 1115 |
+
|
| 1116 |
+
@torch.no_grad()
|
| 1117 |
+
def encode_zq(self, audio_data: Tensor) -> Tensor:
|
| 1118 |
+
indices, _ = self.encode(audio_data)
|
| 1119 |
+
new_indices = torch.zeros_like(indices)
|
| 1120 |
+
new_indices[:, 0] = torch.clamp(indices[:, 0], max=self.quantizer.semantic_quantizer.codebook_size - 1)
|
| 1121 |
+
new_indices[:, 1:] = torch.clamp(indices[:, 1:], max=self.quantizer.quantizer.codebook_size - 1)
|
| 1122 |
+
|
| 1123 |
+
z_q_semantic = self.quantizer.semantic_quantizer.from_codes(new_indices[:, :1])[0]
|
| 1124 |
+
z_q_residual = self.quantizer.quantizer.from_codes(new_indices[:, 1:])[0]
|
| 1125 |
+
z_q = z_q_semantic + z_q_residual
|
| 1126 |
+
return z_q
|
| 1127 |
+
|
| 1128 |
+
@torch.no_grad()
|
| 1129 |
+
def decode_zq(self, z_q: Tensor) -> Tensor:
|
| 1130 |
+
z_q = self.quantizer.post_module(z_q)
|
| 1131 |
+
z_q = self.quantizer.upsample(z_q)
|
| 1132 |
+
return self.decoder(z_q)
|
| 1133 |
+
|
| 1134 |
+
@property
|
| 1135 |
+
def device(self) -> torch.device: return next(self.parameters()).device
|
| 1136 |
+
|
| 1137 |
+
@property
|
| 1138 |
+
def dtype(self) -> torch.dtype: return next(self.parameters()).dtype
|
| 1139 |
+
|
| 1140 |
+
# --------------------------------------------------------------------
|
| 1141 |
+
# Build helpers
|
| 1142 |
+
# --------------------------------------------------------------------
|
| 1143 |
+
|
| 1144 |
+
def build_ae(**cfg) -> DAC:
|
| 1145 |
+
"""
|
| 1146 |
+
Factory used by external loaders
|
| 1147 |
+
"""
|
| 1148 |
+
# Shared transformer config for the RVQ pre/post modules
|
| 1149 |
+
q_config = ModelArgs(
|
| 1150 |
+
block_size=4096, n_layer=8, n_head=16, dim=1024,
|
| 1151 |
+
intermediate_size=3072, head_dim=64, norm_eps=1e-5,
|
| 1152 |
+
dropout_rate=0.1, attn_dropout_rate=0.1, channels_first=True
|
| 1153 |
+
)
|
| 1154 |
+
|
| 1155 |
+
def make_transformer():
|
| 1156 |
+
return WindowLimitedTransformer(
|
| 1157 |
+
causal=True, window_size=128, input_dim=1024, config=q_config
|
| 1158 |
+
)
|
| 1159 |
+
|
| 1160 |
+
quantizer = DownsampleResidualVectorQuantize(
|
| 1161 |
+
input_dim=1024, n_codebooks=9, codebook_size=1024, codebook_dim=8,
|
| 1162 |
+
quantizer_dropout=0.5, downsample_factor=(2, 2),
|
| 1163 |
+
semantic_codebook_size=4096,
|
| 1164 |
+
pre_module=make_transformer(),
|
| 1165 |
+
post_module=make_transformer(),
|
| 1166 |
+
)
|
| 1167 |
+
|
| 1168 |
+
def transformer_general_config(**kw):
|
| 1169 |
+
return ModelArgs(
|
| 1170 |
+
block_size=kw.get("block_size", 16384),
|
| 1171 |
+
n_layer=kw.get("n_layer", 8),
|
| 1172 |
+
n_head=kw.get("n_head", 8),
|
| 1173 |
+
dim=kw.get("dim", 512),
|
| 1174 |
+
intermediate_size=kw.get("intermediate_size", 1536),
|
| 1175 |
+
n_local_heads=kw.get("n_local_heads", -1),
|
| 1176 |
+
head_dim=kw.get("head_dim", 64),
|
| 1177 |
+
rope_base=kw.get("rope_base", 10000),
|
| 1178 |
+
norm_eps=kw.get("norm_eps", 1e-5),
|
| 1179 |
+
dropout_rate=kw.get("dropout_rate", 0.1),
|
| 1180 |
+
attn_dropout_rate=kw.get("attn_dropout_rate", 0.1),
|
| 1181 |
+
channels_first=kw.get("channels_first", True),
|
| 1182 |
+
)
|
| 1183 |
+
|
| 1184 |
+
dac = DAC(
|
| 1185 |
+
encoder_dim=64, encoder_rates=[2, 4, 8, 8], latent_dim=1024,
|
| 1186 |
+
decoder_dim=1536, decoder_rates=[8, 8, 4, 2],
|
| 1187 |
+
quantizer=quantizer, sample_rate=44100, causal=True,
|
| 1188 |
+
encoder_transformer_layers=[0, 0, 0, 4],
|
| 1189 |
+
decoder_transformer_layers=[4, 0, 0, 0],
|
| 1190 |
+
transformer_general_config=transformer_general_config,
|
| 1191 |
+
)
|
| 1192 |
+
return dac
|
| 1193 |
+
|
| 1194 |
+
__all__ = [
|
| 1195 |
+
"DAC",
|
| 1196 |
+
"build_ae",
|
| 1197 |
+
"VectorQuantize",
|
| 1198 |
+
"ResidualVectorQuantize",
|
| 1199 |
+
"DownsampleResidualVectorQuantize",
|
| 1200 |
+
]
|
| 1201 |
+
|
| 1202 |
+
|
| 1203 |
+
# ----- BEGIN DAC MIT LICENSE -----
|
| 1204 |
+
# MIT License
|
| 1205 |
+
# Copyright (c) 2023-present, Descript
|
| 1206 |
+
#
|
| 1207 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 1208 |
+
# of this software and associated documentation files (the "Software"), to deal
|
| 1209 |
+
# in the Software without restriction, including without limitation the rights
|
| 1210 |
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 1211 |
+
# copies of the Software, and to permit persons to whom the Software is
|
| 1212 |
+
# furnished to do so, subject to the following conditions:
|
| 1213 |
+
#
|
| 1214 |
+
# The above copyright notice and this permission notice shall be included in all
|
| 1215 |
+
# copies or substantial portions of the Software.
|
| 1216 |
+
#
|
| 1217 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 1218 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 1219 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 1220 |
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 1221 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 1222 |
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 1223 |
+
# SOFTWARE.
|
| 1224 |
+
# ----- END DAC MIT LICENSE -----
|
| 1225 |
+
|
gradio_app.py
ADDED
|
@@ -0,0 +1,994 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import os
|
| 3 |
+
import json
|
| 4 |
+
import time
|
| 5 |
+
import secrets
|
| 6 |
+
import logging
|
| 7 |
+
import warnings
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from typing import Tuple, Any
|
| 10 |
+
from functools import partial
|
| 11 |
+
|
| 12 |
+
# see lines ~40-50 for running on lower VRAM GPUs
|
| 13 |
+
|
| 14 |
+
logging.getLogger("huggingface_hub").setLevel(logging.ERROR)
|
| 15 |
+
|
| 16 |
+
import gradio as gr
|
| 17 |
+
import torch
|
| 18 |
+
import torchaudio
|
| 19 |
+
|
| 20 |
+
from inference import (
|
| 21 |
+
load_model_from_hf,
|
| 22 |
+
load_fish_ae_from_hf,
|
| 23 |
+
load_pca_state_from_hf,
|
| 24 |
+
load_audio,
|
| 25 |
+
ae_reconstruct,
|
| 26 |
+
sample_pipeline,
|
| 27 |
+
compile_model,
|
| 28 |
+
compile_fish_ae,
|
| 29 |
+
sample_euler_cfg_independent_guidances
|
| 30 |
+
)
|
| 31 |
+
|
| 32 |
+
# --------------------------------------------------------------------
|
| 33 |
+
# IF ON 8GB VRAM GPU, SET FISH_AE_DTYPE to bfloat16 and DEFAULT_SAMPLE_LATENT_LENGTH to < 640 (e.g., 576)
|
| 34 |
+
|
| 35 |
+
# Configuration
|
| 36 |
+
MODEL_DTYPE = torch.bfloat16
|
| 37 |
+
FISH_AE_DTYPE = torch.float32
|
| 38 |
+
# FISH_AE_DTYPE = torch.bfloat16 # USE THIS IF OOM ON 8GB vram GPU
|
| 39 |
+
|
| 40 |
+
DEFAULT_SAMPLE_LATENT_LENGTH = 640 # decrease if OOM on 8GB vram GPU
|
| 41 |
+
# DEFAULT_SAMPLE_LATENT_LENGTH = 576 # (example, ~27 seconds rather than ~30; can change depending on what fits in VRAM)
|
| 42 |
+
|
| 43 |
+
# NOTE peak S1-DAC decoding VRAM > peak latent sampling VRAM, so decoding in chunks (which is posisble as S1-DAC is causal) would allow for full 640-length generation on lower VRAM GPUs
|
| 44 |
+
|
| 45 |
+
# --------------------------------------------------------------------
|
| 46 |
+
|
| 47 |
+
# Audio Prompt Library for Custom Audio Panel (included in repo)
|
| 48 |
+
AUDIO_PROMPT_FOLDER = Path("./audio_prompts")
|
| 49 |
+
|
| 50 |
+
# --------------------------------------------------------------------
|
| 51 |
+
|
| 52 |
+
TEXT_PRESETS_PATH = Path("./text_presets.txt")
|
| 53 |
+
SAMPLER_PRESETS_PATH = Path("./sampler_presets.json")
|
| 54 |
+
|
| 55 |
+
TEMP_AUDIO_DIR = Path("./temp_gradio_audio")
|
| 56 |
+
TEMP_AUDIO_DIR.mkdir(parents=True, exist_ok=True)
|
| 57 |
+
|
| 58 |
+
# --------------------------------------------------------------------
|
| 59 |
+
# Model loading (eager for local use)
|
| 60 |
+
model = load_model_from_hf(dtype=MODEL_DTYPE, delete_blockwise_modules=True)
|
| 61 |
+
fish_ae = load_fish_ae_from_hf(dtype=FISH_AE_DTYPE)
|
| 62 |
+
pca_state = load_pca_state_from_hf()
|
| 63 |
+
|
| 64 |
+
model_compiled = None
|
| 65 |
+
fish_ae_compiled = None
|
| 66 |
+
|
| 67 |
+
# --------------------------------------------------------------------
|
| 68 |
+
# Helper functions
|
| 69 |
+
def make_stem(prefix: str, user_id: str | None = None) -> str:
|
| 70 |
+
"""Create unique filename stem: prefix__user__timestamp_random or prefix__timestamp_random if no user_id."""
|
| 71 |
+
ts = int(time.time() * 1000)
|
| 72 |
+
rand = secrets.token_hex(4)
|
| 73 |
+
if user_id:
|
| 74 |
+
return f"{prefix}__{user_id}__{ts}_{rand}"
|
| 75 |
+
return f"{prefix}__{ts}_{rand}"
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def cleanup_temp_audio(dir_: Path, user_id: str | None, max_age_sec: int = 60 * 5):
|
| 79 |
+
"""Remove old files globally and all previous files for this user."""
|
| 80 |
+
now = time.time()
|
| 81 |
+
|
| 82 |
+
for p in dir_.glob("*"):
|
| 83 |
+
try:
|
| 84 |
+
if p.is_file() and (now - p.stat().st_mtime) > max_age_sec:
|
| 85 |
+
p.unlink(missing_ok=True)
|
| 86 |
+
except Exception:
|
| 87 |
+
pass
|
| 88 |
+
|
| 89 |
+
if user_id:
|
| 90 |
+
for p in dir_.glob(f"*__{user_id}__*"):
|
| 91 |
+
try:
|
| 92 |
+
if p.is_file():
|
| 93 |
+
p.unlink(missing_ok=True)
|
| 94 |
+
except Exception:
|
| 95 |
+
pass
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def save_audio_with_format(audio_tensor: torch.Tensor, base_path: Path, filename: str, sample_rate: int, audio_format: str) -> Path:
|
| 99 |
+
"""Save audio in specified format, fallback to WAV if MP3 encoding fails."""
|
| 100 |
+
if audio_format == "mp3":
|
| 101 |
+
try:
|
| 102 |
+
output_path = base_path / f"{filename}.mp3"
|
| 103 |
+
torchaudio.save(
|
| 104 |
+
str(output_path),
|
| 105 |
+
audio_tensor,
|
| 106 |
+
sample_rate,
|
| 107 |
+
format="mp3",
|
| 108 |
+
encoding="mp3",
|
| 109 |
+
bits_per_sample=None,
|
| 110 |
+
)
|
| 111 |
+
return output_path
|
| 112 |
+
except Exception as e:
|
| 113 |
+
print(f"MP3 encoding failed: {e}, falling back to WAV")
|
| 114 |
+
output_path = base_path / f"{filename}.wav"
|
| 115 |
+
torchaudio.save(str(output_path), audio_tensor, sample_rate)
|
| 116 |
+
return output_path
|
| 117 |
+
|
| 118 |
+
output_path = base_path / f"{filename}.wav"
|
| 119 |
+
torchaudio.save(str(output_path), audio_tensor, sample_rate)
|
| 120 |
+
return output_path
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def to_bool(val: Any) -> bool:
|
| 124 |
+
"""Parse truthy values from common string/bool inputs."""
|
| 125 |
+
return str(val).strip().lower() not in {"", "0", "false", "off", "none", "no"}
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def find_min_bucket_gte(values_str: str, actual_length: int) -> int | None:
|
| 129 |
+
"""Parse comma-separated values and find minimum value >= actual_length.
|
| 130 |
+
|
| 131 |
+
If a single value is provided (no comma), returns that value directly.
|
| 132 |
+
If comma-separated, finds the smallest bucket that can fit the content.
|
| 133 |
+
Returns None if empty string.
|
| 134 |
+
"""
|
| 135 |
+
if not values_str or not values_str.strip():
|
| 136 |
+
return None
|
| 137 |
+
|
| 138 |
+
values_str = values_str.strip()
|
| 139 |
+
|
| 140 |
+
# Single value case - return as-is
|
| 141 |
+
if "," not in values_str:
|
| 142 |
+
return int(values_str)
|
| 143 |
+
|
| 144 |
+
# Multiple values - find minimum >= actual_length
|
| 145 |
+
values = [int(v.strip()) for v in values_str.split(",") if v.strip()]
|
| 146 |
+
if not values:
|
| 147 |
+
return None
|
| 148 |
+
|
| 149 |
+
# Find minimum value >= actual_length
|
| 150 |
+
candidates = [v for v in values if v >= actual_length]
|
| 151 |
+
if candidates:
|
| 152 |
+
return min(candidates)
|
| 153 |
+
|
| 154 |
+
# If no value is >=, return the maximum (best effort)
|
| 155 |
+
return max(values)
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
def generate_audio(
|
| 159 |
+
text_prompt: str,
|
| 160 |
+
speaker_audio_path: str,
|
| 161 |
+
num_steps: int,
|
| 162 |
+
rng_seed: int,
|
| 163 |
+
cfg_scale_text: float,
|
| 164 |
+
cfg_scale_speaker: float,
|
| 165 |
+
cfg_min_t: float,
|
| 166 |
+
cfg_max_t: float,
|
| 167 |
+
truncation_factor: float,
|
| 168 |
+
rescale_k: float,
|
| 169 |
+
rescale_sigma: float,
|
| 170 |
+
force_speaker: bool,
|
| 171 |
+
speaker_kv_scale: float,
|
| 172 |
+
speaker_kv_min_t: float,
|
| 173 |
+
speaker_kv_max_layers: int,
|
| 174 |
+
reconstruct_first_30_seconds: bool,
|
| 175 |
+
use_custom_shapes: bool,
|
| 176 |
+
max_text_byte_length: str,
|
| 177 |
+
max_speaker_latent_length: str,
|
| 178 |
+
sample_latent_length: str,
|
| 179 |
+
audio_format: str,
|
| 180 |
+
use_compile: bool,
|
| 181 |
+
show_original_audio: bool,
|
| 182 |
+
session_id: str,
|
| 183 |
+
) -> Tuple[Any, Any, Any, Any, Any, Any, Any, Any, Any]:
|
| 184 |
+
"""Generate audio using the model."""
|
| 185 |
+
global model_compiled, fish_ae_compiled
|
| 186 |
+
|
| 187 |
+
if use_compile:
|
| 188 |
+
if model_compiled is None:
|
| 189 |
+
try:
|
| 190 |
+
model_compiled = compile_model(model)
|
| 191 |
+
fish_ae_compiled = compile_fish_ae(fish_ae)
|
| 192 |
+
except Exception as e:
|
| 193 |
+
print(f"Compilation wrapping failed: {str(e)}")
|
| 194 |
+
model_compiled = None
|
| 195 |
+
fish_ae_compiled = None
|
| 196 |
+
use_compile = False
|
| 197 |
+
|
| 198 |
+
active_model = model_compiled if use_compile else model
|
| 199 |
+
active_fish_ae = fish_ae_compiled if use_compile else fish_ae
|
| 200 |
+
|
| 201 |
+
cleanup_temp_audio(TEMP_AUDIO_DIR, session_id)
|
| 202 |
+
|
| 203 |
+
start_time = time.time()
|
| 204 |
+
|
| 205 |
+
num_steps_int = min(max(int(num_steps), 1), 80)
|
| 206 |
+
rng_seed_int = int(rng_seed) if rng_seed is not None else 0
|
| 207 |
+
cfg_scale_text_val = float(cfg_scale_text)
|
| 208 |
+
cfg_scale_speaker_val = float(cfg_scale_speaker) if cfg_scale_speaker is not None else None
|
| 209 |
+
cfg_min_t_val = float(cfg_min_t)
|
| 210 |
+
cfg_max_t_val = float(cfg_max_t)
|
| 211 |
+
truncation_factor_val = float(truncation_factor)
|
| 212 |
+
rescale_k_val = float(rescale_k) if rescale_k != 1.0 else None
|
| 213 |
+
rescale_sigma_val = float(rescale_sigma)
|
| 214 |
+
|
| 215 |
+
speaker_kv_enabled = bool(force_speaker)
|
| 216 |
+
if speaker_kv_enabled:
|
| 217 |
+
speaker_kv_scale_val = float(speaker_kv_scale) if speaker_kv_scale is not None else None
|
| 218 |
+
speaker_kv_min_t_val = float(speaker_kv_min_t) if speaker_kv_min_t is not None else None
|
| 219 |
+
speaker_kv_max_layers_val = int(speaker_kv_max_layers) if speaker_kv_max_layers is not None else None
|
| 220 |
+
else:
|
| 221 |
+
speaker_kv_scale_val = None
|
| 222 |
+
speaker_kv_min_t_val = None
|
| 223 |
+
speaker_kv_max_layers_val = None
|
| 224 |
+
|
| 225 |
+
# Load speaker audio early so we can compute actual lengths for bucket selection
|
| 226 |
+
use_zero_speaker = not speaker_audio_path or speaker_audio_path == ""
|
| 227 |
+
speaker_audio = load_audio(speaker_audio_path).cuda() if not use_zero_speaker else None
|
| 228 |
+
|
| 229 |
+
if use_custom_shapes:
|
| 230 |
+
# Compute actual text byte length
|
| 231 |
+
actual_text_byte_length = len(text_prompt.encode("utf-8")) + 1 # +1 for BOS token
|
| 232 |
+
|
| 233 |
+
# Compute actual speaker latent length (audio_samples // 2048)
|
| 234 |
+
AE_DOWNSAMPLE_FACTOR = 2048
|
| 235 |
+
if speaker_audio is not None:
|
| 236 |
+
actual_speaker_latent_length = (speaker_audio.shape[-1] // AE_DOWNSAMPLE_FACTOR) // 4 * 4
|
| 237 |
+
else:
|
| 238 |
+
actual_speaker_latent_length = 0
|
| 239 |
+
|
| 240 |
+
# Find appropriate bucket sizes from comma-separated values
|
| 241 |
+
pad_to_max_text_length = find_min_bucket_gte(max_text_byte_length, actual_text_byte_length)
|
| 242 |
+
pad_to_max_speaker_latent_length = find_min_bucket_gte(max_speaker_latent_length, actual_speaker_latent_length)
|
| 243 |
+
sample_latent_length_val = int(sample_latent_length) if sample_latent_length.strip() else (DEFAULT_SAMPLE_LATENT_LENGTH or 640)
|
| 244 |
+
else:
|
| 245 |
+
pad_to_max_text_length = None
|
| 246 |
+
pad_to_max_speaker_latent_length = None
|
| 247 |
+
sample_latent_length_val = DEFAULT_SAMPLE_LATENT_LENGTH or 640
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
sample_fn = partial(
|
| 251 |
+
sample_euler_cfg_independent_guidances,
|
| 252 |
+
num_steps=num_steps_int,
|
| 253 |
+
cfg_scale_text=cfg_scale_text_val,
|
| 254 |
+
cfg_scale_speaker=cfg_scale_speaker_val,
|
| 255 |
+
cfg_min_t=cfg_min_t_val,
|
| 256 |
+
cfg_max_t=cfg_max_t_val,
|
| 257 |
+
truncation_factor=truncation_factor_val,
|
| 258 |
+
rescale_k=rescale_k_val,
|
| 259 |
+
rescale_sigma=rescale_sigma_val,
|
| 260 |
+
speaker_kv_scale=speaker_kv_scale_val,
|
| 261 |
+
speaker_kv_min_t=speaker_kv_min_t_val,
|
| 262 |
+
speaker_kv_max_layers=speaker_kv_max_layers_val,
|
| 263 |
+
sequence_length=sample_latent_length_val,
|
| 264 |
+
)
|
| 265 |
+
|
| 266 |
+
audio_out, normalized_text = sample_pipeline(
|
| 267 |
+
model=active_model,
|
| 268 |
+
fish_ae=active_fish_ae,
|
| 269 |
+
pca_state=pca_state,
|
| 270 |
+
sample_fn=sample_fn,
|
| 271 |
+
text_prompt=text_prompt,
|
| 272 |
+
speaker_audio=speaker_audio,
|
| 273 |
+
rng_seed=rng_seed_int,
|
| 274 |
+
pad_to_max_text_length=pad_to_max_text_length,
|
| 275 |
+
pad_to_max_speaker_latent_length=pad_to_max_speaker_latent_length,
|
| 276 |
+
normalize_text=True,
|
| 277 |
+
)
|
| 278 |
+
|
| 279 |
+
audio_to_save = audio_out[0].cpu()
|
| 280 |
+
|
| 281 |
+
stem = make_stem("generated", session_id)
|
| 282 |
+
output_path = save_audio_with_format(audio_to_save, TEMP_AUDIO_DIR, stem, 44100, audio_format)
|
| 283 |
+
|
| 284 |
+
generation_time = time.time() - start_time
|
| 285 |
+
time_str = f"β±οΈ Total generation time: {generation_time:.2f}s"
|
| 286 |
+
text_display = f"**Text Prompt (normalized):**\n\n{normalized_text}"
|
| 287 |
+
|
| 288 |
+
recon_output_path = None
|
| 289 |
+
original_output_path = None
|
| 290 |
+
|
| 291 |
+
if reconstruct_first_30_seconds and speaker_audio is not None:
|
| 292 |
+
audio_recon = ae_reconstruct(
|
| 293 |
+
fish_ae=fish_ae,
|
| 294 |
+
pca_state=pca_state,
|
| 295 |
+
audio=torch.nn.functional.pad(
|
| 296 |
+
speaker_audio[..., :2048 * 640],
|
| 297 |
+
(0, max(0, 2048 * 640 - speaker_audio.shape[-1])),
|
| 298 |
+
)[None],
|
| 299 |
+
)[..., : speaker_audio.shape[-1]]
|
| 300 |
+
|
| 301 |
+
recon_stem = make_stem("speaker_recon", session_id)
|
| 302 |
+
recon_output_path = save_audio_with_format(audio_recon.cpu()[0], TEMP_AUDIO_DIR, recon_stem, 44100, audio_format)
|
| 303 |
+
|
| 304 |
+
if show_original_audio and speaker_audio is not None:
|
| 305 |
+
original_stem = make_stem("original_audio", session_id)
|
| 306 |
+
original_output_path = save_audio_with_format(speaker_audio.cpu(), TEMP_AUDIO_DIR, original_stem, 44100, audio_format)
|
| 307 |
+
|
| 308 |
+
show_reference_section = (show_original_audio or reconstruct_first_30_seconds) and speaker_audio is not None
|
| 309 |
+
return (
|
| 310 |
+
gr.update(),
|
| 311 |
+
gr.update(value=str(output_path), visible=True),
|
| 312 |
+
gr.update(value=text_display, visible=True),
|
| 313 |
+
gr.update(value=str(original_output_path) if original_output_path else None, visible=True),
|
| 314 |
+
gr.update(value=time_str, visible=True),
|
| 315 |
+
gr.update(value=str(recon_output_path) if recon_output_path else None, visible=True),
|
| 316 |
+
gr.update(visible=(show_original_audio and speaker_audio is not None)),
|
| 317 |
+
gr.update(visible=(reconstruct_first_30_seconds and speaker_audio is not None)),
|
| 318 |
+
gr.update(visible=show_reference_section),
|
| 319 |
+
)
|
| 320 |
+
|
| 321 |
+
|
| 322 |
+
# UI Helper Functions
|
| 323 |
+
def load_text_presets():
|
| 324 |
+
"""Load text presets from file with category and word count."""
|
| 325 |
+
if TEXT_PRESETS_PATH.exists():
|
| 326 |
+
with open(TEXT_PRESETS_PATH, "r", encoding="utf-8") as f:
|
| 327 |
+
lines = [line.strip() for line in f if line.strip()]
|
| 328 |
+
|
| 329 |
+
result = []
|
| 330 |
+
for line in lines:
|
| 331 |
+
if " | " in line:
|
| 332 |
+
parts = line.split(" | ", 1)
|
| 333 |
+
category = parts[0]
|
| 334 |
+
text = parts[1]
|
| 335 |
+
else:
|
| 336 |
+
category = "Uncategorized"
|
| 337 |
+
text = line
|
| 338 |
+
|
| 339 |
+
word_count = len(text.split())
|
| 340 |
+
result.append([category, str(word_count), text])
|
| 341 |
+
|
| 342 |
+
return result
|
| 343 |
+
return []
|
| 344 |
+
|
| 345 |
+
|
| 346 |
+
def select_text_preset(evt: gr.SelectData):
|
| 347 |
+
"""Handle text preset selection - extract text from the row."""
|
| 348 |
+
if evt.value:
|
| 349 |
+
if isinstance(evt.index, (tuple, list)) and len(evt.index) >= 2:
|
| 350 |
+
row_index = evt.index[0]
|
| 351 |
+
else:
|
| 352 |
+
row_index = evt.index
|
| 353 |
+
|
| 354 |
+
presets_data = load_text_presets()
|
| 355 |
+
if isinstance(row_index, int) and row_index < len(presets_data):
|
| 356 |
+
text = presets_data[row_index][2]
|
| 357 |
+
return gr.update(value=text)
|
| 358 |
+
return gr.update()
|
| 359 |
+
|
| 360 |
+
|
| 361 |
+
def toggle_mode(mode):
|
| 362 |
+
"""Toggle advanced settings section visibility."""
|
| 363 |
+
show_advanced = mode == "Advanced Mode"
|
| 364 |
+
return gr.update(visible=show_advanced)
|
| 365 |
+
|
| 366 |
+
|
| 367 |
+
def update_force_row(force_speaker):
|
| 368 |
+
"""Show KV scaling controls when Force Speaker is enabled."""
|
| 369 |
+
return gr.update(visible=bool(force_speaker))
|
| 370 |
+
|
| 371 |
+
|
| 372 |
+
def apply_cfg_preset(preset_name):
|
| 373 |
+
"""Apply CFG guidance preset."""
|
| 374 |
+
presets = {
|
| 375 |
+
"higher speaker": (3.0, 8.0, 0.5, 1.0),
|
| 376 |
+
"large guidances": (8.0, 8.0, 0.5, 1.0),
|
| 377 |
+
}
|
| 378 |
+
|
| 379 |
+
if preset_name not in presets:
|
| 380 |
+
return [gr.update()] * 5
|
| 381 |
+
|
| 382 |
+
text_scale, speaker_scale, min_t, max_t = presets[preset_name]
|
| 383 |
+
|
| 384 |
+
return [
|
| 385 |
+
gr.update(value=text_scale),
|
| 386 |
+
gr.update(value=speaker_scale),
|
| 387 |
+
gr.update(value=min_t),
|
| 388 |
+
gr.update(value=max_t),
|
| 389 |
+
gr.update(value="Custom"),
|
| 390 |
+
]
|
| 391 |
+
|
| 392 |
+
|
| 393 |
+
def apply_speaker_kv_preset(preset_name):
|
| 394 |
+
"""Apply speaker KV attention control preset."""
|
| 395 |
+
if preset_name == "enable":
|
| 396 |
+
return [
|
| 397 |
+
gr.update(value=True),
|
| 398 |
+
gr.update(visible=True),
|
| 399 |
+
gr.update(value="Custom"),
|
| 400 |
+
]
|
| 401 |
+
if preset_name == "off":
|
| 402 |
+
return [
|
| 403 |
+
gr.update(value=False),
|
| 404 |
+
gr.update(visible=False),
|
| 405 |
+
gr.update(value="Custom"),
|
| 406 |
+
]
|
| 407 |
+
return [gr.update()] * 3
|
| 408 |
+
|
| 409 |
+
|
| 410 |
+
def apply_truncation_preset(preset_name):
|
| 411 |
+
"""Apply truncation & temporal rescaling preset."""
|
| 412 |
+
presets = {
|
| 413 |
+
"flat": (0.8, 1.2, 3.0),
|
| 414 |
+
"sharp": (0.9, 0.96, 3.0),
|
| 415 |
+
"baseline(sharp)": (1.0, 1.0, 3.0),
|
| 416 |
+
}
|
| 417 |
+
|
| 418 |
+
if preset_name == "custom" or preset_name not in presets:
|
| 419 |
+
return [gr.update()] * 4
|
| 420 |
+
|
| 421 |
+
truncation, rescale_k, rescale_sigma = presets[preset_name]
|
| 422 |
+
|
| 423 |
+
return [
|
| 424 |
+
gr.update(value=truncation),
|
| 425 |
+
gr.update(value=rescale_k),
|
| 426 |
+
gr.update(value=rescale_sigma),
|
| 427 |
+
gr.update(value="Custom"),
|
| 428 |
+
]
|
| 429 |
+
|
| 430 |
+
|
| 431 |
+
def load_sampler_presets():
|
| 432 |
+
"""Load sampler presets from JSON file."""
|
| 433 |
+
if SAMPLER_PRESETS_PATH.exists():
|
| 434 |
+
with open(SAMPLER_PRESETS_PATH, "r") as f:
|
| 435 |
+
return json.load(f)
|
| 436 |
+
|
| 437 |
+
default_presets = {
|
| 438 |
+
"Independent-High-Speaker-CFG": {
|
| 439 |
+
"num_steps": "40",
|
| 440 |
+
"cfg_scale_text": "3.0",
|
| 441 |
+
"cfg_scale_speaker": "8.0",
|
| 442 |
+
"cfg_min_t": "0.5",
|
| 443 |
+
"cfg_max_t": "1.0",
|
| 444 |
+
"truncation_factor": "1.",
|
| 445 |
+
"rescale_k": "1.",
|
| 446 |
+
"rescale_sigma": "3.0"
|
| 447 |
+
}
|
| 448 |
+
}
|
| 449 |
+
with open(SAMPLER_PRESETS_PATH, "w") as f:
|
| 450 |
+
json.dump(default_presets, f, indent=2)
|
| 451 |
+
return default_presets
|
| 452 |
+
|
| 453 |
+
|
| 454 |
+
def apply_sampler_preset(preset_name):
|
| 455 |
+
"""Apply a sampler preset to all fields."""
|
| 456 |
+
presets = load_sampler_presets()
|
| 457 |
+
if preset_name == "Custom" or preset_name not in presets:
|
| 458 |
+
return [gr.update()] * 13
|
| 459 |
+
|
| 460 |
+
preset = presets[preset_name]
|
| 461 |
+
speaker_kv_enabled = to_bool(preset.get("speaker_kv_enable", False))
|
| 462 |
+
|
| 463 |
+
def to_num(val, default):
|
| 464 |
+
try:
|
| 465 |
+
return float(val) if isinstance(val, str) else val
|
| 466 |
+
except (ValueError, TypeError):
|
| 467 |
+
return default
|
| 468 |
+
|
| 469 |
+
return [
|
| 470 |
+
gr.update(value=int(to_num(preset.get("num_steps", "40"), 40))),
|
| 471 |
+
gr.update(value=to_num(preset.get("cfg_scale_text", "3.0"), 3.0)),
|
| 472 |
+
gr.update(value=to_num(preset.get("cfg_scale_speaker", "5.0"), 5.0)),
|
| 473 |
+
gr.update(value=to_num(preset.get("cfg_min_t", "0.5"), 0.5)),
|
| 474 |
+
gr.update(value=to_num(preset.get("cfg_max_t", "1.0"), 1.0)),
|
| 475 |
+
gr.update(value=to_num(preset.get("truncation_factor", "0.8"), 0.8)),
|
| 476 |
+
gr.update(value=to_num(preset.get("rescale_k", "1.2"), 1.2)),
|
| 477 |
+
gr.update(value=to_num(preset.get("rescale_sigma", "3.0"), 3.0)),
|
| 478 |
+
gr.update(value=speaker_kv_enabled),
|
| 479 |
+
gr.update(visible=speaker_kv_enabled),
|
| 480 |
+
gr.update(value=to_num(preset.get("speaker_kv_scale", "1.5"), 1.5)),
|
| 481 |
+
gr.update(value=to_num(preset.get("speaker_kv_min_t", "0.9"), 0.9)),
|
| 482 |
+
gr.update(value=int(to_num(preset.get("speaker_kv_max_layers", "24"), 24))),
|
| 483 |
+
]
|
| 484 |
+
|
| 485 |
+
|
| 486 |
+
AUDIO_EXTS = {".wav", ".mp3", ".m4a", ".ogg", ".flac", ".webm", ".aac", ".opus"}
|
| 487 |
+
|
| 488 |
+
|
| 489 |
+
def get_audio_prompt_files(search_query: str = ""):
|
| 490 |
+
"""Get list of audio files from the audio prompt folder, optionally filtered by search query."""
|
| 491 |
+
if AUDIO_PROMPT_FOLDER is None or not AUDIO_PROMPT_FOLDER.exists():
|
| 492 |
+
return []
|
| 493 |
+
|
| 494 |
+
files = sorted([f.name for f in AUDIO_PROMPT_FOLDER.iterdir() if f.is_file() and f.suffix.lower() in AUDIO_EXTS], key=str.lower)
|
| 495 |
+
|
| 496 |
+
# Filter by search query if provided
|
| 497 |
+
if search_query.strip():
|
| 498 |
+
query_lower = search_query.lower()
|
| 499 |
+
files = [f for f in files if query_lower in f.lower()]
|
| 500 |
+
|
| 501 |
+
return [[file] for file in files]
|
| 502 |
+
|
| 503 |
+
|
| 504 |
+
def filter_audio_prompts(search_query: str):
|
| 505 |
+
"""Filter audio prompts based on search query."""
|
| 506 |
+
return gr.update(value=get_audio_prompt_files(search_query))
|
| 507 |
+
|
| 508 |
+
|
| 509 |
+
def select_audio_prompt_file(evt: gr.SelectData):
|
| 510 |
+
"""Handle audio prompt file selection from table."""
|
| 511 |
+
if evt.value and AUDIO_PROMPT_FOLDER is not None:
|
| 512 |
+
file_path = AUDIO_PROMPT_FOLDER / evt.value
|
| 513 |
+
if file_path.exists():
|
| 514 |
+
return gr.update(value=str(file_path))
|
| 515 |
+
return gr.update()
|
| 516 |
+
|
| 517 |
+
|
| 518 |
+
# UI styling and helpers
|
| 519 |
+
LINK_CSS = """
|
| 520 |
+
.preset-inline { display:flex; align-items:baseline; gap:6px; margin-top:-4px; margin-bottom:-12px; }
|
| 521 |
+
.preset-inline .title { font-weight:600; font-size:.95rem; }
|
| 522 |
+
.preset-inline .dim { color:#666; margin:0 4px; }
|
| 523 |
+
a.preset-link { color: #0a5bd8; text-decoration: underline; cursor: pointer; font-weight: 400; }
|
| 524 |
+
a.preset-link:hover { text-decoration: none; opacity: 0.8; }
|
| 525 |
+
.dark a.preset-link,
|
| 526 |
+
[data-theme="dark"] a.preset-link { color: #60a5fa !important; }
|
| 527 |
+
.dark a.preset-link:hover,
|
| 528 |
+
[data-theme="dark"] a.preset-link:hover { color: #93c5fd !important; }
|
| 529 |
+
.dark .preset-inline .dim,
|
| 530 |
+
[data-theme="dark"] .preset-inline .dim { color: #9ca3af !important; }
|
| 531 |
+
.proxy-btn { position:absolute; width:0; height:0; overflow:hidden; padding:0 !important; margin:0 !important; border:0 !important; opacity:0; pointer-events:none; }
|
| 532 |
+
.gr-group { border: 1px solid #d1d5db !important; background: #f3f4f6 !important; }
|
| 533 |
+
.dark .gr-group,
|
| 534 |
+
[data-theme="dark"] .gr-group { border: 1px solid #4b5563 !important; background: #1f2937 !important; }
|
| 535 |
+
.generated-audio-player { border: 3px solid #667eea !important; border-radius: 12px !important; padding: 20px !important; background: linear-gradient(135deg, rgba(102, 126, 234, 0.08) 0%, rgba(118, 75, 162, 0.05) 100%) !important; box-shadow: 0 4px 12px rgba(102, 126, 234, 0.2) !important; margin: 1rem 0 !important; }
|
| 536 |
+
.generated-audio-player > div { background: transparent !important; }
|
| 537 |
+
#component-mode-selector { text-align: center; padding: 1rem 0; }
|
| 538 |
+
#component-mode-selector label { font-size: 1.1rem !important; font-weight: 600 !important; margin-bottom: 0.5rem !important; }
|
| 539 |
+
#component-mode-selector .wrap { justify-content: center !important; }
|
| 540 |
+
#component-mode-selector fieldset { border: 2px solid #e5e7eb !important; border-radius: 8px !important; padding: 1rem !important; background: #f9fafb !important; }
|
| 541 |
+
.dark #component-mode-selector fieldset,
|
| 542 |
+
[data-theme="dark"] #component-mode-selector fieldset { border: 2px solid #4b5563 !important; background: #1f2937 !important; }
|
| 543 |
+
.section-separator { height: 3px !important; background: linear-gradient(90deg, transparent 0%, #667eea 20%, #764ba2 80%, transparent 100%) !important; border: none !important; margin: 2rem 0 !important; }
|
| 544 |
+
.dark .section-separator,
|
| 545 |
+
[data-theme="dark"] .section-separator { background: linear-gradient(90deg, transparent 0%, #667eea 20%, #764ba2 80%, transparent 100%) !important; }
|
| 546 |
+
.gradio-container h1,
|
| 547 |
+
.gradio-container h2 { font-weight: 700 !important; margin-top: 1.5rem !important; margin-bottom: 1rem !important; }
|
| 548 |
+
.tip-box { background: linear-gradient(135deg, #fef3c7 0%, #fde68a 100%) !important; border-left: 4px solid #f59e0b !important; border-radius: 8px !important; padding: 1rem 1.5rem !important; margin: 1rem 0 !important; box-shadow: 0 2px 4px rgba(245, 158, 11, 0.1) !important; }
|
| 549 |
+
.tip-box strong { color: #92400e !important; }
|
| 550 |
+
.dark .tip-box,
|
| 551 |
+
[data-theme="dark"] .tip-box { background: linear-gradient(135deg, #451a03 0%, #78350f 100%) !important; border-left: 4px solid #f59e0b !important; }
|
| 552 |
+
.dark .tip-box strong,
|
| 553 |
+
[data-theme="dark"] .tip-box strong { color: #fbbf24 !important; }
|
| 554 |
+
"""
|
| 555 |
+
|
| 556 |
+
JS_CODE = r"""
|
| 557 |
+
function () {
|
| 558 |
+
const appEl = document.querySelector("gradio-app");
|
| 559 |
+
const root = appEl && appEl.shadowRoot ? appEl.shadowRoot : document;
|
| 560 |
+
function clickHiddenButtonById(id) {
|
| 561 |
+
if (!id) return;
|
| 562 |
+
const host = root.getElementById(id);
|
| 563 |
+
if (!host) return;
|
| 564 |
+
const realBtn = host.querySelector("button, [role='button']") || host;
|
| 565 |
+
realBtn.click();
|
| 566 |
+
}
|
| 567 |
+
root.addEventListener("click", (ev) => {
|
| 568 |
+
const a = ev.target.closest("a.preset-link");
|
| 569 |
+
if (!a) return;
|
| 570 |
+
ev.preventDefault();
|
| 571 |
+
ev.stopPropagation();
|
| 572 |
+
ev.stopImmediatePropagation();
|
| 573 |
+
clickHiddenButtonById(a.getAttribute("data-fire"));
|
| 574 |
+
return false;
|
| 575 |
+
}, true);
|
| 576 |
+
}
|
| 577 |
+
"""
|
| 578 |
+
|
| 579 |
+
|
| 580 |
+
def init_session():
|
| 581 |
+
"""Initialize session ID for this browser tab/session."""
|
| 582 |
+
return secrets.token_hex(8)
|
| 583 |
+
|
| 584 |
+
|
| 585 |
+
with gr.Blocks(title="Echo-TTS", css=LINK_CSS, js=JS_CODE) as demo:
|
| 586 |
+
gr.Markdown("# Echo-TTS")
|
| 587 |
+
gr.Markdown("*Jordan Darefsky, 2025. See technical details [here](https://jordandarefsky.com/blog/2025/echo/)*")
|
| 588 |
+
|
| 589 |
+
gr.Markdown("**License Notice:** All audio outputs are subject to non-commercial use [CC-BY-NC-SA-4.0](https://creativecommons.org/licenses/by-nc-sa/4.0/).")
|
| 590 |
+
|
| 591 |
+
gr.Markdown("**Responsible Use:** Do not use this model to impersonate real people without their explicit consent or to generate deceptive audio.")
|
| 592 |
+
|
| 593 |
+
with gr.Accordion("π Quick Start Instructions", open=True):
|
| 594 |
+
gr.Markdown(
|
| 595 |
+
"""
|
| 596 |
+
1. Upload or record a short reference clip (or leave blank for no speaker reference).
|
| 597 |
+
2. Pick a text preset or type your own prompt.
|
| 598 |
+
3. Click **Generate Audio**.
|
| 599 |
+
|
| 600 |
+
<div class="tip-box">
|
| 601 |
+
π‘ **Tip:** If the generated voice does not match the reference, enable "Force Speaker" and regenerate.
|
| 602 |
+
</div>
|
| 603 |
+
"""
|
| 604 |
+
)
|
| 605 |
+
|
| 606 |
+
session_id_state = gr.State(None)
|
| 607 |
+
|
| 608 |
+
gr.Markdown("# Speaker Reference")
|
| 609 |
+
with gr.Row():
|
| 610 |
+
if AUDIO_PROMPT_FOLDER is not None and AUDIO_PROMPT_FOLDER.exists():
|
| 611 |
+
with gr.Column(scale=1, min_width=200):
|
| 612 |
+
gr.Markdown("#### Audio Library (click to load)")
|
| 613 |
+
audio_prompt_search = gr.Textbox(
|
| 614 |
+
label="",
|
| 615 |
+
placeholder="π Search audio prompts...",
|
| 616 |
+
lines=1,
|
| 617 |
+
max_lines=1,
|
| 618 |
+
)
|
| 619 |
+
audio_prompt_table = gr.Dataframe(
|
| 620 |
+
value=get_audio_prompt_files(),
|
| 621 |
+
headers=["Filename"],
|
| 622 |
+
datatype=["str"],
|
| 623 |
+
row_count=(10, "dynamic"),
|
| 624 |
+
col_count=(1, "fixed"),
|
| 625 |
+
interactive=False,
|
| 626 |
+
label="",
|
| 627 |
+
)
|
| 628 |
+
with gr.Column(scale=2):
|
| 629 |
+
custom_audio_input = gr.Audio(
|
| 630 |
+
sources=["upload", "microphone"],
|
| 631 |
+
type="filepath",
|
| 632 |
+
label="Speaker Reference Audio (first five minutes used; blank for no speaker reference)",
|
| 633 |
+
max_length=600,
|
| 634 |
+
)
|
| 635 |
+
|
| 636 |
+
gr.HTML('<hr class="section-separator">')
|
| 637 |
+
gr.Markdown("# Text Prompt")
|
| 638 |
+
with gr.Accordion("Text Presets", open=True):
|
| 639 |
+
text_presets_table = gr.Dataframe(
|
| 640 |
+
value=load_text_presets(),
|
| 641 |
+
headers=["Category", "Words", "Preset Text"],
|
| 642 |
+
datatype=["str", "str", "str"],
|
| 643 |
+
row_count=(3, "dynamic"),
|
| 644 |
+
col_count=(3, "fixed"),
|
| 645 |
+
interactive=False,
|
| 646 |
+
column_widths=["12%", "6%", "82%"],
|
| 647 |
+
)
|
| 648 |
+
text_prompt = gr.Textbox(label="Text Prompt", placeholder="[S1] Enter your text prompt here...", lines=4)
|
| 649 |
+
|
| 650 |
+
gr.HTML('<hr class="section-separator">')
|
| 651 |
+
gr.Markdown("# Generation")
|
| 652 |
+
|
| 653 |
+
with gr.Row():
|
| 654 |
+
with gr.Column(scale=1):
|
| 655 |
+
pass
|
| 656 |
+
with gr.Column(scale=2):
|
| 657 |
+
mode_selector = gr.Radio(
|
| 658 |
+
choices=["Simple Mode", "Advanced Mode"],
|
| 659 |
+
value="Simple Mode",
|
| 660 |
+
label="",
|
| 661 |
+
info=None,
|
| 662 |
+
elem_id="component-mode-selector",
|
| 663 |
+
)
|
| 664 |
+
with gr.Column(scale=1):
|
| 665 |
+
pass
|
| 666 |
+
|
| 667 |
+
with gr.Accordion("βοΈ Generation Parameters", open=True):
|
| 668 |
+
with gr.Row(equal_height=False):
|
| 669 |
+
presets = load_sampler_presets()
|
| 670 |
+
preset_keys = list(presets.keys())
|
| 671 |
+
first_preset = preset_keys[0] if preset_keys else "Custom"
|
| 672 |
+
|
| 673 |
+
with gr.Column(scale=2):
|
| 674 |
+
preset_dropdown = gr.Dropdown(
|
| 675 |
+
choices=["Custom"] + preset_keys,
|
| 676 |
+
value=first_preset,
|
| 677 |
+
label="Sampler Preset",
|
| 678 |
+
info="Load preset configurations",
|
| 679 |
+
)
|
| 680 |
+
|
| 681 |
+
with gr.Column(scale=0.8, min_width=100):
|
| 682 |
+
num_steps = gr.Number(
|
| 683 |
+
label="Steps",
|
| 684 |
+
value=40,
|
| 685 |
+
info="Sampling steps (Try 20-80)",
|
| 686 |
+
precision=0,
|
| 687 |
+
minimum=5,
|
| 688 |
+
step=5,
|
| 689 |
+
maximum=80,
|
| 690 |
+
)
|
| 691 |
+
|
| 692 |
+
with gr.Column(scale=0.8, min_width=100):
|
| 693 |
+
rng_seed = gr.Number(label="RNG Seed", value=0, info="Seed for noise", precision=0)
|
| 694 |
+
|
| 695 |
+
with gr.Column(scale=3):
|
| 696 |
+
with gr.Group():
|
| 697 |
+
gr.HTML(
|
| 698 |
+
"""
|
| 699 |
+
<div class="preset-inline">
|
| 700 |
+
<span class="title">Speaker KV Attention Scaling</span>
|
| 701 |
+
</div>
|
| 702 |
+
"""
|
| 703 |
+
)
|
| 704 |
+
spk_kv_preset_enable = gr.Button("", elem_id="spk_kv_enable", elem_classes=["proxy-btn"])
|
| 705 |
+
spk_kv_preset_off = gr.Button("", elem_id="spk_kv_off", elem_classes=["proxy-btn"])
|
| 706 |
+
force_speaker = gr.Checkbox(
|
| 707 |
+
label='"Force Speaker" (KV scaling)',
|
| 708 |
+
value=False,
|
| 709 |
+
info="Enable to more strongly match the reference speaker (though higher values may degrade quality)",
|
| 710 |
+
)
|
| 711 |
+
with gr.Row(visible=False) as speaker_kv_row:
|
| 712 |
+
speaker_kv_scale = gr.Number(label="KV Scale", value=1.5, info="Scale factor (>1 -> larger effect; try 1.5, 1.2, ...)", minimum=0, step=0.1)
|
| 713 |
+
speaker_kv_min_t = gr.Number(
|
| 714 |
+
label="KV Min t",
|
| 715 |
+
value=0.9,
|
| 716 |
+
info="(0-1), scale applied from steps t=1. to val",
|
| 717 |
+
minimum=0,
|
| 718 |
+
maximum=1,
|
| 719 |
+
step=0.05,
|
| 720 |
+
)
|
| 721 |
+
speaker_kv_max_layers = gr.Number(
|
| 722 |
+
label="Max Layers",
|
| 723 |
+
value=24,
|
| 724 |
+
info="(0-24), scale applied in first N layers",
|
| 725 |
+
precision=0,
|
| 726 |
+
minimum=0,
|
| 727 |
+
maximum=24,
|
| 728 |
+
)
|
| 729 |
+
|
| 730 |
+
with gr.Column(visible=False) as advanced_mode_column:
|
| 731 |
+
compile_checkbox = gr.Checkbox(
|
| 732 |
+
label="Compile Model",
|
| 733 |
+
value=False,
|
| 734 |
+
info="Compile for faster runs (~10-30% faster); forces Custom Shapes on to avoid excessive recompilation.",
|
| 735 |
+
)
|
| 736 |
+
use_custom_shapes_checkbox = gr.Checkbox(
|
| 737 |
+
label="Use Custom Shapes (Advanced)",
|
| 738 |
+
value=False,
|
| 739 |
+
info="Override default generation length and/or force latent and text padding (if unchecked, no padding is used and latent generation length is 640β30s.)",
|
| 740 |
+
)
|
| 741 |
+
|
| 742 |
+
with gr.Row(visible=False) as custom_shapes_row:
|
| 743 |
+
max_text_byte_length = gr.Textbox(
|
| 744 |
+
label="Max Text Byte Length (padded)",
|
| 745 |
+
value="768",
|
| 746 |
+
info="Single value or comma-separated buckets (auto-selects min >= length); 768 = max; leave blank for no padding",
|
| 747 |
+
scale=1,
|
| 748 |
+
)
|
| 749 |
+
max_speaker_latent_length = gr.Textbox(
|
| 750 |
+
label="Max Speaker Latent Length (padded)",
|
| 751 |
+
value="640, 2816, 6400",
|
| 752 |
+
info="Single value or comma-separated buckets (auto-selects min >= length); 640β30s, 2560β2min, 6400β5min (max); leave blank for no padding",
|
| 753 |
+
scale=1,
|
| 754 |
+
)
|
| 755 |
+
sample_latent_length = gr.Textbox(
|
| 756 |
+
label="Sample Latent Length",
|
| 757 |
+
value=str(DEFAULT_SAMPLE_LATENT_LENGTH),
|
| 758 |
+
info="Maximum sample latent length (640β30s max seen during training; smaller works well for generating prefixes)",
|
| 759 |
+
scale=1,
|
| 760 |
+
)
|
| 761 |
+
|
| 762 |
+
with gr.Row():
|
| 763 |
+
with gr.Column(scale=1):
|
| 764 |
+
with gr.Group():
|
| 765 |
+
gr.HTML(
|
| 766 |
+
"""
|
| 767 |
+
<div class="preset-inline">
|
| 768 |
+
<span class="title">Truncation & Temporal Rescaling</span><span class="dim">(</span>
|
| 769 |
+
<a href="javascript:void(0)" class="preset-link" data-fire="trunc_flat">flat</a>
|
| 770 |
+
<span class="dim">,</span>
|
| 771 |
+
<a href="javascript:void(0)" class="preset-link" data-fire="trunc_sharp">sharp</a>
|
| 772 |
+
<span class="dim">,</span>
|
| 773 |
+
<a href="javascript:void(0)" class="preset-link" data-fire="trunc_baseline">baseline(sharp)</a>
|
| 774 |
+
<span class="dim">)</span>
|
| 775 |
+
</div>
|
| 776 |
+
"""
|
| 777 |
+
)
|
| 778 |
+
trunc_preset_flat = gr.Button("", elem_id="trunc_flat", elem_classes=["proxy-btn"])
|
| 779 |
+
trunc_preset_sharp = gr.Button("", elem_id="trunc_sharp", elem_classes=["proxy-btn"])
|
| 780 |
+
trunc_preset_baseline = gr.Button("", elem_id="trunc_baseline", elem_classes=["proxy-btn"])
|
| 781 |
+
with gr.Row():
|
| 782 |
+
truncation_factor = gr.Number(
|
| 783 |
+
label="Truncation Factor",
|
| 784 |
+
value=0.8,
|
| 785 |
+
info="Multiply initial noise (<1 helps artifacts)",
|
| 786 |
+
minimum=0,
|
| 787 |
+
step=0.05,
|
| 788 |
+
)
|
| 789 |
+
rescale_k = gr.Number(
|
| 790 |
+
label="Rescale k", value=1.2, info="<1=sharpen, >1=flatten, 1=off", minimum=0, step=0.05
|
| 791 |
+
)
|
| 792 |
+
rescale_sigma = gr.Number(
|
| 793 |
+
label="Rescale Ο", value=3.0, info="Sigma parameter", minimum=0, step=0.1
|
| 794 |
+
)
|
| 795 |
+
|
| 796 |
+
with gr.Column(scale=1):
|
| 797 |
+
with gr.Group():
|
| 798 |
+
gr.HTML(
|
| 799 |
+
"""
|
| 800 |
+
<div class="preset-inline">
|
| 801 |
+
<span class="title">CFG Guidance</span><span class="dim">(</span>
|
| 802 |
+
<a href="javascript:void(0)" class="preset-link" data-fire="cfg_higher">higher speaker</a>
|
| 803 |
+
<span class="dim">,</span>
|
| 804 |
+
<a href="javascript:void(0)" class="preset-link" data-fire="cfg_large">large guidances</a>
|
| 805 |
+
<span class="dim">)</span>
|
| 806 |
+
</div>
|
| 807 |
+
"""
|
| 808 |
+
)
|
| 809 |
+
cfg_preset_higher_speaker = gr.Button("", elem_id="cfg_higher", elem_classes=["proxy-btn"])
|
| 810 |
+
cfg_preset_large_guidances = gr.Button("", elem_id="cfg_large", elem_classes=["proxy-btn"])
|
| 811 |
+
with gr.Row():
|
| 812 |
+
cfg_scale_text = gr.Number(
|
| 813 |
+
label="Text CFG Scale", value=3.0, info="Guidance strength for text", minimum=0, step=0.5
|
| 814 |
+
)
|
| 815 |
+
cfg_scale_speaker = gr.Number(
|
| 816 |
+
label="Speaker CFG Scale",
|
| 817 |
+
value=5.0,
|
| 818 |
+
info="Guidance strength for speaker",
|
| 819 |
+
minimum=0,
|
| 820 |
+
step=0.5,
|
| 821 |
+
)
|
| 822 |
+
|
| 823 |
+
with gr.Row():
|
| 824 |
+
cfg_min_t = gr.Number(
|
| 825 |
+
label="CFG Min t", value=0.5, info="(0-1), CFG applied when t >= val", minimum=0, maximum=1, step=0.05
|
| 826 |
+
)
|
| 827 |
+
cfg_max_t = gr.Number(
|
| 828 |
+
label="CFG Max t", value=1.0, info="(0-1), CFG applied when t <= val", minimum=0, maximum=1, step=0.05
|
| 829 |
+
)
|
| 830 |
+
|
| 831 |
+
with gr.Row(equal_height=True):
|
| 832 |
+
audio_format = gr.Radio(choices=["wav", "mp3"], value="wav", label="Format", scale=1, min_width=90)
|
| 833 |
+
generate_btn = gr.Button("Generate Audio", variant="primary", size="lg", scale=10)
|
| 834 |
+
with gr.Column(scale=1):
|
| 835 |
+
show_original_audio = gr.Checkbox(label="Re-display Original Audio (full 5-minute cropped mono)", value=False)
|
| 836 |
+
reconstruct_first_30_seconds = gr.Checkbox(
|
| 837 |
+
label="Show Autoencoder Reconstruction (only first 30s of reference)", value=False
|
| 838 |
+
)
|
| 839 |
+
|
| 840 |
+
gr.HTML('<hr class="section-separator">')
|
| 841 |
+
with gr.Accordion("Generated Audio", open=True, visible=True) as generated_section:
|
| 842 |
+
generation_time_display = gr.Markdown("", visible=False)
|
| 843 |
+
with gr.Group(elem_classes=["generated-audio-player"]):
|
| 844 |
+
generated_audio = gr.Audio(label="Generated Audio", visible=True)
|
| 845 |
+
text_prompt_display = gr.Markdown("", visible=False)
|
| 846 |
+
|
| 847 |
+
gr.Markdown("---")
|
| 848 |
+
reference_audio_header = gr.Markdown("#### Reference Audio", visible=False)
|
| 849 |
+
|
| 850 |
+
with gr.Accordion("Original Audio (5 min Cropped Mono)", open=False, visible=False) as original_accordion:
|
| 851 |
+
original_audio = gr.Audio(label="Original Reference Audio (5 min)", visible=True)
|
| 852 |
+
|
| 853 |
+
with gr.Accordion("Autoencoder Reconstruction of First 30s of Reference", open=False, visible=False) as reference_accordion:
|
| 854 |
+
reference_audio = gr.Audio(label="Decoded Reference Audio (30s)", visible=True)
|
| 855 |
+
|
| 856 |
+
# Event handlers
|
| 857 |
+
if AUDIO_PROMPT_FOLDER is not None and AUDIO_PROMPT_FOLDER.exists():
|
| 858 |
+
audio_prompt_table.select(select_audio_prompt_file, outputs=[custom_audio_input])
|
| 859 |
+
audio_prompt_search.change(filter_audio_prompts, inputs=[audio_prompt_search], outputs=[audio_prompt_table])
|
| 860 |
+
|
| 861 |
+
text_presets_table.select(select_text_preset, outputs=text_prompt)
|
| 862 |
+
|
| 863 |
+
mode_selector.change(toggle_mode, inputs=[mode_selector], outputs=[advanced_mode_column])
|
| 864 |
+
|
| 865 |
+
force_speaker.change(update_force_row, inputs=[force_speaker], outputs=[speaker_kv_row])
|
| 866 |
+
|
| 867 |
+
def toggle_custom_shapes(enabled):
|
| 868 |
+
return gr.update(visible=enabled)
|
| 869 |
+
|
| 870 |
+
use_custom_shapes_checkbox.change(
|
| 871 |
+
toggle_custom_shapes,
|
| 872 |
+
inputs=[use_custom_shapes_checkbox],
|
| 873 |
+
outputs=[custom_shapes_row],
|
| 874 |
+
)
|
| 875 |
+
|
| 876 |
+
def on_compile_change(compile_enabled):
|
| 877 |
+
"""When compile is enabled, force custom shapes to be enabled."""
|
| 878 |
+
if compile_enabled:
|
| 879 |
+
return (
|
| 880 |
+
gr.update(value=True), # use_custom_shapes_checkbox
|
| 881 |
+
gr.update(visible=True), # custom_shapes_row
|
| 882 |
+
)
|
| 883 |
+
return (
|
| 884 |
+
gr.update(),
|
| 885 |
+
gr.update(),
|
| 886 |
+
)
|
| 887 |
+
|
| 888 |
+
compile_checkbox.change(
|
| 889 |
+
on_compile_change,
|
| 890 |
+
inputs=[compile_checkbox],
|
| 891 |
+
outputs=[use_custom_shapes_checkbox, custom_shapes_row],
|
| 892 |
+
)
|
| 893 |
+
|
| 894 |
+
cfg_preset_higher_speaker.click(
|
| 895 |
+
lambda: apply_cfg_preset("higher speaker"), outputs=[cfg_scale_text, cfg_scale_speaker, cfg_min_t, cfg_max_t, preset_dropdown]
|
| 896 |
+
)
|
| 897 |
+
cfg_preset_large_guidances.click(
|
| 898 |
+
lambda: apply_cfg_preset("large guidances"), outputs=[cfg_scale_text, cfg_scale_speaker, cfg_min_t, cfg_max_t, preset_dropdown]
|
| 899 |
+
)
|
| 900 |
+
|
| 901 |
+
spk_kv_preset_enable.click(lambda: apply_speaker_kv_preset("enable"), outputs=[force_speaker, speaker_kv_row, preset_dropdown])
|
| 902 |
+
spk_kv_preset_off.click(lambda: apply_speaker_kv_preset("off"), outputs=[force_speaker, speaker_kv_row, preset_dropdown])
|
| 903 |
+
|
| 904 |
+
trunc_preset_flat.click(lambda: apply_truncation_preset("flat"), outputs=[truncation_factor, rescale_k, rescale_sigma, preset_dropdown])
|
| 905 |
+
trunc_preset_sharp.click(lambda: apply_truncation_preset("sharp"), outputs=[truncation_factor, rescale_k, rescale_sigma, preset_dropdown])
|
| 906 |
+
trunc_preset_baseline.click(
|
| 907 |
+
lambda: apply_truncation_preset("baseline(sharp)"), outputs=[truncation_factor, rescale_k, rescale_sigma, preset_dropdown]
|
| 908 |
+
)
|
| 909 |
+
|
| 910 |
+
preset_dropdown.change(
|
| 911 |
+
apply_sampler_preset,
|
| 912 |
+
inputs=preset_dropdown,
|
| 913 |
+
outputs=[
|
| 914 |
+
num_steps,
|
| 915 |
+
cfg_scale_text,
|
| 916 |
+
cfg_scale_speaker,
|
| 917 |
+
cfg_min_t,
|
| 918 |
+
cfg_max_t,
|
| 919 |
+
truncation_factor,
|
| 920 |
+
rescale_k,
|
| 921 |
+
rescale_sigma,
|
| 922 |
+
force_speaker,
|
| 923 |
+
speaker_kv_row,
|
| 924 |
+
speaker_kv_scale,
|
| 925 |
+
speaker_kv_min_t,
|
| 926 |
+
speaker_kv_max_layers,
|
| 927 |
+
],
|
| 928 |
+
)
|
| 929 |
+
|
| 930 |
+
generate_btn.click(
|
| 931 |
+
generate_audio,
|
| 932 |
+
inputs=[
|
| 933 |
+
text_prompt,
|
| 934 |
+
custom_audio_input,
|
| 935 |
+
num_steps,
|
| 936 |
+
rng_seed,
|
| 937 |
+
cfg_scale_text,
|
| 938 |
+
cfg_scale_speaker,
|
| 939 |
+
cfg_min_t,
|
| 940 |
+
cfg_max_t,
|
| 941 |
+
truncation_factor,
|
| 942 |
+
rescale_k,
|
| 943 |
+
rescale_sigma,
|
| 944 |
+
force_speaker,
|
| 945 |
+
speaker_kv_scale,
|
| 946 |
+
speaker_kv_min_t,
|
| 947 |
+
speaker_kv_max_layers,
|
| 948 |
+
reconstruct_first_30_seconds,
|
| 949 |
+
use_custom_shapes_checkbox,
|
| 950 |
+
max_text_byte_length,
|
| 951 |
+
max_speaker_latent_length,
|
| 952 |
+
sample_latent_length,
|
| 953 |
+
audio_format,
|
| 954 |
+
compile_checkbox,
|
| 955 |
+
show_original_audio,
|
| 956 |
+
session_id_state,
|
| 957 |
+
],
|
| 958 |
+
outputs=[
|
| 959 |
+
generated_section,
|
| 960 |
+
generated_audio,
|
| 961 |
+
text_prompt_display,
|
| 962 |
+
original_audio,
|
| 963 |
+
generation_time_display,
|
| 964 |
+
reference_audio,
|
| 965 |
+
original_accordion,
|
| 966 |
+
reference_accordion,
|
| 967 |
+
reference_audio_header,
|
| 968 |
+
],
|
| 969 |
+
)
|
| 970 |
+
|
| 971 |
+
demo.load(init_session, outputs=[session_id_state]).then(
|
| 972 |
+
lambda: apply_sampler_preset(list(load_sampler_presets().keys())[0]),
|
| 973 |
+
outputs=[
|
| 974 |
+
num_steps,
|
| 975 |
+
cfg_scale_text,
|
| 976 |
+
cfg_scale_speaker,
|
| 977 |
+
cfg_min_t,
|
| 978 |
+
cfg_max_t,
|
| 979 |
+
truncation_factor,
|
| 980 |
+
rescale_k,
|
| 981 |
+
rescale_sigma,
|
| 982 |
+
force_speaker,
|
| 983 |
+
speaker_kv_row,
|
| 984 |
+
speaker_kv_scale,
|
| 985 |
+
speaker_kv_min_t,
|
| 986 |
+
speaker_kv_max_layers,
|
| 987 |
+
],
|
| 988 |
+
)
|
| 989 |
+
|
| 990 |
+
|
| 991 |
+
if __name__ == "__main__":
|
| 992 |
+
demo.launch(
|
| 993 |
+
allowed_paths=[str(AUDIO_PROMPT_FOLDER)]
|
| 994 |
+
)
|
inference.py
ADDED
|
@@ -0,0 +1,462 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import dataclass
|
| 2 |
+
from typing import Callable, List, Tuple
|
| 3 |
+
|
| 4 |
+
from huggingface_hub import hf_hub_download
|
| 5 |
+
import safetensors.torch as st
|
| 6 |
+
import torch
|
| 7 |
+
import torchaudio
|
| 8 |
+
from torchcodec.decoders import AudioDecoder
|
| 9 |
+
|
| 10 |
+
from autoencoder import DAC, build_ae
|
| 11 |
+
from model import EchoDiT
|
| 12 |
+
|
| 13 |
+
def load_model_from_hf(repo_id: str = "jordand/echo-tts-base", device: str = "cuda", dtype: torch.dtype | None = torch.bfloat16, compile: bool = False, token: str | None = None, delete_blockwise_modules: bool = False) -> EchoDiT:
|
| 14 |
+
with torch.device("meta"):
|
| 15 |
+
model = EchoDiT(
|
| 16 |
+
latent_size=80, model_size=2048, num_layers=24, num_heads=16,
|
| 17 |
+
intermediate_size=5888, norm_eps=1e-5,
|
| 18 |
+
text_vocab_size=256, text_model_size=1280, text_num_layers=14,
|
| 19 |
+
text_num_heads=10, text_intermediate_size=3328,
|
| 20 |
+
speaker_patch_size=4, speaker_model_size=1280, speaker_num_layers=14,
|
| 21 |
+
speaker_num_heads=10, speaker_intermediate_size=3328,
|
| 22 |
+
timestep_embed_size=512, adaln_rank=256,
|
| 23 |
+
)
|
| 24 |
+
w_path = hf_hub_download(repo_id, "pytorch_model.safetensors", token=token)
|
| 25 |
+
state = st.load_file(w_path, device="cpu")
|
| 26 |
+
|
| 27 |
+
if delete_blockwise_modules:
|
| 28 |
+
state = {k: v for k, v in state.items() if not (
|
| 29 |
+
k.startswith("latent_encoder.") or
|
| 30 |
+
k.startswith("latent_norm") or
|
| 31 |
+
".wk_latent" in k or
|
| 32 |
+
".wv_latent" in k
|
| 33 |
+
)}
|
| 34 |
+
|
| 35 |
+
if dtype is not None:
|
| 36 |
+
state = {k: v.to(dtype=dtype) for k, v in state.items()}
|
| 37 |
+
|
| 38 |
+
state = {k: v.to(device=device) for k, v in state.items()}
|
| 39 |
+
|
| 40 |
+
model.load_state_dict(state, strict=False, assign=True)
|
| 41 |
+
model = model.eval()
|
| 42 |
+
|
| 43 |
+
if compile:
|
| 44 |
+
model = compile_model(model)
|
| 45 |
+
|
| 46 |
+
return model
|
| 47 |
+
|
| 48 |
+
def compile_model(model: EchoDiT) -> EchoDiT:
|
| 49 |
+
model = torch.compile(model)
|
| 50 |
+
model.get_kv_cache_text = torch.compile(model.get_kv_cache_text)
|
| 51 |
+
model.get_kv_cache_speaker = torch.compile(model.get_kv_cache_speaker)
|
| 52 |
+
model.get_kv_cache_latent = torch.compile(model.get_kv_cache_latent)
|
| 53 |
+
return model
|
| 54 |
+
|
| 55 |
+
def load_fish_ae_from_hf(repo_id: str = "jordand/fish-s1-dac-min", device: str = "cuda", dtype: torch.dtype | None = torch.float32, compile: bool = False, token: str | None = None) -> DAC:
|
| 56 |
+
|
| 57 |
+
with torch.device("meta"):
|
| 58 |
+
fish_ae = build_ae()
|
| 59 |
+
|
| 60 |
+
w_path = hf_hub_download(repo_id, "pytorch_model.safetensors", token=token)
|
| 61 |
+
if dtype is not None and dtype != torch.float32:
|
| 62 |
+
state = st.load_file(w_path, device="cpu")
|
| 63 |
+
state = {k: v.to(dtype=dtype) for k, v in state.items()}
|
| 64 |
+
state = {k: v.to(device=device) for k, v in state.items()}
|
| 65 |
+
fish_ae.load_state_dict(state, strict=False, assign=True)
|
| 66 |
+
else:
|
| 67 |
+
state = st.load_file(w_path, device=device)
|
| 68 |
+
fish_ae.load_state_dict(state, strict=False, assign=True)
|
| 69 |
+
|
| 70 |
+
fish_ae = fish_ae.eval().to(device)
|
| 71 |
+
|
| 72 |
+
if compile:
|
| 73 |
+
fish_ae = compile_fish_ae(fish_ae)
|
| 74 |
+
|
| 75 |
+
return fish_ae
|
| 76 |
+
|
| 77 |
+
def compile_fish_ae(fish_ae: DAC) -> DAC:
|
| 78 |
+
fish_ae.quantizer.upsample = torch.compile(fish_ae.quantizer.upsample)
|
| 79 |
+
fish_ae.quantizer.downsample = torch.compile(fish_ae.quantizer.downsample)
|
| 80 |
+
fish_ae.quantizer.pre_module = torch.compile(fish_ae.quantizer.pre_module)
|
| 81 |
+
fish_ae.quantizer.post_module = torch.compile(fish_ae.quantizer.post_module)
|
| 82 |
+
return fish_ae
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
@dataclass
|
| 86 |
+
class PCAState:
|
| 87 |
+
pca_components: torch.Tensor
|
| 88 |
+
pca_mean: torch.Tensor
|
| 89 |
+
latent_scale: float
|
| 90 |
+
|
| 91 |
+
def load_pca_state_from_hf(repo_id: str = "jordand/echo-tts-base", device: str = "cuda", filename: str = "pca_state.safetensors", token: str | None = None) -> PCAState:
|
| 92 |
+
p_path = hf_hub_download(repo_id, filename, token=token)
|
| 93 |
+
t = st.load_file(p_path, device=device)
|
| 94 |
+
return PCAState(
|
| 95 |
+
pca_components=t["pca_components"],
|
| 96 |
+
pca_mean=t["pca_mean"],
|
| 97 |
+
latent_scale=float(t["latent_scale"].item()),
|
| 98 |
+
)
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
# ________
|
| 102 |
+
|
| 103 |
+
def load_audio(path: str, max_duration: int = 300) -> torch.Tensor:
|
| 104 |
+
|
| 105 |
+
decoder = AudioDecoder(path)
|
| 106 |
+
sr = decoder.metadata.sample_rate
|
| 107 |
+
audio = decoder.get_samples_played_in_range(0, max_duration)
|
| 108 |
+
audio = audio.data.mean(dim=0).unsqueeze(0)
|
| 109 |
+
audio = torchaudio.functional.resample(audio, sr, 44_100)
|
| 110 |
+
audio = audio / torch.maximum(audio.abs().max(), torch.tensor(1.))
|
| 111 |
+
# is this better than clipping? should we target a specific energy level?
|
| 112 |
+
return audio
|
| 113 |
+
|
| 114 |
+
def tokenizer_encode(text: str, append_bos: bool = True, normalize: bool = True, return_normalized_text: bool = False) -> torch.Tensor | Tuple[torch.Tensor, str]:
|
| 115 |
+
|
| 116 |
+
if normalize:
|
| 117 |
+
text = text.replace("β¦", "...")
|
| 118 |
+
text = text.replace('β', "'")
|
| 119 |
+
text = text.replace('β', '"')
|
| 120 |
+
text = text.replace('β', '"')
|
| 121 |
+
text = text.replace("\n", " ")
|
| 122 |
+
text = text.replace(":", ",")
|
| 123 |
+
text = text.replace(";", ",")
|
| 124 |
+
text = text.replace("β", ", ")
|
| 125 |
+
if not text.startswith("[") and not text.startswith("(") and 'S1' not in text and 'S2' not in text:
|
| 126 |
+
text = "[S1] " + text
|
| 127 |
+
|
| 128 |
+
b = list(text.encode("utf-8"))
|
| 129 |
+
if append_bos:
|
| 130 |
+
b.insert(0, 0)
|
| 131 |
+
|
| 132 |
+
if return_normalized_text:
|
| 133 |
+
return torch.tensor(b), text
|
| 134 |
+
|
| 135 |
+
return torch.tensor(b)
|
| 136 |
+
|
| 137 |
+
def get_text_input_ids_and_mask(text_arr: List[str], max_length: int | None, device: str | None = None, normalize: bool = True, return_normalized_text: bool = False, pad_to_max: bool = True) -> Tuple[torch.Tensor, torch.Tensor] | Tuple[torch.Tensor, torch.Tensor, List[str]]:
|
| 138 |
+
encoded_texts = [tokenizer_encode(text, normalize=normalize, return_normalized_text=True) for text in text_arr]
|
| 139 |
+
|
| 140 |
+
if max_length is None:
|
| 141 |
+
max_length = max(len(enc) for enc, _ in encoded_texts)
|
| 142 |
+
|
| 143 |
+
tokens = torch.zeros((len(text_arr), max_length), dtype=torch.int32)
|
| 144 |
+
mask = torch.zeros((len(text_arr), max_length), dtype=torch.bool)
|
| 145 |
+
|
| 146 |
+
for i, (encoded, _) in enumerate(encoded_texts):
|
| 147 |
+
length = min(len(encoded), max_length)
|
| 148 |
+
tokens[i, :length] = encoded[:length]
|
| 149 |
+
mask[i, :length] = 1
|
| 150 |
+
|
| 151 |
+
if not pad_to_max and max_length is not None:
|
| 152 |
+
tokens, mask = tokens[:, :max_length], mask[:, :max_length]
|
| 153 |
+
|
| 154 |
+
if device is not None:
|
| 155 |
+
tokens, mask = tokens.to(device), mask.to(device)
|
| 156 |
+
|
| 157 |
+
if return_normalized_text:
|
| 158 |
+
return tokens, mask, [text for _, text in encoded_texts]
|
| 159 |
+
return tokens, mask
|
| 160 |
+
|
| 161 |
+
# ________
|
| 162 |
+
|
| 163 |
+
@torch.inference_mode()
|
| 164 |
+
def ae_encode(fish_ae: DAC, pca_state: PCAState, audio: torch.Tensor) -> torch.Tensor:
|
| 165 |
+
assert audio.ndim == 3 and audio.shape[1] == 1 # (b, 1, length)
|
| 166 |
+
z_q = fish_ae.encode_zq(audio).float()
|
| 167 |
+
z_q = (z_q.transpose(1, 2) - pca_state.pca_mean) @ pca_state.pca_components.T
|
| 168 |
+
z_q = z_q * pca_state.latent_scale
|
| 169 |
+
return z_q
|
| 170 |
+
|
| 171 |
+
@torch.inference_mode()
|
| 172 |
+
def ae_decode(fish_ae: DAC, pca_state: PCAState, z_q: torch.Tensor) -> torch.Tensor:
|
| 173 |
+
z_q = (z_q / pca_state.latent_scale) @ pca_state.pca_components + pca_state.pca_mean
|
| 174 |
+
return fish_ae.decode_zq(z_q.transpose(1, 2).to(fish_ae.dtype)).float()
|
| 175 |
+
|
| 176 |
+
@torch.inference_mode()
|
| 177 |
+
def ae_reconstruct(fish_ae: DAC, pca_state: PCAState, audio: torch.Tensor) -> torch.Tensor:
|
| 178 |
+
assert audio.ndim == 3 and audio.shape[1] == 1 # (b, 1, length)
|
| 179 |
+
z_q = ae_encode(fish_ae, pca_state, audio.to(fish_ae.dtype))
|
| 180 |
+
return ae_decode(fish_ae, pca_state, z_q)
|
| 181 |
+
|
| 182 |
+
# ________
|
| 183 |
+
|
| 184 |
+
@torch.inference_mode()
|
| 185 |
+
def get_speaker_latent_and_mask(
|
| 186 |
+
fish_ae: DAC,
|
| 187 |
+
pca_state: PCAState,
|
| 188 |
+
audio: torch.Tensor, # (1, length)
|
| 189 |
+
max_speaker_latent_length: int = 6400, # pretrained max length
|
| 190 |
+
audio_chunk_size: int = 640 * 2048, # (~30 seconds, 1/10 max speaker condition size; max chunk seen in training)
|
| 191 |
+
pad_to_max: bool = False,
|
| 192 |
+
divis_by_patch_size: int | None = 4,
|
| 193 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 194 |
+
# gets speaker latent and mask from audio, computes in chunks and concatenates (similar to training setup)
|
| 195 |
+
|
| 196 |
+
AE_DOWNSAMPLE_FACTOR = 2048
|
| 197 |
+
max_audio_len_length = max_speaker_latent_length * AE_DOWNSAMPLE_FACTOR
|
| 198 |
+
|
| 199 |
+
assert audio.ndim == 2 and audio.shape[0] == 1 # (1, length)
|
| 200 |
+
audio = audio[:, :max_audio_len_length]
|
| 201 |
+
|
| 202 |
+
latent_arr = []
|
| 203 |
+
|
| 204 |
+
for i in range(0, audio.shape[1], audio_chunk_size):
|
| 205 |
+
audio_chunk = audio[:, i:i + audio_chunk_size]
|
| 206 |
+
if audio_chunk.shape[1] < audio_chunk_size:
|
| 207 |
+
audio_chunk = torch.nn.functional.pad(audio_chunk, (0, audio_chunk_size - audio_chunk.shape[1]))
|
| 208 |
+
|
| 209 |
+
latent_chunk = ae_encode(fish_ae, pca_state, audio_chunk.unsqueeze(0))
|
| 210 |
+
latent_arr.append(latent_chunk)
|
| 211 |
+
|
| 212 |
+
speaker_latent = torch.cat(latent_arr, dim=1)
|
| 213 |
+
|
| 214 |
+
actual_latent_length = audio.shape[1] // AE_DOWNSAMPLE_FACTOR
|
| 215 |
+
speaker_mask = (torch.arange(speaker_latent.shape[1], device=speaker_latent.device) < actual_latent_length).unsqueeze(0)
|
| 216 |
+
|
| 217 |
+
if pad_to_max and speaker_latent.shape[1] < max_speaker_latent_length:
|
| 218 |
+
speaker_latent = torch.nn.functional.pad(speaker_latent, (0, 0, 0, max_speaker_latent_length - speaker_latent.shape[1]))
|
| 219 |
+
speaker_mask = torch.nn.functional.pad(speaker_mask, (0, max_speaker_latent_length - speaker_mask.shape[1]))
|
| 220 |
+
elif not pad_to_max:
|
| 221 |
+
speaker_latent = speaker_latent[:, :actual_latent_length]
|
| 222 |
+
speaker_mask = speaker_mask[:, :actual_latent_length]
|
| 223 |
+
|
| 224 |
+
if divis_by_patch_size is not None:
|
| 225 |
+
speaker_latent = speaker_latent[:, :speaker_latent.shape[1] // divis_by_patch_size * divis_by_patch_size]
|
| 226 |
+
speaker_mask = speaker_mask[:, :speaker_mask.shape[1] // divis_by_patch_size * divis_by_patch_size]
|
| 227 |
+
|
| 228 |
+
return speaker_latent, speaker_mask
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
# ________
|
| 232 |
+
|
| 233 |
+
def find_flattening_point(data, target_value=0.0, window_size=20, std_threshold=0.05):
|
| 234 |
+
# simple heuristic to find end of latent generations; slow and can be improved
|
| 235 |
+
# (data is (length, 80))
|
| 236 |
+
padded_data = torch.cat([data, torch.zeros(window_size, *data.shape[1:], device=data.device, dtype=data.dtype)])
|
| 237 |
+
for i in range(len(padded_data) - window_size):
|
| 238 |
+
window = padded_data[i:i + window_size]
|
| 239 |
+
if window.std() < std_threshold and abs(window.mean() - target_value) < 0.1:
|
| 240 |
+
return i
|
| 241 |
+
return len(data)
|
| 242 |
+
|
| 243 |
+
def crop_audio_to_flattening_point(audio: torch.Tensor, latent: torch.Tensor) -> torch.Tensor:
|
| 244 |
+
# (audio is (..., length), latent is (length, 80))
|
| 245 |
+
flattening_point = find_flattening_point(latent)
|
| 246 |
+
return audio[..., :flattening_point * 2048]
|
| 247 |
+
|
| 248 |
+
SampleFn = Callable[
|
| 249 |
+
[EchoDiT, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, int],
|
| 250 |
+
torch.Tensor
|
| 251 |
+
]
|
| 252 |
+
|
| 253 |
+
@torch.inference_mode()
|
| 254 |
+
def sample_pipeline(
|
| 255 |
+
model: EchoDiT,
|
| 256 |
+
fish_ae: DAC,
|
| 257 |
+
pca_state: PCAState,
|
| 258 |
+
sample_fn: SampleFn,
|
| 259 |
+
text_prompt: str,
|
| 260 |
+
speaker_audio: torch.Tensor | None,
|
| 261 |
+
rng_seed: int,
|
| 262 |
+
pad_to_max_speaker_latent_length: int | None = None,
|
| 263 |
+
pad_to_max_text_length: int | None = None,
|
| 264 |
+
normalize_text: bool = True,
|
| 265 |
+
) -> Tuple[torch.Tensor, str]:
|
| 266 |
+
|
| 267 |
+
MAX_SPEAKER_LATENT_LENGTH = 6400 # max seen during training, though maybe can go higher?
|
| 268 |
+
MAX_TEXT_LENGTH = 768
|
| 269 |
+
|
| 270 |
+
device, dtype = model.device, model.dtype
|
| 271 |
+
|
| 272 |
+
text_input_ids, text_mask, normalized_text = get_text_input_ids_and_mask([text_prompt], max_length=min(pad_to_max_text_length or MAX_TEXT_LENGTH, MAX_TEXT_LENGTH), device=device, normalize=normalize_text, return_normalized_text=True, pad_to_max=(pad_to_max_text_length is not None))
|
| 273 |
+
|
| 274 |
+
if speaker_audio is None:
|
| 275 |
+
speaker_latent = torch.zeros((1, pad_to_max_speaker_latent_length or 4, 80), device=device, dtype=dtype)
|
| 276 |
+
speaker_mask = torch.zeros((1, pad_to_max_speaker_latent_length or 4), device=device, dtype=torch.bool)
|
| 277 |
+
else:
|
| 278 |
+
speaker_latent, speaker_mask = get_speaker_latent_and_mask(
|
| 279 |
+
fish_ae,
|
| 280 |
+
pca_state,
|
| 281 |
+
speaker_audio.to(fish_ae.dtype).to(device),
|
| 282 |
+
max_speaker_latent_length=pad_to_max_speaker_latent_length or MAX_SPEAKER_LATENT_LENGTH,
|
| 283 |
+
pad_to_max=(pad_to_max_speaker_latent_length is not None)
|
| 284 |
+
)
|
| 285 |
+
|
| 286 |
+
latent_out = sample_fn(model, speaker_latent, speaker_mask, text_input_ids, text_mask, rng_seed)
|
| 287 |
+
|
| 288 |
+
audio_out = ae_decode(fish_ae, pca_state, latent_out)
|
| 289 |
+
|
| 290 |
+
audio_out = crop_audio_to_flattening_point(audio_out, latent_out[0])
|
| 291 |
+
|
| 292 |
+
return audio_out, normalized_text[0]
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
|
| 296 |
+
|
| 297 |
+
# ________
|
| 298 |
+
|
| 299 |
+
|
| 300 |
+
KVCache = List[Tuple[torch.Tensor, torch.Tensor]]
|
| 301 |
+
|
| 302 |
+
def _concat_kv_caches(*caches: KVCache) -> KVCache:
|
| 303 |
+
# helper that concatenates multiple KV caches along the batch dimension
|
| 304 |
+
num_layers = len(caches[0])
|
| 305 |
+
result = []
|
| 306 |
+
for i in range(num_layers):
|
| 307 |
+
k = torch.cat([c[i][0] for c in caches], dim=0)
|
| 308 |
+
v = torch.cat([c[i][1] for c in caches], dim=0)
|
| 309 |
+
result.append((k, v))
|
| 310 |
+
return result
|
| 311 |
+
|
| 312 |
+
def _multiply_kv_cache(cache: KVCache, scale: float, max_layers: int | None = None) -> None:
|
| 313 |
+
# helper that multiplies KV cache values in-place, for kv speaker scaling
|
| 314 |
+
num_layers = len(cache) if max_layers is None else min(max_layers, len(cache))
|
| 315 |
+
for i in range(num_layers):
|
| 316 |
+
k, v = cache[i]
|
| 317 |
+
k.mul_(scale)
|
| 318 |
+
v.mul_(scale)
|
| 319 |
+
|
| 320 |
+
def _temporal_score_rescale(
|
| 321 |
+
v_pred: torch.Tensor, x_t: torch.Tensor, t: float, rescale_k: float, rescale_sigma: float
|
| 322 |
+
) -> torch.Tensor:
|
| 323 |
+
# for https://arxiv.org/pdf/2510.01184
|
| 324 |
+
if t < 1:
|
| 325 |
+
snr = (1 - t) ** 2 / (t ** 2)
|
| 326 |
+
ratio = (snr * rescale_sigma ** 2 + 1) / (snr * rescale_sigma ** 2 / rescale_k + 1)
|
| 327 |
+
return 1 / (1 - t) * (ratio * ((1 - t) * v_pred + x_t) - x_t)
|
| 328 |
+
return v_pred
|
| 329 |
+
|
| 330 |
+
|
| 331 |
+
@torch.inference_mode()
|
| 332 |
+
def sample_euler_cfg_independent_guidances(
|
| 333 |
+
model: EchoDiT,
|
| 334 |
+
speaker_latent: torch.Tensor,
|
| 335 |
+
speaker_mask: torch.Tensor,
|
| 336 |
+
text_input_ids: torch.Tensor,
|
| 337 |
+
text_mask: torch.Tensor,
|
| 338 |
+
rng_seed: int,
|
| 339 |
+
num_steps: int,
|
| 340 |
+
cfg_scale_text: float,
|
| 341 |
+
cfg_scale_speaker: float,
|
| 342 |
+
cfg_min_t: float,
|
| 343 |
+
cfg_max_t: float,
|
| 344 |
+
truncation_factor: float | None,
|
| 345 |
+
rescale_k: float | None,
|
| 346 |
+
rescale_sigma: float | None,
|
| 347 |
+
speaker_kv_scale: float | None,
|
| 348 |
+
speaker_kv_max_layers: int | None,
|
| 349 |
+
speaker_kv_min_t: float | None,
|
| 350 |
+
sequence_length: int | None = None,
|
| 351 |
+
) -> torch.Tensor:
|
| 352 |
+
|
| 353 |
+
if sequence_length is None:
|
| 354 |
+
sequence_length = 640 # max sequence length during training
|
| 355 |
+
|
| 356 |
+
INIT_SCALE = 0.999 # so that we can apply rescale to first step
|
| 357 |
+
|
| 358 |
+
device, dtype = model.device, model.dtype
|
| 359 |
+
batch_size = text_input_ids.shape[0]
|
| 360 |
+
|
| 361 |
+
rng = torch.Generator(device=device).manual_seed(rng_seed)
|
| 362 |
+
|
| 363 |
+
t_schedule = torch.linspace(1., 0., num_steps + 1, device=device) * INIT_SCALE
|
| 364 |
+
|
| 365 |
+
text_mask_uncond = torch.zeros_like(text_mask)
|
| 366 |
+
speaker_mask_uncond = torch.zeros_like(speaker_mask)
|
| 367 |
+
|
| 368 |
+
kv_text_cond = model.get_kv_cache_text(text_input_ids, text_mask)
|
| 369 |
+
kv_speaker_cond = model.get_kv_cache_speaker(speaker_latent.to(dtype))
|
| 370 |
+
|
| 371 |
+
if speaker_kv_scale is not None:
|
| 372 |
+
_multiply_kv_cache(kv_speaker_cond, speaker_kv_scale, speaker_kv_max_layers)
|
| 373 |
+
|
| 374 |
+
# masks prevent decoder from attending to unconds:
|
| 375 |
+
kv_text_full = _concat_kv_caches(kv_text_cond, kv_text_cond, kv_text_cond)
|
| 376 |
+
kv_speaker_full = _concat_kv_caches(kv_speaker_cond, kv_speaker_cond, kv_speaker_cond)
|
| 377 |
+
|
| 378 |
+
full_text_mask = torch.cat([text_mask, text_mask_uncond, text_mask], dim=0)
|
| 379 |
+
full_speaker_mask = torch.cat([speaker_mask, speaker_mask, speaker_mask_uncond], dim=0)
|
| 380 |
+
|
| 381 |
+
x_t = torch.randn((batch_size, sequence_length, 80), device=device, dtype=torch.float32, generator=rng)
|
| 382 |
+
if truncation_factor is not None:
|
| 383 |
+
x_t = x_t * truncation_factor
|
| 384 |
+
|
| 385 |
+
for i in range(num_steps):
|
| 386 |
+
t, t_next = t_schedule[i], t_schedule[i + 1]
|
| 387 |
+
|
| 388 |
+
has_cfg = ((t >= cfg_min_t) * (t <= cfg_max_t)).item()
|
| 389 |
+
|
| 390 |
+
if has_cfg:
|
| 391 |
+
v_cond, v_uncond_text, v_uncond_speaker = model(
|
| 392 |
+
x=torch.cat([x_t, x_t, x_t], dim=0).to(dtype),
|
| 393 |
+
t=(torch.ones((batch_size * 3,), device=device) * t).to(dtype),
|
| 394 |
+
text_mask=full_text_mask,
|
| 395 |
+
speaker_mask=full_speaker_mask,
|
| 396 |
+
kv_cache_text=kv_text_full,
|
| 397 |
+
kv_cache_speaker=kv_speaker_full,
|
| 398 |
+
).float().chunk(3, dim=0)
|
| 399 |
+
v_pred = v_cond + cfg_scale_text * (v_cond - v_uncond_text) + cfg_scale_speaker * (v_cond - v_uncond_speaker) # can also use a single, joint unconditional for fewer NFE
|
| 400 |
+
else:
|
| 401 |
+
v_pred = model(
|
| 402 |
+
x=x_t.to(dtype),
|
| 403 |
+
t=(torch.ones((batch_size,), device=device) * t).to(dtype),
|
| 404 |
+
text_mask=text_mask,
|
| 405 |
+
speaker_mask=speaker_mask,
|
| 406 |
+
kv_cache_text=kv_text_cond,
|
| 407 |
+
kv_cache_speaker=kv_speaker_cond,
|
| 408 |
+
).float()
|
| 409 |
+
|
| 410 |
+
# optional temporal score rescaling: https://arxiv.org/pdf/2510.01184
|
| 411 |
+
if rescale_k is not None and rescale_sigma is not None:
|
| 412 |
+
v_pred = _temporal_score_rescale(v_pred, x_t, t, rescale_k, rescale_sigma)
|
| 413 |
+
|
| 414 |
+
# optional kv speaker scaling
|
| 415 |
+
if speaker_kv_scale is not None and t_next < speaker_kv_min_t and t >= speaker_kv_min_t:
|
| 416 |
+
_multiply_kv_cache(kv_speaker_cond, 1. / speaker_kv_scale, speaker_kv_max_layers)
|
| 417 |
+
kv_speaker_full = _concat_kv_caches(kv_speaker_cond, kv_speaker_cond, kv_speaker_cond)
|
| 418 |
+
|
| 419 |
+
x_t = x_t + v_pred * (t_next - t)
|
| 420 |
+
|
| 421 |
+
return x_t
|
| 422 |
+
|
| 423 |
+
|
| 424 |
+
|
| 425 |
+
# ___________________________________________________________
|
| 426 |
+
# simple example
|
| 427 |
+
|
| 428 |
+
if __name__ == "__main__":
|
| 429 |
+
model = load_model_from_hf(delete_blockwise_modules=True)
|
| 430 |
+
fish_ae = load_fish_ae_from_hf()
|
| 431 |
+
pca_state = load_pca_state_from_hf()
|
| 432 |
+
|
| 433 |
+
speaker_audio_path = "/path/to/speaker/audio.wav"
|
| 434 |
+
speaker_audio = load_audio(speaker_audio_path).cuda()
|
| 435 |
+
speaker_latent, speaker_mask = get_speaker_latent_and_mask(fish_ae, pca_state, speaker_audio)
|
| 436 |
+
|
| 437 |
+
text = "[S1] Alright, I'm going to demo this new model called Echo TTS. Hopefully this works, I'm super excited to try this and see what it can do."
|
| 438 |
+
text_input_ids, text_mask = get_text_input_ids_and_mask([text], max_length=None, device="cuda")
|
| 439 |
+
|
| 440 |
+
latent_out = sample_euler_cfg_independent_guidances(
|
| 441 |
+
model=model,
|
| 442 |
+
speaker_latent=speaker_latent,
|
| 443 |
+
speaker_mask=speaker_mask,
|
| 444 |
+
text_input_ids=text_input_ids,
|
| 445 |
+
text_mask=text_mask,
|
| 446 |
+
rng_seed=0,
|
| 447 |
+
num_steps=40,
|
| 448 |
+
cfg_scale_text=3.0,
|
| 449 |
+
cfg_scale_speaker=8.0,
|
| 450 |
+
cfg_min_t=0.5,
|
| 451 |
+
cfg_max_t=1.0,
|
| 452 |
+
truncation_factor=0.8,
|
| 453 |
+
rescale_k=None,
|
| 454 |
+
rescale_sigma=None,
|
| 455 |
+
speaker_kv_scale=None,
|
| 456 |
+
speaker_kv_max_layers=None,
|
| 457 |
+
speaker_kv_min_t=None,
|
| 458 |
+
sequence_length=640, # (max 640. shorter lengths will generate prefixes, not necessarily full generations)
|
| 459 |
+
)
|
| 460 |
+
audio_out = ae_decode(fish_ae, pca_state, latent_out)
|
| 461 |
+
audio_out = crop_audio_to_flattening_point(audio_out, latent_out[0])
|
| 462 |
+
torchaudio.save("output.wav", audio_out[0].cpu(), 44100)
|
inference_blockwise.py
ADDED
|
@@ -0,0 +1,220 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
from inference import (
|
| 6 |
+
KVCache,
|
| 7 |
+
_concat_kv_caches,
|
| 8 |
+
_multiply_kv_cache,
|
| 9 |
+
_temporal_score_rescale,
|
| 10 |
+
)
|
| 11 |
+
from model import EchoDiT
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
@torch.inference_mode()
|
| 15 |
+
def sample_blockwise_euler_cfg_independent_guidances(
|
| 16 |
+
model: EchoDiT,
|
| 17 |
+
speaker_latent: torch.Tensor,
|
| 18 |
+
speaker_mask: torch.Tensor,
|
| 19 |
+
text_input_ids: torch.Tensor,
|
| 20 |
+
text_mask: torch.Tensor,
|
| 21 |
+
rng_seed: int,
|
| 22 |
+
block_sizes: List[int],
|
| 23 |
+
num_steps: int,
|
| 24 |
+
cfg_scale_text: float,
|
| 25 |
+
cfg_scale_speaker: float,
|
| 26 |
+
cfg_min_t: float,
|
| 27 |
+
cfg_max_t: float,
|
| 28 |
+
truncation_factor: float | None,
|
| 29 |
+
rescale_k: float | None,
|
| 30 |
+
rescale_sigma: float | None,
|
| 31 |
+
speaker_kv_scale: float | None,
|
| 32 |
+
speaker_kv_max_layers: int | None,
|
| 33 |
+
speaker_kv_min_t: float | None,
|
| 34 |
+
continuation_latent: torch.Tensor | None = None,
|
| 35 |
+
) -> torch.Tensor:
|
| 36 |
+
|
| 37 |
+
INIT_SCALE = 0.999 # so that we can apply rescale to first step
|
| 38 |
+
|
| 39 |
+
device, dtype = model.device, model.dtype
|
| 40 |
+
batch_size = text_input_ids.shape[0]
|
| 41 |
+
|
| 42 |
+
rng = torch.Generator(device=device).manual_seed(rng_seed)
|
| 43 |
+
|
| 44 |
+
t_schedule = torch.linspace(1., 0., num_steps + 1, device=device) * INIT_SCALE
|
| 45 |
+
|
| 46 |
+
text_mask_uncond = torch.zeros_like(text_mask)
|
| 47 |
+
speaker_mask_uncond = torch.zeros_like(speaker_mask)
|
| 48 |
+
|
| 49 |
+
kv_text_cond = model.get_kv_cache_text(text_input_ids, text_mask)
|
| 50 |
+
kv_speaker_cond = model.get_kv_cache_speaker(speaker_latent.to(dtype))
|
| 51 |
+
|
| 52 |
+
# masks prevent decoder from attending to unconds:
|
| 53 |
+
kv_text_full = _concat_kv_caches(kv_text_cond, kv_text_cond, kv_text_cond)
|
| 54 |
+
kv_speaker_full = _concat_kv_caches(kv_speaker_cond, kv_speaker_cond, kv_speaker_cond)
|
| 55 |
+
|
| 56 |
+
full_text_mask = torch.cat([text_mask, text_mask_uncond, text_mask], dim=0)
|
| 57 |
+
full_speaker_mask = torch.cat([speaker_mask, speaker_mask, speaker_mask_uncond], dim=0)
|
| 58 |
+
|
| 59 |
+
prefix_latent = torch.zeros((batch_size, sum(block_sizes) , 80), device=device, dtype=torch.float32)
|
| 60 |
+
|
| 61 |
+
start_pos = 0
|
| 62 |
+
if continuation_latent is not None:
|
| 63 |
+
continuation_len = continuation_latent.shape[1]
|
| 64 |
+
prefix_latent = torch.cat([continuation_latent, prefix_latent], dim=1)
|
| 65 |
+
start_pos = continuation_len
|
| 66 |
+
|
| 67 |
+
for block_size in block_sizes:
|
| 68 |
+
if speaker_kv_scale is not None:
|
| 69 |
+
_multiply_kv_cache(kv_speaker_cond, speaker_kv_scale, speaker_kv_max_layers)
|
| 70 |
+
kv_speaker_full = _concat_kv_caches(kv_speaker_cond, kv_speaker_cond, kv_speaker_cond)
|
| 71 |
+
|
| 72 |
+
full_prefix_latent = torch.cat([prefix_latent, prefix_latent, prefix_latent], dim=0)
|
| 73 |
+
kv_latent_full = model.get_kv_cache_latent(full_prefix_latent.to(dtype))
|
| 74 |
+
kv_latent_cond = [(k[:batch_size], v[:batch_size]) for k, v in kv_latent_full]
|
| 75 |
+
|
| 76 |
+
x_t = torch.randn((batch_size, block_size, 80), device=device, dtype=torch.float32, generator=rng)
|
| 77 |
+
if truncation_factor is not None:
|
| 78 |
+
x_t = x_t * truncation_factor
|
| 79 |
+
|
| 80 |
+
for i in range(num_steps):
|
| 81 |
+
t, t_next = t_schedule[i], t_schedule[i + 1]
|
| 82 |
+
|
| 83 |
+
has_cfg = ((t >= cfg_min_t) * (t <= cfg_max_t)).item()
|
| 84 |
+
|
| 85 |
+
if has_cfg:
|
| 86 |
+
v_cond, v_uncond_text, v_uncond_speaker = model(
|
| 87 |
+
x=torch.cat([x_t, x_t, x_t], dim=0).to(dtype),
|
| 88 |
+
t=(torch.ones((batch_size * 3,), device=device) * t).to(dtype),
|
| 89 |
+
text_mask=full_text_mask,
|
| 90 |
+
speaker_mask=full_speaker_mask,
|
| 91 |
+
start_pos=start_pos,
|
| 92 |
+
kv_cache_text=kv_text_full,
|
| 93 |
+
kv_cache_speaker=kv_speaker_full,
|
| 94 |
+
kv_cache_latent=kv_latent_full,
|
| 95 |
+
).float().chunk(3, dim=0)
|
| 96 |
+
v_pred = v_cond + cfg_scale_text * (v_cond - v_uncond_text) + cfg_scale_speaker * (v_cond - v_uncond_speaker)
|
| 97 |
+
else:
|
| 98 |
+
v_pred = model(
|
| 99 |
+
x=x_t.to(dtype),
|
| 100 |
+
t=(torch.ones((batch_size,), device=device) * t).to(dtype),
|
| 101 |
+
text_mask=text_mask,
|
| 102 |
+
speaker_mask=speaker_mask,
|
| 103 |
+
start_pos=start_pos,
|
| 104 |
+
kv_cache_text=kv_text_cond,
|
| 105 |
+
kv_cache_speaker=kv_speaker_cond,
|
| 106 |
+
kv_cache_latent=kv_latent_cond,
|
| 107 |
+
).float()
|
| 108 |
+
|
| 109 |
+
# optional temporal score rescaling: https://arxiv.org/pdf/2510.01184
|
| 110 |
+
if rescale_k is not None and rescale_sigma is not None:
|
| 111 |
+
v_pred = _temporal_score_rescale(v_pred, x_t, t, rescale_k, rescale_sigma)
|
| 112 |
+
|
| 113 |
+
# optional kv speaker scaling
|
| 114 |
+
if speaker_kv_scale is not None and t_next < speaker_kv_min_t and t >= speaker_kv_min_t:
|
| 115 |
+
_multiply_kv_cache(kv_speaker_cond, 1. / speaker_kv_scale, speaker_kv_max_layers)
|
| 116 |
+
kv_speaker_full = _concat_kv_caches(kv_speaker_cond, kv_speaker_cond, kv_speaker_cond)
|
| 117 |
+
|
| 118 |
+
x_t = x_t + v_pred * (t_next - t)
|
| 119 |
+
|
| 120 |
+
prefix_latent[:, start_pos:start_pos + block_size] = x_t
|
| 121 |
+
start_pos += block_size
|
| 122 |
+
|
| 123 |
+
return prefix_latent
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
if __name__ == "__main__":
|
| 127 |
+
import torchaudio
|
| 128 |
+
from inference import (
|
| 129 |
+
load_model_from_hf,
|
| 130 |
+
load_fish_ae_from_hf,
|
| 131 |
+
load_pca_state_from_hf,
|
| 132 |
+
load_audio,
|
| 133 |
+
get_text_input_ids_and_mask,
|
| 134 |
+
get_speaker_latent_and_mask,
|
| 135 |
+
ae_encode,
|
| 136 |
+
ae_decode,
|
| 137 |
+
crop_audio_to_flattening_point,
|
| 138 |
+
)
|
| 139 |
+
|
| 140 |
+
model = load_model_from_hf()
|
| 141 |
+
fish_ae = load_fish_ae_from_hf()
|
| 142 |
+
pca_state = load_pca_state_from_hf()
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
# example 1, generate 320 in three blocks
|
| 146 |
+
|
| 147 |
+
speaker_audio_path = "/path/to/speaker/audio.wav"
|
| 148 |
+
speaker_audio = load_audio(speaker_audio_path).cuda()
|
| 149 |
+
speaker_latent, speaker_mask = get_speaker_latent_and_mask(fish_ae, pca_state, speaker_audio)
|
| 150 |
+
|
| 151 |
+
text = "[S1] Alright, I'm going to demo this new model called Echo TTS."
|
| 152 |
+
text_input_ids, text_mask = get_text_input_ids_and_mask([text], max_length=None, device="cuda")
|
| 153 |
+
|
| 154 |
+
latent_out = sample_blockwise_euler_cfg_independent_guidances(
|
| 155 |
+
model=model,
|
| 156 |
+
speaker_latent=speaker_latent,
|
| 157 |
+
speaker_mask=speaker_mask,
|
| 158 |
+
text_input_ids=text_input_ids,
|
| 159 |
+
text_mask=text_mask,
|
| 160 |
+
rng_seed=0,
|
| 161 |
+
block_sizes=[128, 128, 64], # (sums to 320, so will be ~15 seconds; supports up to 640)
|
| 162 |
+
num_steps=40,
|
| 163 |
+
cfg_scale_text=3.0,
|
| 164 |
+
cfg_scale_speaker=5.0,
|
| 165 |
+
cfg_min_t=0.5,
|
| 166 |
+
cfg_max_t=1.0,
|
| 167 |
+
truncation_factor=0.8,
|
| 168 |
+
rescale_k=None,
|
| 169 |
+
rescale_sigma=None,
|
| 170 |
+
speaker_kv_scale=None,
|
| 171 |
+
speaker_kv_max_layers=None,
|
| 172 |
+
speaker_kv_min_t=None,
|
| 173 |
+
)
|
| 174 |
+
audio_out = ae_decode(fish_ae, pca_state, latent_out)
|
| 175 |
+
audio_out = crop_audio_to_flattening_point(audio_out, latent_out[0])
|
| 176 |
+
torchaudio.save("output_blockwise.wav", audio_out[0].cpu(), 44100)
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
# ___________________________________________________________
|
| 181 |
+
# example 2: with continuation latent (use same speaker audio as first example, generate from partial output of first example)
|
| 182 |
+
|
| 183 |
+
continuation_audio_path = "output_blockwise.wav" # can be any path
|
| 184 |
+
continuation_audio = load_audio(continuation_audio_path).cuda()
|
| 185 |
+
continuation_latent, continuation_mask = get_speaker_latent_and_mask(fish_ae, pca_state, continuation_audio)
|
| 186 |
+
|
| 187 |
+
continuation_latent = continuation_latent[:, :continuation_mask.sum()]
|
| 188 |
+
|
| 189 |
+
text = "[S1] Alright, I'm going to demo this new model called Echo TTS, and now, we're going to continue from the audio we already generated and add some more text."
|
| 190 |
+
# NOTE this MUST include the text from the continuation prefix. can use https://huggingface.co/jordand/whisper-d-v1a to get in-distribution transcription automatically.
|
| 191 |
+
|
| 192 |
+
text_input_ids, text_mask = get_text_input_ids_and_mask([text], max_length=None, device="cuda")
|
| 193 |
+
|
| 194 |
+
continuation_block_sizes = [256] # (generate up to 12 more seconds)
|
| 195 |
+
# NOTE: these do not include the continuation latent length, so sum(block_sizes) + continuation_latent.shape[1] should be < 640 (to be in-distribution with training data)
|
| 196 |
+
|
| 197 |
+
latent_out_continued = sample_blockwise_euler_cfg_independent_guidances(
|
| 198 |
+
model=model,
|
| 199 |
+
speaker_latent=speaker_latent,
|
| 200 |
+
speaker_mask=speaker_mask,
|
| 201 |
+
text_input_ids=text_input_ids,
|
| 202 |
+
text_mask=text_mask,
|
| 203 |
+
rng_seed=0,
|
| 204 |
+
block_sizes=continuation_block_sizes,
|
| 205 |
+
num_steps=40,
|
| 206 |
+
cfg_scale_text=3.0,
|
| 207 |
+
cfg_scale_speaker=3.0,
|
| 208 |
+
cfg_min_t=0.5,
|
| 209 |
+
cfg_max_t=1.0,
|
| 210 |
+
truncation_factor=0.8,
|
| 211 |
+
rescale_k=None,
|
| 212 |
+
rescale_sigma=None,
|
| 213 |
+
speaker_kv_scale=None,
|
| 214 |
+
speaker_kv_max_layers=None,
|
| 215 |
+
speaker_kv_min_t=None,
|
| 216 |
+
continuation_latent=continuation_latent,
|
| 217 |
+
)
|
| 218 |
+
audio_out_continued = ae_decode(fish_ae, pca_state, latent_out_continued)
|
| 219 |
+
audio_out_continued = crop_audio_to_flattening_point(audio_out_continued, latent_out_continued[0])
|
| 220 |
+
torchaudio.save("output_blockwise_continued.wav", audio_out_continued[0].cpu(), 44100)
|
model.py
ADDED
|
@@ -0,0 +1,642 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Tuple, List
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import torch.optim as optim
|
| 6 |
+
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
|
| 9 |
+
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor:
|
| 10 |
+
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)] / dim))
|
| 11 |
+
t = torch.arange(end)
|
| 12 |
+
freqs = torch.outer(t, freqs)
|
| 13 |
+
freqs_cis = torch.complex(torch.cos(freqs), torch.sin(freqs))
|
| 14 |
+
return freqs_cis
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def apply_rotary_emb(
|
| 18 |
+
x: torch.Tensor,
|
| 19 |
+
freqs_cis: torch.Tensor,
|
| 20 |
+
) -> torch.Tensor:
|
| 21 |
+
x_ = torch.view_as_complex(x.float().reshape(*x.shape[:3], -1, 2))
|
| 22 |
+
x_ = x_ * freqs_cis[..., None, :]
|
| 23 |
+
x_ = torch.view_as_real(x_).reshape(x.shape)
|
| 24 |
+
return x_.type_as(x)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def get_timestep_embedding(
|
| 28 |
+
timestep: torch.Tensor,
|
| 29 |
+
embed_size: int,
|
| 30 |
+
) -> torch.Tensor:
|
| 31 |
+
assert embed_size % 2 == 0
|
| 32 |
+
|
| 33 |
+
half = embed_size // 2
|
| 34 |
+
|
| 35 |
+
freqs = 1000 * torch.exp(
|
| 36 |
+
-torch.log(torch.tensor(10000.0)) *
|
| 37 |
+
torch.arange(start=0, end=half, dtype=torch.float32) / half
|
| 38 |
+
).to(timestep.device)
|
| 39 |
+
|
| 40 |
+
args = timestep[..., None] * freqs[None]
|
| 41 |
+
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
| 42 |
+
|
| 43 |
+
return embedding.to(timestep.dtype)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class LowRankAdaLN(nn.Module):
|
| 47 |
+
def __init__(
|
| 48 |
+
self,
|
| 49 |
+
model_size: int,
|
| 50 |
+
rank: int,
|
| 51 |
+
eps: float
|
| 52 |
+
):
|
| 53 |
+
super().__init__()
|
| 54 |
+
self.eps = eps
|
| 55 |
+
|
| 56 |
+
self.shift_down = nn.Linear(model_size, rank, bias=False)
|
| 57 |
+
self.scale_down = nn.Linear(model_size, rank, bias=False)
|
| 58 |
+
self.gate_down = nn.Linear(model_size, rank, bias=False)
|
| 59 |
+
|
| 60 |
+
self.shift_up = nn.Linear(rank, model_size, bias=True)
|
| 61 |
+
self.scale_up = nn.Linear(rank, model_size, bias=True)
|
| 62 |
+
self.gate_up = nn.Linear(rank, model_size, bias=True)
|
| 63 |
+
|
| 64 |
+
def forward(
|
| 65 |
+
self,
|
| 66 |
+
x: torch.Tensor,
|
| 67 |
+
cond_embed: torch.Tensor,
|
| 68 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 69 |
+
|
| 70 |
+
shift, scale, gate = cond_embed.chunk(3, dim=-1)
|
| 71 |
+
|
| 72 |
+
shift = self.shift_up(self.shift_down(F.silu(shift))) + shift
|
| 73 |
+
scale = self.scale_up(self.scale_down(F.silu(scale))) + scale
|
| 74 |
+
gate = self.gate_up(self.gate_down(F.silu(gate))) + gate
|
| 75 |
+
|
| 76 |
+
x_dtype = x.dtype
|
| 77 |
+
x = x.float()
|
| 78 |
+
x = x * torch.rsqrt(torch.pow(x.float(), 2).mean(dim=-1, keepdim=True) + self.eps)
|
| 79 |
+
x = x * (scale + 1) + shift
|
| 80 |
+
|
| 81 |
+
gate = torch.tanh(gate)
|
| 82 |
+
|
| 83 |
+
return x.to(x_dtype), gate
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
class RMSNorm(nn.Module): # could also just use torch rmsnorm
|
| 87 |
+
def __init__(
|
| 88 |
+
self,
|
| 89 |
+
model_size: int | Tuple[int, int],
|
| 90 |
+
eps: float
|
| 91 |
+
):
|
| 92 |
+
super().__init__()
|
| 93 |
+
self.eps = eps
|
| 94 |
+
|
| 95 |
+
if isinstance(model_size, int):
|
| 96 |
+
model_size = (model_size, )
|
| 97 |
+
self.weight = nn.Parameter(torch.ones(model_size))
|
| 98 |
+
|
| 99 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 100 |
+
x_dtype = x.dtype
|
| 101 |
+
x = x.float()
|
| 102 |
+
x = x * torch.rsqrt(torch.pow(x.float(), 2).mean(dim=-1, keepdim=True) + self.eps)
|
| 103 |
+
x = x * self.weight
|
| 104 |
+
return x.to(x_dtype)
|
| 105 |
+
|
| 106 |
+
class SelfAttention(nn.Module):
|
| 107 |
+
def __init__(
|
| 108 |
+
self,
|
| 109 |
+
model_size: int,
|
| 110 |
+
num_heads: int,
|
| 111 |
+
is_causal: bool,
|
| 112 |
+
norm_eps: float
|
| 113 |
+
):
|
| 114 |
+
super().__init__()
|
| 115 |
+
self.num_heads = num_heads
|
| 116 |
+
self.is_causal = is_causal
|
| 117 |
+
|
| 118 |
+
self.wq = nn.Linear(model_size, model_size, bias=False)
|
| 119 |
+
self.wk = nn.Linear(model_size, model_size, bias=False)
|
| 120 |
+
self.wv = nn.Linear(model_size, model_size, bias=False)
|
| 121 |
+
self.wo = nn.Linear(model_size, model_size, bias=False)
|
| 122 |
+
self.gate = nn.Linear(model_size, model_size, bias=False)
|
| 123 |
+
|
| 124 |
+
assert model_size % num_heads == 0
|
| 125 |
+
self.q_norm = RMSNorm((num_heads, model_size // num_heads), eps=norm_eps)
|
| 126 |
+
self.k_norm = RMSNorm((num_heads, model_size // num_heads), eps=norm_eps)
|
| 127 |
+
|
| 128 |
+
def forward(self, x: torch.Tensor, mask: torch.Tensor | None, freqs_cis: torch.Tensor) -> torch.Tensor:
|
| 129 |
+
|
| 130 |
+
batch_size, seq_len = x.shape[:2]
|
| 131 |
+
|
| 132 |
+
xq = self.wq(x).reshape(batch_size, seq_len, self.num_heads, -1)
|
| 133 |
+
xk = self.wk(x).reshape(batch_size, seq_len, self.num_heads, -1)
|
| 134 |
+
xv = self.wv(x).reshape(batch_size, seq_len, self.num_heads, -1)
|
| 135 |
+
|
| 136 |
+
gate = self.gate(x)
|
| 137 |
+
|
| 138 |
+
xq = self.q_norm(xq)
|
| 139 |
+
xk = self.k_norm(xk)
|
| 140 |
+
|
| 141 |
+
xq = apply_rotary_emb(xq, freqs_cis[:seq_len])
|
| 142 |
+
xk = apply_rotary_emb(xk, freqs_cis[:seq_len])
|
| 143 |
+
|
| 144 |
+
if mask is not None:
|
| 145 |
+
assert mask.ndim == 2 # (b, s)
|
| 146 |
+
mask = mask[:, None, None]
|
| 147 |
+
|
| 148 |
+
output = F.scaled_dot_product_attention(
|
| 149 |
+
query=xq.transpose(1, 2),
|
| 150 |
+
key=xk.transpose(1, 2),
|
| 151 |
+
value=xv.transpose(1, 2),
|
| 152 |
+
attn_mask=mask,
|
| 153 |
+
is_causal=self.is_causal
|
| 154 |
+
).transpose(1, 2)
|
| 155 |
+
|
| 156 |
+
output = output.reshape(batch_size, seq_len, -1)
|
| 157 |
+
output = output * torch.sigmoid(gate)
|
| 158 |
+
|
| 159 |
+
output = self.wo(output)
|
| 160 |
+
|
| 161 |
+
return output
|
| 162 |
+
|
| 163 |
+
class JointAttention(nn.Module):
|
| 164 |
+
def __init__(
|
| 165 |
+
self,
|
| 166 |
+
model_size: int,
|
| 167 |
+
num_heads: int,
|
| 168 |
+
text_model_size: int,
|
| 169 |
+
speaker_model_size: int,
|
| 170 |
+
speaker_patch_size: int,
|
| 171 |
+
norm_eps: float
|
| 172 |
+
):
|
| 173 |
+
super().__init__()
|
| 174 |
+
self.speaker_patch_size = speaker_patch_size
|
| 175 |
+
self.num_heads = num_heads
|
| 176 |
+
|
| 177 |
+
self.wq = nn.Linear(model_size, model_size, bias=False)
|
| 178 |
+
self.wk = nn.Linear(model_size, model_size, bias=False)
|
| 179 |
+
self.wv = nn.Linear(model_size, model_size, bias=False)
|
| 180 |
+
|
| 181 |
+
self.wk_text = nn.Linear(text_model_size, model_size, bias=False)
|
| 182 |
+
self.wv_text = nn.Linear(text_model_size, model_size, bias=False)
|
| 183 |
+
|
| 184 |
+
self.wk_speaker = nn.Linear(speaker_model_size, model_size, bias=False)
|
| 185 |
+
self.wv_speaker = nn.Linear(speaker_model_size, model_size, bias=False)
|
| 186 |
+
|
| 187 |
+
self.wk_latent = nn.Linear(speaker_model_size, model_size, bias=False)
|
| 188 |
+
self.wv_latent = nn.Linear(speaker_model_size, model_size, bias=False)
|
| 189 |
+
|
| 190 |
+
assert model_size % num_heads == 0
|
| 191 |
+
self.head_dim = model_size // num_heads
|
| 192 |
+
self.q_norm = RMSNorm((num_heads, self.head_dim), eps=norm_eps)
|
| 193 |
+
self.k_norm = RMSNorm((num_heads, self.head_dim), eps=norm_eps)
|
| 194 |
+
|
| 195 |
+
self.gate = nn.Linear(model_size, model_size, bias=False)
|
| 196 |
+
|
| 197 |
+
self.wo = nn.Linear(model_size, model_size, bias=False)
|
| 198 |
+
|
| 199 |
+
def _apply_rotary_half(self, y: torch.Tensor, fc: torch.Tensor) -> torch.Tensor:
|
| 200 |
+
y1, y2 = y.chunk(2, dim=-2)
|
| 201 |
+
y1 = apply_rotary_emb(y1, fc)
|
| 202 |
+
return torch.cat([y1, y2], dim=-2)
|
| 203 |
+
|
| 204 |
+
def forward(
|
| 205 |
+
self,
|
| 206 |
+
x: torch.Tensor,
|
| 207 |
+
text_mask: torch.Tensor,
|
| 208 |
+
speaker_mask: torch.Tensor,
|
| 209 |
+
freqs_cis: torch.Tensor,
|
| 210 |
+
kv_cache_text: Tuple[torch.Tensor, torch.Tensor],
|
| 211 |
+
kv_cache_speaker: Tuple[torch.Tensor, torch.Tensor],
|
| 212 |
+
start_pos: int | None,
|
| 213 |
+
kv_cache_latent: Tuple[torch.Tensor, torch.Tensor] | None
|
| 214 |
+
) -> torch.Tensor:
|
| 215 |
+
batch_size, seq_len = x.shape[:2]
|
| 216 |
+
|
| 217 |
+
xq = self.wq(x).reshape(batch_size, seq_len, self.num_heads, -1)
|
| 218 |
+
xk_self = self.wk(x).reshape(batch_size, seq_len, self.num_heads, -1)
|
| 219 |
+
xv_self = self.wv(x).reshape(batch_size, seq_len, self.num_heads, -1)
|
| 220 |
+
|
| 221 |
+
xq = self.q_norm(xq)
|
| 222 |
+
xk_self = self.k_norm(xk_self)
|
| 223 |
+
|
| 224 |
+
gate = self.gate(x)
|
| 225 |
+
|
| 226 |
+
if start_pos is None:
|
| 227 |
+
start_pos = 0
|
| 228 |
+
|
| 229 |
+
freqs_q = freqs_cis[start_pos : start_pos + seq_len]
|
| 230 |
+
|
| 231 |
+
xq = self._apply_rotary_half(xq, freqs_q)
|
| 232 |
+
xk_self = self._apply_rotary_half(xk_self, freqs_q)
|
| 233 |
+
|
| 234 |
+
xk_text, xv_text = kv_cache_text
|
| 235 |
+
xk_speaker, xv_speaker = kv_cache_speaker
|
| 236 |
+
|
| 237 |
+
if kv_cache_latent is None or kv_cache_latent[0].shape [1] == 0:
|
| 238 |
+
xk_latent = torch.zeros((batch_size, 0, self.num_heads, xq.shape[-1]), device=x.device, dtype=x.dtype)
|
| 239 |
+
xv_latent = torch.zeros((batch_size, 0, self.num_heads, xq.shape[-1]), device=x.device, dtype=x.dtype)
|
| 240 |
+
latent_mask = torch.zeros((batch_size, 0), dtype=torch.bool, device=x.device)
|
| 241 |
+
else:
|
| 242 |
+
xk_latent, xv_latent = kv_cache_latent
|
| 243 |
+
latent_positions = torch.arange(xk_latent.shape[1], device=x.device, dtype=torch.long) * self.speaker_patch_size
|
| 244 |
+
latent_mask = (latent_positions[None, :] < start_pos).expand(batch_size, xk_latent.shape[1])
|
| 245 |
+
|
| 246 |
+
xk = torch.cat([xk_self, xk_latent, xk_text, xk_speaker], dim=1)
|
| 247 |
+
xv = torch.cat([xv_self, xv_latent, xv_text, xv_speaker], dim=1)
|
| 248 |
+
|
| 249 |
+
self_mask = torch.ones((batch_size, seq_len), dtype=torch.bool, device=x.device)
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
mask = torch.cat([self_mask, latent_mask, text_mask, speaker_mask], dim=1)
|
| 253 |
+
mask = mask[:, None, None]
|
| 254 |
+
|
| 255 |
+
output = F.scaled_dot_product_attention(
|
| 256 |
+
query=xq.transpose(1, 2),
|
| 257 |
+
key=xk.transpose(1, 2),
|
| 258 |
+
value=xv.transpose(1, 2),
|
| 259 |
+
attn_mask=mask,
|
| 260 |
+
is_causal=False
|
| 261 |
+
).transpose(1, 2)
|
| 262 |
+
|
| 263 |
+
output = output.reshape(batch_size, seq_len, -1)
|
| 264 |
+
output = output * torch.sigmoid(gate)
|
| 265 |
+
|
| 266 |
+
output = self.wo(output)
|
| 267 |
+
|
| 268 |
+
return output
|
| 269 |
+
|
| 270 |
+
def get_kv_cache_text(self, text_state: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 271 |
+
batch_size = text_state.shape[0]
|
| 272 |
+
xk = self.wk_text(text_state).reshape(batch_size, text_state.shape[1], self.num_heads, -1)
|
| 273 |
+
xv = self.wv_text(text_state).reshape(batch_size, text_state.shape[1], self.num_heads, -1)
|
| 274 |
+
xk = self.k_norm(xk)
|
| 275 |
+
return xk, xv
|
| 276 |
+
|
| 277 |
+
def get_kv_cache_speaker(self, speaker_state: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 278 |
+
batch_size = speaker_state.shape[0]
|
| 279 |
+
xk = self.wk_speaker(speaker_state).reshape(batch_size, speaker_state.shape[1], self.num_heads, -1)
|
| 280 |
+
xv = self.wv_speaker(speaker_state).reshape(batch_size, speaker_state.shape[1], self.num_heads, -1)
|
| 281 |
+
xk = self.k_norm(xk)
|
| 282 |
+
return xk, xv
|
| 283 |
+
|
| 284 |
+
def get_kv_cache_latent(self, latent_state: torch.Tensor, freqs_cis: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 285 |
+
batch_size = latent_state.shape[0]
|
| 286 |
+
seq_len = latent_state.shape[1]
|
| 287 |
+
xk = self.wk_latent(latent_state).reshape(batch_size, seq_len, self.num_heads, -1)
|
| 288 |
+
xv = self.wv_latent(latent_state).reshape(batch_size, seq_len, self.num_heads, -1)
|
| 289 |
+
xk = self.k_norm(xk)
|
| 290 |
+
|
| 291 |
+
xk = self._apply_rotary_half(xk, freqs_cis)
|
| 292 |
+
|
| 293 |
+
return xk, xv
|
| 294 |
+
|
| 295 |
+
|
| 296 |
+
class MLP(nn.Module):
|
| 297 |
+
def __init__(
|
| 298 |
+
self,
|
| 299 |
+
model_size: int,
|
| 300 |
+
intermediate_size: int
|
| 301 |
+
):
|
| 302 |
+
super().__init__()
|
| 303 |
+
self.w1 = nn.Linear(model_size, intermediate_size, bias=False)
|
| 304 |
+
self.w3 = nn.Linear(model_size, intermediate_size, bias=False)
|
| 305 |
+
self.w2 = nn.Linear(intermediate_size, model_size, bias=False)
|
| 306 |
+
|
| 307 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 308 |
+
return self.w2(F.silu(self.w1(x)) * self.w3(x))
|
| 309 |
+
|
| 310 |
+
|
| 311 |
+
class EncoderTransformerBlock(nn.Module):
|
| 312 |
+
def __init__(
|
| 313 |
+
self,
|
| 314 |
+
model_size: int,
|
| 315 |
+
num_heads: int,
|
| 316 |
+
intermediate_size: int,
|
| 317 |
+
is_causal: bool,
|
| 318 |
+
norm_eps: float
|
| 319 |
+
):
|
| 320 |
+
super().__init__()
|
| 321 |
+
self.attention = SelfAttention(
|
| 322 |
+
model_size=model_size,
|
| 323 |
+
num_heads=num_heads,
|
| 324 |
+
is_causal=is_causal,
|
| 325 |
+
norm_eps=norm_eps
|
| 326 |
+
)
|
| 327 |
+
self.mlp = MLP(
|
| 328 |
+
model_size=model_size,
|
| 329 |
+
intermediate_size=intermediate_size
|
| 330 |
+
)
|
| 331 |
+
|
| 332 |
+
self.attention_norm = RMSNorm(model_size, norm_eps)
|
| 333 |
+
self.mlp_norm = RMSNorm(model_size, norm_eps)
|
| 334 |
+
|
| 335 |
+
def forward(self, x: torch.Tensor, mask: torch.Tensor | None, freqs_cis: torch.Tensor) -> torch.Tensor:
|
| 336 |
+
x = x + self.attention(self.attention_norm(x), mask, freqs_cis)
|
| 337 |
+
x = x + self.mlp(self.mlp_norm(x))
|
| 338 |
+
|
| 339 |
+
return x
|
| 340 |
+
|
| 341 |
+
class TransformerBlock(nn.Module):
|
| 342 |
+
def __init__(
|
| 343 |
+
self,
|
| 344 |
+
model_size: int,
|
| 345 |
+
num_heads: int,
|
| 346 |
+
intermediate_size: int,
|
| 347 |
+
norm_eps: float,
|
| 348 |
+
text_model_size: int,
|
| 349 |
+
speaker_model_size: int,
|
| 350 |
+
speaker_patch_size: int,
|
| 351 |
+
adaln_rank: int,
|
| 352 |
+
):
|
| 353 |
+
super().__init__()
|
| 354 |
+
self.attention = JointAttention(
|
| 355 |
+
model_size=model_size,
|
| 356 |
+
num_heads=num_heads,
|
| 357 |
+
text_model_size=text_model_size,
|
| 358 |
+
speaker_model_size=speaker_model_size,
|
| 359 |
+
speaker_patch_size=speaker_patch_size,
|
| 360 |
+
norm_eps=norm_eps
|
| 361 |
+
)
|
| 362 |
+
|
| 363 |
+
self.mlp = MLP(
|
| 364 |
+
model_size=model_size,
|
| 365 |
+
intermediate_size=intermediate_size
|
| 366 |
+
)
|
| 367 |
+
|
| 368 |
+
self.attention_adaln = LowRankAdaLN(model_size=model_size, rank=adaln_rank, eps=norm_eps)
|
| 369 |
+
self.mlp_adaln = LowRankAdaLN(model_size=model_size, rank=adaln_rank, eps=norm_eps)
|
| 370 |
+
|
| 371 |
+
def forward(
|
| 372 |
+
self,
|
| 373 |
+
x: torch.Tensor,
|
| 374 |
+
cond_embed: torch.Tensor,
|
| 375 |
+
text_mask: torch.Tensor,
|
| 376 |
+
speaker_mask: torch.Tensor,
|
| 377 |
+
freqs_cis: torch.Tensor,
|
| 378 |
+
kv_cache_text: Tuple[torch.Tensor, torch.Tensor],
|
| 379 |
+
kv_cache_speaker: Tuple[torch.Tensor, torch.Tensor],
|
| 380 |
+
start_pos: int | None,
|
| 381 |
+
kv_cache_latent: Tuple[torch.Tensor, torch.Tensor] | None,
|
| 382 |
+
) -> torch.Tensor:
|
| 383 |
+
|
| 384 |
+
x_norm, attention_gate = self.attention_adaln(x, cond_embed)
|
| 385 |
+
x = x + attention_gate * self.attention(x_norm, text_mask, speaker_mask, freqs_cis, kv_cache_text, kv_cache_speaker, start_pos, kv_cache_latent)
|
| 386 |
+
|
| 387 |
+
x_norm, mlp_gate = self.mlp_adaln(x, cond_embed)
|
| 388 |
+
x = x + mlp_gate * self.mlp(x_norm)
|
| 389 |
+
|
| 390 |
+
return x
|
| 391 |
+
|
| 392 |
+
class TextEncoder(nn.Module):
|
| 393 |
+
def __init__(
|
| 394 |
+
self,
|
| 395 |
+
vocab_size: int,
|
| 396 |
+
model_size: int,
|
| 397 |
+
num_layers: int,
|
| 398 |
+
num_heads: int,
|
| 399 |
+
intermediate_size: int,
|
| 400 |
+
norm_eps: float,
|
| 401 |
+
):
|
| 402 |
+
super().__init__()
|
| 403 |
+
self.text_embedding = nn.Embedding(vocab_size, model_size)
|
| 404 |
+
|
| 405 |
+
self.blocks = nn.ModuleList()
|
| 406 |
+
for i in range(num_layers):
|
| 407 |
+
block = EncoderTransformerBlock(
|
| 408 |
+
model_size=model_size,
|
| 409 |
+
num_heads=num_heads,
|
| 410 |
+
intermediate_size=intermediate_size,
|
| 411 |
+
is_causal=False,
|
| 412 |
+
norm_eps=norm_eps
|
| 413 |
+
)
|
| 414 |
+
self.blocks.append(block)
|
| 415 |
+
|
| 416 |
+
self.head_dim = model_size // num_heads
|
| 417 |
+
|
| 418 |
+
|
| 419 |
+
def forward(self, input_ids: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor:
|
| 420 |
+
x = self.text_embedding(input_ids)
|
| 421 |
+
|
| 422 |
+
freqs_cis = precompute_freqs_cis(self.head_dim, input_ids.shape[1]).to(x.device) # could cache
|
| 423 |
+
|
| 424 |
+
for block in self.blocks:
|
| 425 |
+
x = block(x, mask, freqs_cis)
|
| 426 |
+
|
| 427 |
+
return x
|
| 428 |
+
|
| 429 |
+
class SpeakerEncoder(nn.Module):
|
| 430 |
+
def __init__(
|
| 431 |
+
self,
|
| 432 |
+
latent_size: int,
|
| 433 |
+
patch_size: int,
|
| 434 |
+
model_size: int,
|
| 435 |
+
num_layers: int,
|
| 436 |
+
num_heads: int,
|
| 437 |
+
intermediate_size: int,
|
| 438 |
+
norm_eps: float,
|
| 439 |
+
):
|
| 440 |
+
super().__init__()
|
| 441 |
+
self.patch_size = patch_size
|
| 442 |
+
|
| 443 |
+
self.in_proj = nn.Linear(latent_size * patch_size, model_size, bias=True)
|
| 444 |
+
|
| 445 |
+
self.blocks = nn.ModuleList()
|
| 446 |
+
for i in range(num_layers):
|
| 447 |
+
block = EncoderTransformerBlock(
|
| 448 |
+
model_size=model_size,
|
| 449 |
+
num_heads=num_heads,
|
| 450 |
+
intermediate_size=intermediate_size,
|
| 451 |
+
is_causal=True,
|
| 452 |
+
norm_eps=norm_eps
|
| 453 |
+
)
|
| 454 |
+
self.blocks.append(block)
|
| 455 |
+
|
| 456 |
+
self.head_dim = model_size // num_heads
|
| 457 |
+
|
| 458 |
+
def forward(self, latent: torch.Tensor) -> torch.Tensor:
|
| 459 |
+
x = latent.reshape(*latent.shape[:-2], latent.shape[-2] // self.patch_size, latent.shape[-1] * self.patch_size)
|
| 460 |
+
|
| 461 |
+
x = self.in_proj(x)
|
| 462 |
+
x = x / 6. # this helped with initial activation dynamics in early ablations, could also bake into in_proj
|
| 463 |
+
|
| 464 |
+
freqs_cis = precompute_freqs_cis(self.head_dim, x.shape[1]).to(x.device) # could cache
|
| 465 |
+
|
| 466 |
+
for block in self.blocks:
|
| 467 |
+
x = block(x, None, freqs_cis)
|
| 468 |
+
|
| 469 |
+
return x
|
| 470 |
+
|
| 471 |
+
|
| 472 |
+
class EchoDiT(nn.Module):
|
| 473 |
+
def __init__(
|
| 474 |
+
self,
|
| 475 |
+
latent_size: int,
|
| 476 |
+
#
|
| 477 |
+
model_size: int,
|
| 478 |
+
num_layers: int,
|
| 479 |
+
num_heads: int,
|
| 480 |
+
intermediate_size: int,
|
| 481 |
+
norm_eps: float,
|
| 482 |
+
#
|
| 483 |
+
text_vocab_size: int,
|
| 484 |
+
text_model_size: int,
|
| 485 |
+
text_num_layers: int,
|
| 486 |
+
text_num_heads: int,
|
| 487 |
+
text_intermediate_size: int,
|
| 488 |
+
#
|
| 489 |
+
speaker_patch_size: int,
|
| 490 |
+
speaker_model_size: int,
|
| 491 |
+
speaker_num_layers: int,
|
| 492 |
+
speaker_num_heads: int,
|
| 493 |
+
speaker_intermediate_size: int,
|
| 494 |
+
#
|
| 495 |
+
timestep_embed_size: int,
|
| 496 |
+
adaln_rank: int,
|
| 497 |
+
):
|
| 498 |
+
super().__init__()
|
| 499 |
+
self.speaker_patch_size = speaker_patch_size
|
| 500 |
+
self.timestep_embed_size = timestep_embed_size
|
| 501 |
+
|
| 502 |
+
self.text_encoder = TextEncoder(
|
| 503 |
+
vocab_size=text_vocab_size,
|
| 504 |
+
model_size=text_model_size,
|
| 505 |
+
num_layers=text_num_layers,
|
| 506 |
+
num_heads=text_num_heads,
|
| 507 |
+
intermediate_size=text_intermediate_size,
|
| 508 |
+
norm_eps=norm_eps,
|
| 509 |
+
)
|
| 510 |
+
self.speaker_encoder = SpeakerEncoder(
|
| 511 |
+
latent_size=latent_size,
|
| 512 |
+
patch_size=speaker_patch_size,
|
| 513 |
+
model_size=speaker_model_size,
|
| 514 |
+
num_layers=speaker_num_layers,
|
| 515 |
+
num_heads=speaker_num_heads,
|
| 516 |
+
intermediate_size=speaker_intermediate_size,
|
| 517 |
+
norm_eps=norm_eps,
|
| 518 |
+
)
|
| 519 |
+
self.latent_encoder = SpeakerEncoder(
|
| 520 |
+
latent_size=latent_size,
|
| 521 |
+
patch_size=speaker_patch_size,
|
| 522 |
+
model_size=speaker_model_size,
|
| 523 |
+
num_layers=speaker_num_layers,
|
| 524 |
+
num_heads=speaker_num_heads,
|
| 525 |
+
intermediate_size=speaker_intermediate_size,
|
| 526 |
+
norm_eps=norm_eps,
|
| 527 |
+
)
|
| 528 |
+
self.text_norm = RMSNorm(text_model_size, norm_eps)
|
| 529 |
+
self.speaker_norm = RMSNorm(speaker_model_size, norm_eps)
|
| 530 |
+
self.latent_norm = RMSNorm(speaker_model_size, norm_eps)
|
| 531 |
+
|
| 532 |
+
self.cond_module = nn.Sequential(
|
| 533 |
+
nn.Linear(timestep_embed_size, model_size, bias=False),
|
| 534 |
+
nn.SiLU(),
|
| 535 |
+
nn.Linear(model_size, model_size, bias=False),
|
| 536 |
+
nn.SiLU(),
|
| 537 |
+
nn.Linear(model_size, model_size * 3, bias=False),
|
| 538 |
+
)
|
| 539 |
+
|
| 540 |
+
self.in_proj = nn.Linear(latent_size, model_size, bias=True)
|
| 541 |
+
|
| 542 |
+
self.blocks = nn.ModuleList()
|
| 543 |
+
for i in range(num_layers):
|
| 544 |
+
block = TransformerBlock(
|
| 545 |
+
model_size=model_size,
|
| 546 |
+
num_heads=num_heads,
|
| 547 |
+
intermediate_size=intermediate_size,
|
| 548 |
+
norm_eps=norm_eps,
|
| 549 |
+
text_model_size=text_model_size,
|
| 550 |
+
speaker_model_size=speaker_model_size,
|
| 551 |
+
speaker_patch_size=speaker_patch_size,
|
| 552 |
+
adaln_rank=adaln_rank,
|
| 553 |
+
)
|
| 554 |
+
self.blocks.append(block)
|
| 555 |
+
|
| 556 |
+
self.out_norm = RMSNorm(model_size, norm_eps)
|
| 557 |
+
self.out_proj = nn.Linear(model_size, latent_size, bias=True)
|
| 558 |
+
|
| 559 |
+
self.head_dim = model_size // num_heads
|
| 560 |
+
|
| 561 |
+
|
| 562 |
+
|
| 563 |
+
def forward(
|
| 564 |
+
self,
|
| 565 |
+
x: torch.Tensor,
|
| 566 |
+
t: torch.Tensor,
|
| 567 |
+
text_mask: torch.Tensor,
|
| 568 |
+
speaker_mask: torch.Tensor,
|
| 569 |
+
kv_cache_text: List[Tuple[torch.Tensor, torch.Tensor]],
|
| 570 |
+
kv_cache_speaker: List[Tuple[torch.Tensor, torch.Tensor]],
|
| 571 |
+
start_pos: int | None = None,
|
| 572 |
+
kv_cache_latent: List[Tuple[torch.Tensor, torch.Tensor]] | None = None,
|
| 573 |
+
) -> torch.Tensor:
|
| 574 |
+
|
| 575 |
+
if start_pos is None:
|
| 576 |
+
start_pos = 0
|
| 577 |
+
|
| 578 |
+
max_pos = start_pos + x.shape[1]
|
| 579 |
+
freqs_cis = precompute_freqs_cis(self.head_dim, max_pos).to(x.device) # could cache
|
| 580 |
+
|
| 581 |
+
speaker_mask = speaker_mask[..., ::self.speaker_patch_size]
|
| 582 |
+
|
| 583 |
+
cond_embed = self.cond_module(get_timestep_embedding(t, self.timestep_embed_size))
|
| 584 |
+
cond_embed = cond_embed[:, None]
|
| 585 |
+
|
| 586 |
+
x = self.in_proj(x)
|
| 587 |
+
|
| 588 |
+
for i, block in enumerate(self.blocks):
|
| 589 |
+
x = block(
|
| 590 |
+
x=x,
|
| 591 |
+
cond_embed=cond_embed,
|
| 592 |
+
text_mask=text_mask,
|
| 593 |
+
speaker_mask=speaker_mask,
|
| 594 |
+
freqs_cis=freqs_cis,
|
| 595 |
+
kv_cache_text=kv_cache_text[i],
|
| 596 |
+
kv_cache_speaker=kv_cache_speaker[i],
|
| 597 |
+
start_pos=start_pos,
|
| 598 |
+
kv_cache_latent=kv_cache_latent[i] if kv_cache_latent is not None else None,
|
| 599 |
+
)
|
| 600 |
+
|
| 601 |
+
x = self.out_norm(x)
|
| 602 |
+
x = self.out_proj(x)
|
| 603 |
+
|
| 604 |
+
return x.float()
|
| 605 |
+
|
| 606 |
+
def get_kv_cache_text(
|
| 607 |
+
self,
|
| 608 |
+
text_input_ids: torch.Tensor,
|
| 609 |
+
text_mask: torch.Tensor | None,
|
| 610 |
+
) -> List[Tuple[torch.Tensor, torch.Tensor]]:
|
| 611 |
+
text_state = self.text_encoder(text_input_ids, text_mask)
|
| 612 |
+
text_state = self.text_norm(text_state)
|
| 613 |
+
return [block.attention.get_kv_cache_text(text_state) for block in self.blocks]
|
| 614 |
+
|
| 615 |
+
def get_kv_cache_speaker(
|
| 616 |
+
self,
|
| 617 |
+
speaker_latent: torch.Tensor,
|
| 618 |
+
) -> List[Tuple[torch.Tensor, torch.Tensor]]:
|
| 619 |
+
speaker_state = self.speaker_encoder(speaker_latent)
|
| 620 |
+
speaker_state = self.speaker_norm(speaker_state)
|
| 621 |
+
return [block.attention.get_kv_cache_speaker(speaker_state) for block in self.blocks]
|
| 622 |
+
|
| 623 |
+
def get_kv_cache_latent(
|
| 624 |
+
self,
|
| 625 |
+
prefix_latent: torch.Tensor,
|
| 626 |
+
) -> List[Tuple[torch.Tensor, torch.Tensor]]:
|
| 627 |
+
latent_state = self.latent_encoder(prefix_latent)
|
| 628 |
+
latent_state = self.latent_norm(latent_state)
|
| 629 |
+
|
| 630 |
+
seq_len = latent_state.shape[1]
|
| 631 |
+
max_pos = seq_len * self.speaker_patch_size
|
| 632 |
+
freqs_cis = precompute_freqs_cis(self.head_dim, max_pos).to(latent_state.device) # could cache
|
| 633 |
+
positions = torch.arange(seq_len, device=latent_state.device) * self.speaker_patch_size
|
| 634 |
+
freqs_latent = freqs_cis[positions]
|
| 635 |
+
|
| 636 |
+
return [block.attention.get_kv_cache_latent(latent_state, freqs_latent) for block in self.blocks]
|
| 637 |
+
|
| 638 |
+
@property
|
| 639 |
+
def device(self) -> torch.device: return next(self.parameters()).device
|
| 640 |
+
|
| 641 |
+
@property
|
| 642 |
+
def dtype(self) -> torch.dtype: return next(self.parameters()).dtype
|
requirements.txt
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch>=2.9.1
|
| 2 |
+
torchaudio>=2.9.1
|
| 3 |
+
torchcodec>=0.8.1
|
| 4 |
+
huggingface-hub
|
| 5 |
+
numpy
|
| 6 |
+
safetensors
|
| 7 |
+
einops
|
| 8 |
+
gradio==5.49.1
|
sampler_presets.json
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"Independent-High-Speaker-CFG": {
|
| 3 |
+
"num_steps": "40",
|
| 4 |
+
"cfg_scale_text": "3.0",
|
| 5 |
+
"cfg_scale_speaker": "8.0",
|
| 6 |
+
"cfg_min_t": "0.5",
|
| 7 |
+
"cfg_max_t": "1.0",
|
| 8 |
+
"truncation_factor": "1.",
|
| 9 |
+
"rescale_k": "1.",
|
| 10 |
+
"rescale_sigma": "3.0"
|
| 11 |
+
},
|
| 12 |
+
"Independent-High-Speaker-CFG-Flat": {
|
| 13 |
+
"num_steps": "40",
|
| 14 |
+
"cfg_scale_text": "3.0",
|
| 15 |
+
"cfg_scale_speaker": "8.0",
|
| 16 |
+
"cfg_min_t": "0.5",
|
| 17 |
+
"cfg_max_t": "1.0",
|
| 18 |
+
"truncation_factor": "0.8",
|
| 19 |
+
"rescale_k": "1.2",
|
| 20 |
+
"rescale_sigma": "3.0"
|
| 21 |
+
},
|
| 22 |
+
"Independent-High-CFG": {
|
| 23 |
+
"num_steps": "40",
|
| 24 |
+
"cfg_scale_text": "8.0",
|
| 25 |
+
"cfg_scale_speaker": "8.0",
|
| 26 |
+
"cfg_min_t": "0.5",
|
| 27 |
+
"cfg_max_t": "1.0",
|
| 28 |
+
"truncation_factor": "1.",
|
| 29 |
+
"rescale_k": "1.",
|
| 30 |
+
"rescale_sigma": "3.0"
|
| 31 |
+
},
|
| 32 |
+
"Independent-High-CFG-Flat": {
|
| 33 |
+
"num_steps": "40",
|
| 34 |
+
"cfg_scale_text": "8.0",
|
| 35 |
+
"cfg_scale_speaker": "8.0",
|
| 36 |
+
"cfg_min_t": "0.5",
|
| 37 |
+
"cfg_max_t": "1.0",
|
| 38 |
+
"truncation_factor": "0.8",
|
| 39 |
+
"rescale_k": "1.2",
|
| 40 |
+
"rescale_sigma": "3.0"
|
| 41 |
+
},
|
| 42 |
+
"Independent-Low-CFG": {
|
| 43 |
+
"num_steps": "40",
|
| 44 |
+
"cfg_scale_text": "3.0",
|
| 45 |
+
"cfg_scale_speaker": "3.0",
|
| 46 |
+
"cfg_min_t": "0.5",
|
| 47 |
+
"cfg_max_t": "1.0",
|
| 48 |
+
"truncation_factor": "1.",
|
| 49 |
+
"rescale_k": "1.",
|
| 50 |
+
"rescale_sigma": "3.0"
|
| 51 |
+
},
|
| 52 |
+
"Independent-Low-CFG-Flat": {
|
| 53 |
+
"num_steps": "40",
|
| 54 |
+
"cfg_scale_text": "3.0",
|
| 55 |
+
"cfg_scale_speaker": "3.0",
|
| 56 |
+
"cfg_min_t": "0.5",
|
| 57 |
+
"cfg_max_t": "1.0",
|
| 58 |
+
"truncation_factor": "0.8",
|
| 59 |
+
"rescale_k": "1.2",
|
| 60 |
+
"rescale_sigma": "3.0"
|
| 61 |
+
}
|
| 62 |
+
}
|
text_presets.txt
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Reading | [S1] The old lighthouse keeper had seen many storms in his thirty years on the rock, but nothing like this. The fog rolled in thick as wool, swallowing the beam of light before it could reach the churning waves below. Then he heard it, three short bells from the channel, where no ship should be at this hour. He grabbed his lantern and peered into the mist, his heart pounding. Something was out there, something that shouldn't exist.
|
| 2 |
+
|
| 3 |
+
Reading | [S1] Deep beneath the ocean's surface, where sunlight fades to perpetual twilight, extraordinary creatures have evolved in ways that defy imagination. Bioluminescent jellyfish pulse with ethereal blue light, while giant squid hunt in the crushing darkness. At depths of over two miles, the pressure is immense, enough to collapse a submarine, yet life persists.
|
| 4 |
+
|
| 5 |
+
Reading | [S1] The telegram arrived on a Tuesday morning in June, nineteen forty-three. Margaret's hands trembled as she tore open the envelope, dreading the words she knew might be inside. Her brother had shipped out to North Africa six months ago, and his letters had grown increasingly sparse.
|
| 6 |
+
|
| 7 |
+
Cartoon | [S1] After giving everything some more thought, I've decided it's in the best interest of humanity to acquire Nexus AI. (laughs) I've spoken with the CEO and he's on board. Well (laughs), at least that's the impression he gave initially.
|
| 8 |
+
|
| 9 |
+
Single (Disfluent) | [S1] ... explore how we can design, create interfaces that are not confusing, but at the same time can be powerful. Um, you know, I think, uh, in the, the famous, um, usability book, it's, uh, it's this, um, um, oh, geez, I'm, I'm blanking on the term, uh, uh, the, the rule about, um, uh, it's like the simplicity rule. I can't recall. Oh, cognitive load maybe.
|
| 10 |
+
|
| 11 |
+
Single (Disfluent) | [S1] Uh, complacency when the motivation isn't structured properly. Like for example, if you, if you're in the cor- if you work in the corporation for many years, a lot of corporate employees, they just, they're, they're aiming for that stock vesting and they're, they're doing just a sufficient job to, to, to reach that vesting and, and they don't, they're not performing any better than that. Um, and so I think, um, that showed me an important insight. Yeah.
|
| 12 |
+
|
| 13 |
+
Single (Disfluent) | [S1] We see the pattern of revelations, major shifts. I think Neptune in Pisces, which that transit has been happening all of 2021, and Neptune will remain in the sign of Pisces until March of 2029. So it's several years more of this transit. And what it brings is a lot of things, you know, the thing that I tend to emphasize is the profound dissolution or profound changes
|
| 14 |
+
|
| 15 |
+
Single (Disfluent) | [S1] I asked her, "Do you have like a phrase you use," and she mentioned she actually does. Like when things get tense, when there's like a moment, like if her, if her roommate is like venting about work drama or just like is stressed, and her, her roommate like deals with anxiety, I'm like, "Oh, this is probably how it feels to live with me." But, um, and like if, if, if things are rough, like she'll internally just like use this practice where she's like, like, "Not my problem, not mine to carry, not mine to handle, not mine to change." Like she'll sort of repeat that. So that's interesting.
|
| 16 |
+
|
| 17 |
+
Single (Disfluent) | [S1] If I examine the, the, if, if you examine the range of options, uh, beginning from, like, say, individual all the way, right? There will be some revenue stream, uh, there will be some purchase, there'll be some hardware profit margin for someone who creates a smart product, um, uh, there will be memberships, personal and business, uh, and then there'll be usage-based, right? So I still believe that that's kinda how, those are all the metrics. To your point, what is a membership? Up to now, folks
|
| 18 |
+
|
| 19 |
+
Single (Disfluent) | [S1] I think, if, if we can keep it under 25 points allowed, sure, our odds improve significantly. We wouldn't need to put up huge numbers ourselves, or at least that's the theory. And I should, I want to share some other stats which might be a bit outside our current discussion, but regarding this compared to 2018, the team's final four games that year, they managed 18 points total.
|
| 20 |
+
|
| 21 |
+
Singing | [S1] (singing) Amazing grace, how sweet the sound, that saved a wretch like me. I once was lost, but now am found, was blind, but now I see.
|
| 22 |
+
|
| 23 |
+
Conversation | [S1] Alright then. So, so 18 years you spent in that, uh, in that role, but alongside that in, in, was it while you were working that position in '93, you started doing some work with the network? [S2] Uh, yes. It was somewhere around '93. I, I, I played tennis pretty well, you know? I, I, I competed as a tennis player. And the, I got a chance to do some broadcasting over in Brisbane.
|
| 24 |
+
|
| 25 |
+
Conversation | [S1] ... that will provide the analytics component- [S2] Right. [S1] ... to ideally get you to adopt some of their other tools. And- [S2] (laughs) [S1] ... some of those features are valuable too. [S2] That's interesting. [S1] Mailchimp, I mean, that's campaign manage-, uh, not exactly campaign management, but messaging platforms. [S2] Uh-huh. [S1] The, the companies that are, you know,
|
| 26 |
+
|
| 27 |
+
Conversation | [S1] They were like, they were pumped for it, going wild for it, and it disappeared immediately. [S2] Yeah, I think it's about people understanding what's available first. Um... [S1] I think the finish on that one too was really nice. [S2] Yeah. [S1] I mean, that was pretty awesome. [S2] Have you seen those new editions?
|
| 28 |
+
|
| 29 |
+
Conversation | [S1] He was just practicing with them and they were on rotation. [S2] So that was probably in January. [S1] I think startup stereotypes, there is some like that, but some of them, I think they need to be changed. Like we don't all work twenty-hour days. [S2] No, they just need to, it's called not, it's based in Silicon Valley. [S1] Yeah. [S2] But the stereotypes would apply if they, it was called Techlife- [S1] Palo Alto. [S2] ... Cupertino or Mountain View, California.
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
Conversation | [S1] That's a nice overview. [S2] We were at the downtown cinema. [S1] By that, you mean the one in Riverside? [S2] Yeah. [S1] Yeah. So not exactly downtown. [S2] Not exactly downtown, yeah. [S1] I know a little bit about that area. [S2] (laughs) [S1] You know, Millbrook doesn't have a cinema. [S2] (laughs) It's the closest one for us. It's the closest. [S1] Yeah, that's true. [S2] The most nearby. [S1] Riverside is nearby. [S2] Riverside's close. [S1] That's fair. [S2] Support nearby. [S1] You can say, say Riverside, definitely. [S2] Well, yeah, fair enough.
|
| 33 |
+
|
| 34 |
+
Conversation | [S1] But they also, they also discovered, um, they also discovered like patterns in the desert, um, near Peru, like in the Atacama Desert. [S2] Yeah. [S1] Um, and like, it was like, of like perfectly, like, geo- geometric shapes. And they're like, "Yo, this is definitely not like formed by wind. This has to be artificial." [S2] Yeah, it's too precise.
|
| 35 |
+
|
| 36 |
+
Conversation | [S1] 'Cause I, yeah, there, there has to be a way that they can just make the, the system recognize that, no, you did not earn this- [S2] (laughs) [S1] ... on your own. You still have to go and complete one if you want it for your own- [S2] Right. [S1] ... like, profile. [S2] Right. Mm-hmm. [S1] So, yeah. [S2] Um, yeah. So let's actually move into multiplayer.
|
| 37 |
+
|
| 38 |
+
Conversation | [S1] Yeah. [S2] Yeah. TRS as a whole is just relaxed. [S1] But anyway, you know that Mirror app that launched and then got removed like a month later? [S2] Mirror, what, like, to your future? [S1] Yeah. [S2] Oh. [S1] So basically, there was an app, there's a show coming out. [S2] This is a show. [S1] Coming, I don't know what it is. [S2] Yeah, yeah, yeah. [S1] Like 2026 or something. Basically, Marcus, have you heard about this? [S2] I'm sorry, I don't know. No, I don't have an, it's an app- [S1] Okay, so I'll explain. I'll explain. [S2] Yeah. [S1] For context. So there's this app that launched in terms of the show called Mirror.
|
| 39 |
+
|
| 40 |
+
Conversation | [S1] Jamie Patterson, right? [S2] No, I know where- [S1] I know where- [S2] ... Patterson works as well. I know where- [S1] I know- I know he used to work near- on this street, and this is a weird street. [S2] The only person who I don't know where they work, Jamie. But anyway, why are we even talking about who works where? [S1] It was a- it was- it was a really weird street name where Jamie worked. [S2] I- I drove past this street on my commute. [S1] No, you didn't. [S2] Yeah, I did. [S1] No, you drove past the street that my street is down the street of. [S2] Nice. There's, like, one street in Oakfield, I think I'll be able to find it, mate.
|