Spaces:
Running on Zero
Running on Zero
--replace-all committed on
Commit ·
1459ef5
1
Parent(s): a8bc6e2
Add Nano-TTS CPU Gradio Space
Browse files- .gitattributes +2 -0
- .gitignore +6 -0
- README.md +11 -3
- app.py +500 -0
- asserts/audio/en_1.wav +3 -0
- asserts/audio/en_2.wav +3 -0
- asserts/audio/en_3.wav +3 -0
- asserts/audio/en_4.wav +3 -0
- asserts/audio/en_5.wav +3 -0
- asserts/audio/jp_1.mp3 +3 -0
- asserts/audio/jp_2.wav +3 -0
- asserts/audio/jp_3.wav +3 -0
- asserts/audio/jp_4.wav +3 -0
- asserts/audio/jp_5.wav +3 -0
- asserts/audio/zh_1.wav +3 -0
- asserts/audio/zh_2.wav +3 -0
- asserts/audio/zh_3.wav +3 -0
- asserts/audio/zh_4.wav +3 -0
- asserts/audio/zh_5.wav +3 -0
- asserts/audio/zh_6.wav +3 -0
- nano_tts_runtime.py +727 -0
- requirements.txt +7 -0
- text_normalization_pipeline.py +195 -0
- tts_robust_normalizer_single_script.py +366 -0
- weights/codec/.gitattributes +35 -0
- weights/codec/README.md +195 -0
- weights/codec/__init__.py +1 -0
- weights/codec/config.json +304 -0
- weights/codec/configuration_moss_audio_tokenizer.py +467 -0
- weights/codec/model-00001-of-00001.safetensors +3 -0
- weights/codec/model.safetensors.index.json +382 -0
- weights/codec/modeling_moss_audio_tokenizer.py +0 -0
- weights/tts/.gitattributes +35 -0
- weights/tts/README.md +3 -0
- weights/tts/__init__.py +31 -0
- weights/tts/config.json +197 -0
- weights/tts/configuration_nanotts.py +105 -0
- weights/tts/gpt2_decoder.py +618 -0
- weights/tts/modeling_nanotts_global_local.py +0 -0
- weights/tts/prompting.py +92 -0
- weights/tts/pytorch_model.bin +3 -0
- weights/tts/special_tokens_map.json +30 -0
- weights/tts/tokenization_nanotts_sentencepiece.py +103 -0
- weights/tts/tokenizer.model +3 -0
- weights/tts/tokenizer_config.json +52 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
*.mp3 filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
*.wav filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__pycache__/
|
| 2 |
+
*.py[cod]
|
| 3 |
+
generated_audio/
|
| 4 |
+
.cache/
|
| 5 |
+
weights/tts/.cache/
|
| 6 |
+
weights/codec/.cache/
|
README.md
CHANGED
|
@@ -4,11 +4,19 @@ emoji: 📈
|
|
| 4 |
colorFrom: red
|
| 5 |
colorTo: green
|
| 6 |
sdk: gradio
|
| 7 |
-
sdk_version: 6.
|
|
|
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
| 10 |
license: apache-2.0
|
| 11 |
-
short_description:
|
| 12 |
---
|
| 13 |
|
| 14 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
colorFrom: red
|
| 5 |
colorTo: green
|
| 6 |
sdk: gradio
|
| 7 |
+
sdk_version: 6.5.1
|
| 8 |
+
python_version: "3.10"
|
| 9 |
app_file: app.py
|
| 10 |
pinned: false
|
| 11 |
license: apache-2.0
|
| 12 |
+
short_description: CPU-only MOSS TTS Nano Gradio demo with local TTS and codec weights
|
| 13 |
---
|
| 14 |
|
| 15 |
+
This Space runs Nano-TTS on CPU using the local `weights/tts` and `weights/codec` directories.
|
| 16 |
+
|
| 17 |
+
Supported modes:
|
| 18 |
+
|
| 19 |
+
- `voice_clone`: upload a reference audio file or use a built-in preset voice
|
| 20 |
+
- `continuation`: plain TTS, or prompt transcript plus prompt audio
|
| 21 |
+
|
| 22 |
+
Realtime streaming decode is disabled in this Space. Audio is returned after full synthesis finishes.
|
app.py
ADDED
|
@@ -0,0 +1,500 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import argparse
|
| 4 |
+
import functools
|
| 5 |
+
import logging
|
| 6 |
+
import os
|
| 7 |
+
import time
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
|
| 10 |
+
import gradio as gr
|
| 11 |
+
|
| 12 |
+
from nano_tts_runtime import DEFAULT_VOICE, NanoTTSService, build_default_voice_presets
|
| 13 |
+
from text_normalization_pipeline import prepare_tts_request_texts
|
| 14 |
+
|
| 15 |
+
# Resolve all paths relative to this file so the Space works regardless of CWD.
APP_DIR = Path(__file__).resolve().parent
CHECKPOINT_PATH = APP_DIR / "weights" / "tts"
AUDIO_TOKENIZER_PATH = APP_DIR / "weights" / "codec"
# Generated audio goes under /tmp, not the app directory.
OUTPUT_DIR = Path("/tmp") / "nano-tts-space"
# Environment variable that toggles eager model loading at process start.
PRELOAD_ENV_VAR = "NANO_TTS_PRELOAD_AT_STARTUP"

# Inference mode identifiers used by the UI radio and the validation logic.
MODE_VOICE_CLONE = "voice_clone"
MODE_CONTINUATION = "continuation"

# Preset voices bundled with the Space (built once at import time).
_VOICE_PRESETS = build_default_voice_presets()
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def build_voice_choices() -> list[tuple[str, str]]:
    """Return (label, value) dropdown entries for presets whose audio exists on disk.

    Entries backed by a ``.wav`` file are preferred; presets in other formats
    are only offered when no WAV-backed preset is available.
    """
    wav_entries: list[tuple[str, str]] = []
    all_entries: list[tuple[str, str]] = []

    for preset in _VOICE_PRESETS.values():
        if not preset.prompt_audio_path.is_file():
            # Skip presets whose bundled reference audio is missing.
            continue
        entry = (f"{preset.name} - {preset.description}", preset.name)
        all_entries.append(entry)
        if preset.prompt_audio_path.suffix.lower() == ".wav":
            wav_entries.append(entry)

    # Fall back to non-WAV presets only when no WAV preset was found.
    return wav_entries or all_entries
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
# Dropdown choices computed once at import; may be empty if no preset audio exists.
VOICE_CHOICES = build_voice_choices()
# Prefer DEFAULT_VOICE when it survived the file check above; otherwise fall
# back to the first available preset, or "" when there are no presets at all.
DEFAULT_VOICE_VALUE = (
    DEFAULT_VOICE
    if any(value == DEFAULT_VOICE for _, value in VOICE_CHOICES)
    else (VOICE_CHOICES[0][1] if VOICE_CHOICES else "")
)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def parse_bool_env(name: str, default: bool) -> bool:
    """Interpret environment variable *name* as a boolean flag.

    Unset -> *default*; otherwise truthy iff the value (case-insensitive,
    stripped) is one of 1/true/yes/y/on.
    """
    raw = os.getenv(name)
    if raw is None:
        return default
    return raw.strip().lower() in {"1", "true", "yes", "y", "on"}
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def parse_port(value: str | None, default: int) -> int:
|
| 59 |
+
if not value:
|
| 60 |
+
return default
|
| 61 |
+
try:
|
| 62 |
+
return int(value)
|
| 63 |
+
except ValueError:
|
| 64 |
+
return default
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def maybe_delete_file(path: str | Path | None) -> None:
|
| 68 |
+
if not path:
|
| 69 |
+
return
|
| 70 |
+
try:
|
| 71 |
+
Path(path).unlink(missing_ok=True)
|
| 72 |
+
except OSError:
|
| 73 |
+
logging.warning("failed to delete temporary file: %s", path, exc_info=True)
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
@functools.lru_cache(maxsize=1)
def get_tts_service() -> NanoTTSService:
    """Build (once) and return the process-wide CPU-only TTS service.

    lru_cache(maxsize=1) turns this into a lazy singleton, so the heavy model
    objects are constructed at most once per process.
    """
    return NanoTTSService(
        checkpoint_path=CHECKPOINT_PATH,
        audio_tokenizer_path=AUDIO_TOKENIZER_PATH,
        device="cpu",  # this Space runs without a GPU
        dtype="float32",
        attn_implementation="sdpa",
        output_dir=OUTPUT_DIR,
    )
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def preload_service() -> None:
    """Eagerly load model weights so the first user request does not pay the cost."""
    started = time.monotonic()
    logging.info(
        "preloading Nano-TTS model checkpoint=%s codec=%s device=cpu",
        CHECKPOINT_PATH,
        AUDIO_TOKENIZER_PATH,
    )
    # get_model() forces the lazy service to actually load its weights.
    get_tts_service().get_model()
    logging.info("Nano-TTS preload finished in %.2fs", time.monotonic() - started)
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def render_mode_hint(mode: str) -> str:
    """Return the Markdown helper text shown under the mode selector.

    Any mode other than continuation is treated as voice clone.
    """
    if mode != MODE_CONTINUATION:
        return (
            "Current mode: **Voice Clone** \n"
            "Upload a reference audio file or use a built-in preset voice. Audio is returned only after full decoding."
        )
    return (
        "Current mode: **Continuation** \n"
        "Plain TTS uses only target text. If you upload reference audio, you must also provide its transcript."
    )
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def update_mode_ui(mode: str):
    """Return gr.update tuples adjusting (voice, prompt_text, prompt_audio, hint) for *mode*."""
    if mode != MODE_CONTINUATION:
        # Voice-clone layout: preset dropdown visible, transcript hidden and cleared.
        return (
            gr.update(visible=True),
            gr.update(visible=False, value=""),
            gr.update(label="Reference Audio Upload (optional; overrides preset voice)"),
            render_mode_hint(mode),
        )

    # Continuation layout: transcript visible (cleared), preset dropdown hidden.
    return (
        gr.update(visible=False),
        gr.update(
            visible=True,
            value="",
            placeholder="Only for continuation with reference audio.",
        ),
        gr.update(label="Reference Audio Upload (optional; required if Prompt Transcript is set)"),
        render_mode_hint(mode),
    )
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def validate_request(
    *,
    text: str,
    mode: str,
    prompt_text: str,
    prompt_audio_path: str | None,
) -> tuple[str, str | None]:
    """Validate mode-specific input constraints.

    Returns (normalized target text, normalized prompt transcript or None).
    Raises ValueError with a user-facing message for invalid combinations.
    """
    target = str(text or "").strip()
    transcript = str(prompt_text or "").strip()

    if not target:
        raise ValueError("Please enter text to synthesize.")

    if mode == MODE_VOICE_CLONE:
        # Voice clone takes its conditioning from the reference audio alone.
        if transcript:
            raise ValueError("voice_clone mode does not use prompt_text. Leave Prompt Transcript empty.")
        return target, None

    # Continuation: transcript and reference audio must be supplied together.
    if bool(transcript) != bool(prompt_audio_path):
        raise ValueError(
            "continuation mode accepts either target text only, or prompt_text and reference audio together."
        )

    return target, (transcript or None)
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
def build_status_text(
    *,
    result: dict[str, object],
    prepared_texts: dict[str, object],
    reference_source: str,
) -> str:
    """Format the one-line status summary shown after a synthesis run."""
    chunks = result.get("voice_clone_text_chunks") or []
    # A missing/empty/non-list chunk field still counts as one chunk.
    if isinstance(chunks, list) and chunks:
        chunk_count = len(chunks)
    else:
        chunk_count = 1
    return (
        f"Done | mode={result['mode']} | ref={reference_source} | elapsed={result['elapsed_seconds']:.2f}s | "
        f"sample_rate={result['sample_rate']} | attn={result['effective_global_attn_implementation']} | "
        f"chunks={chunk_count} | normalization={prepared_texts['normalization_method']}"
    )
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
def run_inference(
    text: str,
    mode: str,
    voice: str,
    prompt_audio_path: str | None,
    prompt_text: str,
    max_new_frames: int,
    voice_clone_max_text_tokens: int,
    do_sample: bool,
    text_temperature: float,
    text_top_p: float,
    text_top_k: int,
    audio_temperature: float,
    audio_top_p: float,
    audio_top_k: int,
    audio_repetition_penalty: float,
    seed: float | int,
):
    """Validate the request, run full (non-streaming) synthesis, and return outputs.

    Returns the 4-tuple wired to the Gradio outputs:
    ((sample_rate, waveform), status line, normalized text, normalized prompt).
    Failures are logged and re-raised as gr.Error so the UI surfaces them.
    """
    generated_audio_path: str | None = None
    try:
        normalized_text, normalized_prompt_text = validate_request(
            text=text,
            mode=mode,
            prompt_text=prompt_text,
            prompt_audio_path=prompt_audio_path,
        )
        prepared_texts = prepare_tts_request_texts(
            text=normalized_text,
            prompt_text=normalized_prompt_text or "",
            voice=voice,
            enable_wetext=False,  # WeTextProcessing is disabled in this CPU Space
            text_normalizer_manager=None,
        )

        # Describe where the reference conditioning comes from, for the status line.
        if prompt_audio_path:
            reference_source = "uploaded_audio"
        elif mode == MODE_VOICE_CLONE:
            reference_source = f"preset:{voice}"
        else:
            reference_source = "none"

        # A seed of 0 (the UI default) or an empty value means "random": pass None.
        normalized_seed = None
        if seed not in {"", None}:
            resolved_seed = int(seed)
            if resolved_seed != 0:
                normalized_seed = resolved_seed

        result = get_tts_service().synthesize(
            text=str(prepared_texts["text"]),
            mode=mode,
            voice=voice,
            prompt_audio_path=prompt_audio_path or None,
            prompt_text=str(prepared_texts["prompt_text"]).strip() or None,
            max_new_frames=int(max_new_frames),
            voice_clone_max_text_tokens=int(voice_clone_max_text_tokens),
            do_sample=bool(do_sample),
            text_temperature=float(text_temperature),
            text_top_p=float(text_top_p),
            text_top_k=int(text_top_k),
            audio_temperature=float(audio_temperature),
            audio_top_p=float(audio_top_p),
            audio_top_k=int(audio_top_k),
            audio_repetition_penalty=float(audio_repetition_penalty),
            seed=normalized_seed,
            attn_implementation="sdpa",
        )
        generated_audio_path = str(result["audio_path"])

        status_line = build_status_text(
            result=result,
            prepared_texts=prepared_texts,
            reference_source=reference_source,
        )
        return (
            (int(result["sample_rate"]), result["waveform_numpy"]),
            status_line,
            str(prepared_texts["normalized_text"]),
            str(prepared_texts["normalized_prompt_text"]),
        )
    except Exception as exc:
        logging.exception("Nano-TTS inference failed")
        raise gr.Error(str(exc)) from exc
    finally:
        # The waveform is returned in memory; the on-disk temp file is not needed.
        maybe_delete_file(generated_audio_path)
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
def build_demo():
    """Assemble and return the Gradio Blocks UI (not launched here)."""
    # Light green theme; #run-btn targets the generate button via elem_id.
    custom_css = """
:root {
    --bg: #f5f6f0;
    --panel: #ffffff;
    --ink: #15221a;
    --muted: #5a695e;
    --line: #d9dfd6;
    --accent: #285943;
}
.gradio-container {
    background:
        radial-gradient(circle at top left, rgba(162, 198, 167, 0.18), transparent 28%),
        linear-gradient(180deg, #f5f6f0 0%, #edf1ea 100%);
    color: var(--ink);
}
.app-card {
    border: 1px solid var(--line);
    border-radius: 18px;
    background: rgba(255, 255, 255, 0.96);
    padding: 16px;
    box-shadow: 0 20px 40px rgba(21, 34, 26, 0.06);
}
.app-title {
    font-size: 24px;
    font-weight: 700;
    letter-spacing: 0.2px;
    margin-bottom: 6px;
}
.app-subtitle {
    color: var(--muted);
    font-size: 14px;
    line-height: 1.5;
}
#run-btn {
    background: var(--accent);
    border: none;
}
"""

    with gr.Blocks(title="Nano-TTS CPU Space", css=custom_css) as demo:
        # Header card.
        gr.Markdown(
            """
<div class="app-card">
  <div class="app-title">Nano-TTS CPU</div>
  <div class="app-subtitle">
    Hugging Face Space edition backed by local <code>weights/tts</code> and <code>weights/codec</code>.
    Realtime streaming decode is disabled; audio is returned after full synthesis.
  </div>
</div>
"""
        )

        with gr.Row(equal_height=False):
            # Left column: inputs.
            with gr.Column(scale=3):
                text = gr.Textbox(
                    label="Target Text",
                    lines=10,
                    placeholder="Enter the text to synthesize.",
                )
                mode = gr.Radio(
                    choices=[
                        ("Voice Clone", MODE_VOICE_CLONE),
                        ("Continuation", MODE_CONTINUATION),
                    ],
                    value=MODE_VOICE_CLONE,
                    label="Inference Mode",
                )
                mode_hint = gr.Markdown(render_mode_hint(MODE_VOICE_CLONE))
                voice = gr.Dropdown(
                    choices=VOICE_CHOICES,
                    value=DEFAULT_VOICE_VALUE,
                    label="Preset Voice",
                    info="Used only by voice_clone when no reference audio is uploaded.",
                    visible=True,
                )
                prompt_audio = gr.Audio(
                    label="Reference Audio Upload (optional; overrides preset voice)",
                    type="filepath",
                    sources=["upload"],
                )
                prompt_text = gr.Textbox(
                    label="Prompt Transcript",
                    lines=3,
                    visible=False,
                    placeholder="Only for continuation with reference audio.",
                )

                gr.Markdown(
                    "Robust text normalization is always on. WeTextProcessing is disabled in this CPU Space for a simpler deployment path."
                )

                # Sampling / decoding knobs, hidden behind an accordion.
                with gr.Accordion("Advanced Parameters", open=False):
                    max_new_frames = gr.Slider(
                        minimum=64,
                        maximum=512,
                        step=16,
                        value=375,
                        label="max_new_frames",
                    )
                    voice_clone_max_text_tokens = gr.Slider(
                        minimum=25,
                        maximum=200,
                        step=5,
                        value=75,
                        label="voice_clone_max_text_tokens",
                    )
                    do_sample = gr.Checkbox(
                        value=True,
                        label="Enable Sampling",
                    )
                    seed = gr.Number(
                        value=0,
                        precision=0,
                        label="Seed (0 = random)",
                    )
                    text_temperature = gr.Slider(
                        minimum=0.1,
                        maximum=2.0,
                        step=0.05,
                        value=1.0,
                        label="text_temperature",
                    )
                    text_top_p = gr.Slider(
                        minimum=0.1,
                        maximum=1.0,
                        step=0.01,
                        value=1.0,
                        label="text_top_p",
                    )
                    text_top_k = gr.Slider(
                        minimum=1,
                        maximum=100,
                        step=1,
                        value=50,
                        label="text_top_k",
                    )
                    audio_temperature = gr.Slider(
                        minimum=0.1,
                        maximum=2.0,
                        step=0.05,
                        value=0.8,
                        label="audio_temperature",
                    )
                    audio_top_p = gr.Slider(
                        minimum=0.1,
                        maximum=1.0,
                        step=0.01,
                        value=0.95,
                        label="audio_top_p",
                    )
                    audio_top_k = gr.Slider(
                        minimum=1,
                        maximum=100,
                        step=1,
                        value=25,
                        label="audio_top_k",
                    )
                    audio_repetition_penalty = gr.Slider(
                        minimum=0.8,
                        maximum=2.0,
                        step=0.05,
                        value=1.2,
                        label="audio_repetition_penalty",
                    )

                run_btn = gr.Button("Generate Speech", variant="primary", elem_id="run-btn")

            # Right column: outputs.
            with gr.Column(scale=2):
                output_audio = gr.Audio(label="Output Audio", type="numpy")
                status = gr.Textbox(label="Status", lines=4, interactive=False)
                normalized_text = gr.Textbox(label="Normalized Text", lines=6, interactive=False)
                normalized_prompt_text = gr.Textbox(
                    label="Normalized Prompt Transcript",
                    lines=4,
                    interactive=False,
                )

        # Switching modes toggles which conditioning inputs are shown.
        mode.change(
            fn=update_mode_ui,
            inputs=[mode],
            outputs=[voice, prompt_text, prompt_audio, mode_hint],
        )

        run_btn.click(
            fn=run_inference,
            inputs=[
                text,
                mode,
                voice,
                prompt_audio,
                prompt_text,
                max_new_frames,
                voice_clone_max_text_tokens,
                do_sample,
                text_temperature,
                text_top_p,
                text_top_k,
                audio_temperature,
                audio_top_p,
                audio_top_k,
                audio_repetition_penalty,
                seed,
            ],
            outputs=[output_audio, status, normalized_text, normalized_prompt_text],
        )

    return demo
|
| 463 |
+
|
| 464 |
+
|
| 465 |
+
def main() -> None:
    """CLI entry point: parse args and env overrides, optionally preload, launch UI."""
    parser = argparse.ArgumentParser(description="Nano-TTS Hugging Face Space")
    parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument(
        "--port",
        type=int,
        default=int(os.getenv("GRADIO_SERVER_PORT", os.getenv("PORT", "7860"))),
    )
    parser.add_argument("--share", action="store_true")
    args = parser.parse_args()

    logging.basicConfig(
        format="%(asctime)s %(levelname)s %(name)s: %(message)s",
        level=logging.INFO,
    )

    # Environment variables (set by the Spaces runtime) override CLI flags.
    args.host = os.getenv("GRADIO_SERVER_NAME", args.host)
    args.port = parse_port(os.getenv("GRADIO_SERVER_PORT", os.getenv("PORT")), args.port)

    # On Spaces (SPACE_ID set) default to lazy loading; locally, preload by default.
    preload_enabled = parse_bool_env(PRELOAD_ENV_VAR, default=not bool(os.getenv("SPACE_ID")))
    if preload_enabled:
        preload_service()
    else:
        logging.info("Skipping model preload (set %s=1 to enable).", PRELOAD_ENV_VAR)

    demo = build_demo()
    # Single concurrent inference; a short queue keeps the CPU box responsive.
    demo.queue(max_size=4, default_concurrency_limit=1).launch(
        server_name=args.host,
        server_port=args.port,
        share=args.share,
        ssr_mode=False,
    )


if __name__ == "__main__":
    main()
|
asserts/audio/en_1.wav
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:1816ab428334ba2de49dcb8b0a10e17eb1835f7f1f7bcda13504e88f46bed1e8
|
| 3 |
+
size 249284
|
asserts/audio/en_2.wav
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:959cce9498f2bae964ca67136c2c02c7174922813b69aa435b27ec8759b44992
|
| 3 |
+
size 694618
|
asserts/audio/en_3.wav
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:563544e54f6dd66b24a4494fa40b8f9debd7cceb50ae47a149c14bc3610c4aff
|
| 3 |
+
size 455372
|
asserts/audio/en_4.wav
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:cdd4f0ba5a4c0499f5194f5767ffaa9e988ea912210e369f66e2812278ba45ff
|
| 3 |
+
size 458948
|
asserts/audio/en_5.wav
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a0822692aafe818424d9902419ec46bd707bddc401ca1b5a2539229cfc2852e7
|
| 3 |
+
size 5303154
|
asserts/audio/jp_1.mp3
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:5f2cdb58d8050a77f09f5444f43cbd17d56bf9c73d75b98cd994feb2af22dc02
|
| 3 |
+
size 96624
|
asserts/audio/jp_2.wav
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c44a65e55b7376a87607fdea5a6a5ab735c7aef2e007d1fc02a9f50d37bf11a4
|
| 3 |
+
size 227600
|
asserts/audio/jp_3.wav
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:585ff999d7219d6247863a1abc3b112c822fb8603e546146d788bcf14536c57e
|
| 3 |
+
size 427120
|
asserts/audio/jp_4.wav
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:58247506a362aaa347bc732d0196078e72b434046b9ddf3111c30878cdc10213
|
| 3 |
+
size 546884
|
asserts/audio/jp_5.wav
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:cd4a8ef2dc90f080ec8e6abb40d4b3d40c3445d51e57a5244ee46dbba1b2dcf8
|
| 3 |
+
size 346670
|
asserts/audio/zh_1.wav
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:24c0cc85603d26017ac9c3ee89e0a03c66a193a5fdede5db74bb88f670d83723
|
| 3 |
+
size 2000754
|
asserts/audio/zh_2.wav
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:92b9312ca9fbb6f351bc57ae123118a28bd773cfd62dda9fc59f372cea786143
|
| 3 |
+
size 442068
|
asserts/audio/zh_3.wav
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:168d988f2d60773902862e5fd29fb0fad10468925b10a98995f6feb44ceb1cff
|
| 3 |
+
size 411452
|
asserts/audio/zh_4.wav
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:965ba9c61ffbc4dc03b6441a5e22d08d26a747ff3536d669898d3975aebc8e72
|
| 3 |
+
size 1267100
|
asserts/audio/zh_5.wav
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:89247d4db86a7dd921f16f805fbe513e7dd12631e5402aea02a94b4fa19560e7
|
| 3 |
+
size 827036
|
asserts/audio/zh_6.wav
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:6e8b1a0edd604129a6b8bef1a2f3627ad7c4c6069ebb77e50e88781a4048c9c1
|
| 3 |
+
size 285092
|
nano_tts_runtime.py
ADDED
|
@@ -0,0 +1,727 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import importlib
|
| 4 |
+
import logging
|
| 5 |
+
import threading
|
| 6 |
+
import time
|
| 7 |
+
import uuid
|
| 8 |
+
from dataclasses import dataclass
|
| 9 |
+
from functools import lru_cache
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
from typing import Iterator, Optional
|
| 12 |
+
|
| 13 |
+
import numpy as np
|
| 14 |
+
import torch
|
| 15 |
+
from transformers import AutoModel, AutoModelForCausalLM
|
| 16 |
+
|
| 17 |
+
# Model-type tag for the bundled MOSS audio tokenizer checkpoint.
MOSS_AUDIO_TOKENIZER_TYPE = "moss-audio-tokenizer-nano"

# All default paths are resolved relative to this file's directory so the
# Space works regardless of the working directory it is launched from.
APP_DIR = Path(__file__).resolve().parent
DEFAULT_CHECKPOINT_PATH = APP_DIR / "weights" / "tts"
DEFAULT_AUDIO_TOKENIZER_PATH = APP_DIR / "weights" / "codec"
# NOTE(review): "asserts" (not "assets") matches the repo's actual directory
# name — do not "correct" the spelling without renaming the directory too.
DEFAULT_PROMPT_AUDIO_DIR = APP_DIR / "asserts" / "audio"
DEFAULT_OUTPUT_DIR = APP_DIR / "generated_audio"

# Registry of bundled reference voices: preset name -> (audio file name under
# DEFAULT_PROMPT_AUDIO_DIR, short human-readable description).
_DEFAULT_VOICE_FILES: dict[str, tuple[str, str]] = {
    "Junhao": ("zh_1.wav", "Chinese male voice A"),
    "Zhiming": ("zh_2.wav", "Chinese male voice B"),
    "Weiguo": ("zh_5.wav", "Chinese male voice C"),
    "Xiaoyu": ("zh_3.wav", "Chinese female voice A"),
    "Yuewen": ("zh_4.wav", "Chinese female voice B"),
    "Lingyu": ("zh_6.wav", "Chinese female voice C"),
    "Trump": ("en_1.wav", "Trump reference voice"),
    "Ava": ("en_2.wav", "English female voice A"),
    "Bella": ("en_3.wav", "English female voice B"),
    "Adam": ("en_4.wav", "English male voice A"),
    "Nathan": ("en_5.wav", "English male voice B"),
    "Sakura": ("jp_1.mp3", "Japanese female voice A"),
    "Yui": ("jp_2.wav", "Japanese female voice B"),
    "Aoi": ("jp_3.wav", "Japanese female voice C"),
    "Hina": ("jp_4.wav", "Japanese female voice D"),
    "Mei": ("jp_5.wav", "Japanese female voice E"),
}

