Hsopgamers committed on
Commit 176c235 · verified · 1 Parent(s): ffc7d09

Upload folder using huggingface_hub
.gitattributes CHANGED
@@ -33,3 +33,9 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ audio_prompts/EARS[[:space:]]p004[[:space:]]freeform.mp3 filter=lfs diff=lfs merge=lfs -text
+ audio_prompts/EARS[[:space:]]p005[[:space:]]freeform.mp3 filter=lfs diff=lfs merge=lfs -text
+ audio_prompts/EARS[[:space:]]p028[[:space:]]freeform.mp3 filter=lfs diff=lfs merge=lfs -text
+ audio_prompts/EARS[[:space:]]p036[[:space:]]freeform.mp3 filter=lfs diff=lfs merge=lfs -text
+ audio_prompts/expresso_02_ex03-ex01_calm_005.mp3 filter=lfs diff=lfs merge=lfs -text
+ audio_prompts/freesound_demon_chant(use_forcespeaker).mp3 filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,8 @@
+ .DS_Store
+ __pycache__/
+ *.py[cod]
+ .venv/
+ venv/
+ .env
+ .idea/
+ .vscode/
LICENSE ADDED
@@ -0,0 +1,22 @@
+ MIT License
+
+ Copyright (c) 2025 Jordan Darefsky
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
+
README.md CHANGED
@@ -1,12 +1,151 @@
  ---
- title: Echo Tts
- emoji: 🏆
- colorFrom: indigo
- colorTo: purple
+ title: echo-tts
+ app_file: gradio_app.py
  sdk: gradio
- sdk_version: 6.6.0
- app_file: app.py
- pinned: false
+ sdk_version: 5.49.1
  ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # Echo-TTS
+
+ A multi-speaker text-to-speech model with speaker reference conditioning. See the [blog post](https://jordandarefsky.com/blog/2025/echo/) for technical details.
+
+ **Model:** [jordand/echo-tts-base](https://huggingface.co/jordand/echo-tts-base) | **Demo:** [echo-tts-preview](https://huggingface.co/spaces/jordand/echo-tts-preview)
+
+ This work was made possible by the TPU Research Cloud (TRC).
+
+ ## Responsible Use
+
+ Don't use this model to:
+ - Impersonate real people without their consent
+ - Generate deceptive audio (e.g., fraud, misinformation, deepfakes)
+
+ You are responsible for complying with local laws regarding biometric data and voice cloning.
+
+ ## Installation
+
+ ```bash
+ pip install -r requirements.txt
+ ```
+
+ Requires Python 3.10+ and a CUDA-capable GPU with at least 8GB VRAM.
+
+ ## Quick Start
+
+ ### Gradio UI
+
+ ```bash
+ python gradio_app.py
+ ```
+
+ ### Python API
+
+ ```python
+ from inference import (
+     load_model_from_hf,
+     load_fish_ae_from_hf,
+     load_pca_state_from_hf,
+     load_audio,
+     sample_pipeline,
+     sample_euler_cfg_independent_guidances,
+ )
+ from functools import partial
+ import torchaudio
+
+ # Load models (downloads from HuggingFace on first run)
+ model = load_model_from_hf(delete_blockwise_modules=True)
+ fish_ae = load_fish_ae_from_hf()
+ pca_state = load_pca_state_from_hf()
+
+ # Load speaker reference (or set to None for no reference)
+ speaker_audio = load_audio("speaker.wav").cuda()
+
+ # Configure sampler
+ sample_fn = partial(
+     sample_euler_cfg_independent_guidances,
+     num_steps=40,
+     cfg_scale_text=3.0,
+     cfg_scale_speaker=8.0,
+     cfg_min_t=0.5,
+     cfg_max_t=1.0,
+     truncation_factor=None,
+     rescale_k=None,
+     rescale_sigma=None,
+     speaker_kv_scale=None,
+     speaker_kv_max_layers=None,
+     speaker_kv_min_t=None,
+     sequence_length=640,  # ~30 seconds
+ )
+
+ # Generate
+ text = "[S1] Hello, this is a test of the Echo TTS model."
+ audio_out, _ = sample_pipeline(
+     model=model,
+     fish_ae=fish_ae,
+     pca_state=pca_state,
+     sample_fn=sample_fn,
+     text_prompt=text,
+     speaker_audio=speaker_audio,
+     rng_seed=0,
+ )
+
+ torchaudio.save("output.wav", audio_out[0].cpu(), 44100)
+ ```
+
+ See also:
+ - `inference.py` -- lower-level usage example at the bottom of the file
+ - `inference_blockwise.py` -- examples of blockwise/continuation generation
+
+ ## Low VRAM (8GB)
+
+ In `gradio_app.py`, adjust:
+
+ ```python
+ FISH_AE_DTYPE = torch.bfloat16        # instead of float32
+ DEFAULT_SAMPLE_LATENT_LENGTH = 576    # instead of 640 (or lower, depending on what fits)
+ ```
+
+ ## Tips
+
+ ### Generation Length
+
+ Echo is trained to generate up to 30 seconds of audio (640 latents) given text and reference audio. Because the text supplied during training always corresponded to at most 30 seconds of audio, at inference the model will try to fit any text prompt into the 30 seconds of generated audio (so, e.g., long text prompts may result in faster speaking rates). Shorter text prompts work fine and produce shorter outputs, as the model generates latent padding automatically.
+
+ If "Sample Latent Length" (under Custom Shapes in the Gradio UI) / `sequence_length` is set to less than 640, the model will attempt to generate the prefix corresponding to that length. For example, if you set it to 320 and supply ~30 seconds' worth of text, the model will likely generate the first half of the text rather than try to fit all of it into the first 15 seconds.
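The 640-latents-per-30-seconds ratio above makes it easy to pick a `sequence_length` for a target duration. A small illustrative helper (the name `seconds_to_latents` is ours, not part of the repo):

```python
LATENTS_PER_SECOND = 640 / 30  # ~21.3 latents per second of audio

def seconds_to_latents(seconds: float, max_latents: int = 640) -> int:
    """Rough sequence_length for a target duration, capped at the 30 s maximum."""
    return min(max_latents, round(seconds * LATENTS_PER_SECOND))

print(seconds_to_latents(15))  # -> 320 (half the maximum)
print(seconds_to_latents(45))  # -> 640 (clamped to the 30 s cap)
```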
+
+ ### Reference Audio
+
+ You can condition on up to 5 minutes of reference audio, but shorter clips (e.g., 10 seconds or less) also work well.
+
+ ### Force Speaker (KV Scaling)
+
+ Sometimes out-of-distribution text for a given reference speaker will cause the model to generate a different speaker entirely. Enabling "Force Speaker" (which scales the speaker KV activations for a portion of timesteps; default scale 1.5) generally fixes this, but high values may introduce artifacts or "overconditioning." Aim for the lowest scale that produces the correct speaker: 1.0 is the baseline, 1.5 is the default when enabled and will usually force the speaker, and lower values (e.g., 1.3 or 1.1) may suffice.
+
+ ### Text Prompt Format
+
+ Text prompts use the format from [WhisperD](https://huggingface.co/jordand/whisper-d-v1a). By default, colons, semicolons, and em-dashes are normalized to commas (see `tokenizer_encode` in `inference.py`), and "[S1] " is prepended to the prompt if not already present. Commas generally function as pauses. Exclamation points (and other non-bland punctuation) may increase expressiveness but can occasionally lower quality; improving controllability is an important direction for future work.
+
+ The included text presets are stylistically in-distribution with the WhisperD transcription style.
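The normalization described above can be sketched in a few lines. This is a hypothetical helper for illustration only; the actual logic lives in `tokenizer_encode` in `inference.py` and may differ in details:

```python
def normalize_prompt(text: str) -> str:
    # Normalize colons, semicolons, and em-dashes to commas.
    for ch in (":", ";", "\u2014"):
        text = text.replace(ch, ",")
    # Prepend the speaker tag if it is not already present.
    if not text.startswith("[S1] "):
        text = "[S1] " + text
    return text

print(normalize_prompt("Hello world; this is a test: with pauses"))
# -> [S1] Hello world, this is a test, with pauses
```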
+
+ ### Blockwise Generation
+
+ `inference_blockwise.py` includes blockwise sampling, which allows generating audio in smaller blocks as well as producing continuations of existing audio (where the prefix and continuation are up to 30 seconds combined). The model released on HF is a fully fine-tuned model (not the LoRA described in the blog post). Because the S1-DAC decoder is causal, blockwise generation enables audio streaming (not included in the current code). Blockwise functionality hasn't been thoroughly tested and may benefit from different (e.g., smaller) CFG scales.
+
+ ## License
+
+ Code in this repo is MIT-licensed except where file headers specify otherwise (e.g., `autoencoder.py` is Apache-2.0).
+
+ Regardless of the model license, audio outputs are CC-BY-NC-SA-4.0 due to the dependency on the Fish Speech S1-DAC autoencoder, which is itself CC-BY-NC-SA-4.0.
+
+ We have chosen to release the Echo-TTS weights under CC-BY-NC-SA-4.0.
+
+ For the included audio prompts, see `audio_prompts/LICENSE`.
+
+ ## Citation
+
+ ```bibtex
+ @misc{darefsky2025echo,
+   author = {Darefsky, Jordan},
+   title  = {Echo-TTS},
+   year   = {2025},
+   url    = {https://jordandarefsky.com/blog/2025/echo/}
+ }
+ ```
audio_prompts/EARS p004 freeform.mp3 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:68947a209bc11064f749ca0a61b7959243df83565a0e462b87dfc0ffe03aa7b0
+ size 1526439
audio_prompts/EARS p005 freeform.mp3 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:07344d073eb3e22c249ebfe15f31f4ba63fd9f17c71aeee93da199ff3b53fc45
+ size 1351147
audio_prompts/EARS p028 freeform.mp3 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:8351eed5982f1fb5763a475c0fb69dba98a4bb49b0f2bbab12b978ff2b0fedeb
+ size 1211565
audio_prompts/EARS p036 freeform.mp3 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ce77dbb86ea7c29edf2b9804ce9c9315334e9cfeef532dc0c50898a09bae1583
+ size 1227585
audio_prompts/LICENSE ADDED
@@ -0,0 +1,26 @@
+ The audio files in this folder are provided for demonstration purposes and
+ are sourced from the following datasets. Please refer to their original
+ licenses for terms of use.
+
+ EARS Dataset (CC-BY-NC-4.0)
+ ---------------------------
+ - EARS p004 freeform.mp3
+ - EARS p005 freeform.mp3
+ - EARS p028 freeform.mp3
+ - EARS p036 freeform.mp3
+
+ Source: https://github.com/facebookresearch/ears_dataset
+
+ Expresso Dataset (CC-BY-NC-4.0)
+ -------------------------------
+ - expresso_02_ex03-ex01_calm_005.mp3
+
+ Source: https://speechbot.github.io/expresso/
+
+ Freesound (CC0)
+ ---------------
+ - freesound_demon_chant(use_forcespeaker).mp3
+
+ Source: https://freesound.org/s/419507/
+ Author: DylanTheFish
+
audio_prompts/expresso_02_ex03-ex01_calm_005.mp3 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:98855b5b8b6c265a643edeb23ce5cd772391cb90754822e2a0370ea5188225f5
+ size 4802350
audio_prompts/freesound_demon_chant(use_forcespeaker).mp3 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:471f67fff5ea613ec4617b9822b1396da123a1133f199925436a2c40e5d1eb91
+ size 303438
autoencoder.py ADDED
@@ -0,0 +1,1225 @@
+ # SPDX-License-Identifier: Apache-2.0
+
+ # This file contains portions adapted from:
+ #   • Descript Audio Codec (DAC) — MIT License (full text appended below)
+ #   • Fish-Speech S1 DAC Autoencoder — reference implementation (Apache-2.0 / CC-BY-NC),
+ #     rewritten here in a single-file Torch module for interoperability and transparency.
+ #
+ # OVERALL LICENSE (this file): Apache-2.0, except where explicitly marked:
+ #   # SPDX-License-Identifier: MIT
+ # Keep these notices and the embedded MIT text if you redistribute this file.
+
+ # NOTE
+ # Self-contained autoencoder implementation of Fish-S1-DAC (inlining DAC code to avoid dependencies).
+ # Code in this module has been largely copy-and-pasted from the Fish-S1-DAC and DAC repositories,
+ # and refactored with help from ChatGPT/Claude (these models also helped with licensing).
+ # Thus, it differs stylistically from the rest of the codebase (and is likely internally inconsistent as well).
+
+ from __future__ import annotations
+
+ import math
+ from dataclasses import dataclass
+ from typing import List, Optional, Tuple, Union
+
+ import numpy as np
+ import torch
+ from torch import Tensor, nn
+ from torch.nn import functional as F
+ from torch.nn.utils.parametrizations import weight_norm
+ from torch.nn.utils.parametrize import remove_parametrizations
+
+ from einops import rearrange
+
+
+ # --------------------------------------------------------------------
+ # Shared helpers
+ # --------------------------------------------------------------------
+
+ def find_multiple(n: int, k: int) -> int:
+     return n if n % k == 0 else n + k - (n % k)
+
+ def unpad1d(x: Tensor, paddings: Tuple[int, int]) -> Tensor:
+     """Remove padding from x, handling properly zero padding. Only for 1d!"""
+     padding_left, padding_right = paddings
+     assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
+     assert (padding_left + padding_right) <= x.shape[-1]
+     end = x.shape[-1] - padding_right
+     return x[..., padding_left:end]
+
+ def get_extra_padding_for_conv1d(
+     x: Tensor, kernel_size: int, stride: int, padding_total: int = 0
+ ) -> int:
+     """See pad_for_conv1d; enough right pad so striding evenly covers length."""
+     length = x.shape[-1]
+     n_frames = (length - kernel_size + padding_total) / stride + 1
+     ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
+     return ideal_length - length
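The arithmetic in `get_extra_padding_for_conv1d` can be sanity-checked without tensors. This sketch re-implements it for a plain integer length (the function name `extra_padding` is ours):

```python
import math

def extra_padding(length: int, kernel_size: int, stride: int, padding_total: int = 0) -> int:
    # Mirrors get_extra_padding_for_conv1d, but on a plain length instead of a tensor.
    n_frames = (length - kernel_size + padding_total) / stride + 1
    ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
    return ideal_length - length

# With length 10, kernel 4, stride 3: (10 - 4) / 3 + 1 = 3 full frames, so no extra pad.
print(extra_padding(10, 4, 3))  # -> 0
# With length 11 the last window would be incomplete, so 2 samples of right-pad are added.
print(extra_padding(11, 4, 3))  # -> 2
```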
+
+ def pad1d(
+     x: Tensor,
+     paddings: Tuple[int, int],
+     mode: str = "zeros",
+     value: float = 0.0,
+ ) -> Tensor:
+     """
+     Reflect-safe 1D pad: if reflect would underflow on small inputs, insert
+     temporary right zero-pad before reflecting.
+     """
+     length = x.shape[-1]
+     padding_left, padding_right = paddings
+     assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
+     if mode == "reflect":
+         max_pad = max(padding_left, padding_right)
+         extra_pad = 0
+         if length <= max_pad:
+             extra_pad = max_pad - length + 1
+             x = F.pad(x, (0, extra_pad))
+         padded = F.pad(x, (padding_left, padding_right), mode, value)
+         end = padded.shape[-1] - extra_pad
+         return padded[..., :end]
+     else:
+         return F.pad(x, (padding_left, padding_right), mode, value)
+
+
+ # --------------------------------------------------------------------
+ # DAC Layers (adapted) — MIT
+ # Original: https://github.com/descriptinc/descript-audio-codec/blob/main/dac/nn/layers.py
+ # SPDX-License-Identifier: MIT
+ # --------------------------------------------------------------------
+
+ def WNConv1d(*args, **kwargs):
+     return weight_norm(nn.Conv1d(*args, **kwargs))
+
+ def WNConvTranspose1d(*args, **kwargs):
+     return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
+
+ @torch.jit.script
+ def snake(x: Tensor, alpha: Tensor) -> Tensor:
+     shape = x.shape
+     x = x.reshape(shape[0], shape[1], -1)
+     x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
+     x = x.reshape(shape)
+     return x
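The snake activation above computes x + sin²(αx) / (α + 1e-9). A scalar re-implementation (ignoring the tensor reshapes, which only flatten trailing dimensions) shows it is nearly the identity wherever sin(αx) vanishes:

```python
import math

def snake_scalar(x: float, alpha: float) -> float:
    # Scalar version of the snake activation: x + sin^2(alpha * x) / (alpha + 1e-9)
    return x + (1.0 / (alpha + 1e-9)) * math.sin(alpha * x) ** 2

print(snake_scalar(0.0, 1.0))  # -> 0.0 (zero is a fixed point)
# At x = pi with alpha = 1, sin(pi) ~ 0, so the output is essentially x itself.
print(snake_scalar(math.pi, 1.0))
# At x = pi/2 with alpha = 1, sin^2 = 1, so roughly x + 1 is added.
print(snake_scalar(math.pi / 2, 1.0))
```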
+
+ class Snake1d(nn.Module):
+     def __init__(self, channels: int):
+         super().__init__()
+         self.alpha = nn.Parameter(torch.ones(1, channels, 1))
+
+     def forward(self, x: Tensor) -> Tensor:
+         return snake(x, self.alpha)
+
+ # --------------------------------------------------------------------
+ # DAC Vector Quantize (adapted) — MIT
+ # Original: https://github.com/descriptinc/descript-audio-codec/blob/main/dac/nn/quantize.py
+ # SPDX-License-Identifier: MIT
+ # --------------------------------------------------------------------
+
+ class VectorQuantize(nn.Module):
+     """
+     VQ with factorized, l2-normalized codes (ViT-VQGAN style).
+     I/O in (B, D, T).
+     """
+     def __init__(self, input_dim: int, codebook_size: int, codebook_dim: int):
+         super().__init__()
+         self.codebook_size = codebook_size
+         self.codebook_dim = codebook_dim
+         self.in_proj = WNConv1d(input_dim, codebook_dim, kernel_size=1)
+         self.out_proj = WNConv1d(codebook_dim, input_dim, kernel_size=1)
+         self.codebook = nn.Embedding(codebook_size, codebook_dim)
+
+     def forward(self, z: Tensor):
+         z_e = self.in_proj(z)  # (B, D, T)
+         z_q, indices = self.decode_latents(z_e)
+         commitment_loss = F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2])
+         codebook_loss = F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2])
+         z_q = z_e + (z_q - z_e).detach()  # straight-through
+         z_q = self.out_proj(z_q)
+         return z_q, commitment_loss, codebook_loss, indices, z_e
+
+     def embed_code(self, embed_id: Tensor) -> Tensor:
+         return F.embedding(embed_id, self.codebook.weight)
+
+     def decode_code(self, embed_id: Tensor) -> Tensor:
+         return self.embed_code(embed_id).transpose(1, 2)
+
+     def decode_latents(self, latents: Tensor) -> Tuple[Tensor, Tensor]:
+         encodings = rearrange(latents, "b d t -> (b t) d")
+         codebook = self.codebook.weight
+         encodings = F.normalize(encodings)
+         codebook = F.normalize(codebook)
+         dist = (
+             encodings.pow(2).sum(1, keepdim=True)
+             - 2 * encodings @ codebook.t()
+             + codebook.pow(2).sum(1, keepdim=True).t()
+         )
+         indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
+         z_q = self.decode_code(indices)
+         return z_q, indices
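`decode_latents` above picks, for each timestep, the codebook entry with the smallest squared L2 distance after both sides are unit-normalized, which is equivalent to maximum cosine similarity. A small pure-Python sketch of that lookup (illustrative only, not the tensorized implementation):

```python
import math

def nearest_code(vec, codebook):
    # Unit-normalize the query and each code, then pick the code with the
    # smallest squared L2 distance (equivalently, the largest cosine similarity).
    def normalize(v):
        n = math.sqrt(sum(x * x for x in v)) or 1.0
        return [x / n for x in v]
    q = normalize(vec)
    best, best_dist = -1, float("inf")
    for idx, code in enumerate(codebook):
        c = normalize(code)
        dist = sum((a - b) ** 2 for a, b in zip(q, c))
        if dist < best_dist:
            best, best_dist = idx, dist
    return best

codebook = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]]
print(nearest_code([0.9, 0.1], codebook))  # -> 0 (closest in direction, not magnitude)
print(nearest_code([0.2, 5.0], codebook))  # -> 1
```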
+
+
+ class ResidualVectorQuantize(nn.Module):
+     """SoundStream-style residual VQ stack."""
+     def __init__(
+         self,
+         input_dim: int = 512,
+         n_codebooks: int = 9,
+         codebook_size: int = 1024,
+         codebook_dim: Union[int, List[int]] = 8,
+         quantizer_dropout: float = 0.0,
+     ):
+         super().__init__()
+         if isinstance(codebook_dim, int):
+             codebook_dim = [codebook_dim for _ in range(n_codebooks)]
+
+         self.n_codebooks = n_codebooks
+         self.codebook_dim = codebook_dim
+         self.codebook_size = codebook_size
+
+         self.quantizers = nn.ModuleList([
+             VectorQuantize(input_dim, codebook_size, codebook_dim[i])
+             for i in range(n_codebooks)
+         ])
+         self.quantizer_dropout = quantizer_dropout
+
+     def forward(self, z: Tensor, n_quantizers: Optional[int] = None):
+         z_q = 0
+         residual = z
+         commitment_loss = 0
+         codebook_loss = 0
+
+         codebook_indices = []
+         latents = []
+
+         if n_quantizers is None:
+             n_quantizers = self.n_codebooks
+         if self.training:
+             n_quantizers = torch.ones((z.shape[0],)) * self.n_codebooks + 1
+             dropout = torch.randint(1, self.n_codebooks + 1, (z.shape[0],))
+             n_dropout = int(z.shape[0] * self.quantizer_dropout)
+             n_quantizers[:n_dropout] = dropout[:n_dropout]
+             n_quantizers = n_quantizers.to(z.device)
+
+         for i, quantizer in enumerate(self.quantizers):
+             if self.training is False and i >= n_quantizers:
+                 break
+
+             z_q_i, commit_i, codebk_i, indices_i, z_e_i = quantizer(residual)
+
+             mask = (torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers)
+             z_q = z_q + z_q_i * mask[:, None, None]
+             residual = residual - z_q_i
+
+             commitment_loss += (commit_i * mask).mean()
+             codebook_loss += (codebk_i * mask).mean()
+
+             codebook_indices.append(indices_i)
+             latents.append(z_e_i)
+
+         codes = torch.stack(codebook_indices, dim=1)
+         latents = torch.cat(latents, dim=1)
+
+         return z_q, codes, latents, commitment_loss, codebook_loss
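The residual pattern in the loop above (quantize, subtract, quantize the remainder) can be illustrated with scalar rounding in place of learned codebooks. This is purely illustrative and not part of the module:

```python
def residual_quantize(value: float, stages: int = 3, step: float = 1.0):
    # Each stage quantizes the current residual to the nearest multiple of a
    # progressively finer step, then passes the remainder to the next stage.
    approx, residual, codes = 0.0, value, []
    for _ in range(stages):
        q = round(residual / step) * step
        codes.append(q)
        approx += q
        residual -= q
        step /= 10.0
    return approx, codes

approx, codes = residual_quantize(3.14159)
print(codes)   # coarse-to-fine contributions from each stage
print(approx)  # reconstruction improves as stages are added
```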
+
+     def from_codes(self, codes: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
+         z_q = 0.0
+         z_p = []
+         n_codebooks = codes.shape[1]
+         for i in range(n_codebooks):
+             z_p_i = self.quantizers[i].decode_code(codes[:, i, :])
+             z_p.append(z_p_i)
+             z_q_i = self.quantizers[i].out_proj(z_p_i)
+             z_q = z_q + z_q_i
+         return z_q, torch.cat(z_p, dim=1), codes
+
+     def from_latents(self, latents: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
+         z_q = 0
+         z_p = []
+         codes = []
+         dims = np.cumsum([0] + [q.codebook_dim for q in self.quantizers])
+         n_codebooks = np.where(dims <= latents.shape[1])[0].max(axis=0, keepdims=True)[0]
+         for i in range(n_codebooks):
+             j, k = dims[i], dims[i + 1]
+             z_p_i, codes_i = self.quantizers[i].decode_latents(latents[:, j:k, :])
+             z_p.append(z_p_i)
+             codes.append(codes_i)
+             z_q_i = self.quantizers[i].out_proj(z_p_i)
+             z_q = z_q + z_q_i
+         return z_q, torch.cat(z_p, dim=1), torch.stack(codes, dim=1)
+
+
+ # --------------------------------------------------------------------
+ # S1 DAC rvq
+ # --------------------------------------------------------------------
+
+ @dataclass
+ class VQResult:
+     z: Tensor
+     codes: Tensor
+     latents: Tensor
+     codebook_loss: Tensor
+     commitment_loss: Tensor
+     semantic_distill_z: Optional[Tensor] = None
+
+
+ class CausalConvNet(nn.Module):
+     def __init__(
+         self,
+         in_channels,
+         out_channels,
+         kernel_size,
+         dilation=1,
+         stride=1,
+         groups=1,
+         padding=None,
+     ):
+         super().__init__()
+         self.conv = nn.Conv1d(
+             in_channels, out_channels, kernel_size,
+             stride=stride, dilation=dilation, groups=groups,
+         )
+         self.stride = stride
+         self.kernel_size = (kernel_size - 1) * dilation + 1
+         self.dilation = dilation
+         self.padding = self.kernel_size - self.stride
+
+     def forward(self, x: Tensor) -> Tensor:
+         pad = self.padding
+         extra = get_extra_padding_for_conv1d(x, self.kernel_size, self.stride, pad)
+         x = pad1d(x, (pad, extra), mode="constant", value=0)
+         return self.conv(x).contiguous()
+
+     def weight_norm(self, name="weight", dim=0):
+         self.conv = weight_norm(self.conv, name=name, dim=dim)
+         return self
+
+     def remove_weight_norm(self):
+         self.conv = remove_parametrizations(self.conv, "weight")
+         return self
+
+
+ class CausalTransConvNet(nn.Module):
+     def __init__(self, in_channels, out_channels, kernel_size, dilation=1, stride=1, padding=None):
+         super().__init__()
+         self.conv = nn.ConvTranspose1d(
+             in_channels, out_channels, kernel_size,
+             stride=stride, dilation=dilation
+         )
+         self.stride = stride
+         self.kernel_size = kernel_size
+
+     def forward(self, x: Tensor) -> Tensor:
+         x = self.conv(x)
+         pad = self.kernel_size - self.stride
+         padding_right = math.ceil(pad)
+         padding_left = pad - padding_right
+         x = unpad1d(x, (padding_left, padding_right))
+         return x.contiguous()
+
+     def weight_norm(self, name="weight", dim=0):
+         self.conv = weight_norm(self.conv, name=name, dim=dim)
+         return self
+
+     def remove_weight_norm(self):
+         self.conv = remove_parametrizations(self.conv, "weight")
+         return self
+
+
+ def CausalWNConv1d(*args, **kwargs):
+     return CausalConvNet(*args, **kwargs).weight_norm()
+
+ def CausalWNConvTranspose1d(*args, **kwargs):
+     return CausalTransConvNet(*args, **kwargs).weight_norm()
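With the causal padding above (left pad of effective kernel size minus stride, plus the extra right pad), a strided `CausalConvNet` maps an input of length L to an output of length ceil(L / stride). This pure-Python check of the length arithmetic mirrors the tensor code (an assumption, since it re-derives rather than runs the conv):

```python
import math

def causal_out_length(length: int, kernel_size: int, stride: int, dilation: int = 1) -> int:
    # Effective kernel and left padding as computed in CausalConvNet.__init__ / forward.
    eff_kernel = (kernel_size - 1) * dilation + 1
    pad = eff_kernel - stride
    # Extra right pad so that striding evenly covers the padded signal.
    n_frames = (length - eff_kernel + pad) / stride + 1
    extra = (math.ceil(n_frames) - 1) * stride + (eff_kernel - pad) - length
    padded = length + pad + extra
    # Standard conv1d output-length formula on the padded input.
    return (padded - eff_kernel) // stride + 1

# A stride-2 causal conv halves the length (rounding up), regardless of kernel size.
print(causal_out_length(100, kernel_size=4, stride=2))  # -> 50
print(causal_out_length(101, kernel_size=7, stride=2))  # -> 51
```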
332
+
333
+ class ConvNeXtBlock(nn.Module):
334
+ r"""ConvNeXt Block (1D).
335
+ DwConv -> (N, C, L) β†’ (N, L, C) -> LN -> Linear -> GELU -> Linear -> (N, C, L) with residual
336
+ """
337
+ def __init__(
338
+ self,
339
+ dim: int,
340
+ layer_scale_init_value: float = 1e-6,
341
+ mlp_ratio: float = 4.0,
342
+ kernel_size: int = 7,
343
+ dilation: int = 1,
344
+ ):
345
+ super().__init__()
346
+ convnet_type = CausalConvNet
347
+ self.dwconv = convnet_type(
348
+ dim, dim, kernel_size=kernel_size,
349
+ groups=dim, dilation=dilation,
350
+ ) # depthwise conv
351
+ self.norm = nn.LayerNorm(dim, eps=1e-6)
352
+ self.pwconv1 = nn.Linear(dim, int(mlp_ratio * dim))
353
+ self.act = nn.GELU()
354
+ self.pwconv2 = nn.Linear(int(mlp_ratio * dim), dim)
355
+ self.gamma = (
356
+ nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
357
+ if layer_scale_init_value > 0 else None
358
+ )
359
+
360
+ def forward(self, x: Tensor, apply_residual: bool = True) -> Tensor:
361
+ inp = x
362
+ x = self.dwconv(x)
363
+ x = x.permute(0, 2, 1) # (N, C, L) -> (N, L, C)
364
+ x = self.norm(x)
365
+ x = self.pwconv1(x)
366
+ x = self.act(x)
367
+ x = self.pwconv2(x)
368
+ if self.gamma is not None:
369
+ x = self.gamma * x
370
+ x = x.permute(0, 2, 1) # (N, L, C) -> (N, C, L)
371
+ if apply_residual:
372
+ x = inp + x
373
+ return x
374
+
375
+
376
+ class DownsampleResidualVectorQuantize(nn.Module):
377
+ def __init__(
378
+ self,
379
+ input_dim: int = 1024,
380
+ n_codebooks: int = 9,
381
+ codebook_dim: int = 8,
382
+ quantizer_dropout: float = 0.5,
383
+ codebook_size: int = 1024,
384
+ semantic_codebook_size: int = 4096,
385
+ downsample_factor: Tuple[int, ...] = (2, 2),
386
+ downsample_dims: Optional[Tuple[int, ...]] = None,
387
+ pre_module: Optional[nn.Module] = None,
388
+ post_module: Optional[nn.Module] = None,
389
+ semantic_predictor_module: Optional[nn.Module] = None,
390
+ ):
391
+ super().__init__()
392
+
393
+ if downsample_dims is None:
394
+ downsample_dims = tuple(input_dim for _ in range(len(downsample_factor)))
395
+
396
+ all_dims = (input_dim,) + tuple(downsample_dims)
397
+
398
+ self.semantic_quantizer = ResidualVectorQuantize(
399
+ input_dim=input_dim,
400
+ n_codebooks=1,
401
+ codebook_size=semantic_codebook_size,
402
+ codebook_dim=codebook_dim,
403
+ quantizer_dropout=0.0,
404
+ )
405
+
406
+ self.quantizer = ResidualVectorQuantize(
407
+ input_dim=input_dim,
408
+ n_codebooks=n_codebooks,
409
+ codebook_size=codebook_size,
410
+ codebook_dim=codebook_dim,
411
+ quantizer_dropout=quantizer_dropout,
412
+ )
413
+
414
+ convnet_type = CausalConvNet
415
+ transconvnet_type = CausalTransConvNet
416
+
417
+ self.downsample = nn.Sequential(
418
+ *[
419
+ nn.Sequential(
420
+ convnet_type(all_dims[idx], all_dims[idx + 1], kernel_size=factor, stride=factor),
421
+ ConvNeXtBlock(dim=all_dims[idx + 1]),
422
+ )
423
+ for idx, factor in enumerate(downsample_factor)
424
+ ]
425
+ )
426
+
427
+ self.upsample = nn.Sequential(
428
+ *[
429
+ nn.Sequential(
430
+ transconvnet_type(all_dims[idx + 1], all_dims[idx], kernel_size=factor, stride=factor),
431
+ ConvNeXtBlock(dim=all_dims[idx]),
432
+ )
433
+ for idx, factor in reversed(list(enumerate(downsample_factor)))
434
+ ]
435
+ )
436
+
437
+ self.apply(self._init_weights)
438
+ self.pre_module = pre_module if pre_module is not None else nn.Identity()
439
+ self.post_module = post_module if post_module is not None else nn.Identity()
440
+ self.semantic_predictor_module = (
441
+ semantic_predictor_module if semantic_predictor_module is not None else nn.Identity()
442
+ )
443
+
444
+ @staticmethod
445
+ def _init_weights(m):
446
+ if isinstance(m, (nn.Conv1d, nn.Linear)):
447
+ nn.init.trunc_normal_(m.weight, std=0.02)
448
+ if getattr(m, "bias", None) is not None:
449
+ nn.init.constant_(m.bias, 0)
450
+
451
+ def forward(self, z: Tensor, n_quantizers: Optional[int] = None, semantic_len: Optional[Tensor] = None, **kwargs):
452
+ # z: (B, D, T)
453
+ original_shape = z.shape
454
+ if semantic_len is None:
455
+ semantic_len = torch.LongTensor([z.shape[-1]])
456
+
457
+ z = self.downsample(z)
458
+ z = self.pre_module(z) # (B, D, T) or (B, T, D) depending on module; original uses channels-first in/out
459
+
460
+ semantic_z, semantic_codes, semantic_latents, semantic_commitment_loss, semantic_codebook_loss = \
461
+ self.semantic_quantizer(z)
462
+ residual_z = z - semantic_z
463
+ residual_z, codes, latents, commitment_loss, codebook_loss = self.quantizer(residual_z, n_quantizers=n_quantizers)
464
+ z = semantic_z + residual_z
465
+ commitment_loss = commitment_loss + semantic_commitment_loss
466
+ codebook_loss = codebook_loss + semantic_codebook_loss
467
+ codes = torch.cat([semantic_codes, codes], dim=1)
468
+ latents = torch.cat([semantic_latents, latents], dim=1)
469
+ z = self.post_module(z)
470
+ z = self.upsample(z)
471
+
472
+ # Pad or crop z to match original shape (time dimension)
473
+ diff = original_shape[-1] - z.shape[-1]
474
+ right = 0
475
+ left = abs(diff) - right
476
+ if diff > 0:
477
+ z = F.pad(z, (left, right))
478
+ elif diff < 0:
479
+ z = z[..., left:]
480
+
481
+ return VQResult(
482
+ z=z, codes=codes, latents=latents,
483
+ commitment_loss=commitment_loss, codebook_loss=codebook_loss,
484
+ )
+
+     def decode(self, indices: Tensor) -> Tensor:
+         new_indices = torch.zeros_like(indices)
+         new_indices[:, 0] = torch.clamp(indices[:, 0], max=self.semantic_quantizer.codebook_size - 1)
+         new_indices[:, 1:] = torch.clamp(indices[:, 1:], max=self.quantizer.codebook_size - 1)
+
+         z_q_semantic = self.semantic_quantizer.from_codes(new_indices[:, :1])[0]
+         z_q_residual = self.quantizer.from_codes(new_indices[:, 1:])[0]
+         z_q = z_q_semantic + z_q_residual
+         z_q = self.post_module(z_q)
+         z_q = self.upsample(z_q)
+         return z_q
+
+
+ # --------------------------------------------------------------------
+ # Transformer stack
+ # --------------------------------------------------------------------
+
+ @dataclass
+ class ModelArgs:
+     block_size: int = 2048
+     n_layer: int = 8
+     n_head: int = 8
+     dim: int = 512
+     intermediate_size: int = 1536
+     n_local_heads: int = -1
+     head_dim: int = 64
+     rope_base: float = 10000
+     norm_eps: float = 1e-5
+     dropout_rate: float = 0.1
+     attn_dropout_rate: float = 0.1
+     channels_first: bool = True  # to be compatible with conv1d input/output
+     pos_embed_type: str = "rope"  # "rope" or "conformer"
+     max_relative_position: int = 128
+
+     def __post_init__(self):
+         if self.n_local_heads == -1:
+             self.n_local_heads = self.n_head
+         if self.intermediate_size is None:
+             hidden_dim = 4 * self.dim
+             n_hidden = int(2 * hidden_dim / 3)
+             self.intermediate_size = find_multiple(n_hidden, 256)
+         assert self.pos_embed_type in ["rope", "conformer"]
+
+
+ class KVCache(nn.Module):
+     def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim, dtype=torch.bfloat16):
+         super().__init__()
+         cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
+         self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype))
+         self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype))
+
+     def update(self, input_pos: Tensor, k_val: Tensor, v_val: Tensor):
+         # input_pos: [S], k_val: [B, H, S, D]
+         assert input_pos.shape[0] == k_val.shape[2]
+         k_out = self.k_cache
+         v_out = self.v_cache
+         k_out[:, :, input_pos] = k_val
+         v_out[:, :, input_pos] = v_val
+         return (
+             k_out[:, :, : input_pos.max() + 1, :],
+             v_out[:, :, : input_pos.max() + 1, :],
+         )
+
+     def clear_cache(self, prompt_len: int):
+         self.k_cache[:, :, prompt_len:, :].fill_(0)
+         self.v_cache[:, :, prompt_len:, :].fill_(0)
+
+
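The cache update semantics above can be sketched in plain Python (hypothetical stand-in, no torch): writes land at their positions in a preallocated buffer, and only the prefix up to the highest written position is returned to the attention op.

```python
# Hedged sketch of KVCache.update above: in-place writes into a fixed
# buffer, returning the valid prefix (plain-Python, single head/batch).
def kv_update(cache, input_pos, vals):
    for p, v in zip(input_pos, vals):
        cache[p] = v
    return cache[: max(input_pos) + 1]

cache = [None] * 8
print(kv_update(cache, [0, 1, 2], ["k0", "k1", "k2"]))  # ['k0', 'k1', 'k2']
print(kv_update(cache, [3], ["k3"]))                    # ['k0', 'k1', 'k2', 'k3']
```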
+ class Transformer(nn.Module):
+     def __init__(self, config: ModelArgs) -> None:
+         super().__init__()
+         self.config = config
+
+         self.layers = nn.ModuleList(TransformerBlock(config) for _ in range(config.n_layer))
+         self.norm = RMSNorm(config.dim, eps=config.norm_eps)
+
+         if config.pos_embed_type == "rope":
+             freqs_cis = precompute_freqs_cis(self.config.block_size, self.config.head_dim, self.config.rope_base)
+             self.register_buffer("freqs_cis", freqs_cis)
+         else:
+             self.register_buffer("freqs_cis", None)
+
+         causal_mask = torch.tril(torch.ones(self.config.block_size, self.config.block_size, dtype=torch.bool))
+         self.register_buffer("causal_mask", causal_mask)
+
+         self.max_batch_size = -1
+         self.max_seq_length = -1
+         self.use_kv_cache = False
+
+     def setup_caches(self, max_batch_size, max_seq_length):
+         head_dim = self.config.dim // self.config.n_head
+         max_seq_length = find_multiple(max_seq_length, 8)
+         self.max_seq_length = max_seq_length
+         self.max_batch_size = max_batch_size
+         dtype = self.norm.weight.dtype
+         device = self.norm.weight.device
+
+         for b in self.layers:
+             b.attention.kv_cache = KVCache(
+                 max_batch_size, max_seq_length, self.config.n_local_heads, head_dim, dtype
+             ).to(device)
+
+         self.use_kv_cache = True
+
+     def forward(self, x: Tensor, input_pos: Optional[Tensor] = None, mask: Optional[Tensor] = None) -> Tensor:
+         if self.config.pos_embed_type == "rope":
+             assert self.freqs_cis is not None
+             freqs_cis = self.freqs_cis[input_pos]
+         else:
+             freqs_cis = None
+
+         if mask is None:
+             if not self.training and self.use_kv_cache:
+                 mask = self.causal_mask[None, None, input_pos]
+                 mask = mask[..., : input_pos.max() + 1]
+             else:
+                 mask = self.causal_mask[None, None, input_pos]
+                 mask = mask[..., input_pos]
+
+         for layer in self.layers:
+             x = layer(x, input_pos, freqs_cis, mask)
+         x = self.norm(x)
+         return x
+
+
+ class TransformerBlock(nn.Module):
+     def __init__(self, config: ModelArgs) -> None:
+         super().__init__()
+         self.attention = Attention(config)
+         self.feed_forward = FeedForward(config)
+         self.ffn_norm = RMSNorm(config.dim, eps=config.norm_eps)
+         self.attention_norm = RMSNorm(config.dim, eps=config.norm_eps)
+         self.attention_layer_scale = LayerScale(config.dim, inplace=True)
+         self.ffn_layer_scale = LayerScale(config.dim, inplace=True)
+
+     def forward(self, x: Tensor, input_pos: Tensor, freqs_cis: Tensor, mask: Tensor) -> Tensor:
+         h = x + self.attention_layer_scale(
+             self.attention(self.attention_norm(x), freqs_cis, mask, input_pos)
+         )
+         out = h + self.ffn_layer_scale(self.feed_forward(self.ffn_norm(h)))
+         return out
+
+
+ class Attention(nn.Module):
+     def __init__(self, config: ModelArgs):
+         super().__init__()
+         assert config.dim % config.n_head == 0
+
+         total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
+         self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False)
+         self.wo = nn.Linear(config.head_dim * config.n_head, config.dim, bias=False)
+         self.kv_cache = None
+
+         self.n_head = config.n_head
+         self.head_dim = config.head_dim
+         self.n_local_heads = config.n_local_heads
+         self.dim = config.dim
+         self.attn_dropout_rate = config.attn_dropout_rate
+         self.pos_embed_type = config.pos_embed_type
+
+         if self.pos_embed_type == "conformer":
+             self.max_relative_position = config.max_relative_position
+             num_pos_embeddings = 2 * config.max_relative_position + 1
+             self.rel_pos_embeddings = nn.Parameter(torch.zeros(num_pos_embeddings, self.head_dim))
+             nn.init.normal_(self.rel_pos_embeddings, mean=0.0, std=0.02)
+
+     def _compute_conformer_pos_scores(self, q: Tensor, seqlen: int) -> Tensor:
+         positions = torch.arange(seqlen, device=q.device)
+         relative_positions = positions.unsqueeze(1) - positions.unsqueeze(0)  # [S, S]
+         relative_positions = torch.clamp(relative_positions + self.max_relative_position,
+                                          0, 2 * self.max_relative_position)
+         rel_embeddings = self.rel_pos_embeddings[relative_positions]  # [S, S, D]
+         q = q.transpose(1, 2)  # [B, S, H, D]
+         rel_logits = torch.matmul(q, rel_embeddings.transpose(-2, -1))  # [B, S, H, S]
+         rel_logits = rel_logits.transpose(1, 2)  # [B, H, S, S]
+         return rel_logits
+
+     def forward(self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
+         bsz, seqlen, _ = x.shape
+
+         kv_size = self.n_local_heads * self.head_dim
+         # query projection is n_head * head_dim wide; only K/V use the
+         # (possibly smaller) n_local_heads width
+         q, k, v = self.wqkv(x).split([self.n_head * self.head_dim, kv_size, kv_size], dim=-1)
+         context_seqlen = seqlen
+
+         q = q.view(bsz, seqlen, self.n_head, self.head_dim)
+         k = k.view(bsz, context_seqlen, self.n_local_heads, self.head_dim)
+         v = v.view(bsz, context_seqlen, self.n_local_heads, self.head_dim)
+
+         if self.pos_embed_type == "rope":
+             q = apply_rotary_emb(q, freqs_cis)
+             k = apply_rotary_emb(k, freqs_cis)
+
+         q, k, v = map(lambda t: t.transpose(1, 2), (q, k, v))
+
+         if self.kv_cache is not None:
+             k, v = self.kv_cache.update(input_pos, k, v)
+
+         k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
+         v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
+
+         if self.pos_embed_type == "conformer":
+             scale = 1.0 / math.sqrt(self.head_dim)
+             scores = torch.matmul(q, k.transpose(-2, -1)) * scale
+             rel_scores = self._compute_conformer_pos_scores(q, seqlen)
+             scores = scores + rel_scores
+             if mask is not None:
+                 scores = scores.masked_fill(~mask, float("-inf"))
+             attn = F.softmax(scores, dim=-1)
+             if self.attn_dropout_rate > 0 and self.training:
+                 attn = F.dropout(attn, p=self.attn_dropout_rate)
+             y = torch.matmul(attn, v)
+         else:
+             y = F.scaled_dot_product_attention(
+                 q, k, v,
+                 dropout_p=self.attn_dropout_rate if self.training else 0.0,
+                 attn_mask=mask,
+             )
+         y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.head_dim * self.n_head)
+         y = self.wo(y)
+         return y
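The `repeat_interleave` head sharing above can be sketched in plain Python (hypothetical stand-in): in grouped-query attention each of the `n_local_heads` K/V heads serves `n_head // n_local_heads` query heads.

```python
# Hedged sketch of the K/V head repetition above: each shared head is
# duplicated consecutively, matching repeat_interleave along the head dim.
def repeat_kv_heads(kv_heads, n_head):
    rep = n_head // len(kv_heads)
    out = []
    for h in kv_heads:
        out.extend([h] * rep)  # consecutive copies per shared head
    return out

print(repeat_kv_heads(["kv0", "kv1"], 4))  # ['kv0', 'kv0', 'kv1', 'kv1']
```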
+
+
+ class FeedForward(nn.Module):
+     def __init__(self, config: ModelArgs) -> None:
+         super().__init__()
+         self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False)
+         self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False)
+         self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False)
+         self.dropout = nn.Dropout(config.dropout_rate)
+
+     def forward(self, x: Tensor) -> Tensor:
+         return self.w2(self.dropout(F.silu(self.w1(x)) * self.w3(x)))
+
+
+ class RMSNorm(nn.Module):
+     def __init__(self, dim: int, eps: float = 1e-5):
+         super().__init__()
+         self.eps = eps
+         self.weight = nn.Parameter(torch.ones(dim))
+
+     def _norm(self, x):
+         return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)
+
+     def forward(self, x: Tensor) -> Tensor:
+         output = self._norm(x.float()).type_as(x)
+         return output * self.weight
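The RMSNorm math above reduces to a single scalar divide per feature vector; a plain-Python sketch (hypothetical stand-in, omitting the learned `weight` scale):

```python
# Hedged sketch of RMSNorm._norm above: divide each feature by the root
# mean square of the vector, with eps added for numerical stability.
import math

def rmsnorm(xs, eps=1e-5):
    rms = math.sqrt(sum(v * v for v in xs) / len(xs) + eps)
    return [v / rms for v in xs]

print([round(v, 3) for v in rmsnorm([3.0, 4.0])])  # [0.849, 1.131]
```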
+
+
+ class LayerScale(nn.Module):
+     def __init__(self, dim: int, init_values: Union[float, Tensor] = 1e-2, inplace: bool = False) -> None:
+         super().__init__()
+         self.inplace = inplace
+         self.gamma = nn.Parameter(init_values * torch.ones(dim))
+
+     def forward(self, x: Tensor) -> Tensor:
+         return x.mul_(self.gamma) if self.inplace else x * self.gamma
+
+
+ class WindowLimitedTransformer(Transformer):
+     """Transformer with window-limited causal attention."""
+     def __init__(
+         self,
+         config: ModelArgs,
+         input_dim: int = 512,
+         window_size: Optional[int] = None,
+         causal: bool = True,
+         look_ahead_conv: Optional[nn.Module] = None,
+     ):
+         super().__init__(config)
+         self.window_size = window_size
+         self.causal = causal
+         self.channels_first = config.channels_first
+         self.look_ahead_conv = look_ahead_conv if look_ahead_conv is not None else nn.Identity()
+         self.input_proj = nn.Linear(input_dim, config.dim) if input_dim != config.dim else nn.Identity()
+         self.output_proj = nn.Linear(config.dim, input_dim) if input_dim != config.dim else nn.Identity()
+
+     def make_window_limited_mask(self, max_length: int, x_lens: Optional[Tensor] = None) -> Tensor:
+         if self.causal:
+             mask = torch.tril(torch.ones(max_length, max_length))
+             row_indices = torch.arange(max_length).view(-1, 1)
+             window_size = self.window_size or max_length
+             valid_range = (row_indices - window_size + 1).clamp(min=0)
+             column_indices = torch.arange(max_length)
+             mask = (column_indices >= valid_range) & mask.bool()
+         else:
+             raise NotImplementedError
+         mask = mask.bool()[None, None]
+         return mask
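The window-limited mask above admits a compact plain-Python sketch (hypothetical stand-in, no torch): query row `i` may attend keys `j` with `max(0, i - window + 1) <= j <= i`.

```python
# Hedged sketch of make_window_limited_mask above: causal attention
# restricted to a sliding window of the most recent `window` positions.
def window_limited_mask(n, window):
    return [[max(0, i - window + 1) <= j <= i for j in range(n)]
            for i in range(n)]

for row in window_limited_mask(4, 2):
    print([int(v) for v in row])
# [1, 0, 0, 0]
# [1, 1, 0, 0]
# [0, 1, 1, 0]
# [0, 0, 1, 1]
```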
+
+     def make_mask(self, max_length: int, x_lens: Optional[Tensor] = None) -> Tensor:
+         if self.causal:
+             mask = torch.tril(torch.ones(max_length, max_length))
+         else:
+             mask = torch.ones(max_length, max_length)
+         mask = mask.bool()[None, None]
+         if x_lens is not None:
+             # mask out key positions beyond each sequence's valid length
+             mask = mask.repeat(len(x_lens), 1, 1, 1)
+             for i, x_len in enumerate(x_lens):
+                 mask[i, :, :, x_len:] = False
+         return mask
+
+     def forward(self, x: Tensor, x_lens: Optional[Tensor] = None) -> Tensor:
+         if self.channels_first:
+             x = x.transpose(1, 2)
+         x = self.input_proj(x)
+         x = self.look_ahead_conv(x)
+         input_pos = torch.arange(x.shape[1], device=x.device)
+         max_length = x.shape[1]
+         if self.window_size is not None:
+             mask = self.make_window_limited_mask(max_length, x_lens)
+         else:
+             mask = self.make_mask(max_length, x_lens)
+         mask = mask.to(x.device)
+         x = super().forward(x, input_pos, mask)
+         x = self.output_proj(x)
+         if self.channels_first:
+             x = x.transpose(1, 2)
+         return x
+
+
+ def precompute_freqs_cis(
+     seq_len: int, n_elem: int, base: int = 10000, dtype: torch.dtype = torch.bfloat16
+ ) -> Tensor:
+     freqs = 1.0 / (base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem))
+     t = torch.arange(seq_len, device=freqs.device)
+     freqs = torch.outer(t, freqs)
+     freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
+     cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
+     return cache.to(dtype=dtype)
+
+
+ def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
+     xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
+     freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2)
+     x_out2 = torch.stack(
+         [
+             xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
+             xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
+         ],
+         -1,
+     )
+     x_out2 = x_out2.flatten(3)
+     return x_out2.type_as(x)
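The rotary-embedding products above are a 2-D rotation per feature pair; a plain-Python sketch (hypothetical stand-in, matching the real/imaginary terms in `apply_rotary_emb`):

```python
# Hedged sketch of RoPE above: each (even, odd) feature pair is rotated
# by angle t * base**(-2i/d) for position t and pair index i.
import math

def rope_pair(x0, x1, t, i, d, base=10000.0):
    theta = t * base ** (-2 * i / d)
    c, s = math.cos(theta), math.sin(theta)
    return x0 * c - x1 * s, x1 * c + x0 * s

print(rope_pair(1.0, 2.0, t=0, i=0, d=64))  # (1.0, 2.0) -- position 0 is the identity
```

Because each pair is only rotated, vector norms are preserved, which is why RoPE leaves attention-score magnitudes well behaved.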
+
+
+ def init_weights(m):
+     if isinstance(m, nn.Conv1d):
+         nn.init.trunc_normal_(m.weight, std=0.02)
+         if m.bias is not None:
+             nn.init.constant_(m.bias, 0)
+
+
+ # --------------------------------------------------------------------
+ # Top-level AE
+ # --------------------------------------------------------------------
+
+ class EncoderBlock(nn.Module):
+     def __init__(
+         self,
+         dim: int = 16,
+         stride: int = 1,
+         causal: bool = False,
+         n_t_layer: int = 0,
+         transformer_general_config=None,
+     ):
+         super().__init__()
+         conv_class = CausalWNConv1d if causal else WNConv1d
+         transformer_module = (
+             nn.Identity()
+             if n_t_layer == 0
+             else WindowLimitedTransformer(
+                 causal=causal,
+                 input_dim=dim,
+                 window_size=512,
+                 config=transformer_general_config(
+                     n_layer=n_t_layer,
+                     n_head=dim // 64,
+                     dim=dim,
+                     intermediate_size=dim * 3,
+                 ),
+             )
+         )
+         self.block = nn.Sequential(
+             # three multi-receptive-field residual units
+             ResidualUnit(dim // 2, dilation=1, causal=causal),
+             ResidualUnit(dim // 2, dilation=3, causal=causal),
+             ResidualUnit(dim // 2, dilation=9, causal=causal),
+             Snake1d(dim // 2),
+             conv_class(dim // 2, dim, kernel_size=2 * stride, stride=stride, padding=math.ceil(stride / 2)),
+             transformer_module,
+         )
+
+     def forward(self, x: Tensor) -> Tensor:
+         return self.block(x)
+
+
+ class ResidualUnit(nn.Module):
+     def __init__(self, dim: int = 16, dilation: int = 1, causal: bool = False):
+         super().__init__()
+         conv_class = CausalWNConv1d if causal else WNConv1d
+         pad = ((7 - 1) * dilation) // 2
+         self.block = nn.Sequential(
+             Snake1d(dim),
+             conv_class(dim, dim, kernel_size=7, dilation=dilation, padding=pad),
+             Snake1d(dim),
+             conv_class(dim, dim, kernel_size=1),
+         )
+         self.causal = causal
+
+     def forward(self, x: Tensor) -> Tensor:
+         y = self.block(x)
+         pad = x.shape[-1] - y.shape[-1]
+         if pad > 0:
+             if self.causal:
+                 x = x[..., :-pad]
+             else:
+                 x = x[..., pad // 2 : -pad // 2]
+         return x + y
+
+
+ class Encoder(nn.Module):
+     def __init__(
+         self,
+         d_model: int = 64,
+         strides: List[int] = [2, 4, 8, 8],
+         d_latent: int = 64,
+         n_transformer_layers: List[int] = [0, 0, 4, 4],
+         transformer_general_config: Optional[ModelArgs] = None,
+         causal: bool = False,
+     ):
+         super().__init__()
+         conv_class = CausalWNConv1d if causal else WNConv1d
+         layers: List[nn.Module] = [conv_class(1, d_model, kernel_size=7, padding=3)]
+         for stride, n_t_layer in zip(strides, n_transformer_layers):
+             d_model *= 2
+             layers.append(
+                 EncoderBlock(
+                     d_model, stride=stride, causal=causal,
+                     n_t_layer=n_t_layer, transformer_general_config=transformer_general_config,
+                 )
+             )
+         layers += [Snake1d(d_model), conv_class(d_model, d_latent, kernel_size=3, padding=1)]
+         self.block = nn.Sequential(*layers)
+         self.enc_dim = d_model
+
+     def forward(self, x: Tensor) -> Tensor:
+         return self.block(x)
+
+
+ class DecoderBlock(nn.Module):
+     def __init__(
+         self,
+         input_dim: int = 16,
+         output_dim: int = 8,
+         stride: int = 1,
+         causal: bool = False,
+         n_t_layer: int = 0,
+         transformer_general_config=None,
+     ):
+         super().__init__()
+         conv_trans_class = CausalWNConvTranspose1d if causal else WNConvTranspose1d
+         transformer_module = (
+             nn.Identity()
+             if n_t_layer == 0
+             else WindowLimitedTransformer(
+                 causal=causal,
+                 input_dim=input_dim,
+                 window_size=None,
+                 config=transformer_general_config(
+                     n_layer=n_t_layer,
+                     n_head=input_dim // 64,
+                     dim=input_dim,
+                     intermediate_size=input_dim * 3,
+                 ),
+             )
+         )
+         self.block = nn.Sequential(
+             Snake1d(input_dim),
+             conv_trans_class(input_dim, output_dim, kernel_size=2 * stride, stride=stride, padding=math.ceil(stride / 2)),
+             ResidualUnit(output_dim, dilation=1, causal=causal),
+             ResidualUnit(output_dim, dilation=3, causal=causal),
+             ResidualUnit(output_dim, dilation=9, causal=causal),
+         )
+
+     def forward(self, x: Tensor) -> Tensor:
+         return self.block(x)
+
+
+ class Decoder(nn.Module):
+     def __init__(
+         self,
+         input_channel: int,
+         channels: int,
+         rates: List[int],
+         d_out: int = 1,
+         causal: bool = False,
+         n_transformer_layers: List[int] = [0, 0, 0, 0],
+         transformer_general_config=None,
+     ):
+         super().__init__()
+         conv_class = CausalWNConv1d if causal else WNConv1d
+         layers: List[nn.Module] = [conv_class(input_channel, channels, kernel_size=7, padding=3)]
+         for i, (stride, n_t_layer) in enumerate(zip(rates, n_transformer_layers)):
+             input_dim = channels // 2**i
+             output_dim = channels // 2 ** (i + 1)
+             layers.append(
+                 DecoderBlock(
+                     input_dim, output_dim, stride, causal=causal,
+                     n_t_layer=n_t_layer, transformer_general_config=transformer_general_config,
+                 )
+             )
+         layers += [Snake1d(output_dim), conv_class(output_dim, d_out, kernel_size=7, padding=3), nn.Tanh()]
+         self.model = nn.Sequential(*layers)
+
+     def forward(self, x: Tensor) -> Tensor:
+         return self.model(x)
+
+
+ class DAC(nn.Module):
+     def __init__(
+         self,
+         encoder_dim: int = 64,
+         encoder_rates: List[int] = [2, 4, 8, 8],
+         latent_dim: Optional[int] = None,
+         decoder_dim: int = 1536,
+         decoder_rates: List[int] = [8, 8, 4, 2],
+         quantizer: Optional[nn.Module] = None,
+         sample_rate: int = 44100,
+         causal: bool = True,
+         encoder_transformer_layers: List[int] = [0, 0, 0, 0],
+         decoder_transformer_layers: List[int] = [0, 0, 0, 0],
+         transformer_general_config=None,
+     ):
+         super().__init__()
+
+         self.encoder_dim = encoder_dim
+         self.encoder_rates = encoder_rates
+         self.decoder_dim = decoder_dim
+         self.decoder_rates = decoder_rates
+         self.sample_rate = sample_rate
+
+         if latent_dim is None:
+             latent_dim = encoder_dim * (2 ** len(encoder_rates))
+         self.latent_dim = latent_dim
+
+         self.hop_length = int(np.prod(encoder_rates))
+         self.encoder = Encoder(
+             encoder_dim, encoder_rates, latent_dim, causal=causal,
+             n_transformer_layers=encoder_transformer_layers,
+             transformer_general_config=transformer_general_config,
+         )
+         self.quantizer = quantizer
+         self.decoder = Decoder(
+             latent_dim, decoder_dim, decoder_rates, causal=causal,
+             n_transformer_layers=decoder_transformer_layers,
+             transformer_general_config=transformer_general_config,
+         )
+         self.apply(init_weights)
+
+         self.delay = self.get_delay()
+         self.frame_length = self.hop_length * 4
+
+     def get_output_length(self, input_length: int) -> int:
+         length = input_length
+         for stride in self.encoder_rates:
+             length = math.ceil(length / stride)
+         return length
+
+     def get_delay(self) -> int:
+         l_out = self.get_output_length(0)
+         L = l_out
+
+         layers = [layer for layer in self.modules() if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d))]
+         for layer in reversed(layers):
+             d = layer.dilation[0]
+             k = layer.kernel_size[0]
+             s = layer.stride[0]
+             if isinstance(layer, nn.ConvTranspose1d):
+                 L = ((L - d * (k - 1) - 1) / s) + 1
+             elif isinstance(layer, nn.Conv1d):
+                 L = (L - 1) * s + d * (k - 1) + 1
+             L = math.ceil(L)
+
+         l_in = L
+         return (l_in - l_out) // 2
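The hop-length and output-length arithmetic above can be sketched in plain Python (torch-free stand-in): `hop_length` is the product of the encoder strides, and each stride ceil-divides the running length.

```python
# Hedged sketch of DAC.hop_length and get_output_length above.
import math

def output_length(input_length, strides):
    length = input_length
    for s in strides:
        length = math.ceil(length / s)
    return length

strides = [2, 4, 8, 8]  # encoder_rates used elsewhere in this file
hop = math.prod(strides)
print(hop, output_length(44100, strides))  # 512 87 -> ~1s of 44.1kHz audio gives 87 latent frames
```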
+
+     def preprocess(self, audio_data: Tensor, sample_rate: Optional[int]) -> Tensor:
+         if sample_rate is None:
+             sample_rate = self.sample_rate
+         assert sample_rate == self.sample_rate
+
+         length = audio_data.shape[-1]
+         right_pad = math.ceil(length / self.hop_length) * self.hop_length - length
+         audio_data = F.pad(audio_data, (0, right_pad))
+         return audio_data
+
+     def encode(
+         self,
+         audio_data: Tensor,
+         audio_lengths: Optional[Tensor] = None,
+         n_quantizers: Optional[int] = None,
+         **kwargs,
+     ):
+         """Encode audio to quantized code indices."""
+         if audio_data.ndim == 2:
+             audio_data = audio_data.unsqueeze(1)
+         length = audio_data.shape[-1]
+         right_pad = math.ceil(length / self.frame_length) * self.frame_length - length
+         audio_data = F.pad(audio_data, (0, right_pad))
+         if audio_lengths is None:
+             audio_lengths = torch.LongTensor([length + right_pad]).to(audio_data.device)
+
+         z = self.encoder(audio_data)
+         vq_results = self.quantizer(z, n_quantizers, **kwargs)
+         indices = vq_results.codes
+         indices_lens = torch.ceil(audio_lengths / self.frame_length).long()
+         return indices, indices_lens
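The right-padding in `encode` above rounds audio up to a multiple of `frame_length` (`hop_length * 4`, i.e. 2048 with the default strides); a plain-Python sketch:

```python
# Hedged sketch of the right-pad computation in encode() above.
import math

def right_pad_amount(length, frame_length):
    return math.ceil(length / frame_length) * frame_length - length

print(right_pad_amount(44100, 2048))  # 956 -> padded length 45056 = 22 frames
```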
+
+     def decode(self, indices: Tensor, feature_lengths: Tensor):
+         """Decode code indices to audio."""
+         if indices.ndim == 2:
+             indices = indices[None]
+         z = self.quantizer.decode(indices)
+         audio_lengths = feature_lengths * self.frame_length
+         return self.decoder(z), audio_lengths
+
+     def encode_to_codes(self, audio: Tensor, audio_lengths: Optional[Tensor] = None, n_quantizers: Optional[int] = None, **kw):
+         return self.encode(audio, audio_lengths, n_quantizers, **kw)
+
+     def decode_codes(self, indices: Tensor, feature_lengths: Tensor):
+         return self.decode(indices, feature_lengths)
+
+     @torch.no_grad()
+     def encode_zq(self, audio_data: Tensor) -> Tensor:
+         indices, _ = self.encode(audio_data)
+         new_indices = torch.zeros_like(indices)
+         new_indices[:, 0] = torch.clamp(indices[:, 0], max=self.quantizer.semantic_quantizer.codebook_size - 1)
+         new_indices[:, 1:] = torch.clamp(indices[:, 1:], max=self.quantizer.quantizer.codebook_size - 1)
+
+         z_q_semantic = self.quantizer.semantic_quantizer.from_codes(new_indices[:, :1])[0]
+         z_q_residual = self.quantizer.quantizer.from_codes(new_indices[:, 1:])[0]
+         z_q = z_q_semantic + z_q_residual
+         return z_q
+
+     @torch.no_grad()
+     def decode_zq(self, z_q: Tensor) -> Tensor:
+         z_q = self.quantizer.post_module(z_q)
+         z_q = self.quantizer.upsample(z_q)
+         return self.decoder(z_q)
+
+     @property
+     def device(self) -> torch.device:
+         return next(self.parameters()).device
+
+     @property
+     def dtype(self) -> torch.dtype:
+         return next(self.parameters()).dtype
+
+ # --------------------------------------------------------------------
+ # Build helpers
+ # --------------------------------------------------------------------
+
+ def build_ae(**cfg) -> DAC:
+     """Factory used by external loaders."""
+     # Shared transformer config for the RVQ pre/post modules
+     q_config = ModelArgs(
+         block_size=4096, n_layer=8, n_head=16, dim=1024,
+         intermediate_size=3072, head_dim=64, norm_eps=1e-5,
+         dropout_rate=0.1, attn_dropout_rate=0.1, channels_first=True,
+     )
+
+     def make_transformer():
+         return WindowLimitedTransformer(
+             causal=True, window_size=128, input_dim=1024, config=q_config
+         )
+
+     quantizer = DownsampleResidualVectorQuantize(
+         input_dim=1024, n_codebooks=9, codebook_size=1024, codebook_dim=8,
+         quantizer_dropout=0.5, downsample_factor=(2, 2),
+         semantic_codebook_size=4096,
+         pre_module=make_transformer(),
+         post_module=make_transformer(),
+     )
+
+     def transformer_general_config(**kw):
+         return ModelArgs(
+             block_size=kw.get("block_size", 16384),
+             n_layer=kw.get("n_layer", 8),
+             n_head=kw.get("n_head", 8),
+             dim=kw.get("dim", 512),
+             intermediate_size=kw.get("intermediate_size", 1536),
+             n_local_heads=kw.get("n_local_heads", -1),
+             head_dim=kw.get("head_dim", 64),
+             rope_base=kw.get("rope_base", 10000),
+             norm_eps=kw.get("norm_eps", 1e-5),
+             dropout_rate=kw.get("dropout_rate", 0.1),
+             attn_dropout_rate=kw.get("attn_dropout_rate", 0.1),
+             channels_first=kw.get("channels_first", True),
+         )
+
+     dac = DAC(
+         encoder_dim=64, encoder_rates=[2, 4, 8, 8], latent_dim=1024,
+         decoder_dim=1536, decoder_rates=[8, 8, 4, 2],
+         quantizer=quantizer, sample_rate=44100, causal=True,
+         encoder_transformer_layers=[0, 0, 0, 4],
+         decoder_transformer_layers=[4, 0, 0, 0],
+         transformer_general_config=transformer_general_config,
+     )
+     return dac
+
+
+ __all__ = [
+     "DAC",
+     "build_ae",
+     "VectorQuantize",
+     "ResidualVectorQuantize",
+     "DownsampleResidualVectorQuantize",
+ ]
+
+
+ # ----- BEGIN DAC MIT LICENSE -----
+ # MIT License
+ # Copyright (c) 2023-present, Descript
+ #
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
+ # of this software and associated documentation files (the "Software"), to deal
+ # in the Software without restriction, including without limitation the rights
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ # copies of the Software, and to permit persons to whom the Software is
+ # furnished to do so, subject to the following conditions:
+ #
+ # The above copyright notice and this permission notice shall be included in all
+ # copies or substantial portions of the Software.
+ #
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ # SOFTWARE.
+ # ----- END DAC MIT LICENSE -----
+
gradio_app.py ADDED
@@ -0,0 +1,994 @@
+
+ import os
+ import json
+ import time
+ import secrets
+ import logging
+ import warnings
+ from pathlib import Path
+ from typing import Tuple, Any
+ from functools import partial
+
+ # see lines ~40-50 for running on lower VRAM GPUs
+
+ logging.getLogger("huggingface_hub").setLevel(logging.ERROR)
+
+ import gradio as gr
+ import torch
+ import torchaudio
+
+ from inference import (
+     load_model_from_hf,
+     load_fish_ae_from_hf,
+     load_pca_state_from_hf,
+     load_audio,
+     ae_reconstruct,
+     sample_pipeline,
+     compile_model,
+     compile_fish_ae,
+     sample_euler_cfg_independent_guidances,
+ )
+
+ # --------------------------------------------------------------------
+ # IF ON AN 8GB VRAM GPU, SET FISH_AE_DTYPE TO bfloat16 AND DEFAULT_SAMPLE_LATENT_LENGTH TO < 640 (e.g., 576)
+
+ # Configuration
+ MODEL_DTYPE = torch.bfloat16
+ FISH_AE_DTYPE = torch.float32
+ # FISH_AE_DTYPE = torch.bfloat16  # USE THIS IF OOM ON AN 8GB VRAM GPU
+
+ DEFAULT_SAMPLE_LATENT_LENGTH = 640  # decrease if OOM on an 8GB VRAM GPU
+ # DEFAULT_SAMPLE_LATENT_LENGTH = 576  # (example, ~27 seconds rather than ~30; can change depending on what fits in VRAM)
+
+ # NOTE: peak S1-DAC decoding VRAM > peak latent sampling VRAM, so decoding in chunks (which is possible as S1-DAC is causal) would allow full 640-length generation on lower-VRAM GPUs
+
+ # --------------------------------------------------------------------
+
+ # Audio prompt library for the custom audio panel (included in repo)
+ AUDIO_PROMPT_FOLDER = Path("./audio_prompts")
+
+ # --------------------------------------------------------------------
+
+ TEXT_PRESETS_PATH = Path("./text_presets.txt")
+ SAMPLER_PRESETS_PATH = Path("./sampler_presets.json")
+
+ TEMP_AUDIO_DIR = Path("./temp_gradio_audio")
+ TEMP_AUDIO_DIR.mkdir(parents=True, exist_ok=True)
+
+ # --------------------------------------------------------------------
+ # Model loading (eager for local use)
+ model = load_model_from_hf(dtype=MODEL_DTYPE, delete_blockwise_modules=True)
+ fish_ae = load_fish_ae_from_hf(dtype=FISH_AE_DTYPE)
+ pca_state = load_pca_state_from_hf()
+
+ model_compiled = None
+ fish_ae_compiled = None
+
+ # --------------------------------------------------------------------
68
+ # Helper functions
69
+ def make_stem(prefix: str, user_id: str | None = None) -> str:
70
+ """Create unique filename stem: prefix__user__timestamp_random or prefix__timestamp_random if no user_id."""
71
+ ts = int(time.time() * 1000)
72
+ rand = secrets.token_hex(4)
73
+ if user_id:
74
+ return f"{prefix}__{user_id}__{ts}_{rand}"
75
+ return f"{prefix}__{ts}_{rand}"
76
+
77
+
78
+ def cleanup_temp_audio(dir_: Path, user_id: str | None, max_age_sec: int = 60 * 5):
79
+ """Remove old files globally and all previous files for this user."""
80
+ now = time.time()
81
+
82
+ for p in dir_.glob("*"):
83
+ try:
84
+ if p.is_file() and (now - p.stat().st_mtime) > max_age_sec:
85
+ p.unlink(missing_ok=True)
86
+ except Exception:
87
+ pass
88
+
89
+ if user_id:
90
+ for p in dir_.glob(f"*__{user_id}__*"):
91
+ try:
92
+ if p.is_file():
93
+ p.unlink(missing_ok=True)
94
+ except Exception:
95
+ pass
96
+
97
+
98
+ def save_audio_with_format(audio_tensor: torch.Tensor, base_path: Path, filename: str, sample_rate: int, audio_format: str) -> Path:
99
+ """Save audio in specified format, fallback to WAV if MP3 encoding fails."""
100
+ if audio_format == "mp3":
101
+ try:
102
+ output_path = base_path / f"{filename}.mp3"
103
+ torchaudio.save(
104
+ str(output_path),
105
+ audio_tensor,
106
+ sample_rate,
107
+ format="mp3",
108
+ encoding="mp3",
109
+ bits_per_sample=None,
110
+ )
111
+ return output_path
112
+ except Exception as e:
113
+ print(f"MP3 encoding failed: {e}, falling back to WAV")
114
+ output_path = base_path / f"{filename}.wav"
115
+ torchaudio.save(str(output_path), audio_tensor, sample_rate)
116
+ return output_path
117
+
118
+ output_path = base_path / f"{filename}.wav"
119
+ torchaudio.save(str(output_path), audio_tensor, sample_rate)
120
+ return output_path
121
+
122
+
123
+ def to_bool(val: Any) -> bool:
124
+ """Parse truthy values from common string/bool inputs."""
125
+ return str(val).strip().lower() not in {"", "0", "false", "off", "none", "no"}
126
+
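A quick standalone check of the truthy-parsing behavior (the helper is restated here so the snippet runs on its own; it mirrors `to_bool` above):

```python
def to_bool(val):
    """Anything outside the explicit falsy set counts as True."""
    return str(val).strip().lower() not in {"", "0", "false", "off", "none", "no"}

assert to_bool("True") and to_bool(1) and to_bool("yes")
assert not to_bool("off") and not to_bool("0") and not to_bool(None) and not to_bool("")
```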
127
+
128
+ def find_min_bucket_gte(values_str: str, actual_length: int) -> int | None:
129
+ """Parse comma-separated values and find minimum value >= actual_length.
130
+
131
+ If a single value is provided (no comma), returns that value directly.
132
+ If comma-separated, finds the smallest bucket that can fit the content.
133
+ Returns None if empty string.
134
+ """
135
+ if not values_str or not values_str.strip():
136
+ return None
137
+
138
+ values_str = values_str.strip()
139
+
140
+ # Single value case - return as-is
141
+ if "," not in values_str:
142
+ return int(values_str)
143
+
144
+ # Multiple values - find minimum >= actual_length
145
+ values = [int(v.strip()) for v in values_str.split(",") if v.strip()]
146
+ if not values:
147
+ return None
148
+
149
+ # Find minimum value >= actual_length
150
+ candidates = [v for v in values if v >= actual_length]
151
+ if candidates:
152
+ return min(candidates)
153
+
154
+ # If no value is >=, return the maximum (best effort)
155
+ return max(values)
156
+
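The bucket-selection logic can be exercised in isolation (the function is restated so the snippet runs standalone; the bucket strings match the UI defaults used later in this file):

```python
def find_min_bucket_gte(values_str, actual_length):
    """Smallest bucket >= actual_length from a comma-separated list."""
    if not values_str or not values_str.strip():
        return None
    values_str = values_str.strip()
    if "," not in values_str:
        return int(values_str)
    values = [int(v.strip()) for v in values_str.split(",") if v.strip()]
    if not values:
        return None
    candidates = [v for v in values if v >= actual_length]
    return min(candidates) if candidates else max(values)

assert find_min_bucket_gte("640, 2816, 6400", 700) == 2816   # smallest bucket that fits
assert find_min_bucket_gte("640, 2816, 6400", 9000) == 6400  # nothing fits: best-effort max
assert find_min_bucket_gte("768", 123) == 768                # single value returned as-is
assert find_min_bucket_gte("", 10) is None                   # blank -> no padding
```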
157
+
158
+ def generate_audio(
159
+ text_prompt: str,
160
+ speaker_audio_path: str,
161
+ num_steps: int,
162
+ rng_seed: int,
163
+ cfg_scale_text: float,
164
+ cfg_scale_speaker: float,
165
+ cfg_min_t: float,
166
+ cfg_max_t: float,
167
+ truncation_factor: float,
168
+ rescale_k: float,
169
+ rescale_sigma: float,
170
+ force_speaker: bool,
171
+ speaker_kv_scale: float,
172
+ speaker_kv_min_t: float,
173
+ speaker_kv_max_layers: int,
174
+ reconstruct_first_30_seconds: bool,
175
+ use_custom_shapes: bool,
176
+ max_text_byte_length: str,
177
+ max_speaker_latent_length: str,
178
+ sample_latent_length: str,
179
+ audio_format: str,
180
+ use_compile: bool,
181
+ show_original_audio: bool,
182
+ session_id: str,
183
+ ) -> Tuple[Any, Any, Any, Any, Any, Any, Any, Any, Any]:
184
+ """Generate audio using the model."""
185
+ global model_compiled, fish_ae_compiled
186
+
187
+ if use_compile:
188
+ if model_compiled is None:
189
+ try:
190
+ model_compiled = compile_model(model)
191
+ fish_ae_compiled = compile_fish_ae(fish_ae)
192
+ except Exception as e:
193
+ print(f"Compilation wrapping failed: {str(e)}")
194
+ model_compiled = None
195
+ fish_ae_compiled = None
196
+ use_compile = False
197
+
198
+ active_model = model_compiled if use_compile else model
199
+ active_fish_ae = fish_ae_compiled if use_compile else fish_ae
200
+
201
+ cleanup_temp_audio(TEMP_AUDIO_DIR, session_id)
202
+
203
+ start_time = time.time()
204
+
205
+ num_steps_int = min(max(int(num_steps), 1), 80)
206
+ rng_seed_int = int(rng_seed) if rng_seed is not None else 0
207
+ cfg_scale_text_val = float(cfg_scale_text)
208
+ cfg_scale_speaker_val = float(cfg_scale_speaker) if cfg_scale_speaker is not None else None
209
+ cfg_min_t_val = float(cfg_min_t)
210
+ cfg_max_t_val = float(cfg_max_t)
211
+ truncation_factor_val = float(truncation_factor)
212
+ rescale_k_val = float(rescale_k) if rescale_k != 1.0 else None
213
+ rescale_sigma_val = float(rescale_sigma)
214
+
215
+ speaker_kv_enabled = bool(force_speaker)
216
+ if speaker_kv_enabled:
217
+ speaker_kv_scale_val = float(speaker_kv_scale) if speaker_kv_scale is not None else None
218
+ speaker_kv_min_t_val = float(speaker_kv_min_t) if speaker_kv_min_t is not None else None
219
+ speaker_kv_max_layers_val = int(speaker_kv_max_layers) if speaker_kv_max_layers is not None else None
220
+ else:
221
+ speaker_kv_scale_val = None
222
+ speaker_kv_min_t_val = None
223
+ speaker_kv_max_layers_val = None
224
+
225
+ # Load speaker audio early so we can compute actual lengths for bucket selection
226
+ use_zero_speaker = not speaker_audio_path or speaker_audio_path == ""
227
+ speaker_audio = load_audio(speaker_audio_path).cuda() if not use_zero_speaker else None
228
+
229
+ if use_custom_shapes:
230
+ # Compute actual text byte length
231
+ actual_text_byte_length = len(text_prompt.encode("utf-8")) + 1 # +1 for BOS token
232
+
233
+ # Compute actual speaker latent length (audio_samples // 2048, rounded down to a multiple of 4)
234
+ AE_DOWNSAMPLE_FACTOR = 2048
235
+ if speaker_audio is not None:
236
+ actual_speaker_latent_length = (speaker_audio.shape[-1] // AE_DOWNSAMPLE_FACTOR) // 4 * 4
237
+ else:
238
+ actual_speaker_latent_length = 0
239
+
240
+ # Find appropriate bucket sizes from comma-separated values
241
+ pad_to_max_text_length = find_min_bucket_gte(max_text_byte_length, actual_text_byte_length)
242
+ pad_to_max_speaker_latent_length = find_min_bucket_gte(max_speaker_latent_length, actual_speaker_latent_length)
243
+ sample_latent_length_val = int(sample_latent_length) if sample_latent_length.strip() else (DEFAULT_SAMPLE_LATENT_LENGTH or 640)
244
+ else:
245
+ pad_to_max_text_length = None
246
+ pad_to_max_speaker_latent_length = None
247
+ sample_latent_length_val = DEFAULT_SAMPLE_LATENT_LENGTH or 640
248
+
249
+
250
+ sample_fn = partial(
251
+ sample_euler_cfg_independent_guidances,
252
+ num_steps=num_steps_int,
253
+ cfg_scale_text=cfg_scale_text_val,
254
+ cfg_scale_speaker=cfg_scale_speaker_val,
255
+ cfg_min_t=cfg_min_t_val,
256
+ cfg_max_t=cfg_max_t_val,
257
+ truncation_factor=truncation_factor_val,
258
+ rescale_k=rescale_k_val,
259
+ rescale_sigma=rescale_sigma_val,
260
+ speaker_kv_scale=speaker_kv_scale_val,
261
+ speaker_kv_min_t=speaker_kv_min_t_val,
262
+ speaker_kv_max_layers=speaker_kv_max_layers_val,
263
+ sequence_length=sample_latent_length_val,
264
+ )
265
+
266
+ audio_out, normalized_text = sample_pipeline(
267
+ model=active_model,
268
+ fish_ae=active_fish_ae,
269
+ pca_state=pca_state,
270
+ sample_fn=sample_fn,
271
+ text_prompt=text_prompt,
272
+ speaker_audio=speaker_audio,
273
+ rng_seed=rng_seed_int,
274
+ pad_to_max_text_length=pad_to_max_text_length,
275
+ pad_to_max_speaker_latent_length=pad_to_max_speaker_latent_length,
276
+ normalize_text=True,
277
+ )
278
+
279
+ audio_to_save = audio_out[0].cpu()
280
+
281
+ stem = make_stem("generated", session_id)
282
+ output_path = save_audio_with_format(audio_to_save, TEMP_AUDIO_DIR, stem, 44100, audio_format)
283
+
284
+ generation_time = time.time() - start_time
285
+ time_str = f"⏱️ Total generation time: {generation_time:.2f}s"
286
+ text_display = f"**Text Prompt (normalized):**\n\n{normalized_text}"
287
+
288
+ recon_output_path = None
289
+ original_output_path = None
290
+
291
+ if reconstruct_first_30_seconds and speaker_audio is not None:
292
+ audio_recon = ae_reconstruct(
293
+ fish_ae=fish_ae,
294
+ pca_state=pca_state,
295
+ audio=torch.nn.functional.pad(
296
+ speaker_audio[..., :2048 * 640],
297
+ (0, max(0, 2048 * 640 - speaker_audio.shape[-1])),
298
+ )[None],
299
+ )[..., : speaker_audio.shape[-1]]
300
+
301
+ recon_stem = make_stem("speaker_recon", session_id)
302
+ recon_output_path = save_audio_with_format(audio_recon.cpu()[0], TEMP_AUDIO_DIR, recon_stem, 44100, audio_format)
303
+
304
+ if show_original_audio and speaker_audio is not None:
305
+ original_stem = make_stem("original_audio", session_id)
306
+ original_output_path = save_audio_with_format(speaker_audio.cpu(), TEMP_AUDIO_DIR, original_stem, 44100, audio_format)
307
+
308
+ show_reference_section = (show_original_audio or reconstruct_first_30_seconds) and speaker_audio is not None
309
+ return (
310
+ gr.update(),
311
+ gr.update(value=str(output_path), visible=True),
312
+ gr.update(value=text_display, visible=True),
313
+ gr.update(value=str(original_output_path) if original_output_path else None, visible=True),
314
+ gr.update(value=time_str, visible=True),
315
+ gr.update(value=str(recon_output_path) if recon_output_path else None, visible=True),
316
+ gr.update(visible=(show_original_audio and speaker_audio is not None)),
317
+ gr.update(visible=(reconstruct_first_30_seconds and speaker_audio is not None)),
318
+ gr.update(visible=show_reference_section),
319
+ )
320
+
321
+
322
+ # UI Helper Functions
323
+ def load_text_presets():
324
+ """Load text presets from file with category and word count."""
325
+ if TEXT_PRESETS_PATH.exists():
326
+ with open(TEXT_PRESETS_PATH, "r", encoding="utf-8") as f:
327
+ lines = [line.strip() for line in f if line.strip()]
328
+
329
+ result = []
330
+ for line in lines:
331
+ if " | " in line:
332
+ parts = line.split(" | ", 1)
333
+ category = parts[0]
334
+ text = parts[1]
335
+ else:
336
+ category = "Uncategorized"
337
+ text = line
338
+
339
+ word_count = len(text.split())
340
+ result.append([category, str(word_count), text])
341
+
342
+ return result
343
+ return []
344
+
345
+
346
+ def select_text_preset(evt: gr.SelectData):
347
+ """Handle text preset selection - extract text from the row."""
348
+ if evt.value:
349
+ if isinstance(evt.index, (tuple, list)) and len(evt.index) >= 2:
350
+ row_index = evt.index[0]
351
+ else:
352
+ row_index = evt.index
353
+
354
+ presets_data = load_text_presets()
355
+ if isinstance(row_index, int) and row_index < len(presets_data):
356
+ text = presets_data[row_index][2]
357
+ return gr.update(value=text)
358
+ return gr.update()
359
+
360
+
361
+ def toggle_mode(mode):
362
+ """Toggle advanced settings section visibility."""
363
+ show_advanced = mode == "Advanced Mode"
364
+ return gr.update(visible=show_advanced)
365
+
366
+
367
+ def update_force_row(force_speaker):
368
+ """Show KV scaling controls when Force Speaker is enabled."""
369
+ return gr.update(visible=bool(force_speaker))
370
+
371
+
372
+ def apply_cfg_preset(preset_name):
373
+ """Apply CFG guidance preset."""
374
+ presets = {
375
+ "higher speaker": (3.0, 8.0, 0.5, 1.0),
376
+ "large guidances": (8.0, 8.0, 0.5, 1.0),
377
+ }
378
+
379
+ if preset_name not in presets:
380
+ return [gr.update()] * 5
381
+
382
+ text_scale, speaker_scale, min_t, max_t = presets[preset_name]
383
+
384
+ return [
385
+ gr.update(value=text_scale),
386
+ gr.update(value=speaker_scale),
387
+ gr.update(value=min_t),
388
+ gr.update(value=max_t),
389
+ gr.update(value="Custom"),
390
+ ]
391
+
392
+
393
+ def apply_speaker_kv_preset(preset_name):
394
+ """Apply speaker KV attention control preset."""
395
+ if preset_name == "enable":
396
+ return [
397
+ gr.update(value=True),
398
+ gr.update(visible=True),
399
+ gr.update(value="Custom"),
400
+ ]
401
+ if preset_name == "off":
402
+ return [
403
+ gr.update(value=False),
404
+ gr.update(visible=False),
405
+ gr.update(value="Custom"),
406
+ ]
407
+ return [gr.update()] * 3
408
+
409
+
410
+ def apply_truncation_preset(preset_name):
411
+ """Apply truncation & temporal rescaling preset."""
412
+ presets = {
413
+ "flat": (0.8, 1.2, 3.0),
414
+ "sharp": (0.9, 0.96, 3.0),
415
+ "baseline(sharp)": (1.0, 1.0, 3.0),
416
+ }
417
+
418
+ if preset_name == "custom" or preset_name not in presets:
419
+ return [gr.update()] * 4
420
+
421
+ truncation, rescale_k, rescale_sigma = presets[preset_name]
422
+
423
+ return [
424
+ gr.update(value=truncation),
425
+ gr.update(value=rescale_k),
426
+ gr.update(value=rescale_sigma),
427
+ gr.update(value="Custom"),
428
+ ]
429
+
430
+
431
+ def load_sampler_presets():
432
+ """Load sampler presets from JSON file."""
433
+ if SAMPLER_PRESETS_PATH.exists():
434
+ with open(SAMPLER_PRESETS_PATH, "r") as f:
435
+ return json.load(f)
436
+
437
+ default_presets = {
438
+ "Independent-High-Speaker-CFG": {
439
+ "num_steps": "40",
440
+ "cfg_scale_text": "3.0",
441
+ "cfg_scale_speaker": "8.0",
442
+ "cfg_min_t": "0.5",
443
+ "cfg_max_t": "1.0",
444
+ "truncation_factor": "1.",
445
+ "rescale_k": "1.",
446
+ "rescale_sigma": "3.0"
447
+ }
448
+ }
449
+ with open(SAMPLER_PRESETS_PATH, "w") as f:
450
+ json.dump(default_presets, f, indent=2)
451
+ return default_presets
452
+
453
+
454
+ def apply_sampler_preset(preset_name):
455
+ """Apply a sampler preset to all fields."""
456
+ presets = load_sampler_presets()
457
+ if preset_name == "Custom" or preset_name not in presets:
458
+ return [gr.update()] * 13
459
+
460
+ preset = presets[preset_name]
461
+ speaker_kv_enabled = to_bool(preset.get("speaker_kv_enable", False))
462
+
463
+ def to_num(val, default):
464
+ try:
465
+ return float(val) if isinstance(val, str) else val
466
+ except (ValueError, TypeError):
467
+ return default
468
+
469
+ return [
470
+ gr.update(value=int(to_num(preset.get("num_steps", "40"), 40))),
471
+ gr.update(value=to_num(preset.get("cfg_scale_text", "3.0"), 3.0)),
472
+ gr.update(value=to_num(preset.get("cfg_scale_speaker", "5.0"), 5.0)),
473
+ gr.update(value=to_num(preset.get("cfg_min_t", "0.5"), 0.5)),
474
+ gr.update(value=to_num(preset.get("cfg_max_t", "1.0"), 1.0)),
475
+ gr.update(value=to_num(preset.get("truncation_factor", "0.8"), 0.8)),
476
+ gr.update(value=to_num(preset.get("rescale_k", "1.2"), 1.2)),
477
+ gr.update(value=to_num(preset.get("rescale_sigma", "3.0"), 3.0)),
478
+ gr.update(value=speaker_kv_enabled),
479
+ gr.update(visible=speaker_kv_enabled),
480
+ gr.update(value=to_num(preset.get("speaker_kv_scale", "1.5"), 1.5)),
481
+ gr.update(value=to_num(preset.get("speaker_kv_min_t", "0.9"), 0.9)),
482
+ gr.update(value=int(to_num(preset.get("speaker_kv_max_layers", "24"), 24))),
483
+ ]
484
+
485
+
486
+ AUDIO_EXTS = {".wav", ".mp3", ".m4a", ".ogg", ".flac", ".webm", ".aac", ".opus"}
487
+
488
+
489
+ def get_audio_prompt_files(search_query: str = ""):
490
+ """Get list of audio files from the audio prompt folder, optionally filtered by search query."""
491
+ if AUDIO_PROMPT_FOLDER is None or not AUDIO_PROMPT_FOLDER.exists():
492
+ return []
493
+
494
+ files = sorted([f.name for f in AUDIO_PROMPT_FOLDER.iterdir() if f.is_file() and f.suffix.lower() in AUDIO_EXTS], key=str.lower)
495
+
496
+ # Filter by search query if provided
497
+ if search_query.strip():
498
+ query_lower = search_query.lower()
499
+ files = [f for f in files if query_lower in f.lower()]
500
+
501
+ return [[file] for file in files]
502
+
503
+
504
+ def filter_audio_prompts(search_query: str):
505
+ """Filter audio prompts based on search query."""
506
+ return gr.update(value=get_audio_prompt_files(search_query))
507
+
508
+
509
+ def select_audio_prompt_file(evt: gr.SelectData):
510
+ """Handle audio prompt file selection from table."""
511
+ if evt.value and AUDIO_PROMPT_FOLDER is not None:
512
+ file_path = AUDIO_PROMPT_FOLDER / evt.value
513
+ if file_path.exists():
514
+ return gr.update(value=str(file_path))
515
+ return gr.update()
516
+
517
+
518
+ # UI styling and helpers
519
+ LINK_CSS = """
520
+ .preset-inline { display:flex; align-items:baseline; gap:6px; margin-top:-4px; margin-bottom:-12px; }
521
+ .preset-inline .title { font-weight:600; font-size:.95rem; }
522
+ .preset-inline .dim { color:#666; margin:0 4px; }
523
+ a.preset-link { color: #0a5bd8; text-decoration: underline; cursor: pointer; font-weight: 400; }
524
+ a.preset-link:hover { text-decoration: none; opacity: 0.8; }
525
+ .dark a.preset-link,
526
+ [data-theme="dark"] a.preset-link { color: #60a5fa !important; }
527
+ .dark a.preset-link:hover,
528
+ [data-theme="dark"] a.preset-link:hover { color: #93c5fd !important; }
529
+ .dark .preset-inline .dim,
530
+ [data-theme="dark"] .preset-inline .dim { color: #9ca3af !important; }
531
+ .proxy-btn { position:absolute; width:0; height:0; overflow:hidden; padding:0 !important; margin:0 !important; border:0 !important; opacity:0; pointer-events:none; }
532
+ .gr-group { border: 1px solid #d1d5db !important; background: #f3f4f6 !important; }
533
+ .dark .gr-group,
534
+ [data-theme="dark"] .gr-group { border: 1px solid #4b5563 !important; background: #1f2937 !important; }
535
+ .generated-audio-player { border: 3px solid #667eea !important; border-radius: 12px !important; padding: 20px !important; background: linear-gradient(135deg, rgba(102, 126, 234, 0.08) 0%, rgba(118, 75, 162, 0.05) 100%) !important; box-shadow: 0 4px 12px rgba(102, 126, 234, 0.2) !important; margin: 1rem 0 !important; }
536
+ .generated-audio-player > div { background: transparent !important; }
537
+ #component-mode-selector { text-align: center; padding: 1rem 0; }
538
+ #component-mode-selector label { font-size: 1.1rem !important; font-weight: 600 !important; margin-bottom: 0.5rem !important; }
539
+ #component-mode-selector .wrap { justify-content: center !important; }
540
+ #component-mode-selector fieldset { border: 2px solid #e5e7eb !important; border-radius: 8px !important; padding: 1rem !important; background: #f9fafb !important; }
541
+ .dark #component-mode-selector fieldset,
542
+ [data-theme="dark"] #component-mode-selector fieldset { border: 2px solid #4b5563 !important; background: #1f2937 !important; }
543
+ .section-separator { height: 3px !important; background: linear-gradient(90deg, transparent 0%, #667eea 20%, #764ba2 80%, transparent 100%) !important; border: none !important; margin: 2rem 0 !important; }
544
+ .dark .section-separator,
545
+ [data-theme="dark"] .section-separator { background: linear-gradient(90deg, transparent 0%, #667eea 20%, #764ba2 80%, transparent 100%) !important; }
546
+ .gradio-container h1,
547
+ .gradio-container h2 { font-weight: 700 !important; margin-top: 1.5rem !important; margin-bottom: 1rem !important; }
548
+ .tip-box { background: linear-gradient(135deg, #fef3c7 0%, #fde68a 100%) !important; border-left: 4px solid #f59e0b !important; border-radius: 8px !important; padding: 1rem 1.5rem !important; margin: 1rem 0 !important; box-shadow: 0 2px 4px rgba(245, 158, 11, 0.1) !important; }
549
+ .tip-box strong { color: #92400e !important; }
550
+ .dark .tip-box,
551
+ [data-theme="dark"] .tip-box { background: linear-gradient(135deg, #451a03 0%, #78350f 100%) !important; border-left: 4px solid #f59e0b !important; }
552
+ .dark .tip-box strong,
553
+ [data-theme="dark"] .tip-box strong { color: #fbbf24 !important; }
554
+ """
555
+
556
+ JS_CODE = r"""
557
+ function () {
558
+ const appEl = document.querySelector("gradio-app");
559
+ const root = appEl && appEl.shadowRoot ? appEl.shadowRoot : document;
560
+ function clickHiddenButtonById(id) {
561
+ if (!id) return;
562
+ const host = root.getElementById(id);
563
+ if (!host) return;
564
+ const realBtn = host.querySelector("button, [role='button']") || host;
565
+ realBtn.click();
566
+ }
567
+ root.addEventListener("click", (ev) => {
568
+ const a = ev.target.closest("a.preset-link");
569
+ if (!a) return;
570
+ ev.preventDefault();
571
+ ev.stopPropagation();
572
+ ev.stopImmediatePropagation();
573
+ clickHiddenButtonById(a.getAttribute("data-fire"));
574
+ return false;
575
+ }, true);
576
+ }
577
+ """
578
+
579
+
580
+ def init_session():
581
+ """Initialize session ID for this browser tab/session."""
582
+ return secrets.token_hex(8)
583
+
584
+
585
+ with gr.Blocks(title="Echo-TTS", css=LINK_CSS, js=JS_CODE) as demo:
586
+ gr.Markdown("# Echo-TTS")
587
+ gr.Markdown("*Jordan Darefsky, 2025. See technical details [here](https://jordandarefsky.com/blog/2025/echo/)*")
588
+
589
+ gr.Markdown("**License Notice:** All audio outputs are subject to non-commercial use [CC-BY-NC-SA-4.0](https://creativecommons.org/licenses/by-nc-sa/4.0/).")
590
+
591
+ gr.Markdown("**Responsible Use:** Do not use this model to impersonate real people without their explicit consent or to generate deceptive audio.")
592
+
593
+ with gr.Accordion("📖 Quick Start Instructions", open=True):
594
+ gr.Markdown(
595
+ """
596
+ 1. Upload or record a short reference clip (or leave blank for no speaker reference).
597
+ 2. Pick a text preset or type your own prompt.
598
+ 3. Click **Generate Audio**.
599
+
600
+ <div class="tip-box">
601
+ 💡 **Tip:** If the generated voice does not match the reference, enable "Force Speaker" and regenerate.
602
+ </div>
603
+ """
604
+ )
605
+
606
+ session_id_state = gr.State(None)
607
+
608
+ gr.Markdown("# Speaker Reference")
609
+ with gr.Row():
610
+ if AUDIO_PROMPT_FOLDER is not None and AUDIO_PROMPT_FOLDER.exists():
611
+ with gr.Column(scale=1, min_width=200):
612
+ gr.Markdown("#### Audio Library (click to load)")
613
+ audio_prompt_search = gr.Textbox(
614
+ label="",
615
+ placeholder="🔍 Search audio prompts...",
616
+ lines=1,
617
+ max_lines=1,
618
+ )
619
+ audio_prompt_table = gr.Dataframe(
620
+ value=get_audio_prompt_files(),
621
+ headers=["Filename"],
622
+ datatype=["str"],
623
+ row_count=(10, "dynamic"),
624
+ col_count=(1, "fixed"),
625
+ interactive=False,
626
+ label="",
627
+ )
628
+ with gr.Column(scale=2):
629
+ custom_audio_input = gr.Audio(
630
+ sources=["upload", "microphone"],
631
+ type="filepath",
632
+ label="Speaker Reference Audio (first five minutes used; blank for no speaker reference)",
633
+ max_length=600,
634
+ )
635
+
636
+ gr.HTML('<hr class="section-separator">')
637
+ gr.Markdown("# Text Prompt")
638
+ with gr.Accordion("Text Presets", open=True):
639
+ text_presets_table = gr.Dataframe(
640
+ value=load_text_presets(),
641
+ headers=["Category", "Words", "Preset Text"],
642
+ datatype=["str", "str", "str"],
643
+ row_count=(3, "dynamic"),
644
+ col_count=(3, "fixed"),
645
+ interactive=False,
646
+ column_widths=["12%", "6%", "82%"],
647
+ )
648
+ text_prompt = gr.Textbox(label="Text Prompt", placeholder="[S1] Enter your text prompt here...", lines=4)
649
+
650
+ gr.HTML('<hr class="section-separator">')
651
+ gr.Markdown("# Generation")
652
+
653
+ with gr.Row():
654
+ with gr.Column(scale=1):
655
+ pass
656
+ with gr.Column(scale=2):
657
+ mode_selector = gr.Radio(
658
+ choices=["Simple Mode", "Advanced Mode"],
659
+ value="Simple Mode",
660
+ label="",
661
+ info=None,
662
+ elem_id="component-mode-selector",
663
+ )
664
+ with gr.Column(scale=1):
665
+ pass
666
+
667
+ with gr.Accordion("⚙️ Generation Parameters", open=True):
668
+ with gr.Row(equal_height=False):
669
+ presets = load_sampler_presets()
670
+ preset_keys = list(presets.keys())
671
+ first_preset = preset_keys[0] if preset_keys else "Custom"
672
+
673
+ with gr.Column(scale=2):
674
+ preset_dropdown = gr.Dropdown(
675
+ choices=["Custom"] + preset_keys,
676
+ value=first_preset,
677
+ label="Sampler Preset",
678
+ info="Load preset configurations",
679
+ )
680
+
681
+ with gr.Column(scale=0.8, min_width=100):
682
+ num_steps = gr.Number(
683
+ label="Steps",
684
+ value=40,
685
+ info="Sampling steps (Try 20-80)",
686
+ precision=0,
687
+ minimum=5,
688
+ step=5,
689
+ maximum=80,
690
+ )
691
+
692
+ with gr.Column(scale=0.8, min_width=100):
693
+ rng_seed = gr.Number(label="RNG Seed", value=0, info="Seed for noise", precision=0)
694
+
695
+ with gr.Column(scale=3):
696
+ with gr.Group():
697
+ gr.HTML(
698
+ """
699
+ <div class="preset-inline">
700
+ <span class="title">Speaker KV Attention Scaling</span>
701
+ </div>
702
+ """
703
+ )
704
+ spk_kv_preset_enable = gr.Button("", elem_id="spk_kv_enable", elem_classes=["proxy-btn"])
705
+ spk_kv_preset_off = gr.Button("", elem_id="spk_kv_off", elem_classes=["proxy-btn"])
706
+ force_speaker = gr.Checkbox(
707
+ label='"Force Speaker" (KV scaling)',
708
+ value=False,
709
+ info="Enable to more strongly match the reference speaker (though higher values may degrade quality)",
710
+ )
711
+ with gr.Row(visible=False) as speaker_kv_row:
712
+ speaker_kv_scale = gr.Number(label="KV Scale", value=1.5, info="Scale factor (>1 -> larger effect; try 1.5, 1.2, ...)", minimum=0, step=0.1)
713
+ speaker_kv_min_t = gr.Number(
714
+ label="KV Min t",
715
+ value=0.9,
716
+ info="(0-1), scale applied from t=1.0 down to this value",
717
+ minimum=0,
718
+ maximum=1,
719
+ step=0.05,
720
+ )
721
+ speaker_kv_max_layers = gr.Number(
722
+ label="Max Layers",
723
+ value=24,
724
+ info="(0-24), scale applied in first N layers",
725
+ precision=0,
726
+ minimum=0,
727
+ maximum=24,
728
+ )
729
+
730
+ with gr.Column(visible=False) as advanced_mode_column:
731
+ compile_checkbox = gr.Checkbox(
732
+ label="Compile Model",
733
+ value=False,
734
+ info="Compile for faster runs (~10-30% faster); forces Custom Shapes on to avoid excessive recompilation.",
735
+ )
736
+ use_custom_shapes_checkbox = gr.Checkbox(
737
+ label="Use Custom Shapes (Advanced)",
738
+ value=False,
739
+ info="Override default generation length and/or force latent and text padding (if unchecked, no padding is used and the latent generation length is 640 ≈ 30s).",
740
+ )
741
+
742
+ with gr.Row(visible=False) as custom_shapes_row:
743
+ max_text_byte_length = gr.Textbox(
744
+ label="Max Text Byte Length (padded)",
745
+ value="768",
746
+ info="Single value or comma-separated buckets (auto-selects min >= length); 768 = max; leave blank for no padding",
747
+ scale=1,
748
+ )
749
+ max_speaker_latent_length = gr.Textbox(
750
+ label="Max Speaker Latent Length (padded)",
751
+ value="640, 2816, 6400",
752
+ info="Single value or comma-separated buckets (auto-selects min >= length); 640≈30s, 2816≈2min, 6400≈5min (max); leave blank for no padding",
753
+ scale=1,
754
+ )
755
+ sample_latent_length = gr.Textbox(
756
+ label="Sample Latent Length",
757
+ value=str(DEFAULT_SAMPLE_LATENT_LENGTH),
758
+ info="Maximum sample latent length (640≈30s max seen during training; smaller works well for generating prefixes)",
759
+ scale=1,
760
+ )
761
+
762
+ with gr.Row():
763
+ with gr.Column(scale=1):
764
+ with gr.Group():
765
+ gr.HTML(
766
+ """
767
+ <div class="preset-inline">
768
+ <span class="title">Truncation &amp; Temporal Rescaling</span><span class="dim">(</span>
769
+ <a href="javascript:void(0)" class="preset-link" data-fire="trunc_flat">flat</a>
770
+ <span class="dim">,</span>
771
+ <a href="javascript:void(0)" class="preset-link" data-fire="trunc_sharp">sharp</a>
772
+ <span class="dim">,</span>
773
+ <a href="javascript:void(0)" class="preset-link" data-fire="trunc_baseline">baseline(sharp)</a>
774
+ <span class="dim">)</span>
775
+ </div>
776
+ """
777
+ )
778
+ trunc_preset_flat = gr.Button("", elem_id="trunc_flat", elem_classes=["proxy-btn"])
779
+ trunc_preset_sharp = gr.Button("", elem_id="trunc_sharp", elem_classes=["proxy-btn"])
780
+ trunc_preset_baseline = gr.Button("", elem_id="trunc_baseline", elem_classes=["proxy-btn"])
781
+ with gr.Row():
782
+ truncation_factor = gr.Number(
783
+ label="Truncation Factor",
784
+ value=0.8,
785
+ info="Multiplies the initial noise (<1 can reduce artifacts)",
786
+ minimum=0,
787
+ step=0.05,
788
+ )
789
+ rescale_k = gr.Number(
790
+ label="Rescale k", value=1.2, info="<1=sharpen, >1=flatten, 1=off", minimum=0, step=0.05
791
+ )
792
+ rescale_sigma = gr.Number(
793
+ label="Rescale σ", value=3.0, info="Sigma parameter", minimum=0, step=0.1
794
+ )
795
+
796
+ with gr.Column(scale=1):
797
+ with gr.Group():
798
+ gr.HTML(
799
+ """
800
+ <div class="preset-inline">
801
+ <span class="title">CFG Guidance</span><span class="dim">(</span>
802
+ <a href="javascript:void(0)" class="preset-link" data-fire="cfg_higher">higher speaker</a>
803
+ <span class="dim">,</span>
804
+ <a href="javascript:void(0)" class="preset-link" data-fire="cfg_large">large guidances</a>
805
+ <span class="dim">)</span>
806
+ </div>
807
+ """
808
+ )
809
+ cfg_preset_higher_speaker = gr.Button("", elem_id="cfg_higher", elem_classes=["proxy-btn"])
810
+ cfg_preset_large_guidances = gr.Button("", elem_id="cfg_large", elem_classes=["proxy-btn"])
811
+ with gr.Row():
812
+ cfg_scale_text = gr.Number(
813
+ label="Text CFG Scale", value=3.0, info="Guidance strength for text", minimum=0, step=0.5
814
+ )
815
+ cfg_scale_speaker = gr.Number(
816
+ label="Speaker CFG Scale",
817
+ value=5.0,
818
+ info="Guidance strength for speaker",
819
+ minimum=0,
820
+ step=0.5,
821
+ )
822
+
823
+ with gr.Row():
824
+ cfg_min_t = gr.Number(
825
+ label="CFG Min t", value=0.5, info="(0-1), CFG applied when t >= val", minimum=0, maximum=1, step=0.05
826
+ )
827
+ cfg_max_t = gr.Number(
+ label="CFG Max t", value=1.0, info="(0-1), CFG applied when t <= val", minimum=0, maximum=1, step=0.05
+ )
+
+ with gr.Row(equal_height=True):
+ audio_format = gr.Radio(choices=["wav", "mp3"], value="wav", label="Format", scale=1, min_width=90)
+ generate_btn = gr.Button("Generate Audio", variant="primary", size="lg", scale=10)
+ with gr.Column(scale=1):
+ show_original_audio = gr.Checkbox(label="Re-display Original Audio (full 5-minute cropped mono)", value=False)
+ reconstruct_first_30_seconds = gr.Checkbox(
+ label="Show Autoencoder Reconstruction (only first 30s of reference)", value=False
+ )
+
+ gr.HTML('<hr class="section-separator">')
+ with gr.Accordion("Generated Audio", open=True, visible=True) as generated_section:
+ generation_time_display = gr.Markdown("", visible=False)
+ with gr.Group(elem_classes=["generated-audio-player"]):
+ generated_audio = gr.Audio(label="Generated Audio", visible=True)
+ text_prompt_display = gr.Markdown("", visible=False)
+
+ gr.Markdown("---")
+ reference_audio_header = gr.Markdown("#### Reference Audio", visible=False)
+
+ with gr.Accordion("Original Audio (5 min Cropped Mono)", open=False, visible=False) as original_accordion:
+ original_audio = gr.Audio(label="Original Reference Audio (5 min)", visible=True)
+
+ with gr.Accordion("Autoencoder Reconstruction of First 30s of Reference", open=False, visible=False) as reference_accordion:
+ reference_audio = gr.Audio(label="Decoded Reference Audio (30s)", visible=True)
+
+ # Event handlers
+ if AUDIO_PROMPT_FOLDER is not None and AUDIO_PROMPT_FOLDER.exists():
+ audio_prompt_table.select(select_audio_prompt_file, outputs=[custom_audio_input])
+ audio_prompt_search.change(filter_audio_prompts, inputs=[audio_prompt_search], outputs=[audio_prompt_table])
+
+ text_presets_table.select(select_text_preset, outputs=text_prompt)
+
+ mode_selector.change(toggle_mode, inputs=[mode_selector], outputs=[advanced_mode_column])
+
+ force_speaker.change(update_force_row, inputs=[force_speaker], outputs=[speaker_kv_row])
+
+ def toggle_custom_shapes(enabled):
+ return gr.update(visible=enabled)
+
+ use_custom_shapes_checkbox.change(
+ toggle_custom_shapes,
+ inputs=[use_custom_shapes_checkbox],
+ outputs=[custom_shapes_row],
+ )
+
+ def on_compile_change(compile_enabled):
+ """When compile is enabled, force custom shapes to be enabled."""
+ if compile_enabled:
+ return (
+ gr.update(value=True), # use_custom_shapes_checkbox
+ gr.update(visible=True), # custom_shapes_row
+ )
+ return (
+ gr.update(),
+ gr.update(),
+ )
+
+ compile_checkbox.change(
+ on_compile_change,
+ inputs=[compile_checkbox],
+ outputs=[use_custom_shapes_checkbox, custom_shapes_row],
+ )
+
+ cfg_preset_higher_speaker.click(
+ lambda: apply_cfg_preset("higher speaker"), outputs=[cfg_scale_text, cfg_scale_speaker, cfg_min_t, cfg_max_t, preset_dropdown]
+ )
+ cfg_preset_large_guidances.click(
+ lambda: apply_cfg_preset("large guidances"), outputs=[cfg_scale_text, cfg_scale_speaker, cfg_min_t, cfg_max_t, preset_dropdown]
+ )
+
+ spk_kv_preset_enable.click(lambda: apply_speaker_kv_preset("enable"), outputs=[force_speaker, speaker_kv_row, preset_dropdown])
+ spk_kv_preset_off.click(lambda: apply_speaker_kv_preset("off"), outputs=[force_speaker, speaker_kv_row, preset_dropdown])
+
+ trunc_preset_flat.click(lambda: apply_truncation_preset("flat"), outputs=[truncation_factor, rescale_k, rescale_sigma, preset_dropdown])
+ trunc_preset_sharp.click(lambda: apply_truncation_preset("sharp"), outputs=[truncation_factor, rescale_k, rescale_sigma, preset_dropdown])
+ trunc_preset_baseline.click(
+ lambda: apply_truncation_preset("baseline(sharp)"), outputs=[truncation_factor, rescale_k, rescale_sigma, preset_dropdown]
+ )
+
+ preset_dropdown.change(
+ apply_sampler_preset,
+ inputs=preset_dropdown,
+ outputs=[
+ num_steps,
+ cfg_scale_text,
+ cfg_scale_speaker,
+ cfg_min_t,
+ cfg_max_t,
+ truncation_factor,
+ rescale_k,
+ rescale_sigma,
+ force_speaker,
+ speaker_kv_row,
+ speaker_kv_scale,
+ speaker_kv_min_t,
+ speaker_kv_max_layers,
+ ],
+ )
+
+ generate_btn.click(
+ generate_audio,
+ inputs=[
+ text_prompt,
+ custom_audio_input,
+ num_steps,
+ rng_seed,
+ cfg_scale_text,
+ cfg_scale_speaker,
+ cfg_min_t,
+ cfg_max_t,
+ truncation_factor,
+ rescale_k,
+ rescale_sigma,
+ force_speaker,
+ speaker_kv_scale,
+ speaker_kv_min_t,
+ speaker_kv_max_layers,
+ reconstruct_first_30_seconds,
+ use_custom_shapes_checkbox,
+ max_text_byte_length,
+ max_speaker_latent_length,
+ sample_latent_length,
+ audio_format,
+ compile_checkbox,
+ show_original_audio,
+ session_id_state,
+ ],
+ outputs=[
+ generated_section,
+ generated_audio,
+ text_prompt_display,
+ original_audio,
+ generation_time_display,
+ reference_audio,
+ original_accordion,
+ reference_accordion,
+ reference_audio_header,
+ ],
+ )
+
+ demo.load(init_session, outputs=[session_id_state]).then(
+ lambda: apply_sampler_preset(list(load_sampler_presets().keys())[0]),
+ outputs=[
+ num_steps,
+ cfg_scale_text,
+ cfg_scale_speaker,
+ cfg_min_t,
+ cfg_max_t,
+ truncation_factor,
+ rescale_k,
+ rescale_sigma,
+ force_speaker,
+ speaker_kv_row,
+ speaker_kv_scale,
+ speaker_kv_min_t,
+ speaker_kv_max_layers,
+ ],
+ )
+
+
+ if __name__ == "__main__":
+ demo.launch(
+ allowed_paths=[str(AUDIO_PROMPT_FOLDER)]
+ )
inference.py ADDED
@@ -0,0 +1,462 @@
+ from dataclasses import dataclass
+ from typing import Callable, List, Tuple
+
+ from huggingface_hub import hf_hub_download
+ import safetensors.torch as st
+ import torch
+ import torchaudio
+ from torchcodec.decoders import AudioDecoder
+
+ from autoencoder import DAC, build_ae
+ from model import EchoDiT
+
+
+ def load_model_from_hf(repo_id: str = "jordand/echo-tts-base", device: str = "cuda", dtype: torch.dtype | None = torch.bfloat16, compile: bool = False, token: str | None = None, delete_blockwise_modules: bool = False) -> EchoDiT:
+     with torch.device("meta"):
+         model = EchoDiT(
+             latent_size=80, model_size=2048, num_layers=24, num_heads=16,
+             intermediate_size=5888, norm_eps=1e-5,
+             text_vocab_size=256, text_model_size=1280, text_num_layers=14,
+             text_num_heads=10, text_intermediate_size=3328,
+             speaker_patch_size=4, speaker_model_size=1280, speaker_num_layers=14,
+             speaker_num_heads=10, speaker_intermediate_size=3328,
+             timestep_embed_size=512, adaln_rank=256,
+         )
+     w_path = hf_hub_download(repo_id, "pytorch_model.safetensors", token=token)
+     state = st.load_file(w_path, device="cpu")
+
+     if delete_blockwise_modules:
+         state = {k: v for k, v in state.items() if not (
+             k.startswith("latent_encoder.") or
+             k.startswith("latent_norm") or
+             ".wk_latent" in k or
+             ".wv_latent" in k
+         )}
+
+     if dtype is not None:
+         state = {k: v.to(dtype=dtype) for k, v in state.items()}
+
+     state = {k: v.to(device=device) for k, v in state.items()}
+
+     model.load_state_dict(state, strict=False, assign=True)
+     model = model.eval()
+
+     if compile:
+         model = compile_model(model)
+
+     return model
+
+ def compile_model(model: EchoDiT) -> EchoDiT:
+     model = torch.compile(model)
+     model.get_kv_cache_text = torch.compile(model.get_kv_cache_text)
+     model.get_kv_cache_speaker = torch.compile(model.get_kv_cache_speaker)
+     model.get_kv_cache_latent = torch.compile(model.get_kv_cache_latent)
+     return model
+
+ def load_fish_ae_from_hf(repo_id: str = "jordand/fish-s1-dac-min", device: str = "cuda", dtype: torch.dtype | None = torch.float32, compile: bool = False, token: str | None = None) -> DAC:
+     with torch.device("meta"):
+         fish_ae = build_ae()
+
+     w_path = hf_hub_download(repo_id, "pytorch_model.safetensors", token=token)
+     if dtype is not None and dtype != torch.float32:
+         state = st.load_file(w_path, device="cpu")
+         state = {k: v.to(dtype=dtype) for k, v in state.items()}
+         state = {k: v.to(device=device) for k, v in state.items()}
+         fish_ae.load_state_dict(state, strict=False, assign=True)
+     else:
+         state = st.load_file(w_path, device=device)
+         fish_ae.load_state_dict(state, strict=False, assign=True)
+
+     fish_ae = fish_ae.eval().to(device)
+
+     if compile:
+         fish_ae = compile_fish_ae(fish_ae)
+
+     return fish_ae
+
+ def compile_fish_ae(fish_ae: DAC) -> DAC:
+     fish_ae.quantizer.upsample = torch.compile(fish_ae.quantizer.upsample)
+     fish_ae.quantizer.downsample = torch.compile(fish_ae.quantizer.downsample)
+     fish_ae.quantizer.pre_module = torch.compile(fish_ae.quantizer.pre_module)
+     fish_ae.quantizer.post_module = torch.compile(fish_ae.quantizer.post_module)
+     return fish_ae
+
+
+ @dataclass
+ class PCAState:
+     pca_components: torch.Tensor
+     pca_mean: torch.Tensor
+     latent_scale: float
+
+ def load_pca_state_from_hf(repo_id: str = "jordand/echo-tts-base", device: str = "cuda", filename: str = "pca_state.safetensors", token: str | None = None) -> PCAState:
+     p_path = hf_hub_download(repo_id, filename, token=token)
+     t = st.load_file(p_path, device=device)
+     return PCAState(
+         pca_components=t["pca_components"],
+         pca_mean=t["pca_mean"],
+         latent_scale=float(t["latent_scale"].item()),
+     )
+
+
+ # ________
+
+ def load_audio(path: str, max_duration: int = 300) -> torch.Tensor:
+     decoder = AudioDecoder(path)
+     sr = decoder.metadata.sample_rate
+     audio = decoder.get_samples_played_in_range(0, max_duration)
+     audio = audio.data.mean(dim=0).unsqueeze(0)
+     audio = torchaudio.functional.resample(audio, sr, 44_100)
+     audio = audio / torch.maximum(audio.abs().max(), torch.tensor(1.))
+     # is this better than clipping? should we target a specific energy level?
+     return audio
+
+ def tokenizer_encode(text: str, append_bos: bool = True, normalize: bool = True, return_normalized_text: bool = False) -> torch.Tensor | Tuple[torch.Tensor, str]:
+     if normalize:
+         text = text.replace("…", "...")
+         text = text.replace("’", "'")
+         text = text.replace("“", '"')
+         text = text.replace("”", '"')
+         text = text.replace("\n", " ")
+         text = text.replace(":", ",")
+         text = text.replace(";", ",")
+         text = text.replace("—", ", ")
+         if not text.startswith("[") and not text.startswith("(") and 'S1' not in text and 'S2' not in text:
+             text = "[S1] " + text
+
+     b = list(text.encode("utf-8"))
+     if append_bos:
+         b.insert(0, 0)
+
+     if return_normalized_text:
+         return torch.tensor(b), text
+
+     return torch.tensor(b)
+
+ def get_text_input_ids_and_mask(text_arr: List[str], max_length: int | None, device: str | None = None, normalize: bool = True, return_normalized_text: bool = False, pad_to_max: bool = True) -> Tuple[torch.Tensor, torch.Tensor] | Tuple[torch.Tensor, torch.Tensor, List[str]]:
+     encoded_texts = [tokenizer_encode(text, normalize=normalize, return_normalized_text=True) for text in text_arr]
+
+     if max_length is None:
+         max_length = max(len(enc) for enc, _ in encoded_texts)
+
+     tokens = torch.zeros((len(text_arr), max_length), dtype=torch.int32)
+     mask = torch.zeros((len(text_arr), max_length), dtype=torch.bool)
+
+     for i, (encoded, _) in enumerate(encoded_texts):
+         length = min(len(encoded), max_length)
+         tokens[i, :length] = encoded[:length]
+         mask[i, :length] = 1
+
+     if not pad_to_max and max_length is not None:
+         tokens, mask = tokens[:, :max_length], mask[:, :max_length]
+
+     if device is not None:
+         tokens, mask = tokens.to(device), mask.to(device)
+
+     if return_normalized_text:
+         return tokens, mask, [text for _, text in encoded_texts]
+     return tokens, mask
+
+ # ________
+
+ @torch.inference_mode()
+ def ae_encode(fish_ae: DAC, pca_state: PCAState, audio: torch.Tensor) -> torch.Tensor:
+     assert audio.ndim == 3 and audio.shape[1] == 1  # (b, 1, length)
+     z_q = fish_ae.encode_zq(audio).float()
+     z_q = (z_q.transpose(1, 2) - pca_state.pca_mean) @ pca_state.pca_components.T
+     z_q = z_q * pca_state.latent_scale
+     return z_q
+
+ @torch.inference_mode()
+ def ae_decode(fish_ae: DAC, pca_state: PCAState, z_q: torch.Tensor) -> torch.Tensor:
+     z_q = (z_q / pca_state.latent_scale) @ pca_state.pca_components + pca_state.pca_mean
+     return fish_ae.decode_zq(z_q.transpose(1, 2).to(fish_ae.dtype)).float()
+
+ @torch.inference_mode()
+ def ae_reconstruct(fish_ae: DAC, pca_state: PCAState, audio: torch.Tensor) -> torch.Tensor:
+     assert audio.ndim == 3 and audio.shape[1] == 1  # (b, 1, length)
+     z_q = ae_encode(fish_ae, pca_state, audio.to(fish_ae.dtype))
+     return ae_decode(fish_ae, pca_state, z_q)
+
+ # ________
+
+ @torch.inference_mode()
+ def get_speaker_latent_and_mask(
+     fish_ae: DAC,
+     pca_state: PCAState,
+     audio: torch.Tensor,  # (1, length)
+     max_speaker_latent_length: int = 6400,  # pretrained max length
+     audio_chunk_size: int = 640 * 2048,  # (~30 seconds, 1/10 max speaker condition size; max chunk seen in training)
+     pad_to_max: bool = False,
+     divis_by_patch_size: int | None = 4,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+     # gets the speaker latent and mask from audio; computes in chunks and concatenates (similar to the training setup)
+
+     AE_DOWNSAMPLE_FACTOR = 2048
+     max_audio_len_length = max_speaker_latent_length * AE_DOWNSAMPLE_FACTOR
+
+     assert audio.ndim == 2 and audio.shape[0] == 1  # (1, length)
+     audio = audio[:, :max_audio_len_length]
+
+     latent_arr = []
+
+     for i in range(0, audio.shape[1], audio_chunk_size):
+         audio_chunk = audio[:, i:i + audio_chunk_size]
+         if audio_chunk.shape[1] < audio_chunk_size:
+             audio_chunk = torch.nn.functional.pad(audio_chunk, (0, audio_chunk_size - audio_chunk.shape[1]))
+
+         latent_chunk = ae_encode(fish_ae, pca_state, audio_chunk.unsqueeze(0))
+         latent_arr.append(latent_chunk)
+
+     speaker_latent = torch.cat(latent_arr, dim=1)
+
+     actual_latent_length = audio.shape[1] // AE_DOWNSAMPLE_FACTOR
+     speaker_mask = (torch.arange(speaker_latent.shape[1], device=speaker_latent.device) < actual_latent_length).unsqueeze(0)
+
+     if pad_to_max and speaker_latent.shape[1] < max_speaker_latent_length:
+         speaker_latent = torch.nn.functional.pad(speaker_latent, (0, 0, 0, max_speaker_latent_length - speaker_latent.shape[1]))
+         speaker_mask = torch.nn.functional.pad(speaker_mask, (0, max_speaker_latent_length - speaker_mask.shape[1]))
+     elif not pad_to_max:
+         speaker_latent = speaker_latent[:, :actual_latent_length]
+         speaker_mask = speaker_mask[:, :actual_latent_length]
+
+     if divis_by_patch_size is not None:
+         speaker_latent = speaker_latent[:, :speaker_latent.shape[1] // divis_by_patch_size * divis_by_patch_size]
+         speaker_mask = speaker_mask[:, :speaker_mask.shape[1] // divis_by_patch_size * divis_by_patch_size]
+
+     return speaker_latent, speaker_mask
+
+
+ # ________
+
+ def find_flattening_point(data, target_value=0.0, window_size=20, std_threshold=0.05):
+     # simple heuristic to find the end of latent generations; slow and can be improved
+     # (data is (length, 80))
+     padded_data = torch.cat([data, torch.zeros(window_size, *data.shape[1:], device=data.device, dtype=data.dtype)])
+     for i in range(len(padded_data) - window_size):
+         window = padded_data[i:i + window_size]
+         if window.std() < std_threshold and abs(window.mean() - target_value) < 0.1:
+             return i
+     return len(data)
+
+ def crop_audio_to_flattening_point(audio: torch.Tensor, latent: torch.Tensor) -> torch.Tensor:
+     # (audio is (..., length), latent is (length, 80))
+     flattening_point = find_flattening_point(latent)
+     return audio[..., :flattening_point * 2048]
+
+ SampleFn = Callable[
+     [EchoDiT, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, int],
+     torch.Tensor
+ ]
+
+ @torch.inference_mode()
+ def sample_pipeline(
+     model: EchoDiT,
+     fish_ae: DAC,
+     pca_state: PCAState,
+     sample_fn: SampleFn,
+     text_prompt: str,
+     speaker_audio: torch.Tensor | None,
+     rng_seed: int,
+     pad_to_max_speaker_latent_length: int | None = None,
+     pad_to_max_text_length: int | None = None,
+     normalize_text: bool = True,
+ ) -> Tuple[torch.Tensor, str]:
+     MAX_SPEAKER_LATENT_LENGTH = 6400  # max seen during training, though maybe can go higher?
+     MAX_TEXT_LENGTH = 768
+
+     device, dtype = model.device, model.dtype
+
+     text_input_ids, text_mask, normalized_text = get_text_input_ids_and_mask([text_prompt], max_length=min(pad_to_max_text_length or MAX_TEXT_LENGTH, MAX_TEXT_LENGTH), device=device, normalize=normalize_text, return_normalized_text=True, pad_to_max=(pad_to_max_text_length is not None))
+
+     if speaker_audio is None:
+         speaker_latent = torch.zeros((1, pad_to_max_speaker_latent_length or 4, 80), device=device, dtype=dtype)
+         speaker_mask = torch.zeros((1, pad_to_max_speaker_latent_length or 4), device=device, dtype=torch.bool)
+     else:
+         speaker_latent, speaker_mask = get_speaker_latent_and_mask(
+             fish_ae,
+             pca_state,
+             speaker_audio.to(fish_ae.dtype).to(device),
+             max_speaker_latent_length=pad_to_max_speaker_latent_length or MAX_SPEAKER_LATENT_LENGTH,
+             pad_to_max=(pad_to_max_speaker_latent_length is not None),
+         )
+
+     latent_out = sample_fn(model, speaker_latent, speaker_mask, text_input_ids, text_mask, rng_seed)
+
+     audio_out = ae_decode(fish_ae, pca_state, latent_out)
+
+     audio_out = crop_audio_to_flattening_point(audio_out, latent_out[0])
+
+     return audio_out, normalized_text[0]
+
+
+ # ________
+
+ KVCache = List[Tuple[torch.Tensor, torch.Tensor]]
+
+ def _concat_kv_caches(*caches: KVCache) -> KVCache:
+     # helper that concatenates multiple KV caches along the batch dimension
+     num_layers = len(caches[0])
+     result = []
+     for i in range(num_layers):
+         k = torch.cat([c[i][0] for c in caches], dim=0)
+         v = torch.cat([c[i][1] for c in caches], dim=0)
+         result.append((k, v))
+     return result
+
+ def _multiply_kv_cache(cache: KVCache, scale: float, max_layers: int | None = None) -> None:
+     # helper that multiplies KV cache values in-place, for speaker-KV scaling
+     num_layers = len(cache) if max_layers is None else min(max_layers, len(cache))
+     for i in range(num_layers):
+         k, v = cache[i]
+         k.mul_(scale)
+         v.mul_(scale)
+
+ def _temporal_score_rescale(
+     v_pred: torch.Tensor, x_t: torch.Tensor, t: float, rescale_k: float, rescale_sigma: float
+ ) -> torch.Tensor:
+     # for https://arxiv.org/pdf/2510.01184
+     if t < 1:
+         snr = (1 - t) ** 2 / (t ** 2)
+         ratio = (snr * rescale_sigma ** 2 + 1) / (snr * rescale_sigma ** 2 / rescale_k + 1)
+         return 1 / (1 - t) * (ratio * ((1 - t) * v_pred + x_t) - x_t)
+     return v_pred
+
+
+ @torch.inference_mode()
+ def sample_euler_cfg_independent_guidances(
+     model: EchoDiT,
+     speaker_latent: torch.Tensor,
+     speaker_mask: torch.Tensor,
+     text_input_ids: torch.Tensor,
+     text_mask: torch.Tensor,
+     rng_seed: int,
+     num_steps: int,
+     cfg_scale_text: float,
+     cfg_scale_speaker: float,
+     cfg_min_t: float,
+     cfg_max_t: float,
+     truncation_factor: float | None,
+     rescale_k: float | None,
+     rescale_sigma: float | None,
+     speaker_kv_scale: float | None,
+     speaker_kv_max_layers: int | None,
+     speaker_kv_min_t: float | None,
+     sequence_length: int | None = None,
+ ) -> torch.Tensor:
+     if sequence_length is None:
+         sequence_length = 640  # max sequence length during training
+
+     INIT_SCALE = 0.999  # so that we can apply rescale to the first step
+
+     device, dtype = model.device, model.dtype
+     batch_size = text_input_ids.shape[0]
+
+     rng = torch.Generator(device=device).manual_seed(rng_seed)
+
+     t_schedule = torch.linspace(1., 0., num_steps + 1, device=device) * INIT_SCALE
+
+     text_mask_uncond = torch.zeros_like(text_mask)
+     speaker_mask_uncond = torch.zeros_like(speaker_mask)
+
+     kv_text_cond = model.get_kv_cache_text(text_input_ids, text_mask)
+     kv_speaker_cond = model.get_kv_cache_speaker(speaker_latent.to(dtype))
+
+     if speaker_kv_scale is not None:
+         _multiply_kv_cache(kv_speaker_cond, speaker_kv_scale, speaker_kv_max_layers)
+
+     # masks prevent the decoder from attending to unconds:
+     kv_text_full = _concat_kv_caches(kv_text_cond, kv_text_cond, kv_text_cond)
+     kv_speaker_full = _concat_kv_caches(kv_speaker_cond, kv_speaker_cond, kv_speaker_cond)
+
+     full_text_mask = torch.cat([text_mask, text_mask_uncond, text_mask], dim=0)
+     full_speaker_mask = torch.cat([speaker_mask, speaker_mask, speaker_mask_uncond], dim=0)
+
+     x_t = torch.randn((batch_size, sequence_length, 80), device=device, dtype=torch.float32, generator=rng)
+     if truncation_factor is not None:
+         x_t = x_t * truncation_factor
+
+     for i in range(num_steps):
+         t, t_next = t_schedule[i], t_schedule[i + 1]
+
+         has_cfg = ((t >= cfg_min_t) * (t <= cfg_max_t)).item()
+
+         if has_cfg:
+             v_cond, v_uncond_text, v_uncond_speaker = model(
+                 x=torch.cat([x_t, x_t, x_t], dim=0).to(dtype),
+                 t=(torch.ones((batch_size * 3,), device=device) * t).to(dtype),
+                 text_mask=full_text_mask,
+                 speaker_mask=full_speaker_mask,
+                 kv_cache_text=kv_text_full,
+                 kv_cache_speaker=kv_speaker_full,
+             ).float().chunk(3, dim=0)
+             v_pred = v_cond + cfg_scale_text * (v_cond - v_uncond_text) + cfg_scale_speaker * (v_cond - v_uncond_speaker)  # can also use a single, joint unconditional for fewer NFE
+         else:
+             v_pred = model(
+                 x=x_t.to(dtype),
+                 t=(torch.ones((batch_size,), device=device) * t).to(dtype),
+                 text_mask=text_mask,
+                 speaker_mask=speaker_mask,
+                 kv_cache_text=kv_text_cond,
+                 kv_cache_speaker=kv_speaker_cond,
+             ).float()
+
+         # optional temporal score rescaling: https://arxiv.org/pdf/2510.01184
+         if rescale_k is not None and rescale_sigma is not None:
+             v_pred = _temporal_score_rescale(v_pred, x_t, t, rescale_k, rescale_sigma)
+
+         # optional speaker-KV scaling: undo the scale once t drops below speaker_kv_min_t
+         if speaker_kv_scale is not None and t_next < speaker_kv_min_t and t >= speaker_kv_min_t:
+             _multiply_kv_cache(kv_speaker_cond, 1. / speaker_kv_scale, speaker_kv_max_layers)
+             kv_speaker_full = _concat_kv_caches(kv_speaker_cond, kv_speaker_cond, kv_speaker_cond)
+
+         x_t = x_t + v_pred * (t_next - t)
+
+     return x_t
+
+
+ # ___________________________________________________________
+ # simple example
+
+ if __name__ == "__main__":
+     model = load_model_from_hf(delete_blockwise_modules=True)
+     fish_ae = load_fish_ae_from_hf()
+     pca_state = load_pca_state_from_hf()
+
+     speaker_audio_path = "/path/to/speaker/audio.wav"
+     speaker_audio = load_audio(speaker_audio_path).cuda()
+     speaker_latent, speaker_mask = get_speaker_latent_and_mask(fish_ae, pca_state, speaker_audio)
+
+     text = "[S1] Alright, I'm going to demo this new model called Echo TTS. Hopefully this works, I'm super excited to try this and see what it can do."
+     text_input_ids, text_mask = get_text_input_ids_and_mask([text], max_length=None, device="cuda")
+
+     latent_out = sample_euler_cfg_independent_guidances(
+         model=model,
+         speaker_latent=speaker_latent,
+         speaker_mask=speaker_mask,
+         text_input_ids=text_input_ids,
+         text_mask=text_mask,
+         rng_seed=0,
+         num_steps=40,
+         cfg_scale_text=3.0,
+         cfg_scale_speaker=8.0,
+         cfg_min_t=0.5,
+         cfg_max_t=1.0,
+         truncation_factor=0.8,
+         rescale_k=None,
+         rescale_sigma=None,
+         speaker_kv_scale=None,
+         speaker_kv_max_layers=None,
+         speaker_kv_min_t=None,
+         sequence_length=640,  # (max 640; shorter lengths will generate prefixes, not necessarily full generations)
+     )
+     audio_out = ae_decode(fish_ae, pca_state, latent_out)
+     audio_out = crop_audio_to_flattening_point(audio_out, latent_out[0])
+     torchaudio.save("output.wav", audio_out[0].cpu(), 44100)
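`tokenizer_encode` above is a byte-level tokenizer: there is no vocabulary file, only punctuation normalization, an optional `[S1]` speaker tag, and the raw UTF-8 bytes of the text with a BOS byte of 0 (matching `text_vocab_size=256`). A minimal, self-contained sketch of that behavior (reimplemented here for illustration, with both left and right curly quotes mapped to straight quotes; it is not imported from `inference.py` and returns a plain list instead of a tensor):

```python
# Self-contained sketch of the byte-level tokenization used by tokenizer_encode
# (mirrors the normalization rules above; illustrative, not the module itself).

def normalize_and_encode(text: str, append_bos: bool = True) -> list[int]:
    # Normalize punctuation toward what the model saw during training.
    for src, dst in [("…", "..."), ("’", "'"), ("“", '"'), ("”", '"'),
                     ("\n", " "), (":", ","), (";", ","), ("—", ", ")]:
        text = text.replace(src, dst)
    # Prepend a speaker tag when the prompt has none.
    if not text.startswith(("[", "(")) and "S1" not in text and "S2" not in text:
        text = "[S1] " + text
    ids = list(text.encode("utf-8"))  # raw UTF-8 bytes (vocab size 256)
    return ([0] + ids) if append_bos else ids  # 0 is the BOS byte

# Round-trip the bytes (minus BOS) back to text to see the normalized prompt:
print(bytes(normalize_and_encode("Hello; world…")[1:]).decode("utf-8"))
```

The print shows the normalized prompt `[S1] Hello, world...`, which is what actually gets fed to the text encoder.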
inference_blockwise.py ADDED
@@ -0,0 +1,220 @@
1
+ from typing import List
2
+
3
+ import torch
4
+
5
+ from inference import (
6
+ KVCache,
7
+ _concat_kv_caches,
8
+ _multiply_kv_cache,
9
+ _temporal_score_rescale,
10
+ )
11
+ from model import EchoDiT
12
+
13
+
14
+ @torch.inference_mode()
15
+ def sample_blockwise_euler_cfg_independent_guidances(
16
+ model: EchoDiT,
17
+ speaker_latent: torch.Tensor,
18
+ speaker_mask: torch.Tensor,
19
+ text_input_ids: torch.Tensor,
20
+ text_mask: torch.Tensor,
21
+ rng_seed: int,
22
+ block_sizes: List[int],
23
+ num_steps: int,
24
+ cfg_scale_text: float,
25
+ cfg_scale_speaker: float,
26
+ cfg_min_t: float,
27
+ cfg_max_t: float,
28
+ truncation_factor: float | None,
29
+ rescale_k: float | None,
30
+ rescale_sigma: float | None,
31
+ speaker_kv_scale: float | None,
32
+ speaker_kv_max_layers: int | None,
33
+ speaker_kv_min_t: float | None,
34
+ continuation_latent: torch.Tensor | None = None,
35
+ ) -> torch.Tensor:
36
+
37
+ INIT_SCALE = 0.999 # so that we can apply rescale to first step
38
+
39
+ device, dtype = model.device, model.dtype
40
+ batch_size = text_input_ids.shape[0]
41
+
42
+ rng = torch.Generator(device=device).manual_seed(rng_seed)
43
+
44
+ t_schedule = torch.linspace(1., 0., num_steps + 1, device=device) * INIT_SCALE
45
+
46
+ text_mask_uncond = torch.zeros_like(text_mask)
47
+ speaker_mask_uncond = torch.zeros_like(speaker_mask)
48
+
49
+ kv_text_cond = model.get_kv_cache_text(text_input_ids, text_mask)
50
+ kv_speaker_cond = model.get_kv_cache_speaker(speaker_latent.to(dtype))
51
+
52
+ # masks prevent decoder from attending to unconds:
53
+ kv_text_full = _concat_kv_caches(kv_text_cond, kv_text_cond, kv_text_cond)
54
+ kv_speaker_full = _concat_kv_caches(kv_speaker_cond, kv_speaker_cond, kv_speaker_cond)
55
+
56
+ full_text_mask = torch.cat([text_mask, text_mask_uncond, text_mask], dim=0)
57
+ full_speaker_mask = torch.cat([speaker_mask, speaker_mask, speaker_mask_uncond], dim=0)
58
+
59
+ prefix_latent = torch.zeros((batch_size, sum(block_sizes) , 80), device=device, dtype=torch.float32)
60
+
61
+ start_pos = 0
62
+ if continuation_latent is not None:
63
+ continuation_len = continuation_latent.shape[1]
64
+ prefix_latent = torch.cat([continuation_latent, prefix_latent], dim=1)
65
+ start_pos = continuation_len
66
+
67
+ for block_size in block_sizes:
68
+ if speaker_kv_scale is not None:
69
+ _multiply_kv_cache(kv_speaker_cond, speaker_kv_scale, speaker_kv_max_layers)
70
+ kv_speaker_full = _concat_kv_caches(kv_speaker_cond, kv_speaker_cond, kv_speaker_cond)
71
+
72
+ full_prefix_latent = torch.cat([prefix_latent, prefix_latent, prefix_latent], dim=0)
73
+ kv_latent_full = model.get_kv_cache_latent(full_prefix_latent.to(dtype))
74
+ kv_latent_cond = [(k[:batch_size], v[:batch_size]) for k, v in kv_latent_full]
75
+
76
+ x_t = torch.randn((batch_size, block_size, 80), device=device, dtype=torch.float32, generator=rng)
77
+ if truncation_factor is not None:
78
+ x_t = x_t * truncation_factor
79
+
80
+ for i in range(num_steps):
81
+ t, t_next = t_schedule[i], t_schedule[i + 1]
82
+
83
+ has_cfg = ((t >= cfg_min_t) * (t <= cfg_max_t)).item()
84
+
85
+ if has_cfg:
86
+ v_cond, v_uncond_text, v_uncond_speaker = model(
87
+ x=torch.cat([x_t, x_t, x_t], dim=0).to(dtype),
88
+ t=(torch.ones((batch_size * 3,), device=device) * t).to(dtype),
89
+ text_mask=full_text_mask,
90
+ speaker_mask=full_speaker_mask,
91
+ start_pos=start_pos,
92
+ kv_cache_text=kv_text_full,
93
+ kv_cache_speaker=kv_speaker_full,
94
+ kv_cache_latent=kv_latent_full,
95
+ ).float().chunk(3, dim=0)
96
+ v_pred = v_cond + cfg_scale_text * (v_cond - v_uncond_text) + cfg_scale_speaker * (v_cond - v_uncond_speaker)
97
+ else:
98
+ v_pred = model(
99
+ x=x_t.to(dtype),
100
+ t=(torch.ones((batch_size,), device=device) * t).to(dtype),
101
+ text_mask=text_mask,
102
+ speaker_mask=speaker_mask,
103
+ start_pos=start_pos,
104
+ kv_cache_text=kv_text_cond,
105
+ kv_cache_speaker=kv_speaker_cond,
106
+ kv_cache_latent=kv_latent_cond,
107
+ ).float()
108
+
109
+ # optional temporal score rescaling: https://arxiv.org/pdf/2510.01184
110
+ if rescale_k is not None and rescale_sigma is not None:
111
+ v_pred = _temporal_score_rescale(v_pred, x_t, t, rescale_k, rescale_sigma)
112
+
113
+ # optional kv speaker scaling
114
+ if speaker_kv_scale is not None and t_next < speaker_kv_min_t and t >= speaker_kv_min_t:
115
+ _multiply_kv_cache(kv_speaker_cond, 1. / speaker_kv_scale, speaker_kv_max_layers)
116
+ kv_speaker_full = _concat_kv_caches(kv_speaker_cond, kv_speaker_cond, kv_speaker_cond)
117
+
118
+ x_t = x_t + v_pred * (t_next - t)
119
+
120
+ prefix_latent[:, start_pos:start_pos + block_size] = x_t
121
+ start_pos += block_size
122
+
123
+ return prefix_latent
124
+
125
+
126
+ if __name__ == "__main__":
127
+ import torchaudio
128
+ from inference import (
129
+ load_model_from_hf,
130
+ load_fish_ae_from_hf,
131
+ load_pca_state_from_hf,
132
+ load_audio,
133
+ get_text_input_ids_and_mask,
134
+ get_speaker_latent_and_mask,
135
+ ae_encode,
136
+ ae_decode,
137
+ crop_audio_to_flattening_point,
138
+ )
139
+
140
+ model = load_model_from_hf()
141
+ fish_ae = load_fish_ae_from_hf()
142
+ pca_state = load_pca_state_from_hf()
143
+
144
+
145
+ # example 1, generate 320 in three blocks
146
+
147
+ speaker_audio_path = "/path/to/speaker/audio.wav"
148
+ speaker_audio = load_audio(speaker_audio_path).cuda()
149
+ speaker_latent, speaker_mask = get_speaker_latent_and_mask(fish_ae, pca_state, speaker_audio)
+
+ text = "[S1] Alright, I'm going to demo this new model called Echo TTS."
+ text_input_ids, text_mask = get_text_input_ids_and_mask([text], max_length=None, device="cuda")
+
+ latent_out = sample_blockwise_euler_cfg_independent_guidances(
+     model=model,
+     speaker_latent=speaker_latent,
+     speaker_mask=speaker_mask,
+     text_input_ids=text_input_ids,
+     text_mask=text_mask,
+     rng_seed=0,
+     block_sizes=[128, 128, 64],  # sums to 320, so ~15 seconds; supports up to 640
+     num_steps=40,
+     cfg_scale_text=3.0,
+     cfg_scale_speaker=5.0,
+     cfg_min_t=0.5,
+     cfg_max_t=1.0,
+     truncation_factor=0.8,
+     rescale_k=None,
+     rescale_sigma=None,
+     speaker_kv_scale=None,
+     speaker_kv_max_layers=None,
+     speaker_kv_min_t=None,
+ )
+ audio_out = ae_decode(fish_ae, pca_state, latent_out)
+ audio_out = crop_audio_to_flattening_point(audio_out, latent_out[0])
+ torchaudio.save("output_blockwise.wav", audio_out[0].cpu(), 44100)
+
+
+ # ___________________________________________________________
+ # example 2: with continuation latent (use the same speaker audio as the first example, and continue from its partial output)
+
+ continuation_audio_path = "output_blockwise.wav"  # can be any path
+ continuation_audio = load_audio(continuation_audio_path).cuda()
+ continuation_latent, continuation_mask = get_speaker_latent_and_mask(fish_ae, pca_state, continuation_audio)
+
+ continuation_latent = continuation_latent[:, :continuation_mask.sum()]
+
+ text = "[S1] Alright, I'm going to demo this new model called Echo TTS, and now, we're going to continue from the audio we already generated and add some more text."
+ # NOTE: this MUST include the transcript of the continuation prefix; you can use https://huggingface.co/jordand/whisper-d-v1a to get an in-distribution transcription automatically.
+
+ text_input_ids, text_mask = get_text_input_ids_and_mask([text], max_length=None, device="cuda")
+
+ continuation_block_sizes = [256]  # generate up to ~12 more seconds
+ # NOTE: these do not include the continuation latent length, so sum(block_sizes) + continuation_latent.shape[1] should be < 640 (to stay in-distribution with the training data)
+
+ latent_out_continued = sample_blockwise_euler_cfg_independent_guidances(
+     model=model,
+     speaker_latent=speaker_latent,
+     speaker_mask=speaker_mask,
+     text_input_ids=text_input_ids,
+     text_mask=text_mask,
+     rng_seed=0,
+     block_sizes=continuation_block_sizes,
+     num_steps=40,
+     cfg_scale_text=3.0,
+     cfg_scale_speaker=3.0,
+     cfg_min_t=0.5,
+     cfg_max_t=1.0,
+     truncation_factor=0.8,
+     rescale_k=None,
+     rescale_sigma=None,
+     speaker_kv_scale=None,
+     speaker_kv_max_layers=None,
+     speaker_kv_min_t=None,
+     continuation_latent=continuation_latent,
+ )
+ audio_out_continued = ae_decode(fish_ae, pca_state, latent_out_continued)
+ audio_out_continued = crop_audio_to_flattening_point(audio_out_continued, latent_out_continued[0])
+ torchaudio.save("output_blockwise_continued.wav", audio_out_continued[0].cpu(), 44100)
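The budget arithmetic in the notes above (block sizes exclude the continuation prefix, and the combined latent length should stay under 640, since 320 latents is roughly 15 seconds) can be sketched as a small check in plain Python. The helper names and the `max_total_latents=640` default are illustrative assumptions, not part of the repo:

```python
def continuation_budget_ok(block_sizes, prefix_len, max_total_latents=640):
    """Check that generated blocks plus the continuation prefix stay in-distribution."""
    return sum(block_sizes) + prefix_len < max_total_latents

def max_new_latents(prefix_len, max_total_latents=640):
    """Largest number of new latents that can still be requested after a prefix."""
    return max(0, max_total_latents - 1 - prefix_len)

# A 320-latent first pass (~15 s) leaves room for a [256] continuation block.
ok = continuation_budget_ok([256], prefix_len=320)
```

This is just the `sum(block_sizes) + continuation_latent.shape[1] < 640` condition from the comment, packaged so it can be asserted before sampling.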
model.py ADDED
@@ -0,0 +1,642 @@
+ from typing import Tuple, List
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+
+ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor:
+     freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)] / dim))
+     t = torch.arange(end)
+     freqs = torch.outer(t, freqs)
+     freqs_cis = torch.complex(torch.cos(freqs), torch.sin(freqs))
+     return freqs_cis
+
+
+ def apply_rotary_emb(
+     x: torch.Tensor,
+     freqs_cis: torch.Tensor,
+ ) -> torch.Tensor:
+     x_ = torch.view_as_complex(x.float().reshape(*x.shape[:3], -1, 2))
+     x_ = x_ * freqs_cis[..., None, :]
+     x_ = torch.view_as_real(x_).reshape(x.shape)
+     return x_.type_as(x)
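`precompute_freqs_cis` stores one complex phase `e^(i * t * freq)` per position and per channel pair, and `apply_rotary_emb` rotates each (even, odd) channel pair by that phase. The same mechanism in plain Python with `cmath`, as a sketch of the math rather than the tensor API:

```python
import cmath

def freqs_cis(dim, end, theta=10000.0):
    # freqs[k] = theta^(-2k/dim); one complex phase per position t and channel pair k.
    freqs = [theta ** (-(2 * k) / dim) for k in range(dim // 2)]
    return [[cmath.exp(1j * t * f) for f in freqs] for t in range(end)]

def rotate(pairs, phases):
    # Each (even, odd) channel pair is treated as one complex number and rotated.
    return [complex(a, b) * p for (a, b), p in zip(pairs, phases)]

fc = freqs_cis(dim=4, end=8)
out = rotate([(1.0, 0.0), (0.0, 1.0)], fc[1])
```

Two properties worth noting: position 0 gets the identity rotation (all phases equal 1), and rotation never changes a pair's magnitude, which is why RoPE can be applied after the q/k norms above without disturbing their scale.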
+
+
+ def get_timestep_embedding(
+     timestep: torch.Tensor,
+     embed_size: int,
+ ) -> torch.Tensor:
+     assert embed_size % 2 == 0
+
+     half = embed_size // 2
+
+     freqs = 1000 * torch.exp(
+         -torch.log(torch.tensor(10000.0)) *
+         torch.arange(start=0, end=half, dtype=torch.float32) / half
+     ).to(timestep.device)
+
+     args = timestep[..., None] * freqs[None]
+     embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+
+     return embedding.to(timestep.dtype)
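This is the standard sinusoidal embedding, with frequencies scaled by 1000 so diffusion times in [0, 1] sweep a wide phase range. The same computation mirrored in plain Python (for a scalar timestep; illustrative, not the repo's API):

```python
import math

def timestep_embedding(t: float, embed_size: int) -> list[float]:
    assert embed_size % 2 == 0
    half = embed_size // 2
    # freqs[i] = 1000 * 10000^(-i / half), matching the torch version above.
    freqs = [1000.0 * math.exp(-math.log(10000.0) * i / half) for i in range(half)]
    args = [t * f for f in freqs]
    # First half cosines, second half sines, as in torch.cat([cos, sin], dim=-1).
    return [math.cos(a) for a in args] + [math.sin(a) for a in args]
```

At `t = 0` the embedding is all ones followed by all zeros, which makes the zero-timestep conditioning a fixed, easily learnable vector.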
+
+
+ class LowRankAdaLN(nn.Module):
+     def __init__(
+         self,
+         model_size: int,
+         rank: int,
+         eps: float
+     ):
+         super().__init__()
+         self.eps = eps
+
+         self.shift_down = nn.Linear(model_size, rank, bias=False)
+         self.scale_down = nn.Linear(model_size, rank, bias=False)
+         self.gate_down = nn.Linear(model_size, rank, bias=False)
+
+         self.shift_up = nn.Linear(rank, model_size, bias=True)
+         self.scale_up = nn.Linear(rank, model_size, bias=True)
+         self.gate_up = nn.Linear(rank, model_size, bias=True)
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         cond_embed: torch.Tensor,
+     ) -> Tuple[torch.Tensor, torch.Tensor]:
+
+         shift, scale, gate = cond_embed.chunk(3, dim=-1)
+
+         shift = self.shift_up(self.shift_down(F.silu(shift))) + shift
+         scale = self.scale_up(self.scale_down(F.silu(scale))) + scale
+         gate = self.gate_up(self.gate_down(F.silu(gate))) + gate
+
+         x_dtype = x.dtype
+         x = x.float()
+         x = x * torch.rsqrt(torch.pow(x, 2).mean(dim=-1, keepdim=True) + self.eps)
+         x = x * (scale + 1) + shift
+
+         gate = torch.tanh(gate)
+
+         return x.to(x_dtype), gate
+
+
+ class RMSNorm(nn.Module):  # could also just use torch's built-in RMSNorm
+     def __init__(
+         self,
+         model_size: int | Tuple[int, int],
+         eps: float
+     ):
+         super().__init__()
+         self.eps = eps
+
+         if isinstance(model_size, int):
+             model_size = (model_size, )
+         self.weight = nn.Parameter(torch.ones(model_size))
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         x_dtype = x.dtype
+         x = x.float()
+         x = x * torch.rsqrt(torch.pow(x, 2).mean(dim=-1, keepdim=True) + self.eps)
+         x = x * self.weight
+         return x.to(x_dtype)
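RMSNorm divides by the root-mean-square of the features (no mean subtraction, unlike LayerNorm) and rescales by a learned weight; the AdaLN block above applies the same normalization before its shift/scale. A scalar sketch in plain Python (unit weights assumed for illustration):

```python
import math

def rms_norm(x: list[float], weight: list[float], eps: float = 1e-5) -> list[float]:
    # x / sqrt(mean(x^2) + eps), then an elementwise learned scale.
    ms = sum(v * v for v in x) / len(x)
    inv = 1.0 / math.sqrt(ms + eps)
    return [v * inv * w for v, w in zip(x, weight)]

out = rms_norm([3.0, 4.0], [1.0, 1.0])
```

With unit weights the output always has RMS very close to 1 (exactly 1 up to the `eps` term), which is the invariant the norm is there to enforce.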
+
+
+ class SelfAttention(nn.Module):
+     def __init__(
+         self,
+         model_size: int,
+         num_heads: int,
+         is_causal: bool,
+         norm_eps: float
+     ):
+         super().__init__()
+         self.num_heads = num_heads
+         self.is_causal = is_causal
+
+         self.wq = nn.Linear(model_size, model_size, bias=False)
+         self.wk = nn.Linear(model_size, model_size, bias=False)
+         self.wv = nn.Linear(model_size, model_size, bias=False)
+         self.wo = nn.Linear(model_size, model_size, bias=False)
+         self.gate = nn.Linear(model_size, model_size, bias=False)
+
+         assert model_size % num_heads == 0
+         self.q_norm = RMSNorm((num_heads, model_size // num_heads), eps=norm_eps)
+         self.k_norm = RMSNorm((num_heads, model_size // num_heads), eps=norm_eps)
+
+     def forward(self, x: torch.Tensor, mask: torch.Tensor | None, freqs_cis: torch.Tensor) -> torch.Tensor:
+         batch_size, seq_len = x.shape[:2]
+
+         xq = self.wq(x).reshape(batch_size, seq_len, self.num_heads, -1)
+         xk = self.wk(x).reshape(batch_size, seq_len, self.num_heads, -1)
+         xv = self.wv(x).reshape(batch_size, seq_len, self.num_heads, -1)
+
+         gate = self.gate(x)
+
+         xq = self.q_norm(xq)
+         xk = self.k_norm(xk)
+
+         xq = apply_rotary_emb(xq, freqs_cis[:seq_len])
+         xk = apply_rotary_emb(xk, freqs_cis[:seq_len])
+
+         if mask is not None:
+             assert mask.ndim == 2  # (b, s)
+             mask = mask[:, None, None]
+
+         output = F.scaled_dot_product_attention(
+             query=xq.transpose(1, 2),
+             key=xk.transpose(1, 2),
+             value=xv.transpose(1, 2),
+             attn_mask=mask,
+             is_causal=self.is_causal
+         ).transpose(1, 2)
+
+         output = output.reshape(batch_size, seq_len, -1)
+         output = output * torch.sigmoid(gate)
+
+         output = self.wo(output)
+
+         return output
+
+
+ class JointAttention(nn.Module):
+     def __init__(
+         self,
+         model_size: int,
+         num_heads: int,
+         text_model_size: int,
+         speaker_model_size: int,
+         speaker_patch_size: int,
+         norm_eps: float
+     ):
+         super().__init__()
+         self.speaker_patch_size = speaker_patch_size
+         self.num_heads = num_heads
+
+         self.wq = nn.Linear(model_size, model_size, bias=False)
+         self.wk = nn.Linear(model_size, model_size, bias=False)
+         self.wv = nn.Linear(model_size, model_size, bias=False)
+
+         self.wk_text = nn.Linear(text_model_size, model_size, bias=False)
+         self.wv_text = nn.Linear(text_model_size, model_size, bias=False)
+
+         self.wk_speaker = nn.Linear(speaker_model_size, model_size, bias=False)
+         self.wv_speaker = nn.Linear(speaker_model_size, model_size, bias=False)
+
+         self.wk_latent = nn.Linear(speaker_model_size, model_size, bias=False)
+         self.wv_latent = nn.Linear(speaker_model_size, model_size, bias=False)
+
+         assert model_size % num_heads == 0
+         self.head_dim = model_size // num_heads
+         self.q_norm = RMSNorm((num_heads, self.head_dim), eps=norm_eps)
+         self.k_norm = RMSNorm((num_heads, self.head_dim), eps=norm_eps)
+
+         self.gate = nn.Linear(model_size, model_size, bias=False)
+
+         self.wo = nn.Linear(model_size, model_size, bias=False)
+
+     def _apply_rotary_half(self, y: torch.Tensor, fc: torch.Tensor) -> torch.Tensor:
+         y1, y2 = y.chunk(2, dim=-2)
+         y1 = apply_rotary_emb(y1, fc)
+         return torch.cat([y1, y2], dim=-2)
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         text_mask: torch.Tensor,
+         speaker_mask: torch.Tensor,
+         freqs_cis: torch.Tensor,
+         kv_cache_text: Tuple[torch.Tensor, torch.Tensor],
+         kv_cache_speaker: Tuple[torch.Tensor, torch.Tensor],
+         start_pos: int | None,
+         kv_cache_latent: Tuple[torch.Tensor, torch.Tensor] | None
+     ) -> torch.Tensor:
+         batch_size, seq_len = x.shape[:2]
+
+         xq = self.wq(x).reshape(batch_size, seq_len, self.num_heads, -1)
+         xk_self = self.wk(x).reshape(batch_size, seq_len, self.num_heads, -1)
+         xv_self = self.wv(x).reshape(batch_size, seq_len, self.num_heads, -1)
+
+         xq = self.q_norm(xq)
+         xk_self = self.k_norm(xk_self)
+
+         gate = self.gate(x)
+
+         if start_pos is None:
+             start_pos = 0
+
+         freqs_q = freqs_cis[start_pos : start_pos + seq_len]
+
+         xq = self._apply_rotary_half(xq, freqs_q)
+         xk_self = self._apply_rotary_half(xk_self, freqs_q)
+
+         xk_text, xv_text = kv_cache_text
+         xk_speaker, xv_speaker = kv_cache_speaker
+
+         if kv_cache_latent is None or kv_cache_latent[0].shape[1] == 0:
+             xk_latent = torch.zeros((batch_size, 0, self.num_heads, xq.shape[-1]), device=x.device, dtype=x.dtype)
+             xv_latent = torch.zeros((batch_size, 0, self.num_heads, xq.shape[-1]), device=x.device, dtype=x.dtype)
+             latent_mask = torch.zeros((batch_size, 0), dtype=torch.bool, device=x.device)
+         else:
+             xk_latent, xv_latent = kv_cache_latent
+             latent_positions = torch.arange(xk_latent.shape[1], device=x.device, dtype=torch.long) * self.speaker_patch_size
+             latent_mask = (latent_positions[None, :] < start_pos).expand(batch_size, xk_latent.shape[1])
+
+         xk = torch.cat([xk_self, xk_latent, xk_text, xk_speaker], dim=1)
+         xv = torch.cat([xv_self, xv_latent, xv_text, xv_speaker], dim=1)
+
+         self_mask = torch.ones((batch_size, seq_len), dtype=torch.bool, device=x.device)
+
+         mask = torch.cat([self_mask, latent_mask, text_mask, speaker_mask], dim=1)
+         mask = mask[:, None, None]
+
+         output = F.scaled_dot_product_attention(
+             query=xq.transpose(1, 2),
+             key=xk.transpose(1, 2),
+             value=xv.transpose(1, 2),
+             attn_mask=mask,
+             is_causal=False
+         ).transpose(1, 2)
+
+         output = output.reshape(batch_size, seq_len, -1)
+         output = output * torch.sigmoid(gate)
+
+         output = self.wo(output)
+
+         return output
+
+     def get_kv_cache_text(self, text_state: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+         batch_size = text_state.shape[0]
+         xk = self.wk_text(text_state).reshape(batch_size, text_state.shape[1], self.num_heads, -1)
+         xv = self.wv_text(text_state).reshape(batch_size, text_state.shape[1], self.num_heads, -1)
+         xk = self.k_norm(xk)
+         return xk, xv
+
+     def get_kv_cache_speaker(self, speaker_state: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+         batch_size = speaker_state.shape[0]
+         xk = self.wk_speaker(speaker_state).reshape(batch_size, speaker_state.shape[1], self.num_heads, -1)
+         xv = self.wv_speaker(speaker_state).reshape(batch_size, speaker_state.shape[1], self.num_heads, -1)
+         xk = self.k_norm(xk)
+         return xk, xv
+
+     def get_kv_cache_latent(self, latent_state: torch.Tensor, freqs_cis: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+         batch_size = latent_state.shape[0]
+         seq_len = latent_state.shape[1]
+         xk = self.wk_latent(latent_state).reshape(batch_size, seq_len, self.num_heads, -1)
+         xv = self.wv_latent(latent_state).reshape(batch_size, seq_len, self.num_heads, -1)
+         xk = self.k_norm(xk)
+
+         xk = self._apply_rotary_half(xk, freqs_cis)
+
+         return xk, xv
+
+
+ class MLP(nn.Module):
+     def __init__(
+         self,
+         model_size: int,
+         intermediate_size: int
+     ):
+         super().__init__()
+         self.w1 = nn.Linear(model_size, intermediate_size, bias=False)
+         self.w3 = nn.Linear(model_size, intermediate_size, bias=False)
+         self.w2 = nn.Linear(intermediate_size, model_size, bias=False)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         return self.w2(F.silu(self.w1(x)) * self.w3(x))
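This is the SwiGLU-style gated MLP: `w1`'s output passes through SiLU and multiplicatively gates `w3`'s output before the `w2` down-projection. The gating nonlinearity in scalar form, as a plain-Python sketch:

```python
import math

def silu(x: float) -> float:
    # silu(x) = x * sigmoid(x)
    return x / (1.0 + math.exp(-x))

def swiglu_gate(a: float, b: float) -> float:
    # a plays the role of w1(x), b of w3(x); w2 would then project the product back down.
    return silu(a) * b
```

SiLU is near-zero for large negative inputs and near-identity for large positive ones, so each intermediate channel smoothly switches the corresponding `w3` channel on or off.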
+
+
+ class EncoderTransformerBlock(nn.Module):
+     def __init__(
+         self,
+         model_size: int,
+         num_heads: int,
+         intermediate_size: int,
+         is_causal: bool,
+         norm_eps: float
+     ):
+         super().__init__()
+         self.attention = SelfAttention(
+             model_size=model_size,
+             num_heads=num_heads,
+             is_causal=is_causal,
+             norm_eps=norm_eps
+         )
+         self.mlp = MLP(
+             model_size=model_size,
+             intermediate_size=intermediate_size
+         )
+
+         self.attention_norm = RMSNorm(model_size, norm_eps)
+         self.mlp_norm = RMSNorm(model_size, norm_eps)
+
+     def forward(self, x: torch.Tensor, mask: torch.Tensor | None, freqs_cis: torch.Tensor) -> torch.Tensor:
+         x = x + self.attention(self.attention_norm(x), mask, freqs_cis)
+         x = x + self.mlp(self.mlp_norm(x))
+
+         return x
+
+
+ class TransformerBlock(nn.Module):
+     def __init__(
+         self,
+         model_size: int,
+         num_heads: int,
+         intermediate_size: int,
+         norm_eps: float,
+         text_model_size: int,
+         speaker_model_size: int,
+         speaker_patch_size: int,
+         adaln_rank: int,
+     ):
+         super().__init__()
+         self.attention = JointAttention(
+             model_size=model_size,
+             num_heads=num_heads,
+             text_model_size=text_model_size,
+             speaker_model_size=speaker_model_size,
+             speaker_patch_size=speaker_patch_size,
+             norm_eps=norm_eps
+         )
+
+         self.mlp = MLP(
+             model_size=model_size,
+             intermediate_size=intermediate_size
+         )
+
+         self.attention_adaln = LowRankAdaLN(model_size=model_size, rank=adaln_rank, eps=norm_eps)
+         self.mlp_adaln = LowRankAdaLN(model_size=model_size, rank=adaln_rank, eps=norm_eps)
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         cond_embed: torch.Tensor,
+         text_mask: torch.Tensor,
+         speaker_mask: torch.Tensor,
+         freqs_cis: torch.Tensor,
+         kv_cache_text: Tuple[torch.Tensor, torch.Tensor],
+         kv_cache_speaker: Tuple[torch.Tensor, torch.Tensor],
+         start_pos: int | None,
+         kv_cache_latent: Tuple[torch.Tensor, torch.Tensor] | None,
+     ) -> torch.Tensor:
+
+         x_norm, attention_gate = self.attention_adaln(x, cond_embed)
+         x = x + attention_gate * self.attention(x_norm, text_mask, speaker_mask, freqs_cis, kv_cache_text, kv_cache_speaker, start_pos, kv_cache_latent)
+
+         x_norm, mlp_gate = self.mlp_adaln(x, cond_embed)
+         x = x + mlp_gate * self.mlp(x_norm)
+
+         return x
+
+
+ class TextEncoder(nn.Module):
+     def __init__(
+         self,
+         vocab_size: int,
+         model_size: int,
+         num_layers: int,
+         num_heads: int,
+         intermediate_size: int,
+         norm_eps: float,
+     ):
+         super().__init__()
+         self.text_embedding = nn.Embedding(vocab_size, model_size)
+
+         self.blocks = nn.ModuleList()
+         for i in range(num_layers):
+             block = EncoderTransformerBlock(
+                 model_size=model_size,
+                 num_heads=num_heads,
+                 intermediate_size=intermediate_size,
+                 is_causal=False,
+                 norm_eps=norm_eps
+             )
+             self.blocks.append(block)
+
+         self.head_dim = model_size // num_heads
+
+     def forward(self, input_ids: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor:
+         x = self.text_embedding(input_ids)
+
+         freqs_cis = precompute_freqs_cis(self.head_dim, input_ids.shape[1]).to(x.device)  # could cache
+
+         for block in self.blocks:
+             x = block(x, mask, freqs_cis)
+
+         return x
+
+
+ class SpeakerEncoder(nn.Module):
+     def __init__(
+         self,
+         latent_size: int,
+         patch_size: int,
+         model_size: int,
+         num_layers: int,
+         num_heads: int,
+         intermediate_size: int,
+         norm_eps: float,
+     ):
+         super().__init__()
+         self.patch_size = patch_size
+
+         self.in_proj = nn.Linear(latent_size * patch_size, model_size, bias=True)
+
+         self.blocks = nn.ModuleList()
+         for i in range(num_layers):
+             block = EncoderTransformerBlock(
+                 model_size=model_size,
+                 num_heads=num_heads,
+                 intermediate_size=intermediate_size,
+                 is_causal=True,
+                 norm_eps=norm_eps
+             )
+             self.blocks.append(block)
+
+         self.head_dim = model_size // num_heads
+
+     def forward(self, latent: torch.Tensor) -> torch.Tensor:
+         x = latent.reshape(*latent.shape[:-2], latent.shape[-2] // self.patch_size, latent.shape[-1] * self.patch_size)
+
+         x = self.in_proj(x)
+         x = x / 6.  # this helped with initial activation dynamics in early ablations; could also bake into in_proj
+
+         freqs_cis = precompute_freqs_cis(self.head_dim, x.shape[1]).to(x.device)  # could cache
+
+         for block in self.blocks:
+             x = block(x, None, freqs_cis)
+
+         return x
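The reshape at the top of `SpeakerEncoder.forward` is a patchify step: with a row-major layout, viewing a `(T, C)` latent as `(T // patch_size, C * patch_size)` concatenates each group of `patch_size` consecutive frames into one flat token. A plain-Python sketch of that grouping (illustrative helper, not part of the repo):

```python
def patchify(latent: list[list[float]], patch_size: int) -> list[list[float]]:
    # Group patch_size consecutive frames (each a feature vector) into one flat token,
    # mirroring latent.reshape(..., T // patch_size, C * patch_size) above.
    assert len(latent) % patch_size == 0
    return [
        [v for frame in latent[i : i + patch_size] for v in frame]
        for i in range(0, len(latent), patch_size)
    ]

tokens = patchify([[1, 2], [3, 4], [5, 6], [7, 8]], patch_size=2)
```

This is also why the DiT's `speaker_mask` is strided by `speaker_patch_size` before use: one encoder token covers `patch_size` latent frames.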
+
+
+ class EchoDiT(nn.Module):
+     def __init__(
+         self,
+         latent_size: int,
+         #
+         model_size: int,
+         num_layers: int,
+         num_heads: int,
+         intermediate_size: int,
+         norm_eps: float,
+         #
+         text_vocab_size: int,
+         text_model_size: int,
+         text_num_layers: int,
+         text_num_heads: int,
+         text_intermediate_size: int,
+         #
+         speaker_patch_size: int,
+         speaker_model_size: int,
+         speaker_num_layers: int,
+         speaker_num_heads: int,
+         speaker_intermediate_size: int,
+         #
+         timestep_embed_size: int,
+         adaln_rank: int,
+     ):
+         super().__init__()
+         self.speaker_patch_size = speaker_patch_size
+         self.timestep_embed_size = timestep_embed_size
+
+         self.text_encoder = TextEncoder(
+             vocab_size=text_vocab_size,
+             model_size=text_model_size,
+             num_layers=text_num_layers,
+             num_heads=text_num_heads,
+             intermediate_size=text_intermediate_size,
+             norm_eps=norm_eps,
+         )
+         self.speaker_encoder = SpeakerEncoder(
+             latent_size=latent_size,
+             patch_size=speaker_patch_size,
+             model_size=speaker_model_size,
+             num_layers=speaker_num_layers,
+             num_heads=speaker_num_heads,
+             intermediate_size=speaker_intermediate_size,
+             norm_eps=norm_eps,
+         )
+         self.latent_encoder = SpeakerEncoder(
+             latent_size=latent_size,
+             patch_size=speaker_patch_size,
+             model_size=speaker_model_size,
+             num_layers=speaker_num_layers,
+             num_heads=speaker_num_heads,
+             intermediate_size=speaker_intermediate_size,
+             norm_eps=norm_eps,
+         )
+         self.text_norm = RMSNorm(text_model_size, norm_eps)
+         self.speaker_norm = RMSNorm(speaker_model_size, norm_eps)
+         self.latent_norm = RMSNorm(speaker_model_size, norm_eps)
+
+         self.cond_module = nn.Sequential(
+             nn.Linear(timestep_embed_size, model_size, bias=False),
+             nn.SiLU(),
+             nn.Linear(model_size, model_size, bias=False),
+             nn.SiLU(),
+             nn.Linear(model_size, model_size * 3, bias=False),
+         )
+
+         self.in_proj = nn.Linear(latent_size, model_size, bias=True)
+
+         self.blocks = nn.ModuleList()
+         for i in range(num_layers):
+             block = TransformerBlock(
+                 model_size=model_size,
+                 num_heads=num_heads,
+                 intermediate_size=intermediate_size,
+                 norm_eps=norm_eps,
+                 text_model_size=text_model_size,
+                 speaker_model_size=speaker_model_size,
+                 speaker_patch_size=speaker_patch_size,
+                 adaln_rank=adaln_rank,
+             )
+             self.blocks.append(block)
+
+         self.out_norm = RMSNorm(model_size, norm_eps)
+         self.out_proj = nn.Linear(model_size, latent_size, bias=True)
+
+         self.head_dim = model_size // num_heads
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         t: torch.Tensor,
+         text_mask: torch.Tensor,
+         speaker_mask: torch.Tensor,
+         kv_cache_text: List[Tuple[torch.Tensor, torch.Tensor]],
+         kv_cache_speaker: List[Tuple[torch.Tensor, torch.Tensor]],
+         start_pos: int | None = None,
+         kv_cache_latent: List[Tuple[torch.Tensor, torch.Tensor]] | None = None,
+     ) -> torch.Tensor:
+
+         if start_pos is None:
+             start_pos = 0
+
+         max_pos = start_pos + x.shape[1]
+         freqs_cis = precompute_freqs_cis(self.head_dim, max_pos).to(x.device)  # could cache
+
+         speaker_mask = speaker_mask[..., ::self.speaker_patch_size]
+
+         cond_embed = self.cond_module(get_timestep_embedding(t, self.timestep_embed_size))
+         cond_embed = cond_embed[:, None]
+
+         x = self.in_proj(x)
+
+         for i, block in enumerate(self.blocks):
+             x = block(
+                 x=x,
+                 cond_embed=cond_embed,
+                 text_mask=text_mask,
+                 speaker_mask=speaker_mask,
+                 freqs_cis=freqs_cis,
+                 kv_cache_text=kv_cache_text[i],
+                 kv_cache_speaker=kv_cache_speaker[i],
+                 start_pos=start_pos,
+                 kv_cache_latent=kv_cache_latent[i] if kv_cache_latent is not None else None,
+             )
+
+         x = self.out_norm(x)
+         x = self.out_proj(x)
+
+         return x.float()
+
+     def get_kv_cache_text(
+         self,
+         text_input_ids: torch.Tensor,
+         text_mask: torch.Tensor | None,
+     ) -> List[Tuple[torch.Tensor, torch.Tensor]]:
+         text_state = self.text_encoder(text_input_ids, text_mask)
+         text_state = self.text_norm(text_state)
+         return [block.attention.get_kv_cache_text(text_state) for block in self.blocks]
+
+     def get_kv_cache_speaker(
+         self,
+         speaker_latent: torch.Tensor,
+     ) -> List[Tuple[torch.Tensor, torch.Tensor]]:
+         speaker_state = self.speaker_encoder(speaker_latent)
+         speaker_state = self.speaker_norm(speaker_state)
+         return [block.attention.get_kv_cache_speaker(speaker_state) for block in self.blocks]
+
+     def get_kv_cache_latent(
+         self,
+         prefix_latent: torch.Tensor,
+     ) -> List[Tuple[torch.Tensor, torch.Tensor]]:
+         latent_state = self.latent_encoder(prefix_latent)
+         latent_state = self.latent_norm(latent_state)
+
+         seq_len = latent_state.shape[1]
+         max_pos = seq_len * self.speaker_patch_size
+         freqs_cis = precompute_freqs_cis(self.head_dim, max_pos).to(latent_state.device)  # could cache
+         positions = torch.arange(seq_len, device=latent_state.device) * self.speaker_patch_size
+         freqs_latent = freqs_cis[positions]
+
+         return [block.attention.get_kv_cache_latent(latent_state, freqs_latent) for block in self.blocks]
+
+     @property
+     def device(self) -> torch.device: return next(self.parameters()).device
+
+     @property
+     def dtype(self) -> torch.dtype: return next(self.parameters()).dtype
requirements.txt ADDED
@@ -0,0 +1,8 @@
+ torch>=2.9.1
+ torchaudio>=2.9.1
+ torchcodec>=0.8.1
+ huggingface-hub
+ numpy
+ safetensors
+ einops
+ gradio==5.49.1
sampler_presets.json ADDED
@@ -0,0 +1,62 @@
+ {
+     "Independent-High-Speaker-CFG": {
+         "num_steps": "40",
+         "cfg_scale_text": "3.0",
+         "cfg_scale_speaker": "8.0",
+         "cfg_min_t": "0.5",
+         "cfg_max_t": "1.0",
+         "truncation_factor": "1.",
+         "rescale_k": "1.",
+         "rescale_sigma": "3.0"
+     },
+     "Independent-High-Speaker-CFG-Flat": {
+         "num_steps": "40",
+         "cfg_scale_text": "3.0",
+         "cfg_scale_speaker": "8.0",
+         "cfg_min_t": "0.5",
+         "cfg_max_t": "1.0",
+         "truncation_factor": "0.8",
+         "rescale_k": "1.2",
+         "rescale_sigma": "3.0"
+     },
+     "Independent-High-CFG": {
+         "num_steps": "40",
+         "cfg_scale_text": "8.0",
+         "cfg_scale_speaker": "8.0",
+         "cfg_min_t": "0.5",
+         "cfg_max_t": "1.0",
+         "truncation_factor": "1.",
+         "rescale_k": "1.",
+         "rescale_sigma": "3.0"
+     },
+     "Independent-High-CFG-Flat": {
+         "num_steps": "40",
+         "cfg_scale_text": "8.0",
+         "cfg_scale_speaker": "8.0",
+         "cfg_min_t": "0.5",
+         "cfg_max_t": "1.0",
+         "truncation_factor": "0.8",
+         "rescale_k": "1.2",
+         "rescale_sigma": "3.0"
+     },
+     "Independent-Low-CFG": {
+         "num_steps": "40",
+         "cfg_scale_text": "3.0",
+         "cfg_scale_speaker": "3.0",
+         "cfg_min_t": "0.5",
+         "cfg_max_t": "1.0",
+         "truncation_factor": "1.",
+         "rescale_k": "1.",
+         "rescale_sigma": "3.0"
+     },
+     "Independent-Low-CFG-Flat": {
+         "num_steps": "40",
+         "cfg_scale_text": "3.0",
+         "cfg_scale_speaker": "3.0",
+         "cfg_min_t": "0.5",
+         "cfg_max_t": "1.0",
+         "truncation_factor": "0.8",
+         "rescale_k": "1.2",
+         "rescale_sigma": "3.0"
+     }
+ }
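Note that every preset value is stored as a string (convenient for populating UI fields). To feed a preset straight into the sampler's keyword arguments, the strings need to be cast back to numbers. A minimal loader sketch, assuming `num_steps` is the only integer field (the function name and cast rule are assumptions, not part of the repo):

```python
import json

def load_preset(presets_json: str, name: str) -> dict:
    # Cast each string field: int for num_steps, float for everything else.
    raw = json.loads(presets_json)[name]
    return {k: int(v) if k == "num_steps" else float(v) for k, v in raw.items()}

presets = '{"Independent-Low-CFG": {"num_steps": "40", "cfg_scale_text": "3.0", "truncation_factor": "1."}}'
cfg = load_preset(presets, "Independent-Low-CFG")
```

`float("1.")` handles the trailing-dot values like `"1."` in the file without any special casing.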
text_presets.txt ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Reading | [S1] The old lighthouse keeper had seen many storms in his thirty years on the rock, but nothing like this. The fog rolled in thick as wool, swallowing the beam of light before it could reach the churning waves below. Then he heard it, three short bells from the channel, where no ship should be at this hour. He grabbed his lantern and peered into the mist, his heart pounding. Something was out there, something that shouldn't exist.
2
+
3
+ Reading | [S1] Deep beneath the ocean's surface, where sunlight fades to perpetual twilight, extraordinary creatures have evolved in ways that defy imagination. Bioluminescent jellyfish pulse with ethereal blue light, while giant squid hunt in the crushing darkness. At depths of over two miles, the pressure is immense, enough to collapse a submarine, yet life persists.
4
+
5
+ Reading | [S1] The telegram arrived on a Tuesday morning in June, nineteen forty-three. Margaret's hands trembled as she tore open the envelope, dreading the words she knew might be inside. Her brother had shipped out to North Africa six months ago, and his letters had grown increasingly sparse.
6
+
7
+ Cartoon | [S1] After giving everything some more thought, I've decided it's in the best interest of humanity to acquire Nexus AI. (laughs) I've spoken with the CEO and he's on board. Well (laughs), at least that's the impression he gave initially.
8
+
9
+ Single (Disfluent) | [S1] ... explore how we can design, create interfaces that are not confusing, but at the same time can be powerful. Um, you know, I think, uh, in the, the famous, um, usability book, it's, uh, it's this, um, um, oh, geez, I'm, I'm blanking on the term, uh, uh, the, the rule about, um, uh, it's like the simplicity rule. I can't recall. Oh, cognitive load maybe.
10
+
11
+ Single (Disfluent) | [S1] Uh, complacency when the motivation isn't structured properly. Like for example, if you, if you're in the cor- if you work in the corporation for many years, a lot of corporate employees, they just, they're, they're aiming for that stock vesting and they're, they're doing just a sufficient job to, to, to reach that vesting and, and they don't, they're not performing any better than that. Um, and so I think, um, that showed me an important insight. Yeah.
12
+
13
+ Single (Disfluent) | [S1] We see the pattern of revelations, major shifts. I think Neptune in Pisces, which that transit has been happening all of 2021, and Neptune will remain in the sign of Pisces until March of 2029. So it's several years more of this transit. And what it brings is a lot of things, you know, the thing that I tend to emphasize is the profound dissolution or profound changes
14
+
15
+ Single (Disfluent) | [S1] I asked her, "Do you have like a phrase you use," and she mentioned she actually does. Like when things get tense, when there's like a moment, like if her, if her roommate is like venting about work drama or just like is stressed, and her, her roommate like deals with anxiety, I'm like, "Oh, this is probably how it feels to live with me." But, um, and like if, if, if things are rough, like she'll internally just like use this practice where she's like, like, "Not my problem, not mine to carry, not mine to handle, not mine to change." Like she'll sort of repeat that. So that's interesting.
16
+
17
+ Single (Disfluent) | [S1] If I examine the, the, if, if you examine the range of options, uh, beginning from, like, say, individual all the way, right? There will be some revenue stream, uh, there will be some purchase, there'll be some hardware profit margin for someone who creates a smart product, um, uh, there will be memberships, personal and business, uh, and then there'll be usage-based, right? So I still believe that that's kinda how, those are all the metrics. To your point, what is a membership? Up to now, folks
18
+
19
+ Single (Disfluent) | [S1] I think, if, if we can keep it under 25 points allowed, sure, our odds improve significantly. We wouldn't need to put up huge numbers ourselves, or at least that's the theory. And I should, I want to share some other stats which might be a bit outside our current discussion, but regarding this compared to 2018, the team's final four games that year, they managed 18 points total.
20
+
21
+ Singing | [S1] (singing) Amazing grace, how sweet the sound, that saved a wretch like me. I once was lost, but now am found, was blind, but now I see.
22
+
23
+ Conversation | [S1] Alright then. So, so 18 years you spent in that, uh, in that role, but alongside that in, in, was it while you were working that position in '93, you started doing some work with the network? [S2] Uh, yes. It was somewhere around '93. I, I, I played tennis pretty well, you know? I, I, I competed as a tennis player. And the, I got a chance to do some broadcasting over in Brisbane.
24
+
25
+ Conversation | [S1] ... that will provide the analytics component- [S2] Right. [S1] ... to ideally get you to adopt some of their other tools. And- [S2] (laughs) [S1] ... some of those features are valuable too. [S2] That's interesting. [S1] Mailchimp, I mean, that's campaign manage-, uh, not exactly campaign management, but messaging platforms. [S2] Uh-huh. [S1] The, the companies that are, you know,
26
+
27
+ Conversation | [S1] They were like, they were pumped for it, going wild for it, and it disappeared immediately. [S2] Yeah, I think it's about people understanding what's available first. Um... [S1] I think the finish on that one too was really nice. [S2] Yeah. [S1] I mean, that was pretty awesome. [S2] Have you seen those new editions?
28
+
29
+ Conversation | [S1] He was just practicing with them and they were on rotation. [S2] So that was probably in January. [S1] I think startup stereotypes, there is some like that, but some of them, I think they need to be changed. Like we don't all work twenty-hour days. [S2] No, they just need to, it's called not, it's based in Silicon Valley. [S1] Yeah. [S2] But the stereotypes would apply if they, it was called Techlife- [S1] Palo Alto. [S2] ... Cupertino or Mountain View, California.
+
+
+ Conversation | [S1] That's a nice overview. [S2] We were at the downtown cinema. [S1] By that, you mean the one in Riverside? [S2] Yeah. [S1] Yeah. So not exactly downtown. [S2] Not exactly downtown, yeah. [S1] I know a little bit about that area. [S2] (laughs) [S1] You know, Millbrook doesn't have a cinema. [S2] (laughs) It's the closest one for us. It's the closest. [S1] Yeah, that's true. [S2] The most nearby. [S1] Riverside is nearby. [S2] Riverside's close. [S1] That's fair. [S2] Support nearby. [S1] You can say, say Riverside, definitely. [S2] Well, yeah, fair enough.
+
+ Conversation | [S1] But they also, they also discovered, um, they also discovered like patterns in the desert, um, near Peru, like in the Atacama Desert. [S2] Yeah. [S1] Um, and like, it was like, of like perfectly, like, geo- geometric shapes. And they're like, "Yo, this is definitely not like formed by wind. This has to be artificial." [S2] Yeah, it's too precise.
+
+ Conversation | [S1] 'Cause I, yeah, there, there has to be a way that they can just make the, the system recognize that, no, you did not earn this- [S2] (laughs) [S1] ... on your own. You still have to go and complete one if you want it for your own- [S2] Right. [S1] ... like, profile. [S2] Right. Mm-hmm. [S1] So, yeah. [S2] Um, yeah. So let's actually move into multiplayer.
+
+ Conversation | [S1] Yeah. [S2] Yeah. TRS as a whole is just relaxed. [S1] But anyway, you know that Mirror app that launched and then got removed like a month later? [S2] Mirror, what, like, to your future? [S1] Yeah. [S2] Oh. [S1] So basically, there was an app, there's a show coming out. [S2] This is a show. [S1] Coming, I don't know what it is. [S2] Yeah, yeah, yeah. [S1] Like 2026 or something. Basically, Marcus, have you heard about this? [S2] I'm sorry, I don't know. No, I don't have an, it's an app- [S1] Okay, so I'll explain. I'll explain. [S2] Yeah. [S1] For context. So there's this app that launched in terms of the show called Mirror.
+
+ Conversation | [S1] Jamie Patterson, right? [S2] No, I know where- [S1] I know where- [S2] ... Patterson works as well. I know where- [S1] I know- I know he used to work near- on this street, and this is a weird street. [S2] The only person who I don't know where they work, Jamie. But anyway, why are we even talking about who works where? [S1] It was a- it was- it was a really weird street name where Jamie worked. [S2] I- I drove past this street on my commute. [S1] No, you didn't. [S2] Yeah, I did. [S1] No, you drove past the street that my street is down the street of. [S2] Nice. There's, like, one street in Oakfield, I think I'll be able to find it, mate.