# Voice used when a request names no preset (falls back to the first
# registered preset if this key is missing from the registry).
DEFAULT_VOICE = "Junhao"
# flash_attention_2 is only usable with half-precision compute dtypes.
FLASH_ATTENTION_DTYPES = {torch.float16, torch.bfloat16}
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
@dataclass(frozen=True)
class VoicePreset:
    """Immutable description of a single reference voice.

    Instances are built from ``_DEFAULT_VOICE_FILES`` by
    ``build_default_voice_presets`` and looked up by name in
    ``NanoTTSService``.
    """

    # Preset name; also used as the lookup key in the preset registry.
    name: str
    # Path to the reference/prompt audio clip for this voice.
    prompt_audio_path: Path
    # Short human-readable description of the voice.
    description: str
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def build_default_voice_presets() -> dict[str, VoicePreset]:
    """Build the registry of bundled voice presets.

    Returns a mapping from preset name to :class:`VoicePreset`, resolving each
    prompt audio file name to an absolute path under
    ``DEFAULT_PROMPT_AUDIO_DIR``.
    """
    return {
        voice: VoicePreset(
            name=voice,
            prompt_audio_path=(DEFAULT_PROMPT_AUDIO_DIR / audio_file).resolve(),
            description=blurb,
        )
        for voice, (audio_file, blurb) in _DEFAULT_VOICE_FILES.items()
    }
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def resolve_device(device_arg: str) -> torch.device:
    """Map a device string to a concrete ``torch.device``.

    ``"auto"`` selects CUDA when available and CPU otherwise; any other value
    is handed to ``torch.device`` verbatim.
    """
    if device_arg != "auto":
        return torch.device(device_arg)
    preferred = "cuda" if torch.cuda.is_available() else "cpu"
    return torch.device(preferred)
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def resolve_dtype(dtype_arg: str, device: torch.device) -> torch.dtype:
    """Map a dtype string to a concrete ``torch.dtype``.

    Explicit names ("float32"/"float16"/"bfloat16") win. Anything else
    (including "auto") picks bf16 on CUDA when supported, fp16 on other CUDA
    devices, and fp32 on CPU.
    """
    explicit = {
        "float32": torch.float32,
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
    }.get(dtype_arg)
    if explicit is not None:
        return explicit
    if device.type != "cuda":
        return torch.float32
    bf16_ok = hasattr(torch.cuda, "is_bf16_supported") and torch.cuda.is_bf16_supported()
    return torch.bfloat16 if bf16_ok else torch.float16
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def waveform_to_numpy(waveform: torch.Tensor | np.ndarray) -> np.ndarray:
|
| 88 |
+
if torch.is_tensor(waveform):
|
| 89 |
+
array = waveform.detach().cpu().numpy()
|
| 90 |
+
else:
|
| 91 |
+
array = np.asarray(waveform)
|
| 92 |
+
if array.ndim == 1:
|
| 93 |
+
return array.astype(np.float32, copy=False)
|
| 94 |
+
if array.ndim != 2:
|
| 95 |
+
raise ValueError(f"Unsupported waveform shape: {array.shape}")
|
| 96 |
+
if array.shape[0] <= 8 and array.shape[0] < array.shape[1]:
|
| 97 |
+
array = array.T
|
| 98 |
+
return array.astype(np.float32, copy=False)
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
@lru_cache(maxsize=1)
|
| 102 |
+
def _has_flash_attn() -> bool:
|
| 103 |
+
try:
|
| 104 |
+
importlib.import_module("flash_attn")
|
| 105 |
+
except Exception:
|
| 106 |
+
return False
|
| 107 |
+
return True
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
class NanoTTSService:
    """Thread-safe wrapper around the Nano-TTS checkpoint and its audio codec.

    The TTS model (``AutoModelForCausalLM``) and the audio tokenizer/codec
    (``AutoModel``) are loaded lazily from local checkpoint directories and
    cached on the instance behind an ``RLock``, so concurrent callers are
    serialized. Exposes a blocking entry point (:meth:`synthesize`) and a
    streaming one (:meth:`synthesize_stream`); both discard the cached model
    state on failure so the next request reloads from disk.

    NOTE(review): this block was reconstructed from a diff rendering that lost
    indentation; the post-inference bookkeeping in ``synthesize`` /
    ``synthesize_stream`` is placed inside the lock and the final result
    assembly after it — confirm against the original checkpoint source.
    """

    def __init__(
        self,
        *,
        checkpoint_path: str | Path = DEFAULT_CHECKPOINT_PATH,
        audio_tokenizer_path: str | Path = DEFAULT_AUDIO_TOKENIZER_PATH,
        device: str = "auto",
        dtype: str = "auto",
        attn_implementation: str = "auto",
        output_dir: str | Path = DEFAULT_OUTPUT_DIR,
        voice_presets: Optional[dict[str, VoicePreset]] = None,
    ) -> None:
        """Configure paths, device/dtype policy, and the voice registry.

        No model weights are loaded here; loading is deferred to first use.
        ``attn_implementation`` of "auto"/"default"/"model_default" is stored
        as ``None``, meaning "decide per request at runtime".
        """
        self.checkpoint_path = Path(checkpoint_path).expanduser().resolve()
        self.audio_tokenizer_path = Path(audio_tokenizer_path).expanduser().resolve()
        self.output_dir = Path(output_dir).expanduser().resolve()
        self.output_dir.mkdir(parents=True, exist_ok=True)

        self.voice_presets = voice_presets or build_default_voice_presets()
        # Fall back to an arbitrary (first-registered) preset when the module
        # default is absent from a caller-supplied registry.
        self.default_voice = DEFAULT_VOICE if DEFAULT_VOICE in self.voice_presets else next(iter(self.voice_presets))

        self.device = resolve_device(device)
        self.dtype = resolve_dtype(dtype, self.device)
        self.attn_implementation = self._resolve_attn_implementation(attn_implementation)

        # Reentrant lock: public methods take it and call other *_locked
        # helpers that assume it is already held.
        self._lock = threading.RLock()
        self._model = None
        self._audio_tokenizer = None
        # Attention-implementation bookkeeping: what the checkpoint shipped
        # with vs. what we configured after loading (reported by preload()).
        self._checkpoint_global_attn_implementation: str | None = None
        self._checkpoint_local_attn_implementation: str | None = None
        self._configured_global_attn_implementation: str | None = None
        self._configured_local_attn_implementation: str | None = None
        self._configured_audio_tokenizer_attn_implementation: str | None = None
        self._configured_audio_tokenizer_compute_dtype: str | None = None

    def _can_use_flash_attention(self) -> bool:
        """True when CUDA + half-precision dtype + flash_attn are all present."""
        return self.device.type == "cuda" and self.dtype in FLASH_ATTENTION_DTYPES and _has_flash_attn()

    def _resolve_runtime_default_attn_implementation(self) -> str:
        """Pick the default backend for "auto": flash-attention if usable, else SDPA."""
        return "flash_attention_2" if self._can_use_flash_attention() else "sdpa"

    def _resolve_attn_implementation(self, requested: str | None) -> str | None:
        """Validate a requested attention backend name.

        Returns ``None`` for "auto"-like values (decide later), the validated
        name otherwise. An unusable "flash_attention_2" request degrades to
        "sdpa" with a warning instead of failing.

        Raises:
            ValueError: if ``requested`` is not a recognized backend name.
        """
        normalized = str(requested).strip().lower() if requested is not None else "auto"
        if not normalized or normalized in {"auto", "default", "model_default"}:
            return None
        if normalized not in {"sdpa", "flash_attention_2", "eager"}:
            raise ValueError(
                "attn_implementation must be one of: model_default/auto, sdpa, flash_attention_2, eager"
            )
        if normalized == "flash_attention_2":
            if not self._can_use_flash_attention():
                logging.warning(
                    "flash_attention_2 requires CUDA, flash_attn, and fp16/bf16; falling back to sdpa "
                    "(device=%s dtype=%s flash_attn=%s)",
                    self.device,
                    self.dtype,
                    _has_flash_attn(),
                )
                return "sdpa"
        return normalized

    @staticmethod
    def _normalize_loaded_attn_implementation(attn_implementation: object) -> str:
        """Normalize an attention-implementation attribute read off a model.

        Missing/empty/"none" values are reported as "eager".
        """
        normalized = str(attn_implementation).strip().lower() if attn_implementation is not None else ""
        if not normalized or normalized == "none":
            return "eager"
        return normalized

    def _resolve_request_attention_implementation(
        self,
        requested: Optional[str],
    ) -> tuple[str, str, str]:
        """Resolve the attention backend for one request.

        Returns ``(requested_label, global_impl, local_impl)``. Precedence:
        per-request value, then the service-level setting from ``__init__``,
        then the runtime default ("auto").
        """
        normalized = str(requested).strip().lower() if requested is not None else ""
        resolved = self._resolve_attn_implementation(normalized)
        if resolved is not None:
            return normalized, resolved, resolved

        if self.attn_implementation is not None:
            return self.attn_implementation, self.attn_implementation, self.attn_implementation

        runtime_default = self._resolve_runtime_default_attn_implementation()
        return "auto", runtime_default, runtime_default

    @staticmethod
    def _resolve_codec_attention_implementation(tts_attn_implementation: str) -> str:
        """Mirror flash-attention onto the codec; everything else uses SDPA."""
        return "flash_attention_2" if tts_attn_implementation == "flash_attention_2" else "sdpa"

    def _resolve_codec_compute_dtype(self, codec_attn_implementation: str) -> str:
        """Choose the codec compute dtype string for the given codec backend.

        Flash-attention requires half precision (matched to the TTS dtype);
        SDPA/eager paths run the codec in fp32.
        """
        if codec_attn_implementation == "flash_attention_2":
            return "bf16" if self.dtype == torch.bfloat16 else "fp16"
        return "fp32"

    @staticmethod
    def _apply_model_attention_implementation(model, *, global_attn: str, local_attn: str) -> None:
        """Push the chosen backends into the model, if it exposes the hook."""
        if hasattr(model, "_set_attention_implementation"):
            model._set_attention_implementation(global_attn, local_attn_implementation=local_attn)

    def _install_stream_decode_budget_patch(self, model) -> None:
        """Monkey-patch the model class's streaming decode-budget policy (CUDA only).

        Replaces ``_resolve_stream_decode_frame_budget`` with a version that
        emits larger chunks; installed at most once per model class, and the
        original is stashed on the class for reference.
        """
        if self.device.type != "cuda":
            return

        model_cls = model.__class__
        if getattr(model_cls, "_nanotts_stream_decode_budget_patch_installed", False):
            return

        compute_stream_lead = getattr(model_cls, "_compute_stream_lead_seconds", None)
        resolve_stream_budget = getattr(model_cls, "_resolve_stream_decode_frame_budget", None)
        if not callable(compute_stream_lead) or not callable(resolve_stream_budget):
            return

        def _patched_resolve_stream_decode_frame_budget(
            *,
            emitted_samples_total: int,
            sample_rate: int,
            first_audio_emitted_at,
        ) -> int:
            # The upstream streaming policy starts with one decode frame
            # (about 80 ms audio), which makes CUDA realtime decode emit many
            # tiny chunks and underrun browser playback on this checkpoint.
            lead_seconds = compute_stream_lead(
                emitted_samples_total=emitted_samples_total,
                sample_rate=sample_rate,
                first_audio_emitted_at=first_audio_emitted_at,
            )
            # Scale the frame budget with how far ahead of playback we are.
            if first_audio_emitted_at is None or lead_seconds < 0.20:
                return 4
            if lead_seconds < 0.55:
                return 6
            if lead_seconds < 1.10:
                return 8
            return 12

        model_cls._nanotts_original_resolve_stream_decode_frame_budget = resolve_stream_budget
        model_cls._resolve_stream_decode_frame_budget = staticmethod(_patched_resolve_stream_decode_frame_budget)
        model_cls._nanotts_stream_decode_budget_patch_installed = True
        logging.info("installed Nano-TTS CUDA streaming decode budget patch")

    def _discard_loaded_model_locked(self, reason: str) -> None:
        """Drop the cached TTS model (caller must hold the lock)."""
        if self._model is None:
            return
        logging.warning("discarding loaded Nano-TTS model state: %s", reason)
        self._model = None
        if self.device.type == "cuda":
            torch.cuda.empty_cache()

    def _discard_loaded_audio_tokenizer_locked(self, reason: str) -> None:
        """Drop the cached audio tokenizer and its config (caller holds the lock)."""
        if self._audio_tokenizer is None:
            return
        logging.warning("discarding loaded Nano-TTS audio tokenizer state: %s", reason)
        self._audio_tokenizer = None
        self._configured_audio_tokenizer_attn_implementation = None
        self._configured_audio_tokenizer_compute_dtype = None
        if self.device.type == "cuda":
            torch.cuda.empty_cache()

    def _restore_model_execution_state(self, model):
        """Return the model, reloading it if a prior run changed its dtype."""
        current_parameter = next(model.parameters(), None)
        if current_parameter is None or current_parameter.dtype == self.dtype:
            return model
        self._discard_loaded_model_locked(
            f"current_dtype={current_parameter.dtype} expected_dtype={self.dtype}; reloading checkpoint"
        )
        return self._load_model_locked()

    def _read_model_attention_implementation(self, model) -> tuple[str, str]:
        """Read (global, local) attention backends off the model's submodules."""
        global_attn = self._normalize_loaded_attn_implementation(
            getattr(getattr(model, "transformer", None), "attn_implementation", None)
        )
        local_attn = self._normalize_loaded_attn_implementation(
            getattr(getattr(model, "local_transformer", None), "attn_implementation", None)
        )
        return global_attn, local_attn

    def _ensure_paths(self) -> None:
        """Raise ``FileNotFoundError`` if either checkpoint directory is missing."""
        if not self.checkpoint_path.exists():
            raise FileNotFoundError(f"Nano-TTS checkpoint not found: {self.checkpoint_path}")
        if not self.audio_tokenizer_path.exists():
            raise FileNotFoundError(f"Audio tokenizer checkpoint not found: {self.audio_tokenizer_path}")

    def _load_audio_tokenizer_locked(self, *, tts_attn_implementation: str):
        """Load (once) and (re)configure the audio tokenizer; caller holds the lock.

        Reconfiguration (device move, attention backend, compute dtype) runs on
        every call so the codec tracks the current request's TTS backend.
        """
        codec_attn_implementation = self._resolve_codec_attention_implementation(tts_attn_implementation)
        codec_compute_dtype = self._resolve_codec_compute_dtype(codec_attn_implementation)

        if self._audio_tokenizer is None:
            logging.info(
                "loading Nano-TTS audio tokenizer checkpoint=%s device=%s attn=%s compute_dtype=%s",
                self.audio_tokenizer_path,
                self.device,
                codec_attn_implementation,
                codec_compute_dtype,
            )
            audio_tokenizer = AutoModel.from_pretrained(
                str(self.audio_tokenizer_path),
                trust_remote_code=True,
                local_files_only=True,
            )
            if hasattr(audio_tokenizer, "eval"):
                audio_tokenizer.eval()
            self._audio_tokenizer = audio_tokenizer

        audio_tokenizer = self._audio_tokenizer
        # All configuration hooks are optional on the remote-code model;
        # probe with hasattr before calling.
        if hasattr(audio_tokenizer, "to"):
            audio_tokenizer = audio_tokenizer.to(self.device)
        if hasattr(audio_tokenizer, "set_attention_implementation"):
            audio_tokenizer.set_attention_implementation(codec_attn_implementation)
        if hasattr(audio_tokenizer, "set_compute_dtype"):
            audio_tokenizer.set_compute_dtype(codec_compute_dtype)
        if hasattr(audio_tokenizer, "eval"):
            audio_tokenizer.eval()

        self._audio_tokenizer = audio_tokenizer
        self._configured_audio_tokenizer_attn_implementation = codec_attn_implementation
        self._configured_audio_tokenizer_compute_dtype = codec_compute_dtype
        return self._audio_tokenizer

    def _load_model_locked(self):
        """Load and configure the TTS model once; caller must hold the lock.

        Records the checkpoint's shipped attention backends, applies the
        resolved defaults, installs the CUDA streaming patch, and caches the
        eval-mode model on the instance.
        """
        if self._model is not None:
            return self._model

        self._ensure_paths()
        logging.info(
            "loading Nano-TTS checkpoint=%s audio_tokenizer=%s device=%s dtype=%s attn=%s",
            self.checkpoint_path,
            self.audio_tokenizer_path,
            self.device,
            self.dtype,
            self.attn_implementation or "model_default",
        )
        model = AutoModelForCausalLM.from_pretrained(
            str(self.checkpoint_path),
            trust_remote_code=True,
            local_files_only=True,
        )
        model.to(device=self.device, dtype=self.dtype)
        self._checkpoint_global_attn_implementation, self._checkpoint_local_attn_implementation = (
            self._read_model_attention_implementation(model)
        )
        _, default_global_attn, default_local_attn = self._resolve_request_attention_implementation(None)
        self._apply_model_attention_implementation(
            model,
            global_attn=default_global_attn,
            local_attn=default_local_attn,
        )
        self._install_stream_decode_budget_patch(model)
        model.eval()
        self._configured_global_attn_implementation, self._configured_local_attn_implementation = (
            self._read_model_attention_implementation(model)
        )
        self._model = model
        return self._model

    def get_model(self):
        """Return the (lazily loaded) TTS model, taking the service lock."""
        with self._lock:
            return self._load_model_locked()

    def list_voice_names(self) -> list[str]:
        """Return all registered preset names, in registration order."""
        return list(self.voice_presets.keys())

    def get_voice_preset(self, voice_name: Optional[str]) -> VoicePreset:
        """Return the named preset, or the default preset for unknown/empty names."""
        if voice_name and voice_name in self.voice_presets:
            return self.voice_presets[voice_name]
        return self.voice_presets[self.default_voice]

    def resolve_prompt_audio_path(
        self,
        *,
        voice: Optional[str] = None,
        prompt_audio_path: Optional[str | Path] = None,
    ) -> Path:
        """Resolve the reference audio path for a request.

        An explicit ``prompt_audio_path`` wins over the voice preset.

        Raises:
            FileNotFoundError: if the chosen file does not exist.
        """
        if prompt_audio_path:
            resolved = Path(prompt_audio_path).expanduser().resolve()
            if not resolved.exists():
                raise FileNotFoundError(f"Prompt audio not found: {resolved}")
            return resolved

        preset = self.get_voice_preset(voice)
        if not preset.prompt_audio_path.exists():
            raise FileNotFoundError(f"Voice preset prompt audio not found: {preset.prompt_audio_path}")
        return preset.prompt_audio_path

    def preload(self, *, voices: Optional[list[str]] = None, load_model: bool = True) -> dict[str, object]:
        """Optionally load the model, check voice prompt files, and report status.

        Returns a diagnostics dict of device/dtype/attention configuration;
        "eager"/"unknown" placeholders stand in for not-yet-populated state.
        """
        loaded_voices: list[str] = []
        if load_model:
            self.get_model()
        for voice_name in voices or [self.default_voice]:
            preset = self.get_voice_preset(voice_name)
            if preset.prompt_audio_path.exists():
                loaded_voices.append(preset.name)
        return {
            "loaded_voices": loaded_voices,
            "device": str(self.device),
            "dtype": str(self.dtype),
            "attn_implementation": self.attn_implementation or "auto",
            "checkpoint_default_attn_implementation": self._checkpoint_global_attn_implementation or "eager",
            "checkpoint_default_local_attn_implementation": self._checkpoint_local_attn_implementation or "eager",
            "configured_attn_implementation": self._configured_global_attn_implementation or "eager",
            "configured_local_attn_implementation": self._configured_local_attn_implementation or "eager",
            "configured_codec_attn_implementation": self._configured_audio_tokenizer_attn_implementation or "unknown",
            "configured_codec_compute_dtype": self._configured_audio_tokenizer_compute_dtype or "unknown",
        }

    def _build_output_path(self, prefix: str) -> Path:
        """Build a collision-resistant output path: prefix + timestamp + hex suffix."""
        timestamp = time.strftime("%Y%m%d_%H%M%S")
        random_suffix = uuid.uuid4().hex[:8]
        return self.output_dir / f"{prefix}_{timestamp}_{random_suffix}.wav"

    def synthesize(
        self,
        *,
        text: str,
        voice: Optional[str] = None,
        mode: str = "voice_clone",
        output_audio_path: Optional[str | Path] = None,
        prompt_audio_path: Optional[str | Path] = None,
        prompt_text: Optional[str] = None,
        max_new_frames: int = 375,
        voice_clone_max_text_tokens: int = 75,
        voice_clone_max_memory_per_sample_gb: float = 1.0,
        tts_max_batch_size: int = 0,
        codec_max_batch_size: int = 0,
        do_sample: bool = True,
        text_temperature: float = 1.0,
        text_top_p: float = 1.0,
        text_top_k: int = 50,
        audio_temperature: float = 0.8,
        audio_top_p: float = 0.95,
        audio_top_k: int = 25,
        audio_repetition_penalty: float = 1.2,
        nq: Optional[int] = None,
        seed: Optional[int] = None,
        attn_implementation: Optional[str] = None,
    ) -> dict[str, object]:
        """Run one blocking synthesis request and write the result WAV.

        Validates text/mode, resolves the reference audio (required for
        "voice_clone"; optional-with-prompt_text for "continuation"), runs
        ``model.inference`` under the service lock, and returns a result dict
        with the waveform, paths, timings, and effective configuration.

        Raises:
            ValueError: on empty text, invalid mode, or continuation with
                prompt audio but no prompt text.
        """
        normalized_text = str(text or "").strip()
        if not normalized_text:
            raise ValueError("text is required")

        normalized_mode = str(mode).strip().lower()
        if normalized_mode not in {"continuation", "voice_clone"}:
            raise ValueError("mode must be either 'continuation' or 'voice_clone'")

        effective_prompt_audio_path: Optional[Path] = None
        resolved_voice = self.get_voice_preset(voice).name
        if normalized_mode == "voice_clone":
            effective_prompt_audio_path = self.resolve_prompt_audio_path(
                voice=resolved_voice,
                prompt_audio_path=prompt_audio_path,
            )
        elif prompt_audio_path is not None:
            effective_prompt_audio_path = self.resolve_prompt_audio_path(prompt_audio_path=prompt_audio_path)
            if not prompt_text:
                raise ValueError("continuation mode with prompt_audio_path also requires prompt_text")

        output_path = (
            Path(output_audio_path).expanduser().resolve()
            if output_audio_path is not None
            else self._build_output_path(prefix=f"{resolved_voice}_{normalized_mode}")
        )
        output_path.parent.mkdir(parents=True, exist_ok=True)

        started_at = time.monotonic()
        with self._lock:
            model = self._load_model_locked()
            model = self._restore_model_execution_state(model)
            requested_attn_implementation, effective_global_attn_implementation, effective_local_attn_implementation = (
                self._resolve_request_attention_implementation(attn_implementation)
            )
            audio_tokenizer = self._load_audio_tokenizer_locked(
                tts_attn_implementation=effective_global_attn_implementation
            )
            self._apply_model_attention_implementation(
                model,
                global_attn=effective_global_attn_implementation,
                local_attn=effective_local_attn_implementation,
            )
            if seed is not None:
                torch.manual_seed(seed)
                if torch.cuda.is_available():
                    torch.cuda.manual_seed_all(seed)

            try:
                result = model.inference(
                    text=normalized_text,
                    output_audio_path=str(output_path),
                    mode=normalized_mode,
                    prompt_text=prompt_text,
                    prompt_audio_path=None if effective_prompt_audio_path is None else str(effective_prompt_audio_path),
                    text_tokenizer_path=str(self.checkpoint_path),
                    audio_tokenizer=audio_tokenizer,
                    device=self.device,
                    nq=nq,
                    max_new_frames=int(max_new_frames),
                    voice_clone_max_text_tokens=int(voice_clone_max_text_tokens),
                    voice_clone_max_memory_per_sample_gb=float(voice_clone_max_memory_per_sample_gb),
                    tts_max_batch_size=int(tts_max_batch_size),
                    codec_max_batch_size=int(codec_max_batch_size),
                    do_sample=bool(do_sample),
                    use_kv_cache=True,
                    text_temperature=float(text_temperature),
                    text_top_p=float(text_top_p),
                    text_top_k=int(text_top_k),
                    audio_temperature=float(audio_temperature),
                    audio_top_p=float(audio_top_p),
                    audio_top_k=int(audio_top_k),
                    audio_repetition_penalty=float(audio_repetition_penalty),
                )
            except Exception:
                # Any inference failure may leave model/codec state corrupted;
                # drop both so the next request starts from a clean load.
                self._discard_loaded_audio_tokenizer_locked(
                    "inference failed; reloading audio tokenizer on next request"
                )
                self._discard_loaded_model_locked("inference failed; reloading checkpoint on next request")
                raise
            effective_global_attn_implementation, effective_local_attn_implementation = (
                self._read_model_attention_implementation(model)
            )
            current_parameter = next(model.parameters(), None)
            if current_parameter is not None and current_parameter.dtype != self.dtype:
                self._discard_loaded_model_locked(
                    f"inference left model in dtype={current_parameter.dtype}; reloading checkpoint on next request"
                )

        waveform = result["waveform"].detach().cpu()
        waveform_numpy = waveform_to_numpy(waveform)
        return {
            "audio_path": str(output_path),
            "sample_rate": int(result["sample_rate"]),
            "waveform": waveform,
            "waveform_numpy": waveform_numpy,
            "audio_token_ids": result["audio_token_ids"],
            "reference_audio_token_ids": result["reference_audio_token_ids"],
            "elapsed_seconds": time.monotonic() - started_at,
            "voice": resolved_voice,
            "mode": normalized_mode,
            "prompt_audio_path": None if effective_prompt_audio_path is None else str(effective_prompt_audio_path),
            "requested_attn_implementation": requested_attn_implementation,
            "effective_global_attn_implementation": effective_global_attn_implementation,
            "effective_local_attn_implementation": effective_local_attn_implementation,
            "voice_clone_text_chunks": result.get("voice_clone_text_chunks"),
            "voice_clone_chunk_batch_size": result.get("voice_clone_chunk_batch_size"),
            "voice_clone_codec_batch_size": result.get("voice_clone_codec_batch_size"),
        }

    def synthesize_stream(
        self,
        *,
        text: str,
        voice: Optional[str] = None,
        mode: str = "voice_clone",
        output_audio_path: Optional[str | Path] = None,
        prompt_audio_path: Optional[str | Path] = None,
        prompt_text: Optional[str] = None,
        max_new_frames: int = 375,
        voice_clone_max_text_tokens: int = 75,
        voice_clone_max_memory_per_sample_gb: float = 1.0,
        tts_max_batch_size: int = 0,
        codec_max_batch_size: int = 0,
        do_sample: bool = True,
        text_temperature: float = 1.0,
        text_top_p: float = 1.0,
        text_top_k: int = 50,
        audio_temperature: float = 0.8,
        audio_top_p: float = 0.95,
        audio_top_k: int = 25,
        audio_repetition_penalty: float = 1.2,
        nq: Optional[int] = None,
        seed: Optional[int] = None,
        attn_implementation: Optional[str] = None,
    ) -> Iterator[dict[str, object]]:
        """Streaming counterpart of :meth:`synthesize`.

        Yields ``{"type": "audio", ...}`` chunk events as the model produces
        them, then one final ``{"type": "result", ...}`` event mirroring the
        :meth:`synthesize` return dict. The service lock is held for the whole
        generation, so consumers should drain the iterator promptly.

        Raises:
            ValueError: same validation failures as :meth:`synthesize`.
            RuntimeError: if the model stream ends without a "result" event.
        """
        normalized_text = str(text or "").strip()
        if not normalized_text:
            raise ValueError("text is required")

        normalized_mode = str(mode).strip().lower()
        if normalized_mode not in {"continuation", "voice_clone"}:
            raise ValueError("mode must be either 'continuation' or 'voice_clone'")

        effective_prompt_audio_path: Optional[Path] = None
        resolved_voice = self.get_voice_preset(voice).name
        if normalized_mode == "voice_clone":
            effective_prompt_audio_path = self.resolve_prompt_audio_path(
                voice=resolved_voice,
                prompt_audio_path=prompt_audio_path,
            )
        elif prompt_audio_path is not None:
            effective_prompt_audio_path = self.resolve_prompt_audio_path(prompt_audio_path=prompt_audio_path)
            if not prompt_text:
                raise ValueError("continuation mode with prompt_audio_path also requires prompt_text")

        output_path = (
            Path(output_audio_path).expanduser().resolve()
            if output_audio_path is not None
            else self._build_output_path(prefix=f"{resolved_voice}_{normalized_mode}_stream")
        )
        output_path.parent.mkdir(parents=True, exist_ok=True)

        started_at = time.monotonic()
        final_result: dict[str, object] | None = None
        with self._lock:
            model = self._load_model_locked()
            model = self._restore_model_execution_state(model)
            requested_attn_implementation, effective_global_attn_implementation, effective_local_attn_implementation = (
                self._resolve_request_attention_implementation(attn_implementation)
            )
            audio_tokenizer = self._load_audio_tokenizer_locked(
                tts_attn_implementation=effective_global_attn_implementation
            )
            self._apply_model_attention_implementation(
                model,
                global_attn=effective_global_attn_implementation,
                local_attn=effective_local_attn_implementation,
            )
            if seed is not None:
                torch.manual_seed(seed)
                if torch.cuda.is_available():
                    torch.cuda.manual_seed_all(seed)

            try:
                for event in model.inference_stream(
                    text=normalized_text,
                    output_audio_path=str(output_path),
                    mode=normalized_mode,
                    prompt_text=prompt_text,
                    prompt_audio_path=None if effective_prompt_audio_path is None else str(effective_prompt_audio_path),
                    text_tokenizer_path=str(self.checkpoint_path),
                    audio_tokenizer=audio_tokenizer,
                    device=self.device,
                    nq=nq,
                    max_new_frames=int(max_new_frames),
                    voice_clone_max_text_tokens=int(voice_clone_max_text_tokens),
                    voice_clone_max_memory_per_sample_gb=float(voice_clone_max_memory_per_sample_gb),
                    tts_max_batch_size=int(tts_max_batch_size),
                    codec_max_batch_size=int(codec_max_batch_size),
                    do_sample=bool(do_sample),
                    use_kv_cache=True,
                    text_temperature=float(text_temperature),
                    text_top_p=float(text_top_p),
                    text_top_k=int(text_top_k),
                    audio_temperature=float(audio_temperature),
                    audio_top_p=float(audio_top_p),
                    audio_top_k=int(audio_top_k),
                    audio_repetition_penalty=float(audio_repetition_penalty),
                ):
                    event_type = str(event.get("type", ""))
                    if event_type == "audio":
                        waveform = torch.as_tensor(event["waveform"], dtype=torch.float32).cpu()
                        yield {
                            "type": "audio",
                            "waveform": waveform,
                            "waveform_numpy": waveform_to_numpy(waveform),
                            "sample_rate": int(event["sample_rate"]),
                            "chunk_index": int(event.get("chunk_index", 0)),
                            "is_pause": bool(event.get("is_pause", False)),
                            "emitted_audio_seconds": float(event.get("emitted_audio_seconds", 0.0)),
                            "lead_seconds": float(event.get("lead_seconds", 0.0)),
                        }
                        continue
                    if event_type == "result":
                        final_result = dict(event)
            except Exception:
                # Same recovery policy as synthesize(): force a clean reload
                # of both model and codec after any streaming failure.
                self._discard_loaded_audio_tokenizer_locked(
                    "streaming inference failed; reloading audio tokenizer on next request"
                )
                self._discard_loaded_model_locked("streaming inference failed; reloading checkpoint on next request")
                raise

            effective_global_attn_implementation, effective_local_attn_implementation = (
                self._read_model_attention_implementation(model)
            )
            current_parameter = next(model.parameters(), None)
            if current_parameter is not None and current_parameter.dtype != self.dtype:
                self._discard_loaded_model_locked(
                    f"streaming inference left model in dtype={current_parameter.dtype}; reloading checkpoint on next request"
                )

        if final_result is None:
            raise RuntimeError("Streaming synthesis finished without a final result.")

        waveform = torch.as_tensor(final_result["waveform"], dtype=torch.float32).cpu()
        yield {
            "type": "result",
            "audio_path": str(final_result["audio_path"]),
            "sample_rate": int(final_result["sample_rate"]),
            "waveform": waveform,
            "waveform_numpy": waveform_to_numpy(waveform),
            "audio_token_ids": final_result["audio_token_ids"],
            "reference_audio_token_ids": final_result["reference_audio_token_ids"],
            "elapsed_seconds": time.monotonic() - started_at,
            "voice": resolved_voice,
            "mode": normalized_mode,
            "prompt_audio_path": None if effective_prompt_audio_path is None else str(effective_prompt_audio_path),
            "requested_attn_implementation": requested_attn_implementation,
            "effective_global_attn_implementation": effective_global_attn_implementation,
            "effective_local_attn_implementation": effective_local_attn_implementation,
            "voice_clone_text_chunks": final_result.get("voice_clone_text_chunks"),
            "voice_clone_chunk_batch_size": final_result.get("voice_clone_chunk_batch_size"),
            "voice_clone_codec_batch_size": final_result.get("voice_clone_codec_batch_size"),
        }

    def warmup(
        self,
        *,
        text: str = "你好,欢迎使用 Nano-TTS。",
        voice: Optional[str] = None,
    ) -> dict[str, object]:
        """Run a short deterministic synthesis to prime model/codec caches.

        Uses greedy decoding and a small frame budget; the output is written
        to a fixed path under ``output_dir/_warmup`` and overwritten each run.
        """
        return self.synthesize(
            text=text,
            voice=voice or self.default_voice,
            mode="voice_clone",
            output_audio_path=self.output_dir / "_warmup" / "warmup.wav",
            max_new_frames=96,
            voice_clone_max_text_tokens=75,
            do_sample=False,
            text_temperature=1.0,
            text_top_p=1.0,
            text_top_k=50,
            audio_temperature=0.8,
            audio_top_p=0.95,
            audio_top_k=25,
            audio_repetition_penalty=1.0,
        )
|
requirements.txt
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
numpy>=1.24
|
| 2 |
+
sentencepiece>=0.1.99
|
| 3 |
+
torch==2.7.0
|
| 4 |
+
torchaudio==2.7.0
|
| 5 |
+
transformers==4.57.1
|
| 6 |
+
safetensors>=0.4.3
|
| 7 |
+
gradio==6.5.1
|
text_normalization_pipeline.py
ADDED
|
@@ -0,0 +1,195 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import logging
|
| 4 |
+
import re
|
| 5 |
+
import threading
|
| 6 |
+
from dataclasses import dataclass
|
| 7 |
+
|
| 8 |
+
from tts_robust_normalizer_single_script import normalize_tts_text
|
| 9 |
+
|
| 10 |
+
# Voice presets with English speakers; used as a language fallback when the
# request text itself contains neither CJK nor Latin characters.
ENGLISH_VOICES = frozenset({"Trump", "Ava", "Bella", "Adam", "Nathan"})
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
@dataclass(frozen=True)
class TextNormalizationSnapshot:
    """Immutable view of the WeTextProcessing preload state machine."""

    # One of: "pending", "running", "ready", "failed".
    state: str
    # Human-readable status line for UI/log display.
    message: str
    # Exception text when the preload failed; None otherwise.
    error: str | None = None
    # True once the normalizer graphs are loaded and usable.
    ready: bool = False
    # False when the WeTextProcessing modules are not installed at all.
    available: bool = False

    @property
    def failed(self) -> bool:
        # Terminal failure state: preload will not be retried automatically.
        return self.state == "failed"
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class WeTextProcessingManager:
    """Loads WeTextProcessing (``tn``) normalizer graphs on a background thread.

    State transitions: pending -> running -> ready | failed. All mutable
    state is guarded by ``_lock``; actual normalization calls are serialized
    with ``_normalize_lock`` (the tn normalizers are assumed not to be
    thread-safe — TODO confirm against the tn package docs).
    """

    def __init__(self) -> None:
        self._lock = threading.Lock()
        self._normalize_lock = threading.Lock()
        self._thread: threading.Thread | None = None
        self._started = False
        self._state = "pending"
        self._message = "Waiting for WeTextProcessing preload."
        self._error: str | None = None
        self._available = True
        self._normalizers: dict[str, object] | None = None

    def snapshot(self) -> TextNormalizationSnapshot:
        """Return a consistent, immutable copy of the current state."""
        with self._lock:
            return TextNormalizationSnapshot(
                state=self._state,
                message=self._message,
                error=self._error,
                ready=self._state == "ready",
                available=self._available,
            )

    def _set_state(self, *, state: str, message: str, error: str | None = None) -> None:
        # Single place where the state triple is mutated, always under lock.
        with self._lock:
            self._state = state
            self._message = message
            self._error = error

    def start(self) -> None:
        """Kick off the preload thread exactly once (idempotent)."""
        with self._lock:
            if self._started:
                return
            self._started = True
            self._thread = threading.Thread(target=self._run, name="wetext-preload", daemon=True)
            self._thread.start()

    def ensure_ready(self) -> TextNormalizationSnapshot:
        """Start the preload if needed, block until it finishes, return the outcome."""
        # Fix: delegate to start() instead of duplicating its spawn logic
        # (the previous implementation copy-pasted the thread creation).
        self.start()
        with self._lock:
            thread = self._thread
        # Join outside the lock: _run() acquires the same lock via _set_state().
        if thread is not None and thread.is_alive():
            thread.join()
        return self.snapshot()

    def close(self) -> None:
        """No resources to release; kept for interface symmetry."""
        return

    def _run(self) -> None:
        # Background worker: loads the graphs and records the outcome.
        if not self._available:
            self._set_state(
                state="failed",
                message="WeTextProcessing unavailable.",
                error="installed WeTextProcessing modules are unavailable",
            )
            return
        try:
            self._set_state(state="running", message="Loading WeTextProcessing graphs.", error=None)
            self._ensure_normalizers_loaded()
            self._set_state(state="ready", message="WeTextProcessing ready. languages=zh,en", error=None)
        except Exception as exc:
            logging.exception("WeTextProcessing preload failed")
            self._set_state(state="failed", message="WeTextProcessing preload failed.", error=str(exc))

    def _ensure_normalizers_loaded(self) -> dict[str, object]:
        """Build (once) and return the per-language normalizer map."""
        with self._lock:
            if self._normalizers is not None:
                return self._normalizers

        # Imported lazily: tn builds/loads heavy WFST graphs at import time.
        from tn.chinese.normalizer import Normalizer as ZhNormalizer
        from tn.english.normalizer import Normalizer as EnNormalizer

        # NOTE(review): mutates the *root* logger level as a global side
        # effect — presumably to surface tn's cache-build progress; confirm
        # this is intentional before changing it.
        logging.getLogger().setLevel(logging.INFO)
        self._normalizers = {
            "zh": ZhNormalizer(overwrite_cache=False),
            "en": EnNormalizer(overwrite_cache=False),
        }
        return self._normalizers

    def normalize(self, *, text: str, prompt_text: str, language: str) -> tuple[str, str]:
        """Normalize *text* and *prompt_text* with the graph for *language*.

        Raises RuntimeError if the preload failed, ValueError for an
        unsupported language. Empty inputs pass through as "".
        """
        snapshot = self.ensure_ready()
        if not snapshot.ready:
            raise RuntimeError(snapshot.error or snapshot.message)

        with self._normalize_lock:
            normalizers = self._ensure_normalizers_loaded()
            if language not in normalizers:
                raise ValueError(f"Unsupported text normalization language: {language}")
            normalizer = normalizers[language]
            normalized_text = normalizer.normalize(text) if text else ""
            normalized_prompt_text = normalizer.normalize(prompt_text) if prompt_text else ""
        return normalized_text, normalized_prompt_text
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
def resolve_text_normalization_language(*, text: str, voice: str) -> str:
    """Pick the WeTextProcessing language ("zh" or "en") for *text*.

    CJK characters in the text win over Latin letters; when the text is
    script-neutral (digits/punctuation only), fall back to the voice preset,
    defaulting to Chinese.
    """
    has_cjk = re.search(r"[\u3400-\u9fff]", text) is not None
    if has_cjk:
        return "zh"
    has_latin = re.search(r"[A-Za-z]", text) is not None
    if has_latin or voice in ENGLISH_VOICES:
        return "en"
    return "zh"
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def prepare_tts_request_texts(
    *,
    text: str,
    prompt_text: str,
    voice: str,
    enable_wetext: bool,
    text_normalizer_manager: WeTextProcessingManager | None,
) -> dict[str, object]:
    """Run the two-stage text normalization pipeline for one TTS request.

    Stage 1 (optional, gated by *enable_wetext*): semantic normalization via
    WeTextProcessing, with the language resolved from the text/voice.
    Stage 2 (always): robust, non-semantic cleanup via ``normalize_tts_text``.

    Returns a dict with the final ``text``/``prompt_text`` strings plus
    metadata describing which normalizers ran.

    Raises:
        RuntimeError: if *enable_wetext* is set but no manager was supplied.
    """
    raw_text = str(text or "")
    raw_prompt_text = str(prompt_text or "")

    normalization_language = ""
    intermediate_text = raw_text
    intermediate_prompt_text = raw_prompt_text

    if enable_wetext:
        if text_normalizer_manager is None:
            raise RuntimeError("WeTextProcessing manager is unavailable.")
        normalization_language = resolve_text_normalization_language(text=raw_text, voice=voice)
        intermediate_text, intermediate_prompt_text = text_normalizer_manager.normalize(
            text=raw_text,
            prompt_text=raw_prompt_text,
            language=normalization_language,
        )
        # Log only when a stage actually altered the string (avoids log spam).
        if intermediate_text != raw_text:
            logging.info(
                "normalized text chars_before=%d chars_after=%d stage=wetext language=%s",
                len(raw_text),
                len(intermediate_text),
                normalization_language,
            )
        if raw_prompt_text and intermediate_prompt_text != raw_prompt_text:
            logging.info(
                "normalized prompt_text chars_before=%d chars_after=%d stage=wetext language=%s",
                len(raw_prompt_text),
                len(intermediate_prompt_text),
                normalization_language,
            )

    # Stage 2: robust cleanup always runs, even when WeTextProcessing is off.
    final_text = normalize_tts_text(intermediate_text)
    final_prompt_text = normalize_tts_text(intermediate_prompt_text) if intermediate_prompt_text else ""

    if final_text != intermediate_text:
        logging.info(
            "normalized text chars_before=%d chars_after=%d stage=robust_final",
            len(intermediate_text),
            len(final_text),
        )
    if intermediate_prompt_text and final_prompt_text != intermediate_prompt_text:
        logging.info(
            "normalized prompt_text chars_before=%d chars_after=%d stage=robust_final",
            len(intermediate_prompt_text),
            len(final_prompt_text),
        )

    return {
        "text": final_text,
        "prompt_text": final_prompt_text,
        # Duplicated under normalized_* keys for callers wanting explicit names.
        "normalized_text": final_text,
        "normalized_prompt_text": final_prompt_text,
        "normalization_method": (f"wetext:{normalization_language}+robust" if enable_wetext else "robust"),
        "text_normalization_language": normalization_language,
        "text_normalization_enabled": bool(enable_wetext),
    }
|
tts_robust_normalizer_single_script.py
ADDED
|
@@ -0,0 +1,366 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
|
| 4 |
+
"""
|
| 5 |
+
TTS 输入鲁棒性正则化器(非语义 TN)
|
| 6 |
+
|
| 7 |
+
目标
|
| 8 |
+
----
|
| 9 |
+
1. 只做“鲁棒性清洗”,不做数字/单位/日期/金额等语义展开。
|
| 10 |
+
2. 优先保护高风险 token,避免把 `.map`、`app.js.map`、`v2.3.1`、URL、Email、@mention、#hashtag 清坏。
|
| 11 |
+
3. `[]` / `{}` / `【】` / `〖〗` / `『』` / `「」` 统一转成双引号包裹内容。
|
| 12 |
+
4. 对结构性符号做“替换而非删除”:
|
| 13 |
+
- `【】 / 〖〗 / 『』 / 「」` 转成双引号包裹内容。
|
| 14 |
+
- `《》` 只在“独立标题/栏目名”场景拆开;嵌入式标题保持不变。
|
| 15 |
+
- `—— / -- / ——...` 转成句边界。
|
| 16 |
+
5. 对社交平台常见噪声做弱归一化:
|
| 17 |
+
- `...... / ……` -> `。`
|
| 18 |
+
- `???!!!` -> `?!`
|
| 19 |
+
- `!!!` -> `!`
|
| 20 |
+
6. 空格按脚本类型处理:
|
| 21 |
+
- 西文片段内部:连续空格压缩为 1 个。
|
| 22 |
+
- 汉字 / 日文假名片段内部:删除空格。
|
| 23 |
+
- 汉字 / 日文假名 与“拉丁字母类 token / 受保护 token”相邻:保留或补 1 个空格。
|
| 24 |
+
- 汉字 / 日文假名 与纯数字相邻:不强行补空格。
|
| 25 |
+
7. 轻量处理 Markdown 与换行:
|
| 26 |
+
- `[text](url)` -> `text url`
|
| 27 |
+
- 去掉标题 `#`、引用 `>`、列表前缀
|
| 28 |
+
- 换行转句边界 `。`
|
| 29 |
+
|
| 30 |
+
非目标
|
| 31 |
+
------
|
| 32 |
+
1. 不决定“应该怎么读”。
|
| 33 |
+
2. 不做 HTML/SSML/语义标签解释。
|
| 34 |
+
"""
|
| 35 |
+
|
| 36 |
+
from __future__ import annotations
|
| 37 |
+
|
| 38 |
+
import re
|
| 39 |
+
import unicodedata
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
# ---------------------------
# Base constants and regexes
# ---------------------------

# Scripts that do not rely on spaces for tokenization: CJK ideographs + Japanese kana
_CJK_CHARS = r"\u3400-\u4dbf\u4e00-\u9fff\u3040-\u30ff"
_CJK = f"[{_CJK_CHARS}]"

# Protection placeholder pattern (see _protect_spans / _restore_spans)
_PROT = r"___PROT\d+___"

# High-risk tokens that must be protected from the cleanup rules
_URL_RE = re.compile(r"https?://[^\s\u3000,。!?;、)】》〉」』]+")
_EMAIL_RE = re.compile(r"(?<![\w.+-])[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}(?![\w.-])")
_MENTION_RE = re.compile(r"(?<![A-Za-z0-9_])@[A-Za-z0-9_]{1,32}")
_REDDIT_RE = re.compile(r"(?<![A-Za-z0-9_])(?:u|r)/[A-Za-z0-9_]+")
_HASHTAG_RE = re.compile(r"(?<![A-Za-z0-9_])#(?!\s)[^\s#]+")

# Dot-leading tokens: `.map` / `.env` / `.gitignore`
_DOT_TOKEN_RE = re.compile(r"(?<![A-Za-z0-9_])\.(?=[A-Za-z0-9._-]*[A-Za-z0-9])[A-Za-z0-9._-]+")

# File-like tokens: `app.js.map` / `index.d.ts` / `v2.3.1` / `foo/bar-baz.py` etc.
_FILELIKE_RE = re.compile(
    r"(?<![A-Za-z0-9_])"
    r"(?=[A-Za-z0-9._/+:-]*[A-Za-z])"
    r"(?=[A-Za-z0-9._/+:-]*[._/+:-])"
    r"[A-Za-z0-9][A-Za-z0-9._/+:-]*"
    r"(?![A-Za-z0-9_])"
)

# Tokens that take part in "pad a space at CJK/Latin boundaries": must contain
# at least one Latin letter, or be a protected placeholder token
_LATINISH = rf"(?:{_PROT}|(?=[A-Za-z0-9._/+:-]*[A-Za-z])[A-Za-z0-9][A-Za-z0-9._/+:-]*)"

# Zero-width characters (ZWSP/ZWNJ/ZWJ/BOM)
_ZERO_WIDTH_RE = re.compile(r"[\u200b-\u200d\ufeff]")
# Closing quotes/brackets skipped when checking for terminal punctuation.
_TRAILING_CLOSERS = set('"\')]})】》〉」』”’')
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
# ---------------------------
|
| 81 |
+
# 主函数
|
| 82 |
+
# ---------------------------
|
| 83 |
+
|
| 84 |
+
def normalize_tts_text(text: str) -> str:
    """Robustness-oriented normalization of TTS input (non-semantic TN).

    Cleans control characters, light Markdown, spacing and noisy punctuation
    while protecting high-risk tokens (URLs, emails, file names, @mentions).
    """
    cleaned = _normalize_markdown_and_lines(_base_cleanup(text))
    cleaned, protected = _protect_spans(cleaned)

    # Space normalization runs both before and after the punctuation passes,
    # since those rewrites can themselves introduce new spacing artifacts.
    for stage in (
        _normalize_spaces,
        _normalize_structural_punctuation,
        _normalize_repeated_punctuation,
        _normalize_spaces,
    ):
        cleaned = stage(cleaned)

    restored = _restore_spans(cleaned, protected).strip()
    return _ensure_terminal_punctuation_by_line(restored)
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
# ---------------------------
|
| 101 |
+
# 具体规则
|
| 102 |
+
# ---------------------------
|
| 103 |
+
|
| 104 |
+
def _base_cleanup(text: str) -> str:
    """Normalize newlines, drop zero-width and other control characters."""
    unified = text.replace("\r\n", "\n").replace("\r", "\n").replace("\u3000", " ")
    unified = _ZERO_WIDTH_RE.sub("", unified)
    # Keep \n, \t and space; strip every other Unicode "C*" (control/format) char.
    return "".join(
        ch
        for ch in unified
        if ch in "\n\t " or not unicodedata.category(ch).startswith("C")
    )
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def _normalize_markdown_and_lines(text: str) -> str:
    """Flatten light Markdown and turn line breaks into sentence boundaries."""
    # Markdown link: [text](url) -> text url
    text = re.sub(r"\[([^\[\]]+?)\]\((https?://[^)\s]+)\)", r"\1 \2", text)

    stripped_lines: list[str] = []
    for raw_line in text.splitlines():
        candidate = raw_line.strip()
        if not candidate:
            continue
        candidate = re.sub(r"^#{1,6}\s+", "", candidate)  # heading marker
        candidate = re.sub(r"^>\s+", "", candidate)  # blockquote marker
        candidate = re.sub(r"^[-*+]\s+", "", candidate)  # unordered list bullet
        candidate = re.sub(r"^\d+[.)]\s+", "", candidate)  # ordered list number
        stripped_lines.append(candidate)

    if not stripped_lines:
        return ""

    # Every line except the last receives terminal punctuation, so the join
    # yields sentence boundaries instead of run-on text.
    pieces = [_ensure_terminal_punctuation(line) for line in stripped_lines[:-1]]
    pieces.append(stripped_lines[-1])
    return "".join(pieces)
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
def _protect_spans(text: str) -> tuple[str, list[str]]:
    """Replace high-risk tokens with ``___PROTn___`` placeholders.

    Returns the masked text plus the list of original spans, indexed by n,
    for later restoration via _restore_spans.
    """
    spans: list[str] = []

    def _mask(match: re.Match[str]) -> str:
        spans.append(match.group(0))
        return f"___PROT{len(spans) - 1}___"

    # Order matters: the more specific patterns (URL, email) run first so the
    # broader file-like pattern never bites into them.
    patterns = (
        _URL_RE,
        _EMAIL_RE,
        _MENTION_RE,
        _REDDIT_RE,
        _HASHTAG_RE,
        _DOT_TOKEN_RE,
        _FILELIKE_RE,
    )
    for pattern in patterns:
        text = pattern.sub(_mask, text)

    return text, spans
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
def _restore_spans(text: str, protected: list[str]) -> str:
|
| 166 |
+
for idx, original in enumerate(protected):
|
| 167 |
+
text = text.replace(f"___PROT{idx}___", original)
|
| 168 |
+
return text
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
def _normalize_spaces(text: str) -> str:
    """Apply script-aware whitespace rules (compress Latin, delete CJK, pad boundaries)."""
    # Collapse all horizontal whitespace to single ASCII spaces.
    text = re.sub(r"[ \t\r\f\v]+", " ", text)

    removal_rules = (
        # Inside CJK / kana runs: spaces are deleted outright.
        rf"({_CJK})\s+(?={_CJK})",
        # Between CJK and bare digits (either direction): no space is kept.
        rf"({_CJK})\s+(?=\d)",
        rf"(\d)\s+(?={_CJK})",
    )
    for rule in removal_rules:
        text = re.sub(rule, r"\1", text)

    # CJK adjacent to a Latin-ish token (or protected placeholder):
    # keep or insert exactly one space at the boundary.
    text = re.sub(rf"({_CJK})(?=({_LATINISH}))", r"\1 ", text)
    text = re.sub(rf"(({_LATINISH}))(?={_CJK})", r"\1 ", text)

    # Re-collapse runs of spaces introduced above.
    text = re.sub(r" {2,}", " ", text)

    # No spaces around full-width (Chinese) punctuation.
    text = re.sub(r"\s+([,。!?;:、”’」』】)》])", r"\1", text)
    text = re.sub(r"([(【「『《“‘])\s+", r"\1", text)
    text = re.sub(r"([,。!?;:、])\s*", r"\1", text)

    # No space before ASCII punctuation; English spacing after it is left alone.
    text = re.sub(r"\s+([,.;!?])", r"\1", text)

    return re.sub(r" {2,}", " ", text).strip()
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
def _normalize_structural_punctuation(text: str) -> str:
|
| 201 |
+
# 各类结构性括号:统一转成双引号包裹内容
|
| 202 |
+
text = re.sub(r"\[\s*([^\[\]]+?)\s*\]", r'"\1"', text)
|
| 203 |
+
text = re.sub(r"\{\s*([^{}]+?)\s*\}", r'"\1"', text)
|
| 204 |
+
text = re.sub(r"[【〖『「]\s*([^】〗』」]+?)\s*[】〗』」]", r'"\1"', text)
|
| 205 |
+
|
| 206 |
+
# 《》只处理独立标题,不处理嵌入式标题
|
| 207 |
+
# 例:重磅。《新品发布》——现在开始! -> 重磅。新品发布。现在开始!
|
| 208 |
+
text = re.sub(
|
| 209 |
+
r"(^|[。!?!?;;]\s*)《([^》]+)》(?=\s*(?:___PROT\d+___|[—–―-]{2,}|$|[。!?!?;;,,]))",
|
| 210 |
+
r"\1\2",
|
| 211 |
+
text,
|
| 212 |
+
)
|
| 213 |
+
|
| 214 |
+
# 长破折号 / 多连字符:转句边界
|
| 215 |
+
text = re.sub(r"\s*(?:—|–|―|-){2,}\s*", "。", text)
|
| 216 |
+
|
| 217 |
+
return text
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
def _normalize_repeated_punctuation(text: str) -> str:
|
| 221 |
+
# 省略号 / 连续句点
|
| 222 |
+
text = re.sub(r"(?:\.{3,}|…{2,}|……+)", "。", text)
|
| 223 |
+
|
| 224 |
+
# 同类重复标点
|
| 225 |
+
text = re.sub(r"[。.]{2,}", "。", text)
|
| 226 |
+
text = re.sub(r"[,,]{2,}", ",", text)
|
| 227 |
+
text = re.sub(r"[!!]{2,}", "!", text)
|
| 228 |
+
text = re.sub(r"[??]{2,}", "?", text)
|
| 229 |
+
|
| 230 |
+
# 混合问叹号:收敛到 ?!
|
| 231 |
+
def _mixed_qe(match: re.Match[str]) -> str:
|
| 232 |
+
s = match.group(0)
|
| 233 |
+
has_q = any(ch in s for ch in "??")
|
| 234 |
+
has_e = any(ch in s for ch in "!!")
|
| 235 |
+
if has_q and has_e:
|
| 236 |
+
return "?!"
|
| 237 |
+
return "?" if has_q else "!"
|
| 238 |
+
|
| 239 |
+
text = re.sub(r"[!?!?]{2,}", _mixed_qe, text)
|
| 240 |
+
return text
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
def _ensure_terminal_punctuation(text: str) -> str:
    """Append a full stop unless the text already ends in punctuation.

    Trailing whitespace and closing quotes/brackets are skipped when looking
    for the last meaningful character.
    """
    if not text:
        return text

    pos = len(text) - 1
    while pos >= 0 and text[pos].isspace():
        pos -= 1
    while pos >= 0 and text[pos] in _TRAILING_CLOSERS:
        pos -= 1

    already_punctuated = pos >= 0 and unicodedata.category(text[pos]).startswith("P")
    return text if already_punctuated else text + "。"
|
| 256 |
+
|
| 257 |
+
|
| 258 |
+
def _ensure_terminal_punctuation_by_line(text: str) -> str:
    """Apply _ensure_terminal_punctuation to every non-empty line."""
    if not text:
        return text
    normalized: list[str] = []
    for line in text.split("\n"):
        stripped = line.strip()
        normalized.append(_ensure_terminal_punctuation(stripped) if stripped else "")
    return "\n".join(normalized).strip()
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
# ---------------------------
|
| 267 |
+
# 测试
|
| 268 |
+
# ---------------------------
|
| 269 |
+
|
| 270 |
+
# Each entry is (name, raw input, expected normalized output).
TEST_CASES = [
    # 1) .map / dot-leading tokens / file names / version numbers
    (
        "dot_map_sentence",
        "2026 年 3 月 31 日,安全研究员 Chaofan Shou (@Fried_rice) 发现 Anthropic 的 npm 包中暴露了 .map 文件,",
        "2026年3月31日,安全研究员 Chaofan Shou (@Fried_rice) 发现 Anthropic 的 npm 包中暴露了 .map 文件,",
    ),
    ("dot_tokens", "别把 .env、.npmrc、.gitignore 提交上去。", "别把 .env、.npmrc、.gitignore 提交上去。"),
    ("file_names", "请检查 bundle.min.js、package.json 和 processing_moss_tts.py。", "请检查 bundle.min.js、package.json 和 processing_moss_tts.py。"),
    ("index_d_ts", "index.d.ts 里也有同样的问题。", "index.d.ts 里也有同样的问题。"),
    ("version_build", "Bug 的讨论可以精确到 v2.3.1 (Build 15)。", "Bug 的讨论可以精确到 v2.3.1 (Build 15)。"),
    ("version_rc", "3.0.0-rc.1 还不能上生产。", "3.0.0-rc.1 还不能上生产。"),
    ("jar_name", "fabric-api-0.91.3+1.20.2.jar 需要单独下载。", "fabric-api-0.91.3+1.20.2.jar 需要单独下载。"),

    # 2) URL / email / mention / hashtag / Reddit
    ("url", "仓库地址是 https://github.com/instructkr/claude-code", "仓库地址是 https://github.com/instructkr/claude-code。"),
    ("email", "联系邮箱:ops+tts@example.ai", "联系邮箱:ops+tts@example.ai。"),
    ("mention", "@Fried_rice 说这是 source map 暴露。", "@Fried_rice 说这是 source map 暴露。"),
    ("reddit", "去 r/singularity 看讨论。", "去 r/singularity 看讨论。"),
    ("hashtag_chain", "#张雪峰#张雪峰[话题]#张雪峰事件", "#张雪峰#张雪峰[话题]#张雪峰事件。"),
    ("mention_hashtag_boundary", "关注@biscuit0228_并转发#thetime_tbs", "关注 @biscuit0228_ 并转发 #thetime_tbs。"),

    # 3) bracket / control tokens: rewritten to double quotes
    ("speaker_bracket", "[S1]你好。[S2]收到。", '"S1"你好。"S2"收到。'),
    ("event_bracket", "请模仿 {whisper} 的语气说“别出声”。", '请模仿 "whisper" 的语气说“别出声”。'),
    ("order_bracket", "订单号:[AB-1234-XYZ]", '订单号:"AB-1234-XYZ"。'),

    # 4) structural symbols: rewritten to quotes or sentence boundaries, never deleted
    ("struct_headline", "〖重磅〗《新品发布》——现在开始!", '"重磅"《新品发布》。现在开始!'),
    ("struct_notice", "【公告】今天 20:00 维护——预计 30 分钟。", '"公告"今天20:00维护。预计30分钟。'),
    ("struct_quote_chain", "『特别提醒』「不要外传」", '"特别提醒""不要外传"。'),
    ("struct_embedded_quote", "他说【重要通知】明天发布。", '他说"重要通知"明天发布。'),

    # 5) embedded titles: preserved
    ("embedded_title", "我喜欢《哈姆雷特》这本书。", "我喜欢《哈姆雷特》这本书。"),

    # 6) repeated punctuation / social-media noise
    ("noise_qe", "真的假的???!!!", "真的假的?!"),
    ("noise_ellipsis", "这个包把 app.js.map 也发上去了......太离谱了!!!", "这个包把 app.js.map 也发上去了。太离谱了!"),
    ("noise_ellipsis_cn", "【系统提示】请模仿{sad}低沉语气,说“今天下雨了……”", '"系统提示"请模仿"sad"低沉语气,说“今天下雨了。”'),

    # 7) space rules: compress Latin, delete inside CJK, pad mixed-script boundaries
    ("english_spaces", "This is a test.", "This is a test."),
    ("chinese_spaces", "这 是 一 段 含有多种空白的文本。", "这是一段含有多种空白的文本。"),
    ("mixed_spaces_1", "这是Anthropic的npm包", "这是 Anthropic 的 npm 包。"),
    ("mixed_spaces_2", "今天update到v2.3.1了", "今天 update 到 v2.3.1 了。"),
    ("mixed_spaces_3", "处理app.js.map文件", "处理 app.js.map 文件。"),

    # 8) Markdown / lists / newlines
    ("markdown_link", "详情见 [release note](https://github.com/example/release)", "详情见 release note https://github.com/example/release。"),
    ("markdown_heading", "# I made a free open source app to help with markdown files", "I made a free open source app to help with markdown files。"),
    ("list_lines", "- 修复 .map 泄露\n- 发布 v2.3.1", "修复 .map 泄露。发布 v2.3.1。"),
    ("numbered_lines", "1. 安装依赖\n2. 运行测试\n3. 发布 v2.3.1", "安装依赖。运行测试。发布 v2.3.1。"),
    ("newlines", "第一行\n第二行\n第三行", "第一行。第二行。第三行。"),

    # 9) appending terminal punctuation
    ("terminal_punct_plain", "今天发布", "今天发布。"),
    ("terminal_punct_quoted", '他说"你好"', '他说"你好"。'),
    ("terminal_punct_existing", "今天发布。", "今天发布。"),
    ("terminal_punct_newlines", "第一行\n第二行。", "第一行。第二行。"),
    ("terminal_punct_blank_lines", "第一行\n\n第二行", "第一行。第二行。"),

    # 10) zero-width characters / idempotence
    ("zero_width_url", "详见 https://x.com/\u200bSafety", "详见 https://x.com/Safety。"),
]
|
| 335 |
+
|
| 336 |
+
|
| 337 |
+
def run_tests(verbose: bool = True) -> None:
    """Run TEST_CASES through normalize_tts_text and verify idempotence."""
    failures: list[tuple[str, str, str, str]] = []

    for name, raw, expected in TEST_CASES:
        produced = normalize_tts_text(raw)
        if produced != expected:
            failures.append((name, raw, expected, produced))
            continue

        # Idempotence: a second normalization pass must not change the output.
        twice = normalize_tts_text(produced)
        if twice != produced:
            failures.append((name + "_idempotence", produced, produced, twice))

    if failures:
        report = ["\nTEST FAILED:\n"]
        for name, raw, expected, produced in failures:
            report.append(f"[{name}]")
            report.append(f"input : {raw}")
            report.append(f"expected: {expected}")
            report.append(f"actual : {produced}")
            report.append("")
        raise AssertionError("\n".join(report))

    if verbose:
        print(f"All {len(TEST_CASES)} tests passed.")
|
| 363 |
+
|
| 364 |
+
|
| 365 |
+
if __name__ == "__main__":
    # Running the module directly executes its self-tests.
    run_tests()
|
weights/codec/.gitattributes
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
weights/codec/README.md
ADDED
|
@@ -0,0 +1,195 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
license: apache-2.0
|
| 3 |
+
library_name: transformers
|
| 4 |
+
tags:
|
| 5 |
+
- audio
|
| 6 |
+
- audio-tokenizer
|
| 7 |
+
- neural-codec
|
| 8 |
+
- moss-tts-family
|
| 9 |
+
- MOSS Audio Tokenizer
|
| 10 |
+
- speech-tokenizer
|
| 11 |
+
- trust-remote-code
|
| 12 |
+
---
|
| 13 |
+
|
| 14 |
+
# MossAudioTokenizer
|
| 15 |
+
|
| 16 |
+
This is the code for MOSS-Audio-Tokenizer presented in [MOSS-Audio-Tokenizer: Scaling Audio Tokenizers for Future Audio Foundation Models](https://arxiv.org/abs/2602.10934).
|
| 17 |
+
|
| 18 |
+
**MOSSAudioTokenizer** is a unified discrete audio tokenizer based on the **Cat** (**C**ausal **A**udio **T**okenizer with **T**ransformer) architecture. Scaling to 1.6 billion parameters, it functions as a unified discrete interface, delivering both lossless-quality reconstruction and high-level semantic alignment.
|
| 19 |
+
|
| 20 |
+
**Key Features:**
|
| 21 |
+
|
| 22 |
+
* **Extreme Compression & Variable Bitrate**: It compresses 48kHz stereo audio into a remarkably low frame rate of 12.5Hz. Utilizing a 32-layer Residual LFQ quantizer stack, it supports high-fidelity reconstruction across a wide range of bitrates.
|
| 23 |
+
* **Pure Transformer Architecture**: The model features a "CNN-free" homogeneous architecture built entirely from Causal Transformer blocks. With 1.6B combined parameters (Encoder + Decoder), it ensures exceptional scalability and supports low-latency streaming inference.
|
| 24 |
+
* **Large-Scale General Audio Training**: Trained on 3 million hours of diverse audio data, the model excels at encoding and reconstructing all audio domains, including speech, sound effects, and music.
|
| 25 |
+
* **Unified Semantic-Acoustic Representation**: While achieving state-of-the-art reconstruction quality, Cat produces discrete tokens that are "semantic-rich," making them ideal for downstream tasks like speech understanding (ASR) and generation (TTS).
|
| 26 |
+
* **Fully Trained From Scratch**: Cat does not rely on any pretrained encoders (such as HuBERT or Whisper) or distillation from teacher models. All representations are learned autonomously from raw data.
|
| 27 |
+
* **End-to-End Joint Optimization**: All components—including the encoder, quantizer, decoder, discriminator, and a decoder-only LLM for semantic alignment—are optimized jointly in a single unified training pipeline.
|
| 28 |
+
|
| 29 |
+
**Summary:**
|
| 30 |
+
By combining a simple, scalable architecture with massive-scale data, the Cat architecture overcomes the bottlenecks of traditional audio tokenizers. It provides a robust, high-fidelity, and semantically grounded interface for the next generation of native audio foundation models.
|
| 31 |
+
|
| 32 |
+
This repository contains a lightweight remote-code implementation that mirrors the current 🤗 Transformers
|
| 33 |
+
`transformers.models.moss_audio_tokenizer` module. It is intended to be uploaded to a Hugging Face Hub model repository
|
| 34 |
+
and loaded with `trust_remote_code=True` when needed.
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
## Usage
|
| 38 |
+
|
| 39 |
+
### Quickstart
|
| 40 |
+
|
| 41 |
+
```python
|
| 42 |
+
import torch
|
| 43 |
+
from transformers import AutoModel
|
| 44 |
+
import torchaudio
|
| 45 |
+
|
| 46 |
+
repo_id = "OpenMOSS-Team/MOSS-Audio-Tokenizer"
|
| 47 |
+
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True).eval()
|
| 48 |
+
|
| 49 |
+
wav, sr = torchaudio.load('demo/demo_gt.wav')
|
| 50 |
+
if sr != model.sampling_rate:
|
| 51 |
+
wav = torchaudio.functional.resample(wav, sr, model.sampling_rate)
|
| 52 |
+
if wav.shape[0] == 1:
|
| 53 |
+
wav = wav.repeat(model.config.number_channels, 1)
|
| 54 |
+
else:
|
| 55 |
+
wav = wav[: model.config.number_channels]
|
| 56 |
+
wav = wav.unsqueeze(0)
|
| 57 |
+
enc = model.encode(wav, return_dict=True)
|
| 58 |
+
print(f"enc.audio_codes.shape: {enc.audio_codes.shape}")
|
| 59 |
+
dec = model.decode(enc.audio_codes, return_dict=True)
|
| 60 |
+
print(f"dec.audio.shape: {dec.audio.shape}")
|
| 61 |
+
wav = dec.audio.squeeze(0)
|
| 62 |
+
torchaudio.save("demo/demo_rec.wav", wav, sample_rate=model.sampling_rate)
|
| 63 |
+
|
| 64 |
+
# Decode using only the first 8 layers of the residual (RLFQ) quantizer stack
|
| 65 |
+
dec_rvq8 = model.decode(enc.audio_codes[:8], return_dict=True)
|
| 66 |
+
wav_rvq8 = dec_rvq8.audio.squeeze(0)
|
| 67 |
+
torchaudio.save("demo/demo_rec_rvq8.wav", wav_rvq8, sample_rate=model.sampling_rate)
|
| 68 |
+
```
|
| 69 |
+
|
| 70 |
+
### Attention Backend And Compute Dtype
|
| 71 |
+
|
| 72 |
+
`config.attention_implementation` controls whether transformer layers prefer `sdpa` or `flash_attention_2`.
|
| 73 |
+
`config.compute_dtype` controls the non-quantizer autocast dtype and supports `fp32`, `bf16`, and `fp16`.
|
| 74 |
+
|
| 75 |
+
```python
|
| 76 |
+
model.set_attention_implementation("flash_attention_2")
|
| 77 |
+
model.set_compute_dtype("fp16")
|
| 78 |
+
```
|
| 79 |
+
|
| 80 |
+
The quantizer always runs in fp32.
|
| 81 |
+
|
| 82 |
+
### Streaming
|
| 83 |
+
|
| 84 |
+
`MossAudioTokenizerModel.encode`, `decode`, `batch_encode`, and `batch_decode` all support streaming through a
|
| 85 |
+
`chunk_duration` argument.
|
| 86 |
+
|
| 87 |
+
- `chunk_duration` is expressed in seconds.
|
| 88 |
+
- `chunk_duration * MossAudioTokenizerConfig.sampling_rate` must be divisible by `MossAudioTokenizerConfig.downsample_rate`.
|
| 89 |
+
- Streaming batch inference is supported.
|
| 90 |
+
- The public waveform interface expects stereo inputs shaped `(2, T)` or batched stereo inputs shaped `(B, 2, T)`.
|
| 91 |
+
|
| 92 |
+
```python
|
| 93 |
+
import torch
|
| 94 |
+
from transformers import AutoModel
|
| 95 |
+
|
| 96 |
+
repo_id = "OpenMOSS-Team/MOSS-Audio-Tokenizer"
|
| 97 |
+
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True).eval()
|
| 98 |
+
audio = torch.randn(2, 48000 * 6) # dummy stereo waveform
|
| 99 |
+
|
| 100 |
+
# 6.0s @ 48kHz = 288000 samples, divisible by downsample_rate=3840
|
| 101 |
+
enc = model.encode(audio.unsqueeze(0), return_dict=True, chunk_duration=0.08)
|
| 102 |
+
dec = model.decode(enc.audio_codes, return_dict=True, chunk_duration=0.08)
|
| 103 |
+
|
| 104 |
+
batch_enc = model.batch_encode([audio, audio[:, : 48000 * 3]], chunk_duration=0.08)
|
| 105 |
+
codes_list = [
|
| 106 |
+
batch_enc.audio_codes[:, i, : batch_enc.audio_codes_lengths[i]]
|
| 107 |
+
for i in range(batch_enc.audio_codes.shape[1])
|
| 108 |
+
]
|
| 109 |
+
batch_dec = model.batch_decode(codes_list, chunk_duration=0.08)
|
| 110 |
+
```
|
| 111 |
+
|
| 112 |
+
#### Continuous Batch Streaming Decode
|
| 113 |
+
|
| 114 |
+
For decoder-side continuous batching, prefer `batch_decode(..., streaming=True, ...)`.
|
| 115 |
+
|
| 116 |
+
- The first streaming call may pass `max_batch_size=...`. If it is omitted, the first batch size reserves the
|
| 117 |
+
fixed-slot decoder budget for that public stream.
|
| 118 |
+
- Same-size calls continue the existing logical rows in-order.
|
| 119 |
+
- If a later call is larger, the new rows are admitted by tail append.
|
| 120 |
+
- `finalize_indices` means "decode these rows one last time, then evict them". The indices are interpreted against the
|
| 121 |
+
pre-call logical order.
|
| 122 |
+
- After a finalize call returns, the next streaming call may use the smaller survivor batch.
|
| 123 |
+
- `reset_stream=True` discards the hidden public streaming state and starts a fresh stream.
|
| 124 |
+
|
| 125 |
+
Milestone 1 boundaries:
|
| 126 |
+
|
| 127 |
+
- decode-only continuous batching
|
| 128 |
+
- one active streaming decode state per model instance
|
| 129 |
+
- fixed-slot decoder reservation from `max_batch_size`
|
| 130 |
+
- no encode-side continuous batching
|
| 131 |
+
- no physical compaction of surviving decode slots
|
| 132 |
+
- no multi-session concurrency on one model instance
|
| 133 |
+
|
| 134 |
+
```python
|
| 135 |
+
import torch
|
| 136 |
+
from transformers import AutoModel
|
| 137 |
+
|
| 138 |
+
repo_id = "OpenMOSS-Team/MOSS-Audio-Tokenizer"
|
| 139 |
+
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True).eval()
|
| 140 |
+
num_quantizers = model.config.quantizer_kwargs["num_quantizers"]
|
| 141 |
+
|
| 142 |
+
codes_a0 = torch.randint(0, 8, (num_quantizers, 2))
|
| 143 |
+
codes_b0 = torch.randint(0, 8, (num_quantizers, 3))
|
| 144 |
+
codes_a1 = torch.randint(0, 8, (num_quantizers, 2))
|
| 145 |
+
codes_b1 = torch.randint(0, 8, (num_quantizers, 2))
|
| 146 |
+
codes_c0 = torch.randint(0, 8, (num_quantizers, 1))
|
| 147 |
+
codes_a2 = torch.randint(0, 8, (num_quantizers, 1))
|
| 148 |
+
codes_b2 = torch.randint(0, 8, (num_quantizers, 2))
|
| 149 |
+
codes_c1 = torch.randint(0, 8, (num_quantizers, 2))
|
| 150 |
+
codes_b3 = torch.randint(0, 8, (num_quantizers, 1))
|
| 151 |
+
codes_c2 = torch.randint(0, 8, (num_quantizers, 1))
|
| 152 |
+
|
| 153 |
+
# First call reserves 3 fixed decoder slots for A and B.
|
| 154 |
+
out_ab0 = model.batch_decode(
|
| 155 |
+
[codes_a0, codes_b0],
|
| 156 |
+
streaming=True,
|
| 157 |
+
max_batch_size=3,
|
| 158 |
+
reset_stream=True,
|
| 159 |
+
)
|
| 160 |
+
|
| 161 |
+
# Same logical rows continue in-order; C is a tail append.
|
| 162 |
+
out_abc1 = model.batch_decode(
|
| 163 |
+
[codes_a1, codes_b1, codes_c0],
|
| 164 |
+
streaming=True,
|
| 165 |
+
)
|
| 166 |
+
|
| 167 |
+
# Finalize A against the pre-call logical order. A still decodes in this call,
|
| 168 |
+
# then is evicted immediately afterward.
|
| 169 |
+
out_abc2 = model.batch_decode(
|
| 170 |
+
[codes_a2, codes_b2, codes_c1],
|
| 171 |
+
streaming=True,
|
| 172 |
+
finalize_indices=[0],
|
| 173 |
+
)
|
| 174 |
+
|
| 175 |
+
# The next call can shrink to the surviving logical rows only.
|
| 176 |
+
out_bc3 = model.batch_decode(
|
| 177 |
+
[codes_b3, codes_c2],
|
| 178 |
+
streaming=True,
|
| 179 |
+
)
|
| 180 |
+
```
|
| 181 |
+
|
| 182 |
+
## Repository layout
|
| 183 |
+
|
| 184 |
+
- `configuration_moss_audio_tokenizer.py`
|
| 185 |
+
- `modeling_moss_audio_tokenizer.py`
|
| 186 |
+
- `__init__.py`
|
| 187 |
+
- `config.json`
|
| 188 |
+
- model weights
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
## Citation
|
| 192 |
+
If you use this code or our results in your paper, please cite our work as:
|
| 193 |
+
```tex
|
| 194 |
+
|
| 195 |
+
```
|
weights/codec/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""Remote code package for Moss audio tokenizer."""
|
weights/codec/config.json
ADDED
|
@@ -0,0 +1,304 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"architectures": [
|
| 3 |
+
"MossAudioTokenizerModel"
|
| 4 |
+
],
|
| 5 |
+
"auto_map": {
|
| 6 |
+
"AutoConfig": "configuration_moss_audio_tokenizer.MossAudioTokenizerConfig",
|
| 7 |
+
"AutoModel": "modeling_moss_audio_tokenizer.MossAudioTokenizerModel"
|
| 8 |
+
},
|
| 9 |
+
"model_type": "moss-audio-tokenizer",
|
| 10 |
+
"sample_rate": 48000,
|
| 11 |
+
"sampling_rate": 48000,
|
| 12 |
+
"downsample_rate": 3840,
|
| 13 |
+
"causal_transformer_context_duration": 10.0,
|
| 14 |
+
"number_channels": 2,
|
| 15 |
+
"enable_channel_interleave": true,
|
| 16 |
+
"attention_implementation": "sdpa",
|
| 17 |
+
"compute_dtype": "fp32",
|
| 18 |
+
"dtype": "float32",
|
| 19 |
+
"code_dim": 768,
|
| 20 |
+
"encoder_kwargs": [
|
| 21 |
+
{
|
| 22 |
+
"module_type": "PatchedPretransform",
|
| 23 |
+
"patch_size": 240
|
| 24 |
+
},
|
| 25 |
+
{
|
| 26 |
+
"causal": true,
|
| 27 |
+
"context_duration": 4.0,
|
| 28 |
+
"conv_layout": true,
|
| 29 |
+
"d_model": 256,
|
| 30 |
+
"dim_feedforward": 1024,
|
| 31 |
+
"gating": "none",
|
| 32 |
+
"input_dimension": 240,
|
| 33 |
+
"layer_scale": 0.01,
|
| 34 |
+
"max_period": 10000,
|
| 35 |
+
"module_type": "Transformer",
|
| 36 |
+
"norm": "layer_norm",
|
| 37 |
+
"num_heads": 4,
|
| 38 |
+
"num_layers": 4,
|
| 39 |
+
"output_dimension": 384,
|
| 40 |
+
"positional_embedding": "rope"
|
| 41 |
+
},
|
| 42 |
+
{
|
| 43 |
+
"module_type": "PatchedPretransform",
|
| 44 |
+
"patch_size": 2
|
| 45 |
+
},
|
| 46 |
+
{
|
| 47 |
+
"causal": true,
|
| 48 |
+
"context_duration": 6.0,
|
| 49 |
+
"conv_layout": true,
|
| 50 |
+
"d_model": 256,
|
| 51 |
+
"dim_feedforward": 1024,
|
| 52 |
+
"gating": "none",
|
| 53 |
+
"input_dimension": 768,
|
| 54 |
+
"layer_scale": 0.01,
|
| 55 |
+
"max_period": 10000,
|
| 56 |
+
"module_type": "Transformer",
|
| 57 |
+
"norm": "layer_norm",
|
| 58 |
+
"num_heads": 4,
|
| 59 |
+
"num_layers": 2,
|
| 60 |
+
"output_dimension": 384,
|
| 61 |
+
"positional_embedding": "rope"
|
| 62 |
+
},
|
| 63 |
+
{
|
| 64 |
+
"module_type": "PatchedPretransform",
|
| 65 |
+
"patch_size": 2
|
| 66 |
+
},
|
| 67 |
+
{
|
| 68 |
+
"causal": true,
|
| 69 |
+
"context_duration": 8.0,
|
| 70 |
+
"conv_layout": true,
|
| 71 |
+
"d_model": 256,
|
| 72 |
+
"dim_feedforward": 1024,
|
| 73 |
+
"gating": "none",
|
| 74 |
+
"input_dimension": 768,
|
| 75 |
+
"layer_scale": 0.01,
|
| 76 |
+
"max_period": 10000,
|
| 77 |
+
"module_type": "Transformer",
|
| 78 |
+
"norm": "layer_norm",
|
| 79 |
+
"num_heads": 4,
|
| 80 |
+
"num_layers": 2,
|
| 81 |
+
"output_dimension": 384,
|
| 82 |
+
"positional_embedding": "rope"
|
| 83 |
+
},
|
| 84 |
+
{
|
| 85 |
+
"module_type": "PatchedPretransform",
|
| 86 |
+
"patch_size": 2
|
| 87 |
+
},
|
| 88 |
+
{
|
| 89 |
+
"causal": true,
|
| 90 |
+
"context_duration": 10.0,
|
| 91 |
+
"conv_layout": true,
|
| 92 |
+
"d_model": 256,
|
| 93 |
+
"dim_feedforward": 1024,
|
| 94 |
+
"gating": "none",
|
| 95 |
+
"input_dimension": 768,
|
| 96 |
+
"layer_scale": 0.01,
|
| 97 |
+
"max_period": 10000,
|
| 98 |
+
"module_type": "Transformer",
|
| 99 |
+
"norm": "layer_norm",
|
| 100 |
+
"num_heads": 4,
|
| 101 |
+
"num_layers": 4,
|
| 102 |
+
"output_dimension": 192,
|
| 103 |
+
"positional_embedding": "rope"
|
| 104 |
+
},
|
| 105 |
+
{
|
| 106 |
+
"module_type": "PatchedPretransform",
|
| 107 |
+
"patch_size": 4
|
| 108 |
+
}
|
| 109 |
+
],
|
| 110 |
+
"decoder_kwargs": [
|
| 111 |
+
{
|
| 112 |
+
"module_type": "PatchedPretransform",
|
| 113 |
+
"patch_size": 4
|
| 114 |
+
},
|
| 115 |
+
{
|
| 116 |
+
"causal": true,
|
| 117 |
+
"context_duration": 10.0,
|
| 118 |
+
"conv_layout": true,
|
| 119 |
+
"d_model": 256,
|
| 120 |
+
"dim_feedforward": 1024,
|
| 121 |
+
"gating": "none",
|
| 122 |
+
"input_dimension": 192,
|
| 123 |
+
"layer_scale": 0.01,
|
| 124 |
+
"max_period": 10000,
|
| 125 |
+
"module_type": "Transformer",
|
| 126 |
+
"norm": "layer_norm",
|
| 127 |
+
"num_heads": 4,
|
| 128 |
+
"num_layers": 4,
|
| 129 |
+
"output_dimension": 768,
|
| 130 |
+
"positional_embedding": "rope"
|
| 131 |
+
},
|
| 132 |
+
{
|
| 133 |
+
"module_type": "PatchedPretransform",
|
| 134 |
+
"patch_size": 2
|
| 135 |
+
},
|
| 136 |
+
{
|
| 137 |
+
"causal": true,
|
| 138 |
+
"context_duration": 8.0,
|
| 139 |
+
"conv_layout": true,
|
| 140 |
+
"d_model": 256,
|
| 141 |
+
"dim_feedforward": 1024,
|
| 142 |
+
"gating": "none",
|
| 143 |
+
"input_dimension": 384,
|
| 144 |
+
"layer_scale": 0.01,
|
| 145 |
+
"max_period": 10000,
|
| 146 |
+
"module_type": "Transformer",
|
| 147 |
+
"norm": "layer_norm",
|
| 148 |
+
"num_heads": 4,
|
| 149 |
+
"num_layers": 2,
|
| 150 |
+
"output_dimension": 768,
|
| 151 |
+
"positional_embedding": "rope"
|
| 152 |
+
},
|
| 153 |
+
{
|
| 154 |
+
"module_type": "PatchedPretransform",
|
| 155 |
+
"patch_size": 2
|
| 156 |
+
},
|
| 157 |
+
{
|
| 158 |
+
"causal": true,
|
| 159 |
+
"context_duration": 6.0,
|
| 160 |
+
"conv_layout": true,
|
| 161 |
+
"d_model": 256,
|
| 162 |
+
"dim_feedforward": 1024,
|
| 163 |
+
"gating": "none",
|
| 164 |
+
"input_dimension": 384,
|
| 165 |
+
"layer_scale": 0.01,
|
| 166 |
+
"max_period": 10000,
|
| 167 |
+
"module_type": "Transformer",
|
| 168 |
+
"norm": "layer_norm",
|
| 169 |
+
"num_heads": 4,
|
| 170 |
+
"num_layers": 2,
|
| 171 |
+
"output_dimension": 768,
|
| 172 |
+
"positional_embedding": "rope"
|
| 173 |
+
},
|
| 174 |
+
{
|
| 175 |
+
"module_type": "PatchedPretransform",
|
| 176 |
+
"patch_size": 2
|
| 177 |
+
},
|
| 178 |
+
{
|
| 179 |
+
"causal": true,
|
| 180 |
+
"context_duration": 4.0,
|
| 181 |
+
"conv_layout": true,
|
| 182 |
+
"d_model": 256,
|
| 183 |
+
"dim_feedforward": 1024,
|
| 184 |
+
"gating": "none",
|
| 185 |
+
"input_dimension": 384,
|
| 186 |
+
"layer_scale": 0.01,
|
| 187 |
+
"max_period": 10000,
|
| 188 |
+
"module_type": "Transformer",
|
| 189 |
+
"norm": "layer_norm",
|
| 190 |
+
"num_heads": 4,
|
| 191 |
+
"num_layers": 4,
|
| 192 |
+
"output_dimension": 240,
|
| 193 |
+
"positional_embedding": "rope"
|
| 194 |
+
},
|
| 195 |
+
{
|
| 196 |
+
"module_type": "PatchedPretransform",
|
| 197 |
+
"patch_size": 240
|
| 198 |
+
}
|
| 199 |
+
],
|
| 200 |
+
"quantizer_type": "rlfq",
|
| 201 |
+
"quantizer_kwargs": {
|
| 202 |
+
"codebook_dim": 8,
|
| 203 |
+
"codebook_loss_weight": 1.0,
|
| 204 |
+
"codebook_size": 1024,
|
| 205 |
+
"commitment_loss_weight": 0.25,
|
| 206 |
+
"input_dim": 768,
|
| 207 |
+
"num_quantizers": 16,
|
| 208 |
+
"output_dim": 768,
|
| 209 |
+
"quantizer_dropout": 1.0,
|
| 210 |
+
"quantizer_type": "rlfq",
|
| 211 |
+
"rvq_dim": 512
|
| 212 |
+
},
|
| 213 |
+
"transformers_version": "4.56.0.dev0",
|
| 214 |
+
"reversed_decoder_kwargs": [
|
| 215 |
+
{
|
| 216 |
+
"module_type": "PatchedPretransform",
|
| 217 |
+
"patch_size": 240
|
| 218 |
+
},
|
| 219 |
+
{
|
| 220 |
+
"causal": true,
|
| 221 |
+
"context_duration": 4.0,
|
| 222 |
+
"conv_layout": true,
|
| 223 |
+
"d_model": 256,
|
| 224 |
+
"dim_feedforward": 1024,
|
| 225 |
+
"gating": "none",
|
| 226 |
+
"input_dimension": 240,
|
| 227 |
+
"layer_scale": 0.01,
|
| 228 |
+
"max_period": 10000,
|
| 229 |
+
"module_type": "Transformer",
|
| 230 |
+
"norm": "layer_norm",
|
| 231 |
+
"num_heads": 4,
|
| 232 |
+
"num_layers": 4,
|
| 233 |
+
"output_dimension": 384,
|
| 234 |
+
"positional_embedding": "rope"
|
| 235 |
+
},
|
| 236 |
+
{
|
| 237 |
+
"module_type": "PatchedPretransform",
|
| 238 |
+
"patch_size": 2
|
| 239 |
+
},
|
| 240 |
+
{
|
| 241 |
+
"causal": true,
|
| 242 |
+
"context_duration": 6.0,
|
| 243 |
+
"conv_layout": true,
|
| 244 |
+
"d_model": 256,
|
| 245 |
+
"dim_feedforward": 1024,
|
| 246 |
+
"gating": "none",
|
| 247 |
+
"input_dimension": 768,
|
| 248 |
+
"layer_scale": 0.01,
|
| 249 |
+
"max_period": 10000,
|
| 250 |
+
"module_type": "Transformer",
|
| 251 |
+
"norm": "layer_norm",
|
| 252 |
+
"num_heads": 4,
|
| 253 |
+
"num_layers": 2,
|
| 254 |
+
"output_dimension": 384,
|
| 255 |
+
"positional_embedding": "rope"
|
| 256 |
+
},
|
| 257 |
+
{
|
| 258 |
+
"module_type": "PatchedPretransform",
|
| 259 |
+
"patch_size": 2
|
| 260 |
+
},
|
| 261 |
+
{
|
| 262 |
+
"causal": true,
|
| 263 |
+
"context_duration": 8.0,
|
| 264 |
+
"conv_layout": true,
|
| 265 |
+
"d_model": 256,
|
| 266 |
+
"dim_feedforward": 1024,
|
| 267 |
+
"gating": "none",
|
| 268 |
+
"input_dimension": 768,
|
| 269 |
+
"layer_scale": 0.01,
|
| 270 |
+
"max_period": 10000,
|
| 271 |
+
"module_type": "Transformer",
|
| 272 |
+
"norm": "layer_norm",
|
| 273 |
+
"num_heads": 4,
|
| 274 |
+
"num_layers": 2,
|
| 275 |
+
"output_dimension": 384,
|
| 276 |
+
"positional_embedding": "rope"
|
| 277 |
+
},
|
| 278 |
+
{
|
| 279 |
+
"module_type": "PatchedPretransform",
|
| 280 |
+
"patch_size": 2
|
| 281 |
+
},
|
| 282 |
+
{
|
| 283 |
+
"causal": true,
|
| 284 |
+
"context_duration": 10.0,
|
| 285 |
+
"conv_layout": true,
|
| 286 |
+
"d_model": 256,
|
| 287 |
+
"dim_feedforward": 1024,
|
| 288 |
+
"gating": "none",
|
| 289 |
+
"input_dimension": 768,
|
| 290 |
+
"layer_scale": 0.01,
|
| 291 |
+
"max_period": 10000,
|
| 292 |
+
"module_type": "Transformer",
|
| 293 |
+
"norm": "layer_norm",
|
| 294 |
+
"num_heads": 4,
|
| 295 |
+
"num_layers": 4,
|
| 296 |
+
"output_dimension": 192,
|
| 297 |
+
"positional_embedding": "rope"
|
| 298 |
+
},
|
| 299 |
+
{
|
| 300 |
+
"module_type": "PatchedPretransform",
|
| 301 |
+
"patch_size": 4
|
| 302 |
+
}
|
| 303 |
+
]
|
| 304 |
+
}
|
weights/codec/configuration_moss_audio_tokenizer.py
ADDED
|
@@ -0,0 +1,467 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2026 OpenMOSS and the HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""MossAudioTokenizer model configuration."""
|
| 16 |
+
|
| 17 |
+
from typing import Any
|
| 18 |
+
|
| 19 |
+
try:
|
| 20 |
+
from transformers.configuration_utils import PreTrainedConfig
|
| 21 |
+
except ImportError:
|
| 22 |
+
from transformers.configuration_utils import PretrainedConfig as PreTrainedConfig
|
| 23 |
+
from transformers.utils import logging
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
logger = logging.get_logger(__name__)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class MossAudioTokenizerConfig(PreTrainedConfig):
|
| 30 |
+
r"""
|
| 31 |
+
This is the configuration class to store the configuration of a [`MossAudioTokenizerModel`]. It is used to instantiate a
|
| 32 |
+
MossAudioTokenizer model according to the specified arguments, defining the model architecture.
|
| 33 |
+
|
| 34 |
+
Instantiating a configuration with the defaults will yield a similar configuration to that of the
|
| 35 |
+
[VoiceAgentGroup/moss_audio_tokenizer](https://huggingface.co/VoiceAgentGroup/moss_audio_tokenizer) architecture.
|
| 36 |
+
|
| 37 |
+
Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the
|
| 38 |
+
documentation from [`PreTrainedConfig`] for more information.
|
| 39 |
+
|
| 40 |
+
Args:
|
| 41 |
+
sampling_rate (`int`, *optional*, defaults to 48000):
|
| 42 |
+
The sampling rate at which the audio waveform should be digitalized expressed in hertz (Hz).
|
| 43 |
+
downsample_rate (`int`, *optional*, defaults to 3840):
|
| 44 |
+
Total downsampling rate from waveform to tokens.
|
| 45 |
+
causal_transformer_context_duration (`float`, *optional*, defaults to 10.0):
|
| 46 |
+
Legacy global fallback context duration in seconds for causal transformer. If an individual transformer
|
| 47 |
+
entry in `encoder_kwargs` or `decoder_kwargs` provides `context_duration`, that per-module value takes
|
| 48 |
+
precedence.
|
| 49 |
+
encoder_kwargs (`list[dict]`, *optional*):
|
| 50 |
+
List of encoder module configurations. Each dict specifies a module type and its parameters.
|
| 51 |
+
decoder_kwargs (`list[dict]`, *optional*):
|
| 52 |
+
List of decoder module configurations in execution order.
|
| 53 |
+
number_channels (`int`, *optional*, defaults to 2):
|
| 54 |
+
Number of audio channels exposed by the public waveform interface.
|
| 55 |
+
enable_channel_interleave (`bool`, *optional*, defaults to `True`):
|
| 56 |
+
Whether to flatten multi-channel waveforms into a single internal stream before codec inference.
|
| 57 |
+
attention_implementation (`str`, *optional*, defaults to `"sdpa"`):
|
| 58 |
+
Attention implementation to prefer for transformer layers. Supported values are `"sdpa"` and
|
| 59 |
+
`"flash_attention_2"`.
|
| 60 |
+
compute_dtype (`str`, *optional*, defaults to `"fp32"`):
|
| 61 |
+
Inference compute dtype for non-quantizer modules. Supported values are `"fp32"`, `"bf16"`, and `"fp16"`.
|
| 62 |
+
quantizer_type (`str`, *optional*, defaults to `"rlfq"`):
|
| 63 |
+
Quantizer type. Options include `"rvq"`, `"spec_rvq"`, `"rlfq"`, `"random_prefix_rlfq"`.
|
| 64 |
+
quantizer_kwargs (`dict`, *optional*):
|
| 65 |
+
Configuration for the quantizer including `input_dim`, `rvq_dim`, `output_dim`, `num_quantizers`,
|
| 66 |
+
`codebook_size`, and `codebook_dim`.
|
| 67 |
+
|
| 68 |
+
Example:
|
| 69 |
+
|
| 70 |
+
```python
|
| 71 |
+
>>> from transformers import MossAudioTokenizerModel, MossAudioTokenizerConfig
|
| 72 |
+
|
| 73 |
+
>>> # Initializing a MossAudioTokenizer style configuration
|
| 74 |
+
>>> configuration = MossAudioTokenizerConfig()
|
| 75 |
+
|
| 76 |
+
>>> # Initializing a model (with random weights) from the configuration
|
| 77 |
+
>>> model = MossAudioTokenizerModel(configuration)
|
| 78 |
+
|
| 79 |
+
>>> # Accessing the model configuration
|
| 80 |
+
>>> configuration = model.config
|
| 81 |
+
```
|
| 82 |
+
"""
|
| 83 |
+
|
| 84 |
+
model_type = "moss-audio-tokenizer"
|
| 85 |
+
|
| 86 |
+
# Backward-compatible alias used by some checkpoints.
|
| 87 |
+
attribute_map = {"sample_rate": "sampling_rate"}
|
| 88 |
+
|
| 89 |
+
sampling_rate: int
|
| 90 |
+
downsample_rate: int
|
| 91 |
+
causal_transformer_context_duration: float
|
| 92 |
+
encoder_kwargs: list[dict[str, Any]]
|
| 93 |
+
decoder_kwargs: list[dict[str, Any]]
|
| 94 |
+
number_channels: int
|
| 95 |
+
enable_channel_interleave: bool
|
| 96 |
+
attention_implementation: str
|
| 97 |
+
compute_dtype: str
|
| 98 |
+
quantizer_type: str
|
| 99 |
+
quantizer_kwargs: dict[str, Any]
|
| 100 |
+
|
| 101 |
+
def __init__(
    self,
    version: str | None = None,
    sampling_rate: int = 48000,
    downsample_rate: int = 3840,
    causal_transformer_context_duration: float = 10.0,
    encoder_kwargs: list[dict[str, Any]] | None = None,
    decoder_kwargs: list[dict[str, Any]] | None = None,
    number_channels: int = 2,
    enable_channel_interleave: bool = True,
    attention_implementation: str = "sdpa",
    compute_dtype: str = "fp32",
    quantizer_type: str = "rlfq",
    quantizer_kwargs: dict[str, Any] | None = None,
    **kwargs,
):
    """Build the tokenizer configuration, filling in default encoder/decoder/quantizer stacks.

    See the class docstring for the meaning of each argument. Legacy
    checkpoint keys accepted via ``**kwargs`` and mapped onto their modern
    equivalents: ``channels_numbers`` -> ``number_channels``,
    ``attention_backend`` -> ``attention_implementation``,
    ``codec_compute_dtype`` -> ``compute_dtype``, and
    ``reversed_decoder_kwargs`` (reverse-order decoder spec).
    """
    # Some checkpoints include an incorrect/legacy `model_type`
    # (e.g. "speech_tokenizer"). Drop it so it cannot override the
    # class-level `model_type`.
    kwargs.pop("model_type", None)
    if "channels_numbers" in kwargs:  # legacy spelling of `number_channels`
        number_channels = kwargs.pop("channels_numbers")
    # NOTE(review): a legacy `if "enable_channel_interleave" in kwargs` check
    # was removed — it was dead code, since the named parameter captures that
    # keyword before it can ever land in **kwargs.
    if "attention_backend" in kwargs and attention_implementation == "sdpa":
        attention_implementation = kwargs.pop("attention_backend")
    if "codec_compute_dtype" in kwargs and compute_dtype == "fp32":
        compute_dtype = kwargs.pop("codec_compute_dtype")
    # Legacy decoder spec: listed in reverse execution order, with Transformer
    # input/output dimensions swapped (see `_decoder_kwargs_from_reversed`).
    reversed_decoder_kwargs = kwargs.pop("reversed_decoder_kwargs", None)

    # `version` is accepted for compatibility but not used in modeling.
    self.version = version
    self.sampling_rate = sampling_rate
    self.downsample_rate = downsample_rate
    self.causal_transformer_context_duration = causal_transformer_context_duration
    self.number_channels = number_channels
    self.enable_channel_interleave = enable_channel_interleave
    self.attention_implementation = attention_implementation
    self.compute_dtype = compute_dtype

    # --- Encoder stack --------------------------------------------------
    if encoder_kwargs is None:
        encoder_kwargs = self._default_encoder_kwargs()
    else:
        # Copy so we never mutate the caller's dicts, and make sure every
        # Transformer stage carries an explicit context duration.
        encoder_kwargs = [dict(module_kwargs) for module_kwargs in encoder_kwargs]
        for module_kwargs in encoder_kwargs:
            if module_kwargs.get("module_type") == "Transformer":
                module_kwargs.setdefault("context_duration", causal_transformer_context_duration)
    self.encoder_kwargs = encoder_kwargs

    # --- Decoder stack (execution order) --------------------------------
    if decoder_kwargs is None and reversed_decoder_kwargs is not None:
        decoder_kwargs = self._decoder_kwargs_from_reversed(reversed_decoder_kwargs)
    if decoder_kwargs is None:
        decoder_kwargs = self._default_decoder_kwargs()
    else:
        decoder_kwargs = [dict(module_kwargs) for module_kwargs in decoder_kwargs]
        for module_kwargs in decoder_kwargs:
            if module_kwargs.get("module_type") == "Transformer":
                module_kwargs.setdefault("context_duration", causal_transformer_context_duration)
    self.decoder_kwargs = decoder_kwargs

    # --- Quantizer -------------------------------------------------------
    if quantizer_kwargs is None:
        quantizer_kwargs = {
            "input_dim": 768,
            "rvq_dim": 512,
            "output_dim": 768,
            "num_quantizers": 32,
            "codebook_size": 1024,
            "codebook_dim": 8,
            # Fix: honor the explicit `quantizer_type` argument. The previous
            # default hard-coded "rlfq" here, which silently overrode a
            # caller-supplied `quantizer_type` whenever the default
            # `quantizer_kwargs` were used.
            "quantizer_type": quantizer_type,
        }
    else:
        # Fix: copy instead of mutating the caller's dict in place below.
        quantizer_kwargs = dict(quantizer_kwargs)

    kw_qtype = quantizer_kwargs.get("quantizer_type")
    if kw_qtype is not None:
        # An explicit entry inside `quantizer_kwargs` wins over the argument.
        self.quantizer_type = kw_qtype
    else:
        self.quantizer_type = quantizer_type
        quantizer_kwargs["quantizer_type"] = quantizer_type
    self.quantizer_kwargs = quantizer_kwargs

    super().__init__(**kwargs)

@staticmethod
def _patch_cfg(patch_size: int) -> dict[str, Any]:
    """Config dict for a PatchedPretransform stage with the given patch size."""
    return {"module_type": "PatchedPretransform", "patch_size": patch_size}

@staticmethod
def _transformer_cfg(
    input_dimension: int,
    output_dimension: int,
    context_duration: float,
    *,
    d_model: int = 768,
    num_heads: int = 12,
    num_layers: int = 12,
    dim_feedforward: int = 3072,
) -> dict[str, Any]:
    """Config dict for a causal Transformer stage; non-varying keys are fixed defaults."""
    return {
        "module_type": "Transformer",
        "input_dimension": input_dimension,
        "output_dimension": output_dimension,
        "d_model": d_model,
        "num_heads": num_heads,
        "num_layers": num_layers,
        "dim_feedforward": dim_feedforward,
        "causal": True,
        "norm": "layer_norm",
        "positional_embedding": "rope",
        "max_period": 10000,
        "gating": "none",
        "layer_scale": 0.01,
        "conv_layout": True,
        "context_duration": context_duration,
    }

@classmethod
def _default_encoder_kwargs(cls) -> list[dict[str, Any]]:
    """Default encoder pipeline: alternating patch and causal-Transformer stages."""
    return [
        cls._patch_cfg(240),
        cls._transformer_cfg(240, 384, 1.0),
        cls._patch_cfg(2),
        cls._transformer_cfg(768, 384, 2.0),
        cls._patch_cfg(2),
        cls._transformer_cfg(768, 384, 4.0),
        cls._patch_cfg(2),
        cls._transformer_cfg(768, 384, 8.0),
        cls._patch_cfg(2),
        cls._transformer_cfg(768, 640, 10.0),
        cls._patch_cfg(2),
        cls._transformer_cfg(
            1280, 768, 10.0, d_model=1280, num_heads=20, num_layers=32, dim_feedforward=5120
        ),
    ]

@classmethod
def _default_decoder_kwargs(cls) -> list[dict[str, Any]]:
    """Default decoder pipeline (execution order), mirroring the encoder stack."""
    return [
        cls._transformer_cfg(
            768, 1280, 10.0, d_model=1280, num_heads=20, num_layers=32, dim_feedforward=5120
        ),
        cls._patch_cfg(2),
        cls._transformer_cfg(640, 768, 10.0),
        cls._patch_cfg(2),
        cls._transformer_cfg(384, 768, 8.0),
        cls._patch_cfg(2),
        cls._transformer_cfg(384, 768, 4.0),
        cls._patch_cfg(2),
        cls._transformer_cfg(384, 768, 2.0),
        cls._patch_cfg(2),
        cls._transformer_cfg(384, 240, 1.0),
        cls._patch_cfg(240),
    ]

@staticmethod
def _decoder_kwargs_from_reversed(
    reversed_decoder_kwargs: list[dict[str, Any]],
) -> list[dict[str, Any]]:
    """Convert a legacy reverse-order decoder spec to execution order.

    Transformer stages additionally get their input/output dimensions
    swapped, matching the legacy on-disk convention. All dicts are copied.
    """
    decoder_kwargs: list[dict[str, Any]] = []
    for module_kwargs in reversed(reversed_decoder_kwargs):
        module_kwargs = dict(module_kwargs)
        if module_kwargs.get("module_type") == "Transformer":
            module_kwargs["input_dimension"], module_kwargs["output_dimension"] = (
                module_kwargs["output_dimension"],
                module_kwargs["input_dimension"],
            )
        decoder_kwargs.append(module_kwargs)
    return decoder_kwargs
|
| 450 |
+
|
| 451 |
+
@property
def num_quantizers(self) -> int:
    """Number of quantizer stages, falling back to 32 when unspecified."""
    quantizer_cfg = self.quantizer_kwargs
    return quantizer_cfg.get("num_quantizers", 32)
|
| 455 |
+
|
| 456 |
+
@property
def codebook_size(self) -> int:
    """Codebook size per quantizer, falling back to 4096 when unspecified."""
    quantizer_cfg = self.quantizer_kwargs
    return quantizer_cfg.get("codebook_size", 4096)
|
| 460 |
+
|
| 461 |
+
@property
def frame_rate(self) -> float:
    """Token rate in frames per second: sampling_rate / downsample_rate."""
    samples_per_frame = self.downsample_rate
    return self.sampling_rate / samples_per_frame
|
| 465 |
+
|
| 466 |
+
|
| 467 |
+
# Public API of this module.
__all__ = ["MossAudioTokenizerConfig"]
|
weights/codec/model-00001-of-00001.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:34d9880d805eecb21bde975202b1c256dbd0eb98c8680b9d3aeffd2bc6ac2f67
|
| 3 |
+
size 87922568
|
weights/codec/model.safetensors.index.json
ADDED
|
@@ -0,0 +1,382 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"metadata": {
|
| 3 |
+
"total_parameters": 21969664,
|
| 4 |
+
"total_size": 87878656
|
| 5 |
+
},
|
| 6 |
+
"weight_map": {
|
| 7 |
+
"encoder.1.input_proj.weight": "model-00001-of-00001.safetensors",
|
| 8 |
+
"encoder.1.transformer.layers.0.norm1.weight": "model-00001-of-00001.safetensors",
|
| 9 |
+
"encoder.1.transformer.layers.0.norm1.bias": "model-00001-of-00001.safetensors",
|
| 10 |
+
"encoder.1.transformer.layers.0.self_attn.in_proj.weight": "model-00001-of-00001.safetensors",
|
| 11 |
+
"encoder.1.transformer.layers.0.self_attn.out_proj.weight": "model-00001-of-00001.safetensors",
|
| 12 |
+
"encoder.1.transformer.layers.0.norm2.weight": "model-00001-of-00001.safetensors",
|
| 13 |
+
"encoder.1.transformer.layers.0.norm2.bias": "model-00001-of-00001.safetensors",
|
| 14 |
+
"encoder.1.transformer.layers.0.ffn.0.weight": "model-00001-of-00001.safetensors",
|
| 15 |
+
"encoder.1.transformer.layers.0.ffn.2.weight": "model-00001-of-00001.safetensors",
|
| 16 |
+
"encoder.1.transformer.layers.0.layer_scale_1.scale": "model-00001-of-00001.safetensors",
|
| 17 |
+
"encoder.1.transformer.layers.0.layer_scale_2.scale": "model-00001-of-00001.safetensors",
|
| 18 |
+
"encoder.1.transformer.layers.1.norm1.weight": "model-00001-of-00001.safetensors",
|
| 19 |
+
"encoder.1.transformer.layers.1.norm1.bias": "model-00001-of-00001.safetensors",
|
| 20 |
+
"encoder.1.transformer.layers.1.self_attn.in_proj.weight": "model-00001-of-00001.safetensors",
|
| 21 |
+
"encoder.1.transformer.layers.1.self_attn.out_proj.weight": "model-00001-of-00001.safetensors",
|
| 22 |
+
"encoder.1.transformer.layers.1.norm2.weight": "model-00001-of-00001.safetensors",
|
| 23 |
+
"encoder.1.transformer.layers.1.norm2.bias": "model-00001-of-00001.safetensors",
|
| 24 |
+
"encoder.1.transformer.layers.1.ffn.0.weight": "model-00001-of-00001.safetensors",
|
| 25 |
+
"encoder.1.transformer.layers.1.ffn.2.weight": "model-00001-of-00001.safetensors",
|
| 26 |
+
"encoder.1.transformer.layers.1.layer_scale_1.scale": "model-00001-of-00001.safetensors",
|
| 27 |
+
"encoder.1.transformer.layers.1.layer_scale_2.scale": "model-00001-of-00001.safetensors",
|
| 28 |
+
"encoder.1.transformer.layers.2.norm1.weight": "model-00001-of-00001.safetensors",
|
| 29 |
+
"encoder.1.transformer.layers.2.norm1.bias": "model-00001-of-00001.safetensors",
|
| 30 |
+
"encoder.1.transformer.layers.2.self_attn.in_proj.weight": "model-00001-of-00001.safetensors",
|
| 31 |
+
"encoder.1.transformer.layers.2.self_attn.out_proj.weight": "model-00001-of-00001.safetensors",
|
| 32 |
+
"encoder.1.transformer.layers.2.norm2.weight": "model-00001-of-00001.safetensors",
|
| 33 |
+
"encoder.1.transformer.layers.2.norm2.bias": "model-00001-of-00001.safetensors",
|
| 34 |
+
"encoder.1.transformer.layers.2.ffn.0.weight": "model-00001-of-00001.safetensors",
|
| 35 |
+
"encoder.1.transformer.layers.2.ffn.2.weight": "model-00001-of-00001.safetensors",
|
| 36 |
+
"encoder.1.transformer.layers.2.layer_scale_1.scale": "model-00001-of-00001.safetensors",
|
| 37 |
+
"encoder.1.transformer.layers.2.layer_scale_2.scale": "model-00001-of-00001.safetensors",
|
| 38 |
+
"encoder.1.transformer.layers.3.norm1.weight": "model-00001-of-00001.safetensors",
|
| 39 |
+
"encoder.1.transformer.layers.3.norm1.bias": "model-00001-of-00001.safetensors",
|
| 40 |
+
"encoder.1.transformer.layers.3.self_attn.in_proj.weight": "model-00001-of-00001.safetensors",
|
| 41 |
+
"encoder.1.transformer.layers.3.self_attn.out_proj.weight": "model-00001-of-00001.safetensors",
|
| 42 |
+
"encoder.1.transformer.layers.3.norm2.weight": "model-00001-of-00001.safetensors",
|
| 43 |
+
"encoder.1.transformer.layers.3.norm2.bias": "model-00001-of-00001.safetensors",
|
| 44 |
+
"encoder.1.transformer.layers.3.ffn.0.weight": "model-00001-of-00001.safetensors",
|
| 45 |
+
"encoder.1.transformer.layers.3.ffn.2.weight": "model-00001-of-00001.safetensors",
|
| 46 |
+
"encoder.1.transformer.layers.3.layer_scale_1.scale": "model-00001-of-00001.safetensors",
|
| 47 |
+
"encoder.1.transformer.layers.3.layer_scale_2.scale": "model-00001-of-00001.safetensors",
|
| 48 |
+
"encoder.1.output_proj.weight": "model-00001-of-00001.safetensors",
|
| 49 |
+
"encoder.3.input_proj.weight": "model-00001-of-00001.safetensors",
|
| 50 |
+
"encoder.3.transformer.layers.0.norm1.weight": "model-00001-of-00001.safetensors",
|
| 51 |
+
"encoder.3.transformer.layers.0.norm1.bias": "model-00001-of-00001.safetensors",
|
| 52 |
+
"encoder.3.transformer.layers.0.self_attn.in_proj.weight": "model-00001-of-00001.safetensors",
|
| 53 |
+
"encoder.3.transformer.layers.0.self_attn.out_proj.weight": "model-00001-of-00001.safetensors",
|
| 54 |
+
"encoder.3.transformer.layers.0.norm2.weight": "model-00001-of-00001.safetensors",
|
| 55 |
+
"encoder.3.transformer.layers.0.norm2.bias": "model-00001-of-00001.safetensors",
|
| 56 |
+
"encoder.3.transformer.layers.0.ffn.0.weight": "model-00001-of-00001.safetensors",
|
| 57 |
+
"encoder.3.transformer.layers.0.ffn.2.weight": "model-00001-of-00001.safetensors",
|
| 58 |
+
"encoder.3.transformer.layers.0.layer_scale_1.scale": "model-00001-of-00001.safetensors",
|
| 59 |
+
"encoder.3.transformer.layers.0.layer_scale_2.scale": "model-00001-of-00001.safetensors",
|
| 60 |
+
"encoder.3.transformer.layers.1.norm1.weight": "model-00001-of-00001.safetensors",
|
| 61 |
+
"encoder.3.transformer.layers.1.norm1.bias": "model-00001-of-00001.safetensors",
|
| 62 |
+
"encoder.3.transformer.layers.1.self_attn.in_proj.weight": "model-00001-of-00001.safetensors",
|
| 63 |
+
"encoder.3.transformer.layers.1.self_attn.out_proj.weight": "model-00001-of-00001.safetensors",
|
| 64 |
+
"encoder.3.transformer.layers.1.norm2.weight": "model-00001-of-00001.safetensors",
|
| 65 |
+
"encoder.3.transformer.layers.1.norm2.bias": "model-00001-of-00001.safetensors",
|
| 66 |
+
"encoder.3.transformer.layers.1.ffn.0.weight": "model-00001-of-00001.safetensors",
|
| 67 |
+
"encoder.3.transformer.layers.1.ffn.2.weight": "model-00001-of-00001.safetensors",
|
| 68 |
+
"encoder.3.transformer.layers.1.layer_scale_1.scale": "model-00001-of-00001.safetensors",
|
| 69 |
+
"encoder.3.transformer.layers.1.layer_scale_2.scale": "model-00001-of-00001.safetensors",
|
| 70 |
+
"encoder.3.output_proj.weight": "model-00001-of-00001.safetensors",
|
| 71 |
+
"encoder.5.input_proj.weight": "model-00001-of-00001.safetensors",
|
| 72 |
+
"encoder.5.transformer.layers.0.norm1.weight": "model-00001-of-00001.safetensors",
|
| 73 |
+
"encoder.5.transformer.layers.0.norm1.bias": "model-00001-of-00001.safetensors",
|
| 74 |
+
"encoder.5.transformer.layers.0.self_attn.in_proj.weight": "model-00001-of-00001.safetensors",
|
| 75 |
+
"encoder.5.transformer.layers.0.self_attn.out_proj.weight": "model-00001-of-00001.safetensors",
|
| 76 |
+
"encoder.5.transformer.layers.0.norm2.weight": "model-00001-of-00001.safetensors",
|
| 77 |
+
"encoder.5.transformer.layers.0.norm2.bias": "model-00001-of-00001.safetensors",
|
| 78 |
+
"encoder.5.transformer.layers.0.ffn.0.weight": "model-00001-of-00001.safetensors",
|
| 79 |
+
"encoder.5.transformer.layers.0.ffn.2.weight": "model-00001-of-00001.safetensors",
|
| 80 |
+
"encoder.5.transformer.layers.0.layer_scale_1.scale": "model-00001-of-00001.safetensors",
|
| 81 |
+
"encoder.5.transformer.layers.0.layer_scale_2.scale": "model-00001-of-00001.safetensors",
|
| 82 |
+
"encoder.5.transformer.layers.1.norm1.weight": "model-00001-of-00001.safetensors",
|
| 83 |
+
"encoder.5.transformer.layers.1.norm1.bias": "model-00001-of-00001.safetensors",
|
| 84 |
+
"encoder.5.transformer.layers.1.self_attn.in_proj.weight": "model-00001-of-00001.safetensors",
|
| 85 |
+
"encoder.5.transformer.layers.1.self_attn.out_proj.weight": "model-00001-of-00001.safetensors",
|
| 86 |
+
"encoder.5.transformer.layers.1.norm2.weight": "model-00001-of-00001.safetensors",
|
| 87 |
+
"encoder.5.transformer.layers.1.norm2.bias": "model-00001-of-00001.safetensors",
|
| 88 |
+
"encoder.5.transformer.layers.1.ffn.0.weight": "model-00001-of-00001.safetensors",
|
| 89 |
+
"encoder.5.transformer.layers.1.ffn.2.weight": "model-00001-of-00001.safetensors",
|
| 90 |
+
"encoder.5.transformer.layers.1.layer_scale_1.scale": "model-00001-of-00001.safetensors",
|
| 91 |
+
"encoder.5.transformer.layers.1.layer_scale_2.scale": "model-00001-of-00001.safetensors",
|
| 92 |
+
"encoder.5.output_proj.weight": "model-00001-of-00001.safetensors",
|
| 93 |
+
"encoder.7.input_proj.weight": "model-00001-of-00001.safetensors",
|
| 94 |
+
"encoder.7.transformer.layers.0.norm1.weight": "model-00001-of-00001.safetensors",
|
| 95 |
+
"encoder.7.transformer.layers.0.norm1.bias": "model-00001-of-00001.safetensors",
|
| 96 |
+
"encoder.7.transformer.layers.0.self_attn.in_proj.weight": "model-00001-of-00001.safetensors",
|
| 97 |
+
"encoder.7.transformer.layers.0.self_attn.out_proj.weight": "model-00001-of-00001.safetensors",
|
| 98 |
+
"encoder.7.transformer.layers.0.norm2.weight": "model-00001-of-00001.safetensors",
|
| 99 |
+
"encoder.7.transformer.layers.0.norm2.bias": "model-00001-of-00001.safetensors",
|
| 100 |
+
"encoder.7.transformer.layers.0.ffn.0.weight": "model-00001-of-00001.safetensors",
|
| 101 |
+
"encoder.7.transformer.layers.0.ffn.2.weight": "model-00001-of-00001.safetensors",
|
| 102 |
+
"encoder.7.transformer.layers.0.layer_scale_1.scale": "model-00001-of-00001.safetensors",
|
| 103 |
+
"encoder.7.transformer.layers.0.layer_scale_2.scale": "model-00001-of-00001.safetensors",
|
| 104 |
+
"encoder.7.transformer.layers.1.norm1.weight": "model-00001-of-00001.safetensors",
|
| 105 |
+
"encoder.7.transformer.layers.1.norm1.bias": "model-00001-of-00001.safetensors",
|
| 106 |
+
"encoder.7.transformer.layers.1.self_attn.in_proj.weight": "model-00001-of-00001.safetensors",
|
| 107 |
+
"encoder.7.transformer.layers.1.self_attn.out_proj.weight": "model-00001-of-00001.safetensors",
|
| 108 |
+
"encoder.7.transformer.layers.1.norm2.weight": "model-00001-of-00001.safetensors",
|
| 109 |
+
"encoder.7.transformer.layers.1.norm2.bias": "model-00001-of-00001.safetensors",
|
| 110 |
+
"encoder.7.transformer.layers.1.ffn.0.weight": "model-00001-of-00001.safetensors",
|
| 111 |
+
"encoder.7.transformer.layers.1.ffn.2.weight": "model-00001-of-00001.safetensors",
|
| 112 |
+
"encoder.7.transformer.layers.1.layer_scale_1.scale": "model-00001-of-00001.safetensors",
|
| 113 |
+
"encoder.7.transformer.layers.1.layer_scale_2.scale": "model-00001-of-00001.safetensors",
|
| 114 |
+
"encoder.7.transformer.layers.2.norm1.weight": "model-00001-of-00001.safetensors",
|
| 115 |
+
"encoder.7.transformer.layers.2.norm1.bias": "model-00001-of-00001.safetensors",
|
| 116 |
+
"encoder.7.transformer.layers.2.self_attn.in_proj.weight": "model-00001-of-00001.safetensors",
|
| 117 |
+
"encoder.7.transformer.layers.2.self_attn.out_proj.weight": "model-00001-of-00001.safetensors",
|
| 118 |
+
"encoder.7.transformer.layers.2.norm2.weight": "model-00001-of-00001.safetensors",
|
| 119 |
+
"encoder.7.transformer.layers.2.norm2.bias": "model-00001-of-00001.safetensors",
|
| 120 |
+
"encoder.7.transformer.layers.2.ffn.0.weight": "model-00001-of-00001.safetensors",
|
| 121 |
+
"encoder.7.transformer.layers.2.ffn.2.weight": "model-00001-of-00001.safetensors",
|
| 122 |
+
"encoder.7.transformer.layers.2.layer_scale_1.scale": "model-00001-of-00001.safetensors",
|
| 123 |
+
"encoder.7.transformer.layers.2.layer_scale_2.scale": "model-00001-of-00001.safetensors",
|
| 124 |
+
"encoder.7.transformer.layers.3.norm1.weight": "model-00001-of-00001.safetensors",
|
| 125 |
+
"encoder.7.transformer.layers.3.norm1.bias": "model-00001-of-00001.safetensors",
|
| 126 |
+
"encoder.7.transformer.layers.3.self_attn.in_proj.weight": "model-00001-of-00001.safetensors",
|
| 127 |
+
"encoder.7.transformer.layers.3.self_attn.out_proj.weight": "model-00001-of-00001.safetensors",
|
| 128 |
+
"encoder.7.transformer.layers.3.norm2.weight": "model-00001-of-00001.safetensors",
|
| 129 |
+
"encoder.7.transformer.layers.3.norm2.bias": "model-00001-of-00001.safetensors",
|
| 130 |
+
"encoder.7.transformer.layers.3.ffn.0.weight": "model-00001-of-00001.safetensors",
|
| 131 |
+
"encoder.7.transformer.layers.3.ffn.2.weight": "model-00001-of-00001.safetensors",
|
| 132 |
+
"encoder.7.transformer.layers.3.layer_scale_1.scale": "model-00001-of-00001.safetensors",
|
| 133 |
+
"encoder.7.transformer.layers.3.layer_scale_2.scale": "model-00001-of-00001.safetensors",
|
| 134 |
+
"encoder.7.output_proj.weight": "model-00001-of-00001.safetensors",
|
| 135 |
+
"quantizer.input_proj.bias": "model-00001-of-00001.safetensors",
|
| 136 |
+
"quantizer.input_proj.parametrizations.weight.original0": "model-00001-of-00001.safetensors",
|
| 137 |
+
"quantizer.input_proj.parametrizations.weight.original1": "model-00001-of-00001.safetensors",
|
| 138 |
+
"quantizer.output_proj.bias": "model-00001-of-00001.safetensors",
|
| 139 |
+
"quantizer.output_proj.parametrizations.weight.original0": "model-00001-of-00001.safetensors",
|
| 140 |
+
"quantizer.output_proj.parametrizations.weight.original1": "model-00001-of-00001.safetensors",
|
| 141 |
+
"quantizer.quantizers.0.in_proj.bias": "model-00001-of-00001.safetensors",
|
| 142 |
+
"quantizer.quantizers.0.in_proj.parametrizations.weight.original0": "model-00001-of-00001.safetensors",
|
| 143 |
+
"quantizer.quantizers.0.in_proj.parametrizations.weight.original1": "model-00001-of-00001.safetensors",
|
| 144 |
+
"quantizer.quantizers.0.out_proj.bias": "model-00001-of-00001.safetensors",
|
| 145 |
+
"quantizer.quantizers.0.out_proj.parametrizations.weight.original0": "model-00001-of-00001.safetensors",
|
| 146 |
+
"quantizer.quantizers.0.out_proj.parametrizations.weight.original1": "model-00001-of-00001.safetensors",
|
| 147 |
+
"quantizer.quantizers.0.codebook.weight": "model-00001-of-00001.safetensors",
|
| 148 |
+
"quantizer.quantizers.1.in_proj.bias": "model-00001-of-00001.safetensors",
|
| 149 |
+
"quantizer.quantizers.1.in_proj.parametrizations.weight.original0": "model-00001-of-00001.safetensors",
|
| 150 |
+
"quantizer.quantizers.1.in_proj.parametrizations.weight.original1": "model-00001-of-00001.safetensors",
|
| 151 |
+
"quantizer.quantizers.1.out_proj.bias": "model-00001-of-00001.safetensors",
|
| 152 |
+
"quantizer.quantizers.1.out_proj.parametrizations.weight.original0": "model-00001-of-00001.safetensors",
|
| 153 |
+
"quantizer.quantizers.1.out_proj.parametrizations.weight.original1": "model-00001-of-00001.safetensors",
|
| 154 |
+
"quantizer.quantizers.1.codebook.weight": "model-00001-of-00001.safetensors",
|
| 155 |
+
"quantizer.quantizers.2.in_proj.bias": "model-00001-of-00001.safetensors",
|
| 156 |
+
"quantizer.quantizers.2.in_proj.parametrizations.weight.original0": "model-00001-of-00001.safetensors",
|
| 157 |
+
"quantizer.quantizers.2.in_proj.parametrizations.weight.original1": "model-00001-of-00001.safetensors",
|
| 158 |
+
"quantizer.quantizers.2.out_proj.bias": "model-00001-of-00001.safetensors",
|
| 159 |
+
"quantizer.quantizers.2.out_proj.parametrizations.weight.original0": "model-00001-of-00001.safetensors",
|
| 160 |
+
"quantizer.quantizers.2.out_proj.parametrizations.weight.original1": "model-00001-of-00001.safetensors",
|
| 161 |
+
"quantizer.quantizers.2.codebook.weight": "model-00001-of-00001.safetensors",
|
| 162 |
+
"quantizer.quantizers.3.in_proj.bias": "model-00001-of-00001.safetensors",
|
| 163 |
+
"quantizer.quantizers.3.in_proj.parametrizations.weight.original0": "model-00001-of-00001.safetensors",
|
| 164 |
+
"quantizer.quantizers.3.in_proj.parametrizations.weight.original1": "model-00001-of-00001.safetensors",
|
| 165 |
+
"quantizer.quantizers.3.out_proj.bias": "model-00001-of-00001.safetensors",
|
| 166 |
+
"quantizer.quantizers.3.out_proj.parametrizations.weight.original0": "model-00001-of-00001.safetensors",
|
| 167 |
+
"quantizer.quantizers.3.out_proj.parametrizations.weight.original1": "model-00001-of-00001.safetensors",
|
| 168 |
+
"quantizer.quantizers.3.codebook.weight": "model-00001-of-00001.safetensors",
|
| 169 |
+
"quantizer.quantizers.4.in_proj.bias": "model-00001-of-00001.safetensors",
|
| 170 |
+
"quantizer.quantizers.4.in_proj.parametrizations.weight.original0": "model-00001-of-00001.safetensors",
|
| 171 |
+
"quantizer.quantizers.4.in_proj.parametrizations.weight.original1": "model-00001-of-00001.safetensors",
|
| 172 |
+
"quantizer.quantizers.4.out_proj.bias": "model-00001-of-00001.safetensors",
|
| 173 |
+
"quantizer.quantizers.4.out_proj.parametrizations.weight.original0": "model-00001-of-00001.safetensors",
|
| 174 |
+
"quantizer.quantizers.4.out_proj.parametrizations.weight.original1": "model-00001-of-00001.safetensors",
|
| 175 |
+
"quantizer.quantizers.4.codebook.weight": "model-00001-of-00001.safetensors",
|
| 176 |
+
"quantizer.quantizers.5.in_proj.bias": "model-00001-of-00001.safetensors",
|
| 177 |
+
"quantizer.quantizers.5.in_proj.parametrizations.weight.original0": "model-00001-of-00001.safetensors",
|
| 178 |
+
"quantizer.quantizers.5.in_proj.parametrizations.weight.original1": "model-00001-of-00001.safetensors",
|
| 179 |
+
"quantizer.quantizers.5.out_proj.bias": "model-00001-of-00001.safetensors",
|
| 180 |
+
"quantizer.quantizers.5.out_proj.parametrizations.weight.original0": "model-00001-of-00001.safetensors",
|
| 181 |
+
"quantizer.quantizers.5.out_proj.parametrizations.weight.original1": "model-00001-of-00001.safetensors",
|
| 182 |
+
"quantizer.quantizers.5.codebook.weight": "model-00001-of-00001.safetensors",
|
| 183 |
+
"quantizer.quantizers.6.in_proj.bias": "model-00001-of-00001.safetensors",
|
| 184 |
+
"quantizer.quantizers.6.in_proj.parametrizations.weight.original0": "model-00001-of-00001.safetensors",
|
| 185 |
+
"quantizer.quantizers.6.in_proj.parametrizations.weight.original1": "model-00001-of-00001.safetensors",
|
| 186 |
+
"quantizer.quantizers.6.out_proj.bias": "model-00001-of-00001.safetensors",
|
| 187 |
+
"quantizer.quantizers.6.out_proj.parametrizations.weight.original0": "model-00001-of-00001.safetensors",
|
| 188 |
+
"quantizer.quantizers.6.out_proj.parametrizations.weight.original1": "model-00001-of-00001.safetensors",
|
| 189 |
+
"quantizer.quantizers.6.codebook.weight": "model-00001-of-00001.safetensors",
|
| 190 |
+
"quantizer.quantizers.7.in_proj.bias": "model-00001-of-00001.safetensors",
|
| 191 |
+
"quantizer.quantizers.7.in_proj.parametrizations.weight.original0": "model-00001-of-00001.safetensors",
|
| 192 |
+
"quantizer.quantizers.7.in_proj.parametrizations.weight.original1": "model-00001-of-00001.safetensors",
|
| 193 |
+
"quantizer.quantizers.7.out_proj.bias": "model-00001-of-00001.safetensors",
|
| 194 |
+
"quantizer.quantizers.7.out_proj.parametrizations.weight.original0": "model-00001-of-00001.safetensors",
|
| 195 |
+
"quantizer.quantizers.7.out_proj.parametrizations.weight.original1": "model-00001-of-00001.safetensors",
|
| 196 |
+
"quantizer.quantizers.7.codebook.weight": "model-00001-of-00001.safetensors",
|
| 197 |
+
"quantizer.quantizers.8.in_proj.bias": "model-00001-of-00001.safetensors",
|
| 198 |
+
"quantizer.quantizers.8.in_proj.parametrizations.weight.original0": "model-00001-of-00001.safetensors",
|
| 199 |
+
"quantizer.quantizers.8.in_proj.parametrizations.weight.original1": "model-00001-of-00001.safetensors",
|
| 200 |
+
"quantizer.quantizers.8.out_proj.bias": "model-00001-of-00001.safetensors",
|
| 201 |
+
"quantizer.quantizers.8.out_proj.parametrizations.weight.original0": "model-00001-of-00001.safetensors",
|
| 202 |
+
"quantizer.quantizers.8.out_proj.parametrizations.weight.original1": "model-00001-of-00001.safetensors",
|
| 203 |
+
"quantizer.quantizers.8.codebook.weight": "model-00001-of-00001.safetensors",
|
| 204 |
+
"quantizer.quantizers.9.in_proj.bias": "model-00001-of-00001.safetensors",
|
| 205 |
+
"quantizer.quantizers.9.in_proj.parametrizations.weight.original0": "model-00001-of-00001.safetensors",
|
| 206 |
+
"quantizer.quantizers.9.in_proj.parametrizations.weight.original1": "model-00001-of-00001.safetensors",
|
| 207 |
+
"quantizer.quantizers.9.out_proj.bias": "model-00001-of-00001.safetensors",
|
| 208 |
+
"quantizer.quantizers.9.out_proj.parametrizations.weight.original0": "model-00001-of-00001.safetensors",
|
| 209 |
+
"quantizer.quantizers.9.out_proj.parametrizations.weight.original1": "model-00001-of-00001.safetensors",
|
| 210 |
+
"quantizer.quantizers.9.codebook.weight": "model-00001-of-00001.safetensors",
|
| 211 |
+
"quantizer.quantizers.10.in_proj.bias": "model-00001-of-00001.safetensors",
|
| 212 |
+
"quantizer.quantizers.10.in_proj.parametrizations.weight.original0": "model-00001-of-00001.safetensors",
|
| 213 |
+
"quantizer.quantizers.10.in_proj.parametrizations.weight.original1": "model-00001-of-00001.safetensors",
|
| 214 |
+
"quantizer.quantizers.10.out_proj.bias": "model-00001-of-00001.safetensors",
|
| 215 |
+
"quantizer.quantizers.10.out_proj.parametrizations.weight.original0": "model-00001-of-00001.safetensors",
|
| 216 |
+
"quantizer.quantizers.10.out_proj.parametrizations.weight.original1": "model-00001-of-00001.safetensors",
|
| 217 |
+
"quantizer.quantizers.10.codebook.weight": "model-00001-of-00001.safetensors",
|
| 218 |
+
"quantizer.quantizers.11.in_proj.bias": "model-00001-of-00001.safetensors",
|
| 219 |
+
"quantizer.quantizers.11.in_proj.parametrizations.weight.original0": "model-00001-of-00001.safetensors",
|
| 220 |
+
"quantizer.quantizers.11.in_proj.parametrizations.weight.original1": "model-00001-of-00001.safetensors",
|
| 221 |
+
"quantizer.quantizers.11.out_proj.bias": "model-00001-of-00001.safetensors",
|
| 222 |
+
"quantizer.quantizers.11.out_proj.parametrizations.weight.original0": "model-00001-of-00001.safetensors",
|
| 223 |
+
"quantizer.quantizers.11.out_proj.parametrizations.weight.original1": "model-00001-of-00001.safetensors",
|
| 224 |
+
"quantizer.quantizers.11.codebook.weight": "model-00001-of-00001.safetensors",
|
| 225 |
+
"quantizer.quantizers.12.in_proj.bias": "model-00001-of-00001.safetensors",
|
| 226 |
+
"quantizer.quantizers.12.in_proj.parametrizations.weight.original0": "model-00001-of-00001.safetensors",
|
| 227 |
+
"quantizer.quantizers.12.in_proj.parametrizations.weight.original1": "model-00001-of-00001.safetensors",
|
| 228 |
+
"quantizer.quantizers.12.out_proj.bias": "model-00001-of-00001.safetensors",
|
| 229 |
+
"quantizer.quantizers.12.out_proj.parametrizations.weight.original0": "model-00001-of-00001.safetensors",
|
| 230 |
+
"quantizer.quantizers.12.out_proj.parametrizations.weight.original1": "model-00001-of-00001.safetensors",
|
| 231 |
+
"quantizer.quantizers.12.codebook.weight": "model-00001-of-00001.safetensors",
|
| 232 |
+
"quantizer.quantizers.13.in_proj.bias": "model-00001-of-00001.safetensors",
|
| 233 |
+
"quantizer.quantizers.13.in_proj.parametrizations.weight.original0": "model-00001-of-00001.safetensors",
|
| 234 |
+
"quantizer.quantizers.13.in_proj.parametrizations.weight.original1": "model-00001-of-00001.safetensors",
|
| 235 |
+
"quantizer.quantizers.13.out_proj.bias": "model-00001-of-00001.safetensors",
|
| 236 |
+
"quantizer.quantizers.13.out_proj.parametrizations.weight.original0": "model-00001-of-00001.safetensors",
|
| 237 |
+
"quantizer.quantizers.13.out_proj.parametrizations.weight.original1": "model-00001-of-00001.safetensors",
|
| 238 |
+
"quantizer.quantizers.13.codebook.weight": "model-00001-of-00001.safetensors",
|
| 239 |
+
"quantizer.quantizers.14.in_proj.bias": "model-00001-of-00001.safetensors",
|
| 240 |
+
"quantizer.quantizers.14.in_proj.parametrizations.weight.original0": "model-00001-of-00001.safetensors",
|
| 241 |
+
"quantizer.quantizers.14.in_proj.parametrizations.weight.original1": "model-00001-of-00001.safetensors",
|
| 242 |
+
"quantizer.quantizers.14.out_proj.bias": "model-00001-of-00001.safetensors",
|
| 243 |
+
"quantizer.quantizers.14.out_proj.parametrizations.weight.original0": "model-00001-of-00001.safetensors",
|
| 244 |
+
"quantizer.quantizers.14.out_proj.parametrizations.weight.original1": "model-00001-of-00001.safetensors",
|
| 245 |
+
"quantizer.quantizers.14.codebook.weight": "model-00001-of-00001.safetensors",
|
| 246 |
+
"quantizer.quantizers.15.in_proj.bias": "model-00001-of-00001.safetensors",
|
| 247 |
+
"quantizer.quantizers.15.in_proj.parametrizations.weight.original0": "model-00001-of-00001.safetensors",
|
| 248 |
+
"quantizer.quantizers.15.in_proj.parametrizations.weight.original1": "model-00001-of-00001.safetensors",
|
| 249 |
+
"quantizer.quantizers.15.out_proj.bias": "model-00001-of-00001.safetensors",
|
| 250 |
+
"quantizer.quantizers.15.out_proj.parametrizations.weight.original0": "model-00001-of-00001.safetensors",
|
| 251 |
+
"quantizer.quantizers.15.out_proj.parametrizations.weight.original1": "model-00001-of-00001.safetensors",
|
| 252 |
+
"quantizer.quantizers.15.codebook.weight": "model-00001-of-00001.safetensors",
|
| 253 |
+
"decoder.1.input_proj.weight": "model-00001-of-00001.safetensors",
|
| 254 |
+
"decoder.1.transformer.layers.0.norm1.weight": "model-00001-of-00001.safetensors",
|
| 255 |
+
"decoder.1.transformer.layers.0.norm1.bias": "model-00001-of-00001.safetensors",
|
| 256 |
+
"decoder.1.transformer.layers.0.self_attn.in_proj.weight": "model-00001-of-00001.safetensors",
|
| 257 |
+
"decoder.1.transformer.layers.0.self_attn.out_proj.weight": "model-00001-of-00001.safetensors",
|
| 258 |
+
"decoder.1.transformer.layers.0.norm2.weight": "model-00001-of-00001.safetensors",
|
| 259 |
+
"decoder.1.transformer.layers.0.norm2.bias": "model-00001-of-00001.safetensors",
|
| 260 |
+
"decoder.1.transformer.layers.0.ffn.0.weight": "model-00001-of-00001.safetensors",
|
| 261 |
+
"decoder.1.transformer.layers.0.ffn.2.weight": "model-00001-of-00001.safetensors",
|
| 262 |
+
"decoder.1.transformer.layers.0.layer_scale_1.scale": "model-00001-of-00001.safetensors",
|
| 263 |
+
"decoder.1.transformer.layers.0.layer_scale_2.scale": "model-00001-of-00001.safetensors",
|
| 264 |
+
"decoder.1.transformer.layers.1.norm1.weight": "model-00001-of-00001.safetensors",
|
| 265 |
+
"decoder.1.transformer.layers.1.norm1.bias": "model-00001-of-00001.safetensors",
|
| 266 |
+
"decoder.1.transformer.layers.1.self_attn.in_proj.weight": "model-00001-of-00001.safetensors",
|
| 267 |
+
"decoder.1.transformer.layers.1.self_attn.out_proj.weight": "model-00001-of-00001.safetensors",
|
| 268 |
+
"decoder.1.transformer.layers.1.norm2.weight": "model-00001-of-00001.safetensors",
|
| 269 |
+
"decoder.1.transformer.layers.1.norm2.bias": "model-00001-of-00001.safetensors",
|
| 270 |
+
"decoder.1.transformer.layers.1.ffn.0.weight": "model-00001-of-00001.safetensors",
|
| 271 |
+
"decoder.1.transformer.layers.1.ffn.2.weight": "model-00001-of-00001.safetensors",
|
| 272 |
+
"decoder.1.transformer.layers.1.layer_scale_1.scale": "model-00001-of-00001.safetensors",
|
| 273 |
+
"decoder.1.transformer.layers.1.layer_scale_2.scale": "model-00001-of-00001.safetensors",
|
| 274 |
+
"decoder.1.transformer.layers.2.norm1.weight": "model-00001-of-00001.safetensors",
|
| 275 |
+
"decoder.1.transformer.layers.2.norm1.bias": "model-00001-of-00001.safetensors",
|
| 276 |
+
"decoder.1.transformer.layers.2.self_attn.in_proj.weight": "model-00001-of-00001.safetensors",
|
| 277 |
+
"decoder.1.transformer.layers.2.self_attn.out_proj.weight": "model-00001-of-00001.safetensors",
|
| 278 |
+
"decoder.1.transformer.layers.2.norm2.weight": "model-00001-of-00001.safetensors",
|
| 279 |
+
"decoder.1.transformer.layers.2.norm2.bias": "model-00001-of-00001.safetensors",
|
| 280 |
+
"decoder.1.transformer.layers.2.ffn.0.weight": "model-00001-of-00001.safetensors",
|
| 281 |
+
"decoder.1.transformer.layers.2.ffn.2.weight": "model-00001-of-00001.safetensors",
|
| 282 |
+
"decoder.1.transformer.layers.2.layer_scale_1.scale": "model-00001-of-00001.safetensors",
|
| 283 |
+
"decoder.1.transformer.layers.2.layer_scale_2.scale": "model-00001-of-00001.safetensors",
|
| 284 |
+
"decoder.1.transformer.layers.3.norm1.weight": "model-00001-of-00001.safetensors",
|
| 285 |
+
"decoder.1.transformer.layers.3.norm1.bias": "model-00001-of-00001.safetensors",
|
| 286 |
+
"decoder.1.transformer.layers.3.self_attn.in_proj.weight": "model-00001-of-00001.safetensors",
|
| 287 |
+
"decoder.1.transformer.layers.3.self_attn.out_proj.weight": "model-00001-of-00001.safetensors",
|
| 288 |
+
"decoder.1.transformer.layers.3.norm2.weight": "model-00001-of-00001.safetensors",
|
| 289 |
+
"decoder.1.transformer.layers.3.norm2.bias": "model-00001-of-00001.safetensors",
|
| 290 |
+
"decoder.1.transformer.layers.3.ffn.0.weight": "model-00001-of-00001.safetensors",
|
| 291 |
+
"decoder.1.transformer.layers.3.ffn.2.weight": "model-00001-of-00001.safetensors",
|
| 292 |
+
"decoder.1.transformer.layers.3.layer_scale_1.scale": "model-00001-of-00001.safetensors",
|
| 293 |
+
"decoder.1.transformer.layers.3.layer_scale_2.scale": "model-00001-of-00001.safetensors",
|
| 294 |
+
"decoder.1.output_proj.weight": "model-00001-of-00001.safetensors",
|
| 295 |
+
"decoder.3.input_proj.weight": "model-00001-of-00001.safetensors",
|
| 296 |
+
"decoder.3.transformer.layers.0.norm1.weight": "model-00001-of-00001.safetensors",
|
| 297 |
+
"decoder.3.transformer.layers.0.norm1.bias": "model-00001-of-00001.safetensors",
|
| 298 |
+
"decoder.3.transformer.layers.0.self_attn.in_proj.weight": "model-00001-of-00001.safetensors",
|
| 299 |
+
"decoder.3.transformer.layers.0.self_attn.out_proj.weight": "model-00001-of-00001.safetensors",
|
| 300 |
+
"decoder.3.transformer.layers.0.norm2.weight": "model-00001-of-00001.safetensors",
|
| 301 |
+
"decoder.3.transformer.layers.0.norm2.bias": "model-00001-of-00001.safetensors",
|
| 302 |
+
"decoder.3.transformer.layers.0.ffn.0.weight": "model-00001-of-00001.safetensors",
|
| 303 |
+
"decoder.3.transformer.layers.0.ffn.2.weight": "model-00001-of-00001.safetensors",
|
| 304 |
+
"decoder.3.transformer.layers.0.layer_scale_1.scale": "model-00001-of-00001.safetensors",
|
| 305 |
+
"decoder.3.transformer.layers.0.layer_scale_2.scale": "model-00001-of-00001.safetensors",
|
| 306 |
+
"decoder.3.transformer.layers.1.norm1.weight": "model-00001-of-00001.safetensors",
|
| 307 |
+
"decoder.3.transformer.layers.1.norm1.bias": "model-00001-of-00001.safetensors",
|
| 308 |
+
"decoder.3.transformer.layers.1.self_attn.in_proj.weight": "model-00001-of-00001.safetensors",
|
| 309 |
+
"decoder.3.transformer.layers.1.self_attn.out_proj.weight": "model-00001-of-00001.safetensors",
|
| 310 |
+
"decoder.3.transformer.layers.1.norm2.weight": "model-00001-of-00001.safetensors",
|
| 311 |
+
"decoder.3.transformer.layers.1.norm2.bias": "model-00001-of-00001.safetensors",
|
| 312 |
+
"decoder.3.transformer.layers.1.ffn.0.weight": "model-00001-of-00001.safetensors",
|
| 313 |
+
"decoder.3.transformer.layers.1.ffn.2.weight": "model-00001-of-00001.safetensors",
|
| 314 |
+
"decoder.3.transformer.layers.1.layer_scale_1.scale": "model-00001-of-00001.safetensors",
|
| 315 |
+
"decoder.3.transformer.layers.1.layer_scale_2.scale": "model-00001-of-00001.safetensors",
|
| 316 |
+
"decoder.3.output_proj.weight": "model-00001-of-00001.safetensors",
|
| 317 |
+
"decoder.5.input_proj.weight": "model-00001-of-00001.safetensors",
|
| 318 |
+
"decoder.5.transformer.layers.0.norm1.weight": "model-00001-of-00001.safetensors",
|
| 319 |
+
"decoder.5.transformer.layers.0.norm1.bias": "model-00001-of-00001.safetensors",
|
| 320 |
+
"decoder.5.transformer.layers.0.self_attn.in_proj.weight": "model-00001-of-00001.safetensors",
|
| 321 |
+
"decoder.5.transformer.layers.0.self_attn.out_proj.weight": "model-00001-of-00001.safetensors",
|
| 322 |
+
"decoder.5.transformer.layers.0.norm2.weight": "model-00001-of-00001.safetensors",
|
| 323 |
+
"decoder.5.transformer.layers.0.norm2.bias": "model-00001-of-00001.safetensors",
|
| 324 |
+
"decoder.5.transformer.layers.0.ffn.0.weight": "model-00001-of-00001.safetensors",
|
| 325 |
+
"decoder.5.transformer.layers.0.ffn.2.weight": "model-00001-of-00001.safetensors",
|
| 326 |
+
"decoder.5.transformer.layers.0.layer_scale_1.scale": "model-00001-of-00001.safetensors",
|
| 327 |
+
"decoder.5.transformer.layers.0.layer_scale_2.scale": "model-00001-of-00001.safetensors",
|
| 328 |
+
"decoder.5.transformer.layers.1.norm1.weight": "model-00001-of-00001.safetensors",
|
| 329 |
+
"decoder.5.transformer.layers.1.norm1.bias": "model-00001-of-00001.safetensors",
|
| 330 |
+
"decoder.5.transformer.layers.1.self_attn.in_proj.weight": "model-00001-of-00001.safetensors",
|
| 331 |
+
"decoder.5.transformer.layers.1.self_attn.out_proj.weight": "model-00001-of-00001.safetensors",
|
| 332 |
+
"decoder.5.transformer.layers.1.norm2.weight": "model-00001-of-00001.safetensors",
|
| 333 |
+
"decoder.5.transformer.layers.1.norm2.bias": "model-00001-of-00001.safetensors",
|
| 334 |
+
"decoder.5.transformer.layers.1.ffn.0.weight": "model-00001-of-00001.safetensors",
|
| 335 |
+
"decoder.5.transformer.layers.1.ffn.2.weight": "model-00001-of-00001.safetensors",
|
| 336 |
+
"decoder.5.transformer.layers.1.layer_scale_1.scale": "model-00001-of-00001.safetensors",
|
| 337 |
+
"decoder.5.transformer.layers.1.layer_scale_2.scale": "model-00001-of-00001.safetensors",
|
| 338 |
+
"decoder.5.output_proj.weight": "model-00001-of-00001.safetensors",
|
| 339 |
+
"decoder.7.input_proj.weight": "model-00001-of-00001.safetensors",
|
| 340 |
+
"decoder.7.transformer.layers.0.norm1.weight": "model-00001-of-00001.safetensors",
|
| 341 |
+
"decoder.7.transformer.layers.0.norm1.bias": "model-00001-of-00001.safetensors",
|
| 342 |
+
"decoder.7.transformer.layers.0.self_attn.in_proj.weight": "model-00001-of-00001.safetensors",
|
| 343 |
+
"decoder.7.transformer.layers.0.self_attn.out_proj.weight": "model-00001-of-00001.safetensors",
|
| 344 |
+
"decoder.7.transformer.layers.0.norm2.weight": "model-00001-of-00001.safetensors",
|
| 345 |
+
"decoder.7.transformer.layers.0.norm2.bias": "model-00001-of-00001.safetensors",
|
| 346 |
+
"decoder.7.transformer.layers.0.ffn.0.weight": "model-00001-of-00001.safetensors",
|
| 347 |
+
"decoder.7.transformer.layers.0.ffn.2.weight": "model-00001-of-00001.safetensors",
|
| 348 |
+
"decoder.7.transformer.layers.0.layer_scale_1.scale": "model-00001-of-00001.safetensors",
|
| 349 |
+
"decoder.7.transformer.layers.0.layer_scale_2.scale": "model-00001-of-00001.safetensors",
|
| 350 |
+
"decoder.7.transformer.layers.1.norm1.weight": "model-00001-of-00001.safetensors",
|
| 351 |
+
"decoder.7.transformer.layers.1.norm1.bias": "model-00001-of-00001.safetensors",
|
| 352 |
+
"decoder.7.transformer.layers.1.self_attn.in_proj.weight": "model-00001-of-00001.safetensors",
|
| 353 |
+
"decoder.7.transformer.layers.1.self_attn.out_proj.weight": "model-00001-of-00001.safetensors",
|
| 354 |
+
"decoder.7.transformer.layers.1.norm2.weight": "model-00001-of-00001.safetensors",
|
| 355 |
+
"decoder.7.transformer.layers.1.norm2.bias": "model-00001-of-00001.safetensors",
|
| 356 |
+
"decoder.7.transformer.layers.1.ffn.0.weight": "model-00001-of-00001.safetensors",
|
| 357 |
+
"decoder.7.transformer.layers.1.ffn.2.weight": "model-00001-of-00001.safetensors",
|
| 358 |
+
"decoder.7.transformer.layers.1.layer_scale_1.scale": "model-00001-of-00001.safetensors",
|
| 359 |
+
"decoder.7.transformer.layers.1.layer_scale_2.scale": "model-00001-of-00001.safetensors",
|
| 360 |
+
"decoder.7.transformer.layers.2.norm1.weight": "model-00001-of-00001.safetensors",
|
| 361 |
+
"decoder.7.transformer.layers.2.norm1.bias": "model-00001-of-00001.safetensors",
|
| 362 |
+
"decoder.7.transformer.layers.2.self_attn.in_proj.weight": "model-00001-of-00001.safetensors",
|
| 363 |
+
"decoder.7.transformer.layers.2.self_attn.out_proj.weight": "model-00001-of-00001.safetensors",
|
| 364 |
+
"decoder.7.transformer.layers.2.norm2.weight": "model-00001-of-00001.safetensors",
|
| 365 |
+
"decoder.7.transformer.layers.2.norm2.bias": "model-00001-of-00001.safetensors",
|
| 366 |
+
"decoder.7.transformer.layers.2.ffn.0.weight": "model-00001-of-00001.safetensors",
|
| 367 |
+
"decoder.7.transformer.layers.2.ffn.2.weight": "model-00001-of-00001.safetensors",
|
| 368 |
+
"decoder.7.transformer.layers.2.layer_scale_1.scale": "model-00001-of-00001.safetensors",
|
| 369 |
+
"decoder.7.transformer.layers.2.layer_scale_2.scale": "model-00001-of-00001.safetensors",
|
| 370 |
+
"decoder.7.transformer.layers.3.norm1.weight": "model-00001-of-00001.safetensors",
|
| 371 |
+
"decoder.7.transformer.layers.3.norm1.bias": "model-00001-of-00001.safetensors",
|
| 372 |
+
"decoder.7.transformer.layers.3.self_attn.in_proj.weight": "model-00001-of-00001.safetensors",
|
| 373 |
+
"decoder.7.transformer.layers.3.self_attn.out_proj.weight": "model-00001-of-00001.safetensors",
|
| 374 |
+
"decoder.7.transformer.layers.3.norm2.weight": "model-00001-of-00001.safetensors",
|
| 375 |
+
"decoder.7.transformer.layers.3.norm2.bias": "model-00001-of-00001.safetensors",
|
| 376 |
+
"decoder.7.transformer.layers.3.ffn.0.weight": "model-00001-of-00001.safetensors",
|
| 377 |
+
"decoder.7.transformer.layers.3.ffn.2.weight": "model-00001-of-00001.safetensors",
|
| 378 |
+
"decoder.7.transformer.layers.3.layer_scale_1.scale": "model-00001-of-00001.safetensors",
|
| 379 |
+
"decoder.7.transformer.layers.3.layer_scale_2.scale": "model-00001-of-00001.safetensors",
|
| 380 |
+
"decoder.7.output_proj.weight": "model-00001-of-00001.safetensors"
|
| 381 |
+
}
|
| 382 |
+
}
|
weights/codec/modeling_moss_audio_tokenizer.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
weights/tts/.gitattributes
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
weights/tts/README.md
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
license: apache-2.0
|
| 3 |
+
---
|
weights/tts/__init__.py
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .configuration_nanotts import NanoTTSConfig
|
| 2 |
+
from .modeling_nanotts_global_local import (
|
| 3 |
+
NanoTTSGenerationOutput,
|
| 4 |
+
NanoTTSGlobalLocalForCausalLM,
|
| 5 |
+
NanoTTSOutput,
|
| 6 |
+
)
|
| 7 |
+
from .tokenization_nanotts_sentencepiece import NanoTTSSentencePieceTokenizer
|
| 8 |
+
|
| 9 |
+
try:
|
| 10 |
+
NanoTTSConfig.register_for_auto_class()
|
| 11 |
+
except Exception:
|
| 12 |
+
pass
|
| 13 |
+
|
| 14 |
+
for auto_class_name in ("AutoModel", "AutoModelForCausalLM"):
|
| 15 |
+
try:
|
| 16 |
+
NanoTTSGlobalLocalForCausalLM.register_for_auto_class(auto_class_name)
|
| 17 |
+
except Exception:
|
| 18 |
+
pass
|
| 19 |
+
|
| 20 |
+
try:
|
| 21 |
+
NanoTTSSentencePieceTokenizer.register_for_auto_class("AutoTokenizer")
|
| 22 |
+
except Exception:
|
| 23 |
+
pass
|
| 24 |
+
|
| 25 |
+
__all__ = [
|
| 26 |
+
"NanoTTSConfig",
|
| 27 |
+
"NanoTTSGlobalLocalForCausalLM",
|
| 28 |
+
"NanoTTSSentencePieceTokenizer",
|
| 29 |
+
"NanoTTSGenerationOutput",
|
| 30 |
+
"NanoTTSOutput",
|
| 31 |
+
]
|
weights/tts/config.json
ADDED
|
@@ -0,0 +1,197 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"add_cross_attention": false,
|
| 3 |
+
"architectures": [
|
| 4 |
+
"NanoTTSGlobalLocalForCausalLM"
|
| 5 |
+
],
|
| 6 |
+
"attn_implementation": "sdpa",
|
| 7 |
+
"audio_assistant_slot_token_id": 9,
|
| 8 |
+
"audio_codebook_sizes": [
|
| 9 |
+
1024,
|
| 10 |
+
1024,
|
| 11 |
+
1024,
|
| 12 |
+
1024,
|
| 13 |
+
1024,
|
| 14 |
+
1024,
|
| 15 |
+
1024,
|
| 16 |
+
1024,
|
| 17 |
+
1024,
|
| 18 |
+
1024,
|
| 19 |
+
1024,
|
| 20 |
+
1024,
|
| 21 |
+
1024,
|
| 22 |
+
1024,
|
| 23 |
+
1024,
|
| 24 |
+
1024
|
| 25 |
+
],
|
| 26 |
+
"audio_end_token_id": 7,
|
| 27 |
+
"audio_pad_token_id": 1024,
|
| 28 |
+
"audio_start_token_id": 6,
|
| 29 |
+
"audio_tokenizer_pretrained_name_or_path": "OpenMOSS-Team/MOSS-Audio-Tokenizer-Nano",
|
| 30 |
+
"audio_tokenizer_sample_rate": 48000,
|
| 31 |
+
"audio_tokenizer_type": "moss-audio-tokenizer-nano",
|
| 32 |
+
"audio_user_slot_token_id": 8,
|
| 33 |
+
"audio_vocab_size": 1024,
|
| 34 |
+
"bad_words_ids": null,
|
| 35 |
+
"begin_suppress_tokens": null,
|
| 36 |
+
"bos_token_id": null,
|
| 37 |
+
"chunk_size_feed_forward": 0,
|
| 38 |
+
"cross_attention_hidden_size": null,
|
| 39 |
+
"decoder_start_token_id": null,
|
| 40 |
+
"diversity_penalty": 0.0,
|
| 41 |
+
"do_sample": false,
|
| 42 |
+
"dtype": "float32",
|
| 43 |
+
"early_stopping": false,
|
| 44 |
+
"encoder_no_repeat_ngram_size": 0,
|
| 45 |
+
"eos_token_id": null,
|
| 46 |
+
"exponential_decay_length_penalty": null,
|
| 47 |
+
"finetuning_task": null,
|
| 48 |
+
"forced_bos_token_id": null,
|
| 49 |
+
"forced_eos_token_id": null,
|
| 50 |
+
"gpt2_config": {
|
| 51 |
+
"_name_or_path": "",
|
| 52 |
+
"activation_function": "gelu_new",
|
| 53 |
+
"add_cross_attention": false,
|
| 54 |
+
"architectures": null,
|
| 55 |
+
"attn_pdrop": 0.0,
|
| 56 |
+
"bad_words_ids": null,
|
| 57 |
+
"begin_suppress_tokens": null,
|
| 58 |
+
"bos_token_id": 1,
|
| 59 |
+
"chunk_size_feed_forward": 0,
|
| 60 |
+
"cross_attention_hidden_size": null,
|
| 61 |
+
"decoder_start_token_id": null,
|
| 62 |
+
"diversity_penalty": 0.0,
|
| 63 |
+
"do_sample": false,
|
| 64 |
+
"dtype": null,
|
| 65 |
+
"early_stopping": false,
|
| 66 |
+
"embd_pdrop": 0.0,
|
| 67 |
+
"encoder_no_repeat_ngram_size": 0,
|
| 68 |
+
"eos_token_id": 2,
|
| 69 |
+
"exponential_decay_length_penalty": null,
|
| 70 |
+
"finetuning_task": null,
|
| 71 |
+
"forced_bos_token_id": null,
|
| 72 |
+
"forced_eos_token_id": null,
|
| 73 |
+
"id2label": {
|
| 74 |
+
"0": "LABEL_0",
|
| 75 |
+
"1": "LABEL_1"
|
| 76 |
+
},
|
| 77 |
+
"initializer_range": 0.02,
|
| 78 |
+
"is_decoder": false,
|
| 79 |
+
"is_encoder_decoder": false,
|
| 80 |
+
"label2id": {
|
| 81 |
+
"LABEL_0": 0,
|
| 82 |
+
"LABEL_1": 1
|
| 83 |
+
},
|
| 84 |
+
"layer_norm_epsilon": 1e-05,
|
| 85 |
+
"length_penalty": 1.0,
|
| 86 |
+
"max_length": 20,
|
| 87 |
+
"min_length": 0,
|
| 88 |
+
"model_type": "gpt2",
|
| 89 |
+
"n_ctx": 32768,
|
| 90 |
+
"n_embd": 768,
|
| 91 |
+
"n_head": 12,
|
| 92 |
+
"n_inner": 3072,
|
| 93 |
+
"n_layer": 12,
|
| 94 |
+
"n_positions": 32768,
|
| 95 |
+
"no_repeat_ngram_size": 0,
|
| 96 |
+
"num_beam_groups": 1,
|
| 97 |
+
"num_beams": 1,
|
| 98 |
+
"num_return_sequences": 1,
|
| 99 |
+
"output_attentions": false,
|
| 100 |
+
"output_hidden_states": false,
|
| 101 |
+
"output_scores": false,
|
| 102 |
+
"pad_token_id": 3,
|
| 103 |
+
"position_embedding_type": "rope",
|
| 104 |
+
"prefix": null,
|
| 105 |
+
"problem_type": null,
|
| 106 |
+
"pruned_heads": {},
|
| 107 |
+
"remove_invalid_values": false,
|
| 108 |
+
"reorder_and_upcast_attn": false,
|
| 109 |
+
"repetition_penalty": 1.0,
|
| 110 |
+
"resid_pdrop": 0.0,
|
| 111 |
+
"return_dict": true,
|
| 112 |
+
"return_dict_in_generate": false,
|
| 113 |
+
"rope_base": 10000.0,
|
| 114 |
+
"scale_attn_by_inverse_layer_idx": false,
|
| 115 |
+
"scale_attn_weights": true,
|
| 116 |
+
"sep_token_id": null,
|
| 117 |
+
"summary_activation": null,
|
| 118 |
+
"summary_first_dropout": 0.1,
|
| 119 |
+
"summary_proj_to_labels": true,
|
| 120 |
+
"summary_type": "cls_index",
|
| 121 |
+
"summary_use_proj": true,
|
| 122 |
+
"suppress_tokens": null,
|
| 123 |
+
"task_specific_params": null,
|
| 124 |
+
"temperature": 1.0,
|
| 125 |
+
"tf_legacy_loss": false,
|
| 126 |
+
"tie_encoder_decoder": false,
|
| 127 |
+
"tie_word_embeddings": true,
|
| 128 |
+
"tokenizer_class": null,
|
| 129 |
+
"top_k": 50,
|
| 130 |
+
"top_p": 1.0,
|
| 131 |
+
"torchscript": false,
|
| 132 |
+
"transformers_version": "4.57.1",
|
| 133 |
+
"typical_p": 1.0,
|
| 134 |
+
"use_bfloat16": false,
|
| 135 |
+
"use_cache": true,
|
| 136 |
+
"vocab_size": 16384
|
| 137 |
+
},
|
| 138 |
+
"hidden_size": 768,
|
| 139 |
+
"id2label": {
|
| 140 |
+
"0": "LABEL_0",
|
| 141 |
+
"1": "LABEL_1"
|
| 142 |
+
},
|
| 143 |
+
"im_end_token_id": 5,
|
| 144 |
+
"im_start_token_id": 4,
|
| 145 |
+
"initializer_range": 0.02,
|
| 146 |
+
"is_decoder": false,
|
| 147 |
+
"is_encoder_decoder": false,
|
| 148 |
+
"label2id": {
|
| 149 |
+
"LABEL_0": 0,
|
| 150 |
+
"LABEL_1": 1
|
| 151 |
+
},
|
| 152 |
+
"length_penalty": 1.0,
|
| 153 |
+
"local_transformer_attn_implementation": "sdpa",
|
| 154 |
+
"local_transformer_layers": 1,
|
| 155 |
+
"max_length": 20,
|
| 156 |
+
"max_position_embeddings": 32768,
|
| 157 |
+
"min_length": 0,
|
| 158 |
+
"model_architecture": "global_local_transformer",
|
| 159 |
+
"model_type": "nano_tts",
|
| 160 |
+
"n_vq": 16,
|
| 161 |
+
"no_repeat_ngram_size": 0,
|
| 162 |
+
"num_beam_groups": 1,
|
| 163 |
+
"num_beams": 1,
|
| 164 |
+
"num_return_sequences": 1,
|
| 165 |
+
"output_attentions": false,
|
| 166 |
+
"output_hidden_states": false,
|
| 167 |
+
"output_scores": false,
|
| 168 |
+
"pad_token_id": 3,
|
| 169 |
+
"prefix": null,
|
| 170 |
+
"problem_type": null,
|
| 171 |
+
"pruned_heads": {},
|
| 172 |
+
"remove_invalid_values": false,
|
| 173 |
+
"repetition_penalty": 1.0,
|
| 174 |
+
"return_dict": true,
|
| 175 |
+
"return_dict_in_generate": false,
|
| 176 |
+
"sep_token_id": null,
|
| 177 |
+
"suppress_tokens": null,
|
| 178 |
+
"task_specific_params": null,
|
| 179 |
+
"temperature": 1.0,
|
| 180 |
+
"tf_legacy_loss": false,
|
| 181 |
+
"tie_encoder_decoder": false,
|
| 182 |
+
"tie_word_embeddings": true,
|
| 183 |
+
"tokenizer_class": "NanoTTSSentencePieceTokenizer",
|
| 184 |
+
"tokenizer_use_fast": false,
|
| 185 |
+
"top_k": 50,
|
| 186 |
+
"top_p": 1.0,
|
| 187 |
+
"torchscript": false,
|
| 188 |
+
"transformers_version": "4.57.1",
|
| 189 |
+
"typical_p": 1.0,
|
| 190 |
+
"use_bfloat16": false,
|
| 191 |
+
"vocab_size": 16384,
|
| 192 |
+
"auto_map": {
|
| 193 |
+
"AutoConfig": "configuration_nanotts.NanoTTSConfig",
|
| 194 |
+
"AutoModel": "modeling_nanotts_global_local.NanoTTSGlobalLocalForCausalLM",
|
| 195 |
+
"AutoModelForCausalLM": "modeling_nanotts_global_local.NanoTTSGlobalLocalForCausalLM"
|
| 196 |
+
}
|
| 197 |
+
}
|
weights/tts/configuration_nanotts.py
ADDED
|
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
from typing import Any, Dict, Optional, Union
|
| 3 |
+
|
| 4 |
+
from transformers.configuration_utils import PretrainedConfig
|
| 5 |
+
from transformers.models.gpt2.configuration_gpt2 import GPT2Config
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class NanoTTSConfig(PretrainedConfig):
    """Configuration for the Nano-TTS global/local transformer TTS model.

    Wraps a ``GPT2Config`` for the backbone and adds the audio-token layout:
    ``n_vq`` residual codebooks, per-codebook sizes, special token ids for the
    chat/audio prompt format, and audio-tokenizer settings.
    """

    model_type = "nano_tts"
    # Inference outputs whose values should not be compared/serialized.
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        gpt2_config: Optional[Union[GPT2Config, Dict[str, Any]]] = None,
        n_vq: int = 8,
        audio_vocab_size: Optional[int] = 1024,
        audio_codebook_sizes: Optional[list[int]] = None,
        audio_pad_token_id: int = 1024,
        pad_token_id: int = 151643,
        im_start_token_id: int = 151644,
        im_end_token_id: int = 151645,
        audio_start_token_id: int = 151652,
        audio_end_token_id: int = 151653,
        audio_user_slot_token_id: int = 151654,
        audio_assistant_slot_token_id: int = 151656,
        tokenizer_use_fast: bool = False,
        audio_tokenizer_type: str = "moss-audio-tokenizer-nano",
        audio_tokenizer_pretrained_name_or_path: Optional[str] = "OpenMOSS-Team/MOSS-Audio-Tokenizer-Nano",
        audio_tokenizer_sample_rate: int = 48000,
        attn_implementation: str = "flash_attention_2",
        initializer_range: float = 0.02,
        model_architecture: str = "global_local_transformer",
        local_transformer_layers: int = 4,
        local_transformer_attn_implementation: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """Validate the codebook layout and store all fields.

        Raises:
            ValueError: if neither ``audio_vocab_size`` nor
                ``audio_codebook_sizes`` is given, if the codebook list has the
                wrong length or non-positive entries, if ``audio_vocab_size``
                is smaller than the largest codebook, or if
                ``audio_pad_token_id`` would collide with a codebook entry.
        """
        # Accept a dict (e.g. deserialized JSON), an existing GPT2Config,
        # or None (library defaults) for the backbone config.
        if isinstance(gpt2_config, dict):
            self.gpt2_config = GPT2Config(**gpt2_config)
        elif gpt2_config is None:
            self.gpt2_config = GPT2Config()
        else:
            self.gpt2_config = gpt2_config

        self.n_vq = int(n_vq)
        if audio_codebook_sizes is None:
            if audio_vocab_size is None:
                raise ValueError("audio_vocab_size must be set when audio_codebook_sizes is not provided.")
            # Uniform layout: every one of the n_vq codebooks has the same size.
            resolved_audio_codebook_sizes = [int(audio_vocab_size)] * self.n_vq
        else:
            resolved_audio_codebook_sizes = [int(codebook_size) for codebook_size in audio_codebook_sizes]
            if len(resolved_audio_codebook_sizes) != self.n_vq:
                raise ValueError(
                    "audio_codebook_sizes must have length n_vq "
                    f"(expected {self.n_vq}, got {len(resolved_audio_codebook_sizes)})."
                )
        if any(codebook_size <= 0 for codebook_size in resolved_audio_codebook_sizes):
            raise ValueError("audio_codebook_sizes must contain positive integers.")

        max_audio_codebook_size = max(resolved_audio_codebook_sizes)
        if audio_vocab_size is not None and int(audio_vocab_size) < max_audio_codebook_size:
            raise ValueError(
                "audio_vocab_size must be >= max(audio_codebook_sizes) "
                f"(got {audio_vocab_size}, expected at least {max_audio_codebook_size})."
            )

        self.audio_codebook_sizes = resolved_audio_codebook_sizes
        # When no explicit vocab size was given, derive it from the largest codebook.
        self.audio_vocab_size = (
            max_audio_codebook_size if audio_vocab_size is None else int(audio_vocab_size)
        )
        self.audio_pad_token_id = int(audio_pad_token_id)
        if self.audio_pad_token_id < max_audio_codebook_size:
            raise ValueError(
                "audio_pad_token_id must be >= max(audio_codebook_sizes) so pad stays outside every codebook "
                f"(got {self.audio_pad_token_id}, max codebook size {max_audio_codebook_size})."
            )
        # Special token ids for the chat-style prompt format (defaults follow the
        # Qwen-style id range starting at 151643).
        self.pad_token_id = pad_token_id
        self.im_start_token_id = im_start_token_id
        self.im_end_token_id = im_end_token_id
        self.audio_start_token_id = audio_start_token_id
        self.audio_end_token_id = audio_end_token_id
        self.audio_user_slot_token_id = audio_user_slot_token_id
        self.audio_assistant_slot_token_id = audio_assistant_slot_token_id
        self.tokenizer_use_fast = tokenizer_use_fast
        self.audio_tokenizer_type = audio_tokenizer_type
        self.audio_tokenizer_pretrained_name_or_path = audio_tokenizer_pretrained_name_or_path
        self.audio_tokenizer_sample_rate = audio_tokenizer_sample_rate
        self.attn_implementation = attn_implementation
        self.initializer_range = initializer_range
        self.model_architecture = model_architecture
        self.local_transformer_layers = local_transformer_layers
        # The local transformer inherits the global attention backend unless
        # explicitly overridden.
        self.local_transformer_attn_implementation = (
            attn_implementation
            if local_transformer_attn_implementation is None
            else local_transformer_attn_implementation
        )
        # Mirror backbone dimensions at the top level so generic HF tooling
        # (which reads config.vocab_size / hidden_size / max_position_embeddings)
        # sees the right values.
        self.vocab_size = self.gpt2_config.vocab_size
        self.hidden_size = self.gpt2_config.hidden_size
        self.max_position_embeddings = self.gpt2_config.n_positions

        super().__init__(pad_token_id=pad_token_id, **kwargs)

    def to_dict(self) -> Dict[str, Any]:
        """Serialize, expanding the nested GPT2Config into a plain dict."""
        output = super().to_dict()
        output["gpt2_config"] = self.gpt2_config.to_dict()
        return output
|
weights/tts/gpt2_decoder.py
ADDED
|
@@ -0,0 +1,618 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
from dataclasses import dataclass
|
| 5 |
+
from typing import Optional
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
import torch.utils.checkpoint
|
| 10 |
+
from transformers.activations import ACT2FN
|
| 11 |
+
from transformers.modeling_outputs import BaseModelOutputWithPast
|
| 12 |
+
from transformers.models.gpt2.configuration_gpt2 import GPT2Config
|
| 13 |
+
|
| 14 |
+
try:
|
| 15 |
+
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
| 16 |
+
from flash_attn.bert_padding import pad_input, unpad_input
|
| 17 |
+
|
| 18 |
+
_FLASH_ATTN_AVAILABLE = True
|
| 19 |
+
except Exception:
|
| 20 |
+
flash_attn_func = None
|
| 21 |
+
flash_attn_varlen_func = None
|
| 22 |
+
pad_input = None
|
| 23 |
+
unpad_input = None
|
| 24 |
+
_FLASH_ATTN_AVAILABLE = False
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
@dataclass
class PackedSequenceMetadata:
    """Bookkeeping for variable-length sequences packed into one flat batch,
    in the layout expected by flash-attn's varlen kernels."""

    # Cumulative sequence lengths, shape (num_sequences + 1,), starting at 0.
    cu_seqlens: torch.Tensor
    # Length of the longest packed sequence.
    max_seqlen: int
    # Flat token indices used to gather/scatter between the padded (batch, seq)
    # layout and the packed layout; None when the input is already unpadded.
    indices: Optional[torch.Tensor] = None
    # Original padded batch size (needed to re-pad the flash-attn output).
    batch_size: Optional[int] = None
    # Original padded sequence length (needed to re-pad the flash-attn output).
    seq_len: Optional[int] = None
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class NanoGPT2RotaryEmbedding(nn.Module):
    """Rotary position embedding (interleaved layout).

    Produces per-position cos/sin tables where each inverse frequency is
    duplicated over adjacent channel pairs, matching ``rotate_half``'s
    even/odd pairing.
    """

    def __init__(self, dim: int, base: float = 10000.0) -> None:
        super().__init__()
        if dim % 2 != 0:
            raise ValueError(f"RoPE head_dim must be even, got {dim}")
        # Standard RoPE frequency schedule: base^(-2i/dim) for i in [0, dim/2).
        exponents = torch.arange(0, dim, 2, dtype=torch.float32) / dim
        self.register_buffer("inv_freq", 1.0 / (base ** exponents), persistent=False)

    def forward(
        self,
        position_ids: torch.LongTensor,
        *,
        device: torch.device,
        dtype: torch.dtype,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Return (cos, sin), each shaped (batch, seq, 1, dim), ready to
        broadcast over the head axis."""
        if position_ids.ndim == 1:
            position_ids = position_ids.unsqueeze(0)
        positions = position_ids.to(device=device, dtype=self.inv_freq.dtype)
        # Outer product over the frequency axis: (b, s) x (d/2,) -> (b, s, d/2).
        angles = positions[:, :, None] * self.inv_freq
        # Duplicate each angle for the paired channels: (b, s, d/2) -> (b, s, d).
        paired_angles = angles.repeat_interleave(2, dim=-1)
        cos = paired_angles.cos().unsqueeze(2).to(dtype=dtype)
        sin = paired_angles.sin().unsqueeze(2).to(dtype=dtype)
        return cos, sin
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def rotate_half(hidden_states: torch.Tensor) -> torch.Tensor:
    """Rotate adjacent channel pairs: (x0, x1) -> (-x1, x0) along the last dim."""
    paired = hidden_states.reshape(*hidden_states.shape[:-1], -1, 2)
    rotated_pairs = torch.stack((-paired[..., 1], paired[..., 0]), dim=-1)
    return rotated_pairs.reshape_as(hidden_states)


def apply_rotary_pos_emb(
    hidden_states: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
) -> torch.Tensor:
    """Apply interleaved rotary position embedding to ``hidden_states``."""
    rotated = rotate_half(hidden_states)
    return hidden_states * cos + rotated * sin
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
class NanoGPT2MLP(nn.Module):
    """Position-wise feed-forward block: Linear -> activation -> Linear -> dropout.

    Inner width is ``config.n_inner`` when set, otherwise 4x the hidden size.
    """

    def __init__(self, config: GPT2Config) -> None:
        super().__init__()
        width = int(config.hidden_size)
        expanded = int(config.n_inner or 4 * width)
        self.fc_in = nn.Linear(width, expanded)
        self.fc_out = nn.Linear(expanded, width)
        self.act = ACT2FN[config.activation_function]
        self.dropout = nn.Dropout(config.resid_pdrop)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        projected = self.act(self.fc_in(hidden_states))
        return self.dropout(self.fc_out(projected))
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
class NanoGPT2Attention(nn.Module):
    """Multi-head causal self-attention with selectable backends.

    Supports three implementations chosen by ``attn_implementation``:
    flash_attention_2 (prefill only, no KV cache), sdpa, and eager.
    Optionally applies rotary embeddings when the config requests
    ``position_embedding_type='rope'``.
    """

    def __init__(self, config: GPT2Config, layer_idx: int, attn_implementation: str) -> None:
        super().__init__()
        hidden_size = int(config.hidden_size)
        num_heads = int(config.num_attention_heads)
        if hidden_size % num_heads != 0:
            raise ValueError(f"hidden_size={hidden_size} must be divisible by num_attention_heads={num_heads}")

        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.embed_dim = hidden_size
        self.layer_idx = layer_idx
        self.attn_implementation = attn_implementation
        self.attn_dropout = float(config.attn_pdrop)
        self.resid_dropout = nn.Dropout(config.resid_pdrop)
        # GPT-2 style scaling options (see GPT2Config).
        self.scale_attn_weights = bool(getattr(config, "scale_attn_weights", True))
        self.scale_attn_by_inverse_layer_idx = bool(getattr(config, "scale_attn_by_inverse_layer_idx", False))
        self.position_embedding_type = str(getattr(config, "position_embedding_type", "absolute")).lower()
        if self.position_embedding_type not in {"absolute", "rope"}:
            raise ValueError(f"Unsupported position_embedding_type={self.position_embedding_type!r}")

        # Fused QKV projection followed by an output projection, as in GPT-2.
        self.c_attn = nn.Linear(hidden_size, 3 * hidden_size)
        self.c_proj = nn.Linear(hidden_size, hidden_size)
        self.rotary_emb = None
        if self.position_embedding_type == "rope":
            self.rotary_emb = NanoGPT2RotaryEmbedding(
                self.head_dim,
                base=float(getattr(config, "rope_base", 10000.0)),
            )

    def _split_heads(self, tensor: torch.Tensor) -> torch.Tensor:
        """Reshape (..., embed_dim) into (..., num_heads, head_dim).

        Accepts either a padded (batch, seq, embed) tensor or an unpadded
        (total_tokens, embed) tensor (the packed flash-attn layout).
        """
        if tensor.ndim == 3:
            batch_size, seq_len, _ = tensor.shape
            return tensor.view(batch_size, seq_len, self.num_heads, self.head_dim)
        if tensor.ndim == 2:
            total_tokens, _ = tensor.shape
            return tensor.view(total_tokens, self.num_heads, self.head_dim)
        raise ValueError(f"Unsupported tensor rank for attention split: {tensor.ndim}")

    def _merge_heads(self, tensor: torch.Tensor) -> torch.Tensor:
        """Inverse of :meth:`_split_heads`: collapse heads back into embed_dim."""
        if tensor.ndim == 4:
            batch_size, seq_len, _, _ = tensor.shape
            return tensor.reshape(batch_size, seq_len, self.embed_dim)
        if tensor.ndim == 3:
            total_tokens, _, _ = tensor.shape
            return tensor.reshape(total_tokens, self.embed_dim)
        raise ValueError(f"Unsupported tensor rank for attention merge: {tensor.ndim}")

    def _causal_attention_mask(
        self,
        attention_mask: Optional[torch.Tensor],
        query_length: int,
        key_length: int,
        device: torch.device,
    ) -> torch.Tensor:
        """Build a boolean (1|B, 1, Q, K) mask: True = attend.

        When key_length > query_length (decoding with a cache), queries are
        offset so the last query aligns with the last key. The padding mask,
        if given, is ANDed over the key axis.
        """
        query_positions = torch.arange(query_length, device=device, dtype=torch.long)
        # Shift query positions so they index into the full key timeline.
        query_positions = query_positions + max(key_length - query_length, 0)
        key_positions = torch.arange(key_length, device=device, dtype=torch.long)
        causal = key_positions.unsqueeze(0) <= query_positions.unsqueeze(1)
        causal = causal.unsqueeze(0).unsqueeze(0)
        if attention_mask is None:
            return causal
        key_mask = attention_mask[:, None, None, :].to(dtype=torch.bool)
        return causal & key_mask

    def _eager_attention(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Reference matmul-softmax attention. Inputs are (B, S, H, D);
        returns the same layout."""
        # Move heads before the sequence axis for batched matmul.
        query = query.transpose(1, 2)
        key = key.transpose(1, 2)
        value = value.transpose(1, 2)

        scale = 1.0
        if self.scale_attn_weights:
            scale /= self.head_dim ** 0.5
        if self.scale_attn_by_inverse_layer_idx:
            scale /= float(self.layer_idx + 1)

        scores = torch.matmul(query, key.transpose(-1, -2)) * scale
        causal_mask = self._causal_attention_mask(
            attention_mask=attention_mask,
            query_length=query.shape[-2],
            key_length=key.shape[-2],
            device=query.device,
        )
        # Disallowed positions get the most negative representable value so
        # softmax sends them to ~0.
        scores = scores.masked_fill(~causal_mask, torch.finfo(scores.dtype).min)
        probs = torch.softmax(scores, dim=-1)
        if self.training and self.attn_dropout > 0:
            probs = torch.dropout(probs, self.attn_dropout, train=True)
        output = torch.matmul(probs, value)
        return output.transpose(1, 2).contiguous()

    def _sdpa_attention(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Attention via torch.nn.functional.scaled_dot_product_attention.

        Uses the kernel's built-in causal path when there is no padding mask;
        otherwise builds an explicit boolean mask.
        """
        query = query.transpose(1, 2)
        key = key.transpose(1, 2)
        value = value.transpose(1, 2)
        mask = None
        query_attention_mask = None
        if attention_mask is not None:
            query_length = query.shape[-2]
            key_length = key.shape[-2]
            mask = self._causal_attention_mask(
                attention_mask=attention_mask,
                query_length=query_length,
                key_length=key_length,
                device=query.device,
            )
            query_attention_mask = attention_mask[:, -query_length:].to(dtype=torch.bool, device=query.device)
            if not bool(query_attention_mask.all()):
                # SDPA can produce NaNs when a query row is fully masked. For padded query positions,
                # keep a single aligned key visible, then zero the query output after attention.
                mask = mask.expand(query.shape[0], -1, -1, -1).clone()
                invalid_batch, invalid_query = torch.nonzero(~query_attention_mask, as_tuple=True)
                aligned_key = invalid_query + max(key_length - query_length, 0)
                mask[invalid_batch, :, invalid_query, aligned_key] = True
        output = torch.nn.functional.scaled_dot_product_attention(
            query,
            key,
            value,
            attn_mask=mask,
            dropout_p=self.attn_dropout if self.training else 0.0,
            is_causal=mask is None,
        )
        if query_attention_mask is not None and not bool(query_attention_mask.all()):
            # Zero out outputs for padded query rows (their fake visible key
            # produced an arbitrary, finite value).
            output = output.masked_fill(~query_attention_mask[:, None, :, None], 0.0)
        return output.transpose(1, 2).contiguous()

    def _flash_attention(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
        packed_metadata: Optional[PackedSequenceMetadata],
    ) -> torch.Tensor:
        """Attention via flash-attn. Requires CUDA and fp16/bf16 tensors.

        Three paths: pre-packed sequences (``packed_metadata``), fully dense
        batches (no padding), and padded batches (unpad -> varlen -> re-pad).
        """
        if not _FLASH_ATTN_AVAILABLE:
            raise ImportError("flash_attn is not installed, but attn_implementation='flash_attention_2' was requested.")
        if query.device.type != "cuda":
            raise ValueError("flash_attention_2 requires CUDA tensors.")
        if query.dtype not in (torch.float16, torch.bfloat16):
            raise ValueError(
                f"flash_attention_2 requires fp16/bf16 tensors, but received dtype={query.dtype}."
            )

        dropout_p = self.attn_dropout if self.training else 0.0
        if packed_metadata is not None:
            if packed_metadata.indices is not None:
                # Gather only the real tokens out of the padded layout.
                query = query.reshape(-1, self.num_heads, self.head_dim).index_select(0, packed_metadata.indices)
                key = key.reshape(-1, self.num_heads, self.head_dim).index_select(0, packed_metadata.indices)
                value = value.reshape(-1, self.num_heads, self.head_dim).index_select(0, packed_metadata.indices)
            output = flash_attn_varlen_func(
                query,
                key,
                value,
                packed_metadata.cu_seqlens,
                packed_metadata.cu_seqlens,
                packed_metadata.max_seqlen,
                packed_metadata.max_seqlen,
                dropout_p=dropout_p,
                causal=True,
            )
            if packed_metadata.indices is None:
                return output
            # Scatter back into the original padded (batch, seq) layout.
            return pad_input(
                output,
                packed_metadata.indices,
                packed_metadata.batch_size,
                packed_metadata.seq_len,
            )

        if attention_mask is None or bool(attention_mask.all()):
            # Dense batch: the fixed-length kernel suffices.
            return flash_attn_func(
                query,
                key,
                value,
                dropout_p=dropout_p,
                causal=True,
            )

        # Padded batch: strip padding, run the varlen kernel, then re-pad.
        unpadded_query, indices, cu_seqlens, max_seqlen, _ = unpad_input(query, attention_mask)
        unpadded_key, _, _, _, _ = unpad_input(key, attention_mask)
        unpadded_value, _, _, _, _ = unpad_input(value, attention_mask)
        output = flash_attn_varlen_func(
            unpadded_query,
            unpadded_key,
            unpadded_value,
            cu_seqlens,
            cu_seqlens,
            max_seqlen,
            max_seqlen,
            dropout_p=dropout_p,
            causal=True,
        )
        return pad_input(output, indices, query.shape[0], query.shape[1])

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        packed_metadata: Optional[PackedSequenceMetadata] = None,
        layer_past: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False,
    ) -> tuple[torch.Tensor, Optional[tuple[torch.Tensor, torch.Tensor]]]:
        """Run attention and return (output, present_kv_or_None).

        ``layer_past`` holds (key, value) from earlier steps with the sequence
        on dim 1; new keys/values are appended to it.
        """
        qkv = self.c_attn(hidden_states)
        query, key, value = qkv.split(self.embed_dim, dim=-1)
        query = self._split_heads(query)
        key = self._split_heads(key)
        value = self._split_heads(value)

        if self.rotary_emb is not None:
            if position_ids is None:
                raise ValueError("position_ids must be provided when position_embedding_type='rope'.")
            cos, sin = self.rotary_emb(
                position_ids.to(device=query.device),
                device=query.device,
                dtype=query.dtype,
            )
            query = apply_rotary_pos_emb(query, cos, sin)
            key = apply_rotary_pos_emb(key, cos, sin)

        if layer_past is not None:
            past_key, past_value = layer_past
            key = torch.cat([past_key.to(device=key.device, dtype=key.dtype), key], dim=1)
            value = torch.cat([past_value.to(device=value.device, dtype=value.dtype), value], dim=1)

        present = (key, value) if use_cache else None

        # Flash path only supports the no-cache (prefill) case; decoding with a
        # cache under flash config falls through to the eager branch below.
        if self.attn_implementation == "flash_attention_2" and layer_past is None:
            attn_output = self._flash_attention(
                query=query,
                key=key,
                value=value,
                attention_mask=attention_mask,
                packed_metadata=packed_metadata,
            )
        elif self.attn_implementation == "sdpa":
            attn_output = self._sdpa_attention(
                query=query,
                key=key,
                value=value,
                attention_mask=attention_mask,
            )
        else:
            attn_output = self._eager_attention(
                query=query,
                key=key,
                value=value,
                attention_mask=attention_mask,
            )

        attn_output = self._merge_heads(attn_output)
        attn_output = self.c_proj(attn_output)
        return self.resid_dropout(attn_output), present
|
| 354 |
+
|
| 355 |
+
|
| 356 |
+
class NanoGPT2Block(nn.Module):
    """Pre-norm transformer layer: LayerNorm -> attention -> residual,
    then LayerNorm -> MLP -> residual."""

    def __init__(self, config: GPT2Config, layer_idx: int, attn_implementation: str) -> None:
        super().__init__()
        width = int(config.hidden_size)
        self.ln_1 = nn.LayerNorm(width, eps=config.layer_norm_epsilon)
        self.attn = NanoGPT2Attention(config, layer_idx=layer_idx, attn_implementation=attn_implementation)
        self.ln_2 = nn.LayerNorm(width, eps=config.layer_norm_epsilon)
        self.mlp = NanoGPT2MLP(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        packed_metadata: Optional[PackedSequenceMetadata] = None,
        layer_past: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False,
    ) -> tuple[torch.Tensor, Optional[tuple[torch.Tensor, torch.Tensor]]]:
        normed = self.ln_1(hidden_states)
        attn_out, present = self.attn(
            normed,
            attention_mask=attention_mask,
            position_ids=position_ids,
            packed_metadata=packed_metadata,
            layer_past=layer_past,
            use_cache=use_cache,
        )
        after_attn = hidden_states + attn_out
        output = after_attn + self.mlp(self.ln_2(after_attn))
        return output, present
|
| 385 |
+
|
| 386 |
+
|
| 387 |
+
class NanoGPT2Model(nn.Module):
|
| 388 |
+
    def __init__(self, config: GPT2Config, attn_implementation: str = "eager") -> None:
        """Build the decoder stack: token embedding, optional absolute position
        embedding, dropout, ``n_layer`` blocks, and a final LayerNorm."""
        super().__init__()
        self.config = config
        self.attn_implementation = attn_implementation
        self.position_embedding_type = str(getattr(config, "position_embedding_type", "absolute")).lower()
        if self.position_embedding_type not in {"absolute", "rope"}:
            raise ValueError(f"Unsupported position_embedding_type={self.position_embedding_type!r}")
        hidden_size = int(config.hidden_size)
        self.wte = nn.Embedding(config.vocab_size, hidden_size)
        # Under RoPE, positions are applied inside attention, so wpe becomes a no-op.
        self.wpe = nn.Embedding(config.n_positions, hidden_size) if self.position_embedding_type == "absolute" else nn.Identity()
        self.drop = nn.Dropout(config.embd_pdrop)
        self.h = nn.ModuleList(
            [NanoGPT2Block(config, layer_idx=index, attn_implementation=attn_implementation) for index in range(config.n_layer)]
        )
        self.ln_f = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        self.gradient_checkpointing = False
        # Re-initialize all submodules with the configured initializer_range.
        self._reset_parameters()
|
| 405 |
+
|
| 406 |
+
def _reset_parameters(self) -> None:
|
| 407 |
+
init_std = float(self.config.initializer_range)
|
| 408 |
+
for module in self.modules():
|
| 409 |
+
if isinstance(module, nn.Linear):
|
| 410 |
+
nn.init.normal_(module.weight, mean=0.0, std=init_std)
|
| 411 |
+
if module.bias is not None:
|
| 412 |
+
nn.init.zeros_(module.bias)
|
| 413 |
+
elif isinstance(module, nn.Embedding):
|
| 414 |
+
nn.init.normal_(module.weight, mean=0.0, std=init_std)
|
| 415 |
+
elif isinstance(module, nn.LayerNorm):
|
| 416 |
+
nn.init.ones_(module.weight)
|
| 417 |
+
nn.init.zeros_(module.bias)
|
| 418 |
+
|
| 419 |
+
@staticmethod
|
| 420 |
+
def _normalize_num_sequences(
|
| 421 |
+
cu_seqlens: torch.Tensor,
|
| 422 |
+
num_sequences: Optional[torch.Tensor],
|
| 423 |
+
device: torch.device,
|
| 424 |
+
) -> torch.Tensor:
|
| 425 |
+
if cu_seqlens.ndim == 1:
|
| 426 |
+
cu_seqlens = cu_seqlens.unsqueeze(0)
|
| 427 |
+
if num_sequences is None:
|
| 428 |
+
counts = []
|
| 429 |
+
for boundary in cu_seqlens:
|
| 430 |
+
diffs = boundary[1:] - boundary[:-1]
|
| 431 |
+
counts.append(int((diffs > 0).sum().item()))
|
| 432 |
+
return torch.tensor(counts, dtype=torch.int32, device=device)
|
| 433 |
+
if num_sequences.ndim == 0:
|
| 434 |
+
return num_sequences.unsqueeze(0)
|
| 435 |
+
return num_sequences
|
| 436 |
+
|
| 437 |
+
@staticmethod
|
| 438 |
+
def build_packed_position_ids(
|
| 439 |
+
attention_mask: Optional[torch.Tensor],
|
| 440 |
+
cu_seqlens: torch.Tensor,
|
| 441 |
+
num_sequences: Optional[torch.Tensor],
|
| 442 |
+
) -> torch.Tensor:
|
| 443 |
+
if cu_seqlens.ndim == 1:
|
| 444 |
+
cu_seqlens = cu_seqlens.unsqueeze(0)
|
| 445 |
+
batch_size, seq_len = cu_seqlens.shape[0], cu_seqlens.shape[1] - 1
|
| 446 |
+
device = cu_seqlens.device
|
| 447 |
+
position_ids = torch.zeros((batch_size, seq_len), dtype=torch.long, device=device)
|
| 448 |
+
counts = NanoGPT2Model._normalize_num_sequences(cu_seqlens, num_sequences, device=device)
|
| 449 |
+
for batch_index in range(batch_size):
|
| 450 |
+
sequence_count = int(counts[batch_index].item())
|
| 451 |
+
boundaries = cu_seqlens[batch_index, : sequence_count + 1].tolist()
|
| 452 |
+
for start, end in zip(boundaries[:-1], boundaries[1:]):
|
| 453 |
+
start = int(start)
|
| 454 |
+
end = int(end)
|
| 455 |
+
if end > start:
|
| 456 |
+
position_ids[batch_index, start:end] = torch.arange(end - start, device=device)
|
| 457 |
+
if attention_mask is not None:
|
| 458 |
+
position_ids = position_ids * attention_mask.to(dtype=position_ids.dtype)
|
| 459 |
+
return position_ids
|
| 460 |
+
|
| 461 |
+
    @staticmethod
    def build_packed_metadata(
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        num_sequences: Optional[torch.Tensor],
    ) -> PackedSequenceMetadata:
        """Translate per-row segment boundaries into flat varlen metadata.

        Produces the gather indices and global cumulative lengths that
        ``flash_attn_varlen_func`` needs, treating the padded (batch, seq)
        input as one flat token stream.

        Raises:
            ValueError: if no row contains a non-empty segment.
        """
        if cu_seqlens.ndim == 1:
            cu_seqlens = cu_seqlens.unsqueeze(0)
        device = hidden_states.device
        counts = NanoGPT2Model._normalize_num_sequences(cu_seqlens, num_sequences, device=device)
        flat_indices = []
        # Global cumulative lengths across ALL rows' segments, starting at 0.
        cumulative = [0]
        max_seqlen = 0
        seq_len = hidden_states.shape[1]

        for batch_index in range(hidden_states.shape[0]):
            sequence_count = int(counts[batch_index].item())
            boundaries = cu_seqlens[batch_index, : sequence_count + 1].tolist()
            for start, end in zip(boundaries[:-1], boundaries[1:]):
                start = int(start)
                end = int(end)
                if end <= start:
                    continue  # skip empty segments
                # Flat index of each token of this segment in the flattened
                # (batch * seq) layout.
                segment_indices = batch_index * seq_len + torch.arange(start, end, device=device)
                flat_indices.append(segment_indices)
                cumulative.append(cumulative[-1] + (end - start))
                max_seqlen = max(max_seqlen, end - start)

        if not flat_indices:
            raise ValueError("cu_seqlens did not describe any non-empty packed sequences.")

        indices = torch.cat(flat_indices, dim=0)
        return PackedSequenceMetadata(
            cu_seqlens=torch.tensor(cumulative, dtype=torch.int32, device=device),
            max_seqlen=max_seqlen,
            indices=indices,
            batch_size=hidden_states.shape[0],
            seq_len=hidden_states.shape[1],
        )
|
| 500 |
+
|
| 501 |
+
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[tuple[tuple[torch.Tensor, torch.Tensor], ...]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: bool = True,
        cu_seqlens: Optional[torch.Tensor] = None,
        num_sequences: Optional[torch.Tensor] = None,
    ) -> BaseModelOutputWithPast:
        """Run the transformer stack over precomputed ``inputs_embeds``.

        ``input_ids`` and ``output_attentions`` are accepted only for interface
        compatibility and are discarded.  ``cu_seqlens``/``num_sequences``
        enable packed-sequence attention and are mutually exclusive with
        ``use_cache``.

        Raises:
            ValueError: if ``inputs_embeds`` is missing, or if ``use_cache`` is
                combined with ``cu_seqlens`` or with gradient checkpointing.
        """
        # Deliberately unused; embeddings are the only supported input and
        # attention weights are never materialized (attentions=None below).
        del input_ids, output_attentions

        if inputs_embeds is None:
            raise ValueError("inputs_embeds must be provided.")

        use_cache = bool(use_cache)
        if use_cache and cu_seqlens is not None:
            raise ValueError("use_cache=True is not supported together with cu_seqlens packing.")

        hidden_states = inputs_embeds
        # Default to a full-visibility boolean mask matching (batch, seq).
        if attention_mask is None:
            attention_mask = torch.ones(hidden_states.shape[:2], dtype=torch.bool, device=hidden_states.device)
        else:
            attention_mask = attention_mask.to(dtype=torch.bool, device=hidden_states.device)
        # Trailing slice covering only the current query tokens — the mask may
        # be longer than the input when a KV cache is in use.
        query_attention_mask = attention_mask[:, -hidden_states.shape[1] :]

        packed_metadata = None
        if position_ids is None:
            if cu_seqlens is not None:
                # Packed mode: positions restart at 0 for each packed segment.
                position_ids = self.build_packed_position_ids(
                    attention_mask=attention_mask,
                    cu_seqlens=cu_seqlens.to(device=hidden_states.device),
                    num_sequences=num_sequences.to(device=hidden_states.device) if num_sequences is not None else None,
                )
            elif attention_mask is not None:
                # Positions count only unmasked tokens; padded slots get 0.
                position_ids = attention_mask.long().cumsum(dim=-1) - 1
                position_ids = position_ids.masked_fill(~attention_mask, 0)
                position_ids = position_ids[:, -hidden_states.shape[1] :]
            else:
                # NOTE(review): attention_mask is always non-None here (it is
                # defaulted above), so this branch looks unreachable — confirm
                # before relying on it.
                past_length = 0
                if past_key_values is not None and len(past_key_values) > 0:
                    past_length = past_key_values[0][0].shape[1]
                position_ids = torch.arange(hidden_states.shape[1], device=hidden_states.device, dtype=torch.long)
                position_ids = position_ids + past_length
                position_ids = position_ids.unsqueeze(0).expand(hidden_states.shape[0], -1)

        # Packed varlen metadata is only consumed by the flash-attention path.
        if cu_seqlens is not None and self.attn_implementation == "flash_attention_2":
            packed_metadata = self.build_packed_metadata(
                hidden_states=hidden_states,
                cu_seqlens=cu_seqlens.to(device=hidden_states.device),
                num_sequences=num_sequences.to(device=hidden_states.device) if num_sequences is not None else None,
            )

        if self.position_embedding_type == "absolute":
            hidden_states = hidden_states + self.wpe(position_ids)
        hidden_states = self.drop(hidden_states)
        # Zero out embeddings at padded query positions.
        hidden_states = hidden_states * query_attention_mask.unsqueeze(-1).to(dtype=hidden_states.dtype)

        all_hidden_states = () if output_hidden_states else None
        presents = [] if use_cache else None
        for layer_index, block in enumerate(self.h):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            if self.gradient_checkpointing and self.training:
                if use_cache:
                    raise ValueError("use_cache=True is not supported when gradient checkpointing is enabled during training.")

                # Wrapper drops the per-layer `present` (not needed when
                # checkpointing, since use_cache is forced off).
                def custom_forward(*inputs):
                    output, _ = block(
                        inputs[0],
                        attention_mask=inputs[1],
                        position_ids=inputs[2],
                        packed_metadata=packed_metadata,
                        layer_past=None,
                        use_cache=False,
                    )
                    return output

                hidden_states = torch.utils.checkpoint.checkpoint(
                    custom_forward,
                    hidden_states,
                    attention_mask,
                    position_ids,
                    use_reentrant=False,
                )
                present = None
            else:
                hidden_states, present = block(
                    hidden_states,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    packed_metadata=packed_metadata,
                    layer_past=None if past_key_values is None else past_key_values[layer_index],
                    use_cache=use_cache,
                )
            # Re-zero padded positions after every block.
            hidden_states = hidden_states * query_attention_mask.unsqueeze(-1).to(dtype=hidden_states.dtype)
            if presents is not None:
                presents.append(present)

        hidden_states = self.ln_f(hidden_states)
        hidden_states = hidden_states * query_attention_mask.unsqueeze(-1).to(dtype=hidden_states.dtype)
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return (hidden_states, tuple(presents) if presents is not None else None, all_hidden_states, None)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=tuple(presents) if presents is not None else None,
            hidden_states=all_hidden_states,
            attentions=None,
        )
|
weights/tts/modeling_nanotts_global_local.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
weights/tts/prompting.py
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from typing import List, Sequence
|
| 4 |
+
|
| 5 |
+
from .configuration_nanotts import NanoTTSConfig
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
# Literal fragments of the chat-style prompt template.  These runtime string
# values are load-bearing for the model's prompt format — do not edit them.
USER_ROLE_PREFIX = "user\n"
# Opening of the <user_inst> block, up to and including the reference list header.
USER_TEMPLATE_REFERENCE_PREFIX = (
    "<user_inst>\n"
    "- Reference(s):\n"
)
# Fixed control fields emitted after the reference section.  Every field is
# pinned to "None" so only the "- Text:" payload varies per request.
USER_TEMPLATE_AFTER_REFERENCE = (
    "\n- Instruction:\nNone\n"
    "- Tokens:\nNone\n"
    "- Quality:\nNone\n"
    "- Sound Event:\nNone\n"
    "- Ambient Sound:\nNone\n"
    "- Language:\nNone\n"
    "- Text:\n"
)
# Full user-template prefix for the common case of an empty ("None") reference.
USER_TEMPLATE_PREFIX = USER_TEMPLATE_REFERENCE_PREFIX + "None" + USER_TEMPLATE_AFTER_REFERENCE
USER_TEMPLATE_SUFFIX = "\n</user_inst>"
# Separator between the user turn and the assistant turn.
ASSISTANT_TURN_PREFIX = "\n"
ASSISTANT_ROLE_PREFIX = "assistant\n"
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def encode_text(tokenizer, text: str) -> List[int]:
    """Encode *text* to token ids, preferring ``add_special_tokens=False``.

    Falls back to a plain ``encode(text)`` call for tokenizers whose encode
    signature does not accept the HF-style keyword.
    """
    try:
        ids = tokenizer.encode(text, add_special_tokens=False)
    except TypeError:
        # Tokenizer rejected the keyword argument — retry without it.
        ids = tokenizer.encode(text)
    return list(ids)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def decode_text(tokenizer, token_ids: Sequence[int]) -> str:
    """Decode *token_ids* to text, keeping special tokens and raw spacing.

    Degrades gracefully across tokenizers with progressively narrower
    ``decode`` signatures: full HF keywords first, then
    ``skip_special_tokens`` only, then a bare positional call.
    """
    ids = list(token_ids)
    for kwargs in (
        {"skip_special_tokens": False, "clean_up_tokenization_spaces": False},
        {"skip_special_tokens": False},
    ):
        try:
            return str(tokenizer.decode(ids, **kwargs))
        except TypeError:
            # Signature does not accept these keywords — try a narrower call.
            continue
    return str(tokenizer.decode(ids))
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def build_user_prompt_prefix(tokenizer, config: NanoTTSConfig) -> List[int]:
    """Token ids opening the user turn: im_start id, role header, reference header."""
    prefix: List[int] = [config.im_start_token_id]
    prefix.extend(encode_text(tokenizer, USER_ROLE_PREFIX))
    prefix.extend(encode_text(tokenizer, USER_TEMPLATE_REFERENCE_PREFIX))
    return prefix
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def build_user_prompt_after_reference(tokenizer) -> List[int]:
    """Token ids for the fixed template fields that follow the reference section."""
    ids = encode_text(tokenizer, USER_TEMPLATE_AFTER_REFERENCE)
    return ids
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def build_assistant_prompt_prefix(tokenizer, config: NanoTTSConfig) -> List[int]:
    """Token ids closing the user turn and opening the assistant turn.

    Layout: template suffix, im_end id, turn separator, im_start id, role header.
    """
    ids: List[int] = encode_text(tokenizer, USER_TEMPLATE_SUFFIX)
    ids.append(config.im_end_token_id)
    ids.extend(encode_text(tokenizer, ASSISTANT_TURN_PREFIX))
    ids.append(config.im_start_token_id)
    ids.extend(encode_text(tokenizer, ASSISTANT_ROLE_PREFIX))
    return ids
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def build_prompt_prefix(tokenizer, config: NanoTTSConfig) -> List[int]:
    """Full user-turn prompt prefix with an empty ("None") reference section."""
    ids = list(build_user_prompt_prefix(tokenizer, config))
    ids += encode_text(tokenizer, "None")
    ids += build_user_prompt_after_reference(tokenizer)
    return ids
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def build_prompt_suffix(tokenizer, config: NanoTTSConfig) -> List[int]:
    """Token ids appended after the text tokens (the assistant-turn opener)."""
    return build_assistant_prompt_prefix(tokenizer, config)
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def build_prompt_token_ids(
    tokenizer,
    config: NanoTTSConfig,
    text_token_ids: Sequence[int],
) -> List[int]:
    """Assemble the full prompt: prefix + coerced text token ids + suffix."""
    body = [int(token_id) for token_id in text_token_ids]
    prefix = build_prompt_prefix(tokenizer, config)
    suffix = build_prompt_suffix(tokenizer, config)
    return prefix + body + suffix
|
weights/tts/pytorch_model.bin
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:24003f2f11ac8a2cbf70514db2d8f1c02fb451aa6b3c0bffc9da09f31cd7caa5
|
| 3 |
+
size 234693095
|
weights/tts/special_tokens_map.json
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"bos_token": {
|
| 3 |
+
"content": "<s>",
|
| 4 |
+
"lstrip": false,
|
| 5 |
+
"normalized": false,
|
| 6 |
+
"rstrip": false,
|
| 7 |
+
"single_word": false
|
| 8 |
+
},
|
| 9 |
+
"eos_token": {
|
| 10 |
+
"content": "</s>",
|
| 11 |
+
"lstrip": false,
|
| 12 |
+
"normalized": false,
|
| 13 |
+
"rstrip": false,
|
| 14 |
+
"single_word": false
|
| 15 |
+
},
|
| 16 |
+
"pad_token": {
|
| 17 |
+
"content": "<pad>",
|
| 18 |
+
"lstrip": false,
|
| 19 |
+
"normalized": false,
|
| 20 |
+
"rstrip": false,
|
| 21 |
+
"single_word": false
|
| 22 |
+
},
|
| 23 |
+
"unk_token": {
|
| 24 |
+
"content": "<unk>",
|
| 25 |
+
"lstrip": false,
|
| 26 |
+
"normalized": false,
|
| 27 |
+
"rstrip": false,
|
| 28 |
+
"single_word": false
|
| 29 |
+
}
|
| 30 |
+
}
|
weights/tts/tokenization_nanotts_sentencepiece.py
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import shutil
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from typing import Any
|
| 6 |
+
|
| 7 |
+
import sentencepiece as spm
|
| 8 |
+
from transformers import PreTrainedTokenizer
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
# Maps the HF "vocab_file" init argument to the on-disk SentencePiece model name.
VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model"}
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class NanoTTSSentencePieceTokenizer(PreTrainedTokenizer):
    """Minimal HF-compatible wrapper around a raw SentencePiece model.

    No special tokens are inserted automatically:
    ``build_inputs_with_special_tokens`` simply concatenates the given
    sequences, and the special-tokens mask / token-type ids are all zeros.
    """

    vocab_files_names = VOCAB_FILES_NAMES
    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        vocab_file: str,
        unk_token: str = "<unk>",
        bos_token: str = "<s>",
        eos_token: str = "</s>",
        pad_token: str = "<pad>",
        sp_model_kwargs: dict[str, Any] | None = None,
        **kwargs,
    ) -> None:
        """Load the SentencePiece model at *vocab_file* and register special tokens.

        NOTE: the SentencePiece processor is set up before ``super().__init__``
        runs — the base class may query the vocabulary during initialization,
        so keep this ordering.
        """
        self.vocab_file = str(vocab_file)
        self.sp_model_kwargs = {} if sp_model_kwargs is None else dict(sp_model_kwargs)
        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
        self.sp_model.Load(self.vocab_file)
        super().__init__(
            unk_token=unk_token,
            bos_token=bos_token,
            eos_token=eos_token,
            pad_token=pad_token,
            **kwargs,
        )

    @property
    def vocab_size(self) -> int:
        """Size of the underlying SentencePiece vocabulary (added tokens excluded)."""
        return int(self.sp_model.get_piece_size())

    def get_vocab(self) -> dict[str, int]:
        """Return piece->id for the SP vocabulary, overlaid with added tokens."""
        vocab = {self.sp_model.id_to_piece(i): i for i in range(self.vocab_size)}
        vocab.update(self.added_tokens_encoder)
        return vocab

    def _tokenize(self, text: str) -> list[str]:
        """Split *text* into SentencePiece pieces."""
        return list(self.sp_model.encode(text, out_type=str))

    def _convert_token_to_id(self, token: str) -> int:
        """Map a piece to its id via the SP model."""
        token_id = int(self.sp_model.piece_to_id(token))
        return token_id

    def _convert_id_to_token(self, index: int) -> str:
        """Map an id back to its piece via the SP model."""
        return str(self.sp_model.id_to_piece(int(index)))

    def convert_tokens_to_string(self, tokens: list[str]) -> str:
        """Join pieces back into text using SentencePiece's own detokenizer."""
        return str(self.sp_model.decode(tokens))

    def save_vocabulary(self, save_directory: str, filename_prefix: str | None = None) -> tuple[str]:
        """Copy the SentencePiece model file into *save_directory*.

        Skips the copy when source and destination resolve to the same path.
        Returns a one-tuple with the saved file path (HF convention).
        """
        save_dir = Path(save_directory)
        save_dir.mkdir(parents=True, exist_ok=True)
        out_name = "tokenizer.model" if filename_prefix is None else f"{filename_prefix}-tokenizer.model"
        out_path = save_dir / out_name
        if Path(self.vocab_file).resolve() != out_path.resolve():
            shutil.copyfile(self.vocab_file, out_path)
        return (str(out_path),)

    def build_inputs_with_special_tokens(
        self,
        token_ids_0: list[int],
        token_ids_1: list[int] | None = None,
    ) -> list[int]:
        """Concatenate the sequence(s) verbatim — no special tokens are added."""
        if token_ids_1 is None:
            return list(token_ids_0)
        return list(token_ids_0) + list(token_ids_1)

    def get_special_tokens_mask(
        self,
        token_ids_0: list[int],
        token_ids_1: list[int] | None = None,
        already_has_special_tokens: bool = False,
    ) -> list[int]:
        """All-zeros mask, since this tokenizer never inserts special tokens."""
        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0,
                token_ids_1=token_ids_1,
                already_has_special_tokens=True,
            )
        if token_ids_1 is None:
            return [0] * len(token_ids_0)
        return [0] * (len(token_ids_0) + len(token_ids_1))

    def create_token_type_ids_from_sequences(
        self,
        token_ids_0: list[int],
        token_ids_1: list[int] | None = None,
    ) -> list[int]:
        """All-zeros token-type ids; sequence pairs are not distinguished."""
        if token_ids_1 is None:
            return [0] * len(token_ids_0)
        return [0] * (len(token_ids_0) + len(token_ids_1))
|
weights/tts/tokenizer.model
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c353ee1479b536bf414c1b247f5542b6607fb8ae91320e5af1781fee200fddff
|
| 3 |
+
size 470897
|
weights/tts/tokenizer_config.json
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"added_tokens_decoder": {
|
| 3 |
+
"0": {
|
| 4 |
+
"content": "<unk>",
|
| 5 |
+
"lstrip": false,
|
| 6 |
+
"normalized": false,
|
| 7 |
+
"rstrip": false,
|
| 8 |
+
"single_word": false,
|
| 9 |
+
"special": true
|
| 10 |
+
},
|
| 11 |
+
"1": {
|
| 12 |
+
"content": "<s>",
|
| 13 |
+
"lstrip": false,
|
| 14 |
+
"normalized": false,
|
| 15 |
+
"rstrip": false,
|
| 16 |
+
"single_word": false,
|
| 17 |
+
"special": true
|
| 18 |
+
},
|
| 19 |
+
"2": {
|
| 20 |
+
"content": "</s>",
|
| 21 |
+
"lstrip": false,
|
| 22 |
+
"normalized": false,
|
| 23 |
+
"rstrip": false,
|
| 24 |
+
"single_word": false,
|
| 25 |
+
"special": true
|
| 26 |
+
},
|
| 27 |
+
"3": {
|
| 28 |
+
"content": "<pad>",
|
| 29 |
+
"lstrip": false,
|
| 30 |
+
"normalized": false,
|
| 31 |
+
"rstrip": false,
|
| 32 |
+
"single_word": false,
|
| 33 |
+
"special": true
|
| 34 |
+
}
|
| 35 |
+
},
|
| 36 |
+
"additional_special_tokens": [],
|
| 37 |
+
"auto_map": {
|
| 38 |
+
"AutoTokenizer": [
|
| 39 |
+
"tokenization_nanotts_sentencepiece.NanoTTSSentencePieceTokenizer",
|
| 40 |
+
null
|
| 41 |
+
]
|
| 42 |
+
},
|
| 43 |
+
"backend": "custom",
|
| 44 |
+
"bos_token": "<s>",
|
| 45 |
+
"clean_up_tokenization_spaces": false,
|
| 46 |
+
"eos_token": "</s>",
|
| 47 |
+
"extra_special_tokens": {},
|
| 48 |
+
"model_max_length": 16384,
|
| 49 |
+
"pad_token": "<pad>",
|
| 50 |
+
"tokenizer_class": "NanoTTSSentencePieceTokenizer",
|
| 51 |
+
"unk_token": "<unk>"
|
| 52 |
+
}
|