--replace-all committed on
Commit
1459ef5
·
1 Parent(s): a8bc6e2

Add Nano-TTS CPU Gradio Space

Browse files
Files changed (45) hide show
  1. .gitattributes +2 -0
  2. .gitignore +6 -0
  3. README.md +11 -3
  4. app.py +500 -0
  5. asserts/audio/en_1.wav +3 -0
  6. asserts/audio/en_2.wav +3 -0
  7. asserts/audio/en_3.wav +3 -0
  8. asserts/audio/en_4.wav +3 -0
  9. asserts/audio/en_5.wav +3 -0
  10. asserts/audio/jp_1.mp3 +3 -0
  11. asserts/audio/jp_2.wav +3 -0
  12. asserts/audio/jp_3.wav +3 -0
  13. asserts/audio/jp_4.wav +3 -0
  14. asserts/audio/jp_5.wav +3 -0
  15. asserts/audio/zh_1.wav +3 -0
  16. asserts/audio/zh_2.wav +3 -0
  17. asserts/audio/zh_3.wav +3 -0
  18. asserts/audio/zh_4.wav +3 -0
  19. asserts/audio/zh_5.wav +3 -0
  20. asserts/audio/zh_6.wav +3 -0
  21. nano_tts_runtime.py +727 -0
  22. requirements.txt +7 -0
  23. text_normalization_pipeline.py +195 -0
  24. tts_robust_normalizer_single_script.py +366 -0
  25. weights/codec/.gitattributes +35 -0
  26. weights/codec/README.md +195 -0
  27. weights/codec/__init__.py +1 -0
  28. weights/codec/config.json +304 -0
  29. weights/codec/configuration_moss_audio_tokenizer.py +467 -0
  30. weights/codec/model-00001-of-00001.safetensors +3 -0
  31. weights/codec/model.safetensors.index.json +382 -0
  32. weights/codec/modeling_moss_audio_tokenizer.py +0 -0
  33. weights/tts/.gitattributes +35 -0
  34. weights/tts/README.md +3 -0
  35. weights/tts/__init__.py +31 -0
  36. weights/tts/config.json +197 -0
  37. weights/tts/configuration_nanotts.py +105 -0
  38. weights/tts/gpt2_decoder.py +618 -0
  39. weights/tts/modeling_nanotts_global_local.py +0 -0
  40. weights/tts/prompting.py +92 -0
  41. weights/tts/pytorch_model.bin +3 -0
  42. weights/tts/special_tokens_map.json +30 -0
  43. weights/tts/tokenization_nanotts_sentencepiece.py +103 -0
  44. weights/tts/tokenizer.model +3 -0
  45. weights/tts/tokenizer_config.json +52 -0
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.mp3 filter=lfs diff=lfs merge=lfs -text
37
+ *.wav filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ __pycache__/
2
+ *.py[cod]
3
+ generated_audio/
4
+ .cache/
5
+ weights/tts/.cache/
6
+ weights/codec/.cache/
README.md CHANGED
@@ -4,11 +4,19 @@ emoji: 📈
4
  colorFrom: red
5
  colorTo: green
6
  sdk: gradio
7
- sdk_version: 6.11.0
 
8
  app_file: app.py
9
  pinned: false
10
  license: apache-2.0
11
- short_description: space for MOSS-TTS-Nano
12
  ---
13
 
14
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
4
  colorFrom: red
5
  colorTo: green
6
  sdk: gradio
7
+ sdk_version: 6.5.1
8
+ python_version: "3.10"
9
  app_file: app.py
10
  pinned: false
11
  license: apache-2.0
12
+ short_description: CPU-only MOSS TTS Nano Gradio demo with local TTS and codec weights
13
  ---
14
 
15
+ This Space runs Nano-TTS on CPU using the local `weights/tts` and `weights/codec` directories.
16
+
17
+ Supported modes:
18
+
19
+ - `voice_clone`: upload a reference audio file or use a built-in preset voice
20
+ - `continuation`: plain TTS, or prompt transcript plus prompt audio
21
+
22
+ Realtime streaming decode is disabled in this Space. Audio is returned after full synthesis finishes.
app.py ADDED
@@ -0,0 +1,500 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import argparse
4
+ import functools
5
+ import logging
6
+ import os
7
+ import time
8
+ from pathlib import Path
9
+
10
+ import gradio as gr
11
+
12
+ from nano_tts_runtime import DEFAULT_VOICE, NanoTTSService, build_default_voice_presets
13
+ from text_normalization_pipeline import prepare_tts_request_texts
14
+
15
+ APP_DIR = Path(__file__).resolve().parent
16
+ CHECKPOINT_PATH = APP_DIR / "weights" / "tts"
17
+ AUDIO_TOKENIZER_PATH = APP_DIR / "weights" / "codec"
18
+ OUTPUT_DIR = Path("/tmp") / "nano-tts-space"
19
+ PRELOAD_ENV_VAR = "NANO_TTS_PRELOAD_AT_STARTUP"
20
+
21
+ MODE_VOICE_CLONE = "voice_clone"
22
+ MODE_CONTINUATION = "continuation"
23
+
24
+ _VOICE_PRESETS = build_default_voice_presets()
25
+
26
+
27
def build_voice_choices() -> list[tuple[str, str]]:
    """Return (label, value) dropdown choices for presets whose audio file exists.

    WAV-backed presets are preferred; presets in other formats are only
    offered when no WAV-backed preset is available on disk.
    """
    wav_backed: list[tuple[str, str]] = []
    any_format: list[tuple[str, str]] = []

    for preset in _VOICE_PRESETS.values():
        if preset.prompt_audio_path.is_file():
            choice = (f"{preset.name} - {preset.description}", preset.name)
            any_format.append(choice)
            if preset.prompt_audio_path.suffix.lower() == ".wav":
                wav_backed.append(choice)

    return wav_backed if wav_backed else any_format
41
+
42
+
43
+ VOICE_CHOICES = build_voice_choices()
44
+ DEFAULT_VOICE_VALUE = (
45
+ DEFAULT_VOICE
46
+ if any(value == DEFAULT_VOICE for _, value in VOICE_CHOICES)
47
+ else (VOICE_CHOICES[0][1] if VOICE_CHOICES else "")
48
+ )
49
+
50
+
51
def parse_bool_env(name: str, default: bool) -> bool:
    """Interpret the environment variable *name* as a boolean flag.

    Returns *default* when the variable is unset; otherwise True only for
    the truthy spellings 1/true/yes/y/on (case-insensitive, surrounding
    whitespace ignored).
    """
    raw = os.getenv(name)
    if raw is None:
        return default
    truthy_tokens = {"1", "true", "yes", "y", "on"}
    return raw.strip().lower() in truthy_tokens
56
+
57
+
58
def parse_port(value: str | None, default: int) -> int:
    """Parse *value* as a TCP port number, falling back to *default*.

    None, empty strings, and non-numeric text all yield *default*.
    """
    if value:
        try:
            return int(value)
        except ValueError:
            pass
    return default
65
+
66
+
67
def maybe_delete_file(path: str | Path | None) -> None:
    """Best-effort removal of a temporary file; never raises.

    A falsy *path* is a no-op; a missing file is tolerated. OSError is
    logged with traceback instead of propagated, since cleanup failure
    should not break a request.
    """
    if path:
        try:
            Path(path).unlink(missing_ok=True)
        except OSError:
            logging.warning("failed to delete temporary file: %s", path, exc_info=True)
74
+
75
+
76
@functools.lru_cache(maxsize=1)
def get_tts_service() -> NanoTTSService:
    """Return the process-wide NanoTTSService instance, building it lazily.

    lru_cache(maxsize=1) on a zero-argument function is a simple memoized
    singleton. The service is pinned to CPU-only float32 inference with
    SDPA attention against the bundled local weight directories.
    """
    return NanoTTSService(
        checkpoint_path=CHECKPOINT_PATH,
        audio_tokenizer_path=AUDIO_TOKENIZER_PATH,
        device="cpu",
        dtype="float32",
        attn_implementation="sdpa",
        output_dir=OUTPUT_DIR,
    )
86
+
87
+
88
def preload_service() -> None:
    """Eagerly load the TTS model so the first user request is not slow."""
    t0 = time.monotonic()
    logging.info(
        "preloading Nano-TTS model checkpoint=%s codec=%s device=cpu",
        CHECKPOINT_PATH,
        AUDIO_TOKENIZER_PATH,
    )
    service = get_tts_service()
    service.get_model()
    elapsed = time.monotonic() - t0
    logging.info("Nano-TTS preload finished in %.2fs", elapsed)
97
+
98
+
99
def render_mode_hint(mode: str) -> str:
    """Return the Markdown helper text describing the selected inference mode."""
    continuation_hint = (
        "Current mode: **Continuation** \n"
        "Plain TTS uses only target text. If you upload reference audio, you must also provide its transcript."
    )
    voice_clone_hint = (
        "Current mode: **Voice Clone** \n"
        "Upload a reference audio file or use a built-in preset voice. Audio is returned only after full decoding."
    )
    return continuation_hint if mode == MODE_CONTINUATION else voice_clone_hint
109
+
110
+
111
def update_mode_ui(mode: str):
    """Reconfigure the voice/prompt widgets when the inference mode changes.

    Returns Gradio updates for, in order: the preset-voice dropdown, the
    prompt-transcript textbox, the reference-audio uploader, and the mode
    hint markdown.
    """
    hint = render_mode_hint(mode)

    if mode != MODE_CONTINUATION:
        # Voice-clone: show presets, hide the transcript box, relabel uploader.
        return (
            gr.update(visible=True),
            gr.update(visible=False, value=""),
            gr.update(label="Reference Audio Upload (optional; overrides preset voice)"),
            hint,
        )

    # Continuation: hide presets, show a cleared transcript box.
    return (
        gr.update(visible=False),
        gr.update(
            visible=True,
            value="",
            placeholder="Only for continuation with reference audio.",
        ),
        gr.update(label="Reference Audio Upload (optional; required if Prompt Transcript is set)"),
        hint,
    )
130
+
131
+
132
def validate_request(
    *,
    text: str,
    mode: str,
    prompt_text: str,
    prompt_audio_path: str | None,
) -> tuple[str, str | None]:
    """Validate and normalize a synthesis request.

    Returns ``(clean_text, clean_prompt_text_or_None)``. Raises ValueError
    with a user-facing message when the mode/prompt combination is invalid.
    """
    clean_text = str(text or "").strip()
    clean_prompt = str(prompt_text or "").strip()

    if not clean_text:
        raise ValueError("Please enter text to synthesize.")

    if mode == MODE_VOICE_CLONE:
        # Voice cloning takes its reference from audio only, never a transcript.
        if clean_prompt:
            raise ValueError("voice_clone mode does not use prompt_text. Leave Prompt Transcript empty.")
        return clean_text, None

    # Continuation: transcript and reference audio must be given together.
    if bool(clean_prompt) != bool(prompt_audio_path):
        raise ValueError(
            "continuation mode accepts either target text only, or prompt_text and reference audio together."
        )

    return clean_text, (clean_prompt or None)
157
+
158
+
159
def build_status_text(
    *,
    result: dict[str, object],
    prepared_texts: dict[str, object],
    reference_source: str,
) -> str:
    """Format the one-line status summary shown after a synthesis run."""
    chunks = result.get("voice_clone_text_chunks") or []
    if isinstance(chunks, list) and chunks:
        chunk_count = len(chunks)
    else:
        # Missing, non-list, or empty payloads all count as a single chunk.
        chunk_count = 1
    return (
        f"Done | mode={result['mode']} | ref={reference_source} | elapsed={result['elapsed_seconds']:.2f}s | "
        f"sample_rate={result['sample_rate']} | attn={result['effective_global_attn_implementation']} | "
        f"chunks={chunk_count} | normalization={prepared_texts['normalization_method']}"
    )
172
+
173
+
174
def run_inference(
    text: str,
    mode: str,
    voice: str,
    prompt_audio_path: str | None,
    prompt_text: str,
    max_new_frames: int,
    voice_clone_max_text_tokens: int,
    do_sample: bool,
    text_temperature: float,
    text_top_p: float,
    text_top_k: int,
    audio_temperature: float,
    audio_top_p: float,
    audio_top_k: int,
    audio_repetition_penalty: float,
    seed: float | int,
):
    """Gradio click handler: validate, normalize, synthesize, clean up.

    Returns (audio, status, normalized_text, normalized_prompt_text), where
    audio is a (sample_rate, numpy waveform) tuple for gr.Audio. Any failure
    is logged and re-raised as gr.Error so the UI shows the message.
    """
    generated_audio_path: str | None = None
    try:
        normalized_text, normalized_prompt_text = validate_request(
            text=text,
            mode=mode,
            prompt_text=prompt_text,
            prompt_audio_path=prompt_audio_path,
        )
        # WeTextProcessing is deliberately disabled in this CPU Space.
        prepared_texts = prepare_tts_request_texts(
            text=normalized_text,
            prompt_text=normalized_prompt_text or "",
            voice=voice,
            enable_wetext=False,
            text_normalizer_manager=None,
        )
        reference_source = (
            "uploaded_audio"
            if prompt_audio_path
            else (f"preset:{voice}" if mode == MODE_VOICE_CLONE else "none")
        )
        # Seed 0 (or blank) means "random": leave normalized_seed as None.
        normalized_seed = None
        if seed not in {"", None}:
            resolved_seed = int(seed)
            if resolved_seed != 0:
                normalized_seed = resolved_seed

        result = get_tts_service().synthesize(
            text=str(prepared_texts["text"]),
            mode=mode,
            voice=voice,
            prompt_audio_path=prompt_audio_path or None,
            prompt_text=str(prepared_texts["prompt_text"]).strip() or None,
            max_new_frames=int(max_new_frames),
            voice_clone_max_text_tokens=int(voice_clone_max_text_tokens),
            do_sample=bool(do_sample),
            text_temperature=float(text_temperature),
            text_top_p=float(text_top_p),
            text_top_k=int(text_top_k),
            audio_temperature=float(audio_temperature),
            audio_top_p=float(audio_top_p),
            audio_top_k=int(audio_top_k),
            audio_repetition_penalty=float(audio_repetition_penalty),
            seed=normalized_seed,
            attn_implementation="sdpa",
        )
        generated_audio_path = str(result["audio_path"])
        return (
            (int(result["sample_rate"]), result["waveform_numpy"]),
            build_status_text(
                result=result,
                prepared_texts=prepared_texts,
                reference_source=reference_source,
            ),
            str(prepared_texts["normalized_text"]),
            str(prepared_texts["normalized_prompt_text"]),
        )
    except Exception as exc:
        logging.exception("Nano-TTS inference failed")
        raise gr.Error(str(exc)) from exc
    finally:
        # The waveform is returned in memory, so the on-disk file written by
        # synthesize() is removed unconditionally — also on success.
        maybe_delete_file(generated_audio_path)
253
+
254
+
255
def build_demo():
    """Construct the Gradio Blocks UI and wire its event handlers."""
    # Styling only; defines a light neutral/green theme. No behavior here.
    custom_css = """
    :root {
      --bg: #f5f6f0;
      --panel: #ffffff;
      --ink: #15221a;
      --muted: #5a695e;
      --line: #d9dfd6;
      --accent: #285943;
    }
    .gradio-container {
      background:
        radial-gradient(circle at top left, rgba(162, 198, 167, 0.18), transparent 28%),
        linear-gradient(180deg, #f5f6f0 0%, #edf1ea 100%);
      color: var(--ink);
    }
    .app-card {
      border: 1px solid var(--line);
      border-radius: 18px;
      background: rgba(255, 255, 255, 0.96);
      padding: 16px;
      box-shadow: 0 20px 40px rgba(21, 34, 26, 0.06);
    }
    .app-title {
      font-size: 24px;
      font-weight: 700;
      letter-spacing: 0.2px;
      margin-bottom: 6px;
    }
    .app-subtitle {
      color: var(--muted);
      font-size: 14px;
      line-height: 1.5;
    }
    #run-btn {
      background: var(--accent);
      border: none;
    }
    """

    with gr.Blocks(title="Nano-TTS CPU Space", css=custom_css) as demo:
        gr.Markdown(
            """
            <div class="app-card">
            <div class="app-title">Nano-TTS CPU</div>
            <div class="app-subtitle">
            Hugging Face Space edition backed by local <code>weights/tts</code> and <code>weights/codec</code>.
            Realtime streaming decode is disabled; audio is returned after full synthesis.
            </div>
            </div>
            """
        )

        with gr.Row(equal_height=False):
            # Left column: all inputs.
            with gr.Column(scale=3):
                text = gr.Textbox(
                    label="Target Text",
                    lines=10,
                    placeholder="Enter the text to synthesize.",
                )
                mode = gr.Radio(
                    choices=[
                        ("Voice Clone", MODE_VOICE_CLONE),
                        ("Continuation", MODE_CONTINUATION),
                    ],
                    value=MODE_VOICE_CLONE,
                    label="Inference Mode",
                )
                mode_hint = gr.Markdown(render_mode_hint(MODE_VOICE_CLONE))
                voice = gr.Dropdown(
                    choices=VOICE_CHOICES,
                    value=DEFAULT_VOICE_VALUE,
                    label="Preset Voice",
                    info="Used only by voice_clone when no reference audio is uploaded.",
                    visible=True,
                )
                prompt_audio = gr.Audio(
                    label="Reference Audio Upload (optional; overrides preset voice)",
                    type="filepath",
                    sources=["upload"],
                )
                # Hidden by default; update_mode_ui reveals it in continuation mode.
                prompt_text = gr.Textbox(
                    label="Prompt Transcript",
                    lines=3,
                    visible=False,
                    placeholder="Only for continuation with reference audio.",
                )

                gr.Markdown(
                    "Robust text normalization is always on. WeTextProcessing is disabled in this CPU Space for a simpler deployment path."
                )

                with gr.Accordion("Advanced Parameters", open=False):
                    max_new_frames = gr.Slider(
                        minimum=64,
                        maximum=512,
                        step=16,
                        value=375,
                        label="max_new_frames",
                    )
                    voice_clone_max_text_tokens = gr.Slider(
                        minimum=25,
                        maximum=200,
                        step=5,
                        value=75,
                        label="voice_clone_max_text_tokens",
                    )
                    do_sample = gr.Checkbox(
                        value=True,
                        label="Enable Sampling",
                    )
                    seed = gr.Number(
                        value=0,
                        precision=0,
                        label="Seed (0 = random)",
                    )
                    text_temperature = gr.Slider(
                        minimum=0.1,
                        maximum=2.0,
                        step=0.05,
                        value=1.0,
                        label="text_temperature",
                    )
                    text_top_p = gr.Slider(
                        minimum=0.1,
                        maximum=1.0,
                        step=0.01,
                        value=1.0,
                        label="text_top_p",
                    )
                    text_top_k = gr.Slider(
                        minimum=1,
                        maximum=100,
                        step=1,
                        value=50,
                        label="text_top_k",
                    )
                    audio_temperature = gr.Slider(
                        minimum=0.1,
                        maximum=2.0,
                        step=0.05,
                        value=0.8,
                        label="audio_temperature",
                    )
                    audio_top_p = gr.Slider(
                        minimum=0.1,
                        maximum=1.0,
                        step=0.01,
                        value=0.95,
                        label="audio_top_p",
                    )
                    audio_top_k = gr.Slider(
                        minimum=1,
                        maximum=100,
                        step=1,
                        value=25,
                        label="audio_top_k",
                    )
                    audio_repetition_penalty = gr.Slider(
                        minimum=0.8,
                        maximum=2.0,
                        step=0.05,
                        value=1.2,
                        label="audio_repetition_penalty",
                    )

                run_btn = gr.Button("Generate Speech", variant="primary", elem_id="run-btn")

            # Right column: outputs.
            with gr.Column(scale=2):
                output_audio = gr.Audio(label="Output Audio", type="numpy")
                status = gr.Textbox(label="Status", lines=4, interactive=False)
                normalized_text = gr.Textbox(label="Normalized Text", lines=6, interactive=False)
                normalized_prompt_text = gr.Textbox(
                    label="Normalized Prompt Transcript",
                    lines=4,
                    interactive=False,
                )

        # Switching modes re-labels/hides the voice and transcript widgets.
        mode.change(
            fn=update_mode_ui,
            inputs=[mode],
            outputs=[voice, prompt_text, prompt_audio, mode_hint],
        )

        # Input order must match run_inference's positional parameter order.
        run_btn.click(
            fn=run_inference,
            inputs=[
                text,
                mode,
                voice,
                prompt_audio,
                prompt_text,
                max_new_frames,
                voice_clone_max_text_tokens,
                do_sample,
                text_temperature,
                text_top_p,
                text_top_k,
                audio_temperature,
                audio_top_p,
                audio_top_k,
                audio_repetition_penalty,
                seed,
            ],
            outputs=[output_audio, status, normalized_text, normalized_prompt_text],
        )

    return demo
463
+
464
+
465
def main() -> None:
    """CLI entry point: parse args, configure logging, optionally preload, launch."""
    parser = argparse.ArgumentParser(description="Nano-TTS Hugging Face Space")
    parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument(
        "--port",
        type=int,
        default=int(os.getenv("GRADIO_SERVER_PORT", os.getenv("PORT", "7860"))),
    )
    parser.add_argument("--share", action="store_true")
    args = parser.parse_args()

    logging.basicConfig(
        format="%(asctime)s %(levelname)s %(name)s: %(message)s",
        level=logging.INFO,
    )

    # Environment variables win over CLI flags (Hugging Face Spaces sets them).
    args.host = os.getenv("GRADIO_SERVER_NAME", args.host)
    args.port = parse_port(os.getenv("GRADIO_SERVER_PORT", os.getenv("PORT")), args.port)

    # Preload defaults to on locally and off on Spaces (SPACE_ID present);
    # NANO_TTS_PRELOAD_AT_STARTUP overrides either way.
    preload_enabled = parse_bool_env(PRELOAD_ENV_VAR, default=not bool(os.getenv("SPACE_ID")))
    if preload_enabled:
        preload_service()
    else:
        logging.info("Skipping model preload (set %s=1 to enable).", PRELOAD_ENV_VAR)

    demo = build_demo()
    # Concurrency limit 1 serializes CPU inference; a small queue bounds backlog.
    demo.queue(max_size=4, default_concurrency_limit=1).launch(
        server_name=args.host,
        server_port=args.port,
        share=args.share,
        ssr_mode=False,
    )
497
+
498
+
499
+ if __name__ == "__main__":
500
+ main()
asserts/audio/en_1.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1816ab428334ba2de49dcb8b0a10e17eb1835f7f1f7bcda13504e88f46bed1e8
3
+ size 249284
asserts/audio/en_2.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:959cce9498f2bae964ca67136c2c02c7174922813b69aa435b27ec8759b44992
3
+ size 694618
asserts/audio/en_3.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:563544e54f6dd66b24a4494fa40b8f9debd7cceb50ae47a149c14bc3610c4aff
3
+ size 455372
asserts/audio/en_4.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cdd4f0ba5a4c0499f5194f5767ffaa9e988ea912210e369f66e2812278ba45ff
3
+ size 458948
asserts/audio/en_5.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a0822692aafe818424d9902419ec46bd707bddc401ca1b5a2539229cfc2852e7
3
+ size 5303154
asserts/audio/jp_1.mp3 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5f2cdb58d8050a77f09f5444f43cbd17d56bf9c73d75b98cd994feb2af22dc02
3
+ size 96624
asserts/audio/jp_2.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c44a65e55b7376a87607fdea5a6a5ab735c7aef2e007d1fc02a9f50d37bf11a4
3
+ size 227600
asserts/audio/jp_3.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:585ff999d7219d6247863a1abc3b112c822fb8603e546146d788bcf14536c57e
3
+ size 427120
asserts/audio/jp_4.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:58247506a362aaa347bc732d0196078e72b434046b9ddf3111c30878cdc10213
3
+ size 546884
asserts/audio/jp_5.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cd4a8ef2dc90f080ec8e6abb40d4b3d40c3445d51e57a5244ee46dbba1b2dcf8
3
+ size 346670
asserts/audio/zh_1.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:24c0cc85603d26017ac9c3ee89e0a03c66a193a5fdede5db74bb88f670d83723
3
+ size 2000754
asserts/audio/zh_2.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:92b9312ca9fbb6f351bc57ae123118a28bd773cfd62dda9fc59f372cea786143
3
+ size 442068
asserts/audio/zh_3.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:168d988f2d60773902862e5fd29fb0fad10468925b10a98995f6feb44ceb1cff
3
+ size 411452
asserts/audio/zh_4.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:965ba9c61ffbc4dc03b6441a5e22d08d26a747ff3536d669898d3975aebc8e72
3
+ size 1267100
asserts/audio/zh_5.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:89247d4db86a7dd921f16f805fbe513e7dd12631e5402aea02a94b4fa19560e7
3
+ size 827036
asserts/audio/zh_6.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6e8b1a0edd604129a6b8bef1a2f3627ad7c4c6069ebb77e50e88781a4048c9c1
3
+ size 285092
nano_tts_runtime.py ADDED
@@ -0,0 +1,727 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import importlib
4
+ import logging
5
+ import threading
6
+ import time
7
+ import uuid
8
+ from dataclasses import dataclass
9
+ from functools import lru_cache
10
+ from pathlib import Path
11
+ from typing import Iterator, Optional
12
+
13
+ import numpy as np
14
+ import torch
15
+ from transformers import AutoModel, AutoModelForCausalLM
16
+
17
+ MOSS_AUDIO_TOKENIZER_TYPE = "moss-audio-tokenizer-nano"
18
+
19
+ APP_DIR = Path(__file__).resolve().parent
20
+ DEFAULT_CHECKPOINT_PATH = APP_DIR / "weights" / "tts"
21
+ DEFAULT_AUDIO_TOKENIZER_PATH = APP_DIR / "weights" / "codec"
22
+ DEFAULT_PROMPT_AUDIO_DIR = APP_DIR / "asserts" / "audio"
23
+ DEFAULT_OUTPUT_DIR = APP_DIR / "generated_audio"
24
+
25
+ _DEFAULT_VOICE_FILES: dict[str, tuple[str, str]] = {
26
+ "Junhao": ("zh_1.wav", "Chinese male voice A"),
27
+ "Zhiming": ("zh_2.wav", "Chinese male voice B"),
28
+ "Weiguo": ("zh_5.wav", "Chinese male voice C"),
29
+ "Xiaoyu": ("zh_3.wav", "Chinese female voice A"),
30
+ "Yuewen": ("zh_4.wav", "Chinese female voice B"),
31
+ "Lingyu": ("zh_6.wav", "Chinese female voice C"),
32
+ "Trump": ("en_1.wav", "Trump reference voice"),
33
+ "Ava": ("en_2.wav", "English female voice A"),
34
+ "Bella": ("en_3.wav", "English female voice B"),
35
+ "Adam": ("en_4.wav", "English male voice A"),
36
+ "Nathan": ("en_5.wav", "English male voice B"),
37
+ "Sakura": ("jp_1.mp3", "Japanese female voice A"),
38
+ "Yui": ("jp_2.wav", "Japanese female voice B"),
39
+ "Aoi": ("jp_3.wav", "Japanese female voice C"),
40
+ "Hina": ("jp_4.wav", "Japanese female voice D"),
41
+ "Mei": ("jp_5.wav", "Japanese female voice E"),
42
+ }
43
+
44
+ DEFAULT_VOICE = "Junhao"
45
+ FLASH_ATTENTION_DTYPES = {torch.float16, torch.bfloat16}
46
+
47
+
48
@dataclass(frozen=True)
class VoicePreset:
    """Immutable description of one built-in reference voice."""

    # Display/lookup name of the preset (e.g. "Junhao").
    name: str
    # Absolute path to the reference prompt audio file.
    prompt_audio_path: Path
    # Short human-readable description shown in the UI.
    description: str
53
+
54
+
55
def build_default_voice_presets() -> dict[str, VoicePreset]:
    """Build the name -> VoicePreset table from the bundled prompt audio files."""
    return {
        voice_name: VoicePreset(
            name=voice_name,
            prompt_audio_path=(DEFAULT_PROMPT_AUDIO_DIR / audio_file).resolve(),
            description=blurb,
        )
        for voice_name, (audio_file, blurb) in _DEFAULT_VOICE_FILES.items()
    }
65
+
66
+
67
def resolve_device(device_arg: str) -> torch.device:
    """Map a device string to torch.device; "auto" prefers CUDA when available."""
    if device_arg != "auto":
        return torch.device(device_arg)
    chosen = "cuda" if torch.cuda.is_available() else "cpu"
    return torch.device(chosen)
71
+
72
+
73
def resolve_dtype(dtype_arg: str, device: torch.device) -> torch.dtype:
    """Resolve a dtype name to a torch.dtype.

    Explicit names win. Anything else (e.g. "auto") picks half precision on
    CUDA — bfloat16 when the GPU supports it, else float16 — and float32
    everywhere else.
    """
    explicit = {
        "float32": torch.float32,
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
    }
    if dtype_arg in explicit:
        return explicit[dtype_arg]
    if device.type != "cuda":
        return torch.float32
    bf16_ok = hasattr(torch.cuda, "is_bf16_supported") and torch.cuda.is_bf16_supported()
    return torch.bfloat16 if bf16_ok else torch.float16
85
+
86
+
87
def waveform_to_numpy(waveform: torch.Tensor | np.ndarray) -> np.ndarray:
    """Convert a waveform (tensor or array) to float32 numpy, samples-first.

    1-D input passes through unchanged. 2-D input that looks channels-first
    (leading axis of at most 8 that is shorter than the sample axis) is
    transposed to (samples, channels). Any other rank raises ValueError.
    """
    if torch.is_tensor(waveform):
        data = waveform.detach().cpu().numpy()
    else:
        data = np.asarray(waveform)

    if data.ndim == 1:
        return data.astype(np.float32, copy=False)
    if data.ndim != 2:
        raise ValueError(f"Unsupported waveform shape: {data.shape}")

    looks_channels_first = data.shape[0] <= 8 and data.shape[0] < data.shape[1]
    if looks_channels_first:
        data = data.T
    return data.astype(np.float32, copy=False)
99
+
100
+
101
@lru_cache(maxsize=1)
def _has_flash_attn() -> bool:
    """Report whether the optional flash_attn package can be imported.

    Cached for the process lifetime — the answer cannot change at runtime.
    Any import-time failure (missing package, broken CUDA build) counts as
    "not available".
    """
    try:
        importlib.import_module("flash_attn")
        return True
    except Exception:
        return False
108
+
109
+
110
class NanoTTSService:
    """Lazy-loading wrapper around the Nano-TTS checkpoint and audio codec.

    Device, dtype, and attention backend are resolved eagerly at
    construction; the actual models are loaded on demand.
    """

    def __init__(
        self,
        *,
        checkpoint_path: str | Path = DEFAULT_CHECKPOINT_PATH,
        audio_tokenizer_path: str | Path = DEFAULT_AUDIO_TOKENIZER_PATH,
        device: str = "auto",
        dtype: str = "auto",
        attn_implementation: str = "auto",
        output_dir: str | Path = DEFAULT_OUTPUT_DIR,
        voice_presets: Optional[dict[str, VoicePreset]] = None,
    ) -> None:
        # Resolve all paths up front; the output directory is created eagerly.
        self.checkpoint_path = Path(checkpoint_path).expanduser().resolve()
        self.audio_tokenizer_path = Path(audio_tokenizer_path).expanduser().resolve()
        self.output_dir = Path(output_dir).expanduser().resolve()
        self.output_dir.mkdir(parents=True, exist_ok=True)

        self.voice_presets = voice_presets or build_default_voice_presets()
        # Fall back to the first available preset if the default is absent.
        self.default_voice = DEFAULT_VOICE if DEFAULT_VOICE in self.voice_presets else next(iter(self.voice_presets))

        self.device = resolve_device(device)
        self.dtype = resolve_dtype(dtype, self.device)
        # None means "defer to model/runtime default" (see _resolve_attn_implementation).
        self.attn_implementation = self._resolve_attn_implementation(attn_implementation)

        # Reentrant lock guarding lazy model loading and config mutation.
        self._lock = threading.RLock()
        # Lazily populated model/codec handles, plus the attention/dtype
        # settings they were last configured with (None until first load).
        self._model = None
        self._audio_tokenizer = None
        self._checkpoint_global_attn_implementation: str | None = None
        self._checkpoint_local_attn_implementation: str | None = None
        self._configured_global_attn_implementation: str | None = None
        self._configured_local_attn_implementation: str | None = None
        self._configured_audio_tokenizer_attn_implementation: str | None = None
        self._configured_audio_tokenizer_compute_dtype: str | None = None
143
+
144
+ def _can_use_flash_attention(self) -> bool:
145
+ return self.device.type == "cuda" and self.dtype in FLASH_ATTENTION_DTYPES and _has_flash_attn()
146
+
147
+ def _resolve_runtime_default_attn_implementation(self) -> str:
148
+ return "flash_attention_2" if self._can_use_flash_attention() else "sdpa"
149
+
150
    def _resolve_attn_implementation(self, requested: str | None) -> str | None:
        """Validate and resolve an attention-backend request.

        Returns None for ""/"auto"/"default"/"model_default" (meaning "defer
        to the model or runtime default"); otherwise one of "sdpa",
        "flash_attention_2", or "eager". Raises ValueError for anything else.
        flash_attention_2 degrades to sdpa (with a warning) when the runtime
        cannot support it.
        """
        normalized = str(requested).strip().lower() if requested is not None else "auto"
        if not normalized or normalized in {"auto", "default", "model_default"}:
            return None
        if normalized not in {"sdpa", "flash_attention_2", "eager"}:
            raise ValueError(
                "attn_implementation must be one of: model_default/auto, sdpa, flash_attention_2, eager"
            )
        if normalized == "flash_attention_2":
            if not self._can_use_flash_attention():
                logging.warning(
                    "flash_attention_2 requires CUDA, flash_attn, and fp16/bf16; falling back to sdpa "
                    "(device=%s dtype=%s flash_attn=%s)",
                    self.device,
                    self.dtype,
                    _has_flash_attn(),
                )
                return "sdpa"
        return normalized
169
+
170
+ @staticmethod
171
+ def _normalize_loaded_attn_implementation(attn_implementation: object) -> str:
172
+ normalized = str(attn_implementation).strip().lower() if attn_implementation is not None else ""
173
+ if not normalized or normalized == "none":
174
+ return "eager"
175
+ return normalized
176
+
177
    def _resolve_request_attention_implementation(
        self,
        requested: Optional[str],
    ) -> tuple[str, str, str]:
        """Resolve a per-request attention override into concrete backends.

        Precedence: explicit request > service-level setting > runtime default.
        Returns a 3-tuple; from the structure it looks like
        (requested_label, global_impl, local_impl) — TODO confirm against the
        call sites (not visible here).
        """
        normalized = str(requested).strip().lower() if requested is not None else ""
        resolved = self._resolve_attn_implementation(normalized)
        if resolved is not None:
            return normalized, resolved, resolved

        if self.attn_implementation is not None:
            return self.attn_implementation, self.attn_implementation, self.attn_implementation

        runtime_default = self._resolve_runtime_default_attn_implementation()
        return "auto", runtime_default, runtime_default
191
+
192
+ @staticmethod
193
+ def _resolve_codec_attention_implementation(tts_attn_implementation: str) -> str:
194
+ return "flash_attention_2" if tts_attn_implementation == "flash_attention_2" else "sdpa"
195
+
196
+ def _resolve_codec_compute_dtype(self, codec_attn_implementation: str) -> str:
197
+ if codec_attn_implementation == "flash_attention_2":
198
+ return "bf16" if self.dtype == torch.bfloat16 else "fp16"
199
+ return "fp32"
200
+
201
    @staticmethod
    def _apply_model_attention_implementation(model, *, global_attn: str, local_attn: str) -> None:
        """Forward the chosen attention backends to the model, when supported.

        Relies on the checkpoint class's private _set_attention_implementation
        hook; silently does nothing for models that lack it.
        """
        if hasattr(model, "_set_attention_implementation"):
            model._set_attention_implementation(global_attn, local_attn_implementation=local_attn)
205
+
206
    def _install_stream_decode_budget_patch(self, model) -> None:
        """Monkey-patch the model class's streaming decode frame budget (CUDA only).

        Installed once per model class; a class attribute flag guards
        re-entry. The replacement budget starts at 4 frames and grows with
        the buffered playback lead, instead of the upstream 1-frame start.
        No-op on CPU or when the expected private hooks are absent.
        """
        if self.device.type != "cuda":
            return

        model_cls = model.__class__
        if getattr(model_cls, "_nanotts_stream_decode_budget_patch_installed", False):
            return

        compute_stream_lead = getattr(model_cls, "_compute_stream_lead_seconds", None)
        resolve_stream_budget = getattr(model_cls, "_resolve_stream_decode_frame_budget", None)
        if not callable(compute_stream_lead) or not callable(resolve_stream_budget):
            return

        def _patched_resolve_stream_decode_frame_budget(
            *,
            emitted_samples_total: int,
            sample_rate: int,
            first_audio_emitted_at,
        ) -> int:
            # The upstream streaming policy starts with one decode frame
            # (about 80 ms audio), which makes CUDA realtime decode emit many
            # tiny chunks and underrun browser playback on this checkpoint.
            lead_seconds = compute_stream_lead(
                emitted_samples_total=emitted_samples_total,
                sample_rate=sample_rate,
                first_audio_emitted_at=first_audio_emitted_at,
            )
            # Grow the per-step decode budget as the buffered lead increases.
            if first_audio_emitted_at is None or lead_seconds < 0.20:
                return 4
            if lead_seconds < 0.55:
                return 6
            if lead_seconds < 1.10:
                return 8
            return 12

        # Keep the original implementation reachable for debugging/rollback.
        model_cls._nanotts_original_resolve_stream_decode_frame_budget = resolve_stream_budget
        model_cls._resolve_stream_decode_frame_budget = staticmethod(_patched_resolve_stream_decode_frame_budget)
        model_cls._nanotts_stream_decode_budget_patch_installed = True
        logging.info("installed Nano-TTS CUDA streaming decode budget patch")
245
+
246
+ def _discard_loaded_model_locked(self, reason: str) -> None:
247
+ if self._model is None:
248
+ return
249
+ logging.warning("discarding loaded Nano-TTS model state: %s", reason)
250
+ self._model = None
251
+ if self.device.type == "cuda":
252
+ torch.cuda.empty_cache()
253
+
254
    def _discard_loaded_audio_tokenizer_locked(self, reason: str) -> None:
        """Drop the cached audio tokenizer so the next request reloads it.

        Caller must hold ``self._lock``. Also resets the recorded codec
        attention/dtype configuration, and releases cached CUDA memory when
        running on a GPU.
        """
        if self._audio_tokenizer is None:
            return
        logging.warning("discarding loaded Nano-TTS audio tokenizer state: %s", reason)
        self._audio_tokenizer = None
        self._configured_audio_tokenizer_attn_implementation = None
        self._configured_audio_tokenizer_compute_dtype = None
        if self.device.type == "cuda":
            torch.cuda.empty_cache()
263
+
264
    def _restore_model_execution_state(self, model):
        """Return a model guaranteed to be in the service dtype.

        If a previous inference left the cached model in a different dtype,
        discard it and reload a fresh copy from the checkpoint. Caller must
        hold ``self._lock``.
        """
        # Inspect one parameter as a proxy for the whole model's dtype.
        current_parameter = next(model.parameters(), None)
        if current_parameter is None or current_parameter.dtype == self.dtype:
            return model
        self._discard_loaded_model_locked(
            f"current_dtype={current_parameter.dtype} expected_dtype={self.dtype}; reloading checkpoint"
        )
        return self._load_model_locked()
272
+
273
+ def _read_model_attention_implementation(self, model) -> tuple[str, str]:
274
+ global_attn = self._normalize_loaded_attn_implementation(
275
+ getattr(getattr(model, "transformer", None), "attn_implementation", None)
276
+ )
277
+ local_attn = self._normalize_loaded_attn_implementation(
278
+ getattr(getattr(model, "local_transformer", None), "attn_implementation", None)
279
+ )
280
+ return global_attn, local_attn
281
+
282
+ def _ensure_paths(self) -> None:
283
+ if not self.checkpoint_path.exists():
284
+ raise FileNotFoundError(f"Nano-TTS checkpoint not found: {self.checkpoint_path}")
285
+ if not self.audio_tokenizer_path.exists():
286
+ raise FileNotFoundError(f"Audio tokenizer checkpoint not found: {self.audio_tokenizer_path}")
287
+
288
    def _load_audio_tokenizer_locked(self, *, tts_attn_implementation: str):
        """Load (once) and (re)configure the codec/audio tokenizer.

        Caller must hold ``self._lock``. The codec's attention backend and
        compute dtype are derived from the TTS attention backend and applied
        on every call, so a cached tokenizer is reconfigured to match the
        current request. Records the applied configuration for status
        reporting.
        """
        codec_attn_implementation = self._resolve_codec_attention_implementation(tts_attn_implementation)
        codec_compute_dtype = self._resolve_codec_compute_dtype(codec_attn_implementation)

        if self._audio_tokenizer is None:
            logging.info(
                "loading Nano-TTS audio tokenizer checkpoint=%s device=%s attn=%s compute_dtype=%s",
                self.audio_tokenizer_path,
                self.device,
                codec_attn_implementation,
                codec_compute_dtype,
            )
            audio_tokenizer = AutoModel.from_pretrained(
                str(self.audio_tokenizer_path),
                trust_remote_code=True,
                local_files_only=True,
            )
            if hasattr(audio_tokenizer, "eval"):
                audio_tokenizer.eval()
            self._audio_tokenizer = audio_tokenizer

        # Reconfigure the (possibly cached) tokenizer; every hook is optional
        # because the remote-code model class defines these methods, not us.
        audio_tokenizer = self._audio_tokenizer
        if hasattr(audio_tokenizer, "to"):
            audio_tokenizer = audio_tokenizer.to(self.device)
        if hasattr(audio_tokenizer, "set_attention_implementation"):
            audio_tokenizer.set_attention_implementation(codec_attn_implementation)
        if hasattr(audio_tokenizer, "set_compute_dtype"):
            audio_tokenizer.set_compute_dtype(codec_compute_dtype)
        if hasattr(audio_tokenizer, "eval"):
            audio_tokenizer.eval()

        self._audio_tokenizer = audio_tokenizer
        self._configured_audio_tokenizer_attn_implementation = codec_attn_implementation
        self._configured_audio_tokenizer_compute_dtype = codec_compute_dtype
        return self._audio_tokenizer
323
+
324
    def _load_model_locked(self):
        """Load (once) and fully configure the Nano-TTS model.

        Caller must hold ``self._lock``. Records the checkpoint's default
        attention backends before overriding them with the service defaults,
        installs the CUDA streaming-decode patch, then caches the model in
        ``self._model``.
        """
        if self._model is not None:
            return self._model

        self._ensure_paths()
        logging.info(
            "loading Nano-TTS checkpoint=%s audio_tokenizer=%s device=%s dtype=%s attn=%s",
            self.checkpoint_path,
            self.audio_tokenizer_path,
            self.device,
            self.dtype,
            self.attn_implementation or "model_default",
        )
        model = AutoModelForCausalLM.from_pretrained(
            str(self.checkpoint_path),
            trust_remote_code=True,
            local_files_only=True,
        )
        model.to(device=self.device, dtype=self.dtype)
        # Remember what the checkpoint shipped with before overriding it.
        self._checkpoint_global_attn_implementation, self._checkpoint_local_attn_implementation = (
            self._read_model_attention_implementation(model)
        )
        _, default_global_attn, default_local_attn = self._resolve_request_attention_implementation(None)
        self._apply_model_attention_implementation(
            model,
            global_attn=default_global_attn,
            local_attn=default_local_attn,
        )
        self._install_stream_decode_budget_patch(model)
        model.eval()
        # Record what the model now actually reports as configured.
        self._configured_global_attn_implementation, self._configured_local_attn_implementation = (
            self._read_model_attention_implementation(model)
        )
        self._model = model
        return self._model
359
+
360
+ def get_model(self):
361
+ with self._lock:
362
+ return self._load_model_locked()
363
+
364
+ def list_voice_names(self) -> list[str]:
365
+ return list(self.voice_presets.keys())
366
+
367
+ def get_voice_preset(self, voice_name: Optional[str]) -> VoicePreset:
368
+ if voice_name and voice_name in self.voice_presets:
369
+ return self.voice_presets[voice_name]
370
+ return self.voice_presets[self.default_voice]
371
+
372
+ def resolve_prompt_audio_path(
373
+ self,
374
+ *,
375
+ voice: Optional[str] = None,
376
+ prompt_audio_path: Optional[str | Path] = None,
377
+ ) -> Path:
378
+ if prompt_audio_path:
379
+ resolved = Path(prompt_audio_path).expanduser().resolve()
380
+ if not resolved.exists():
381
+ raise FileNotFoundError(f"Prompt audio not found: {resolved}")
382
+ return resolved
383
+
384
+ preset = self.get_voice_preset(voice)
385
+ if not preset.prompt_audio_path.exists():
386
+ raise FileNotFoundError(f"Voice preset prompt audio not found: {preset.prompt_audio_path}")
387
+ return preset.prompt_audio_path
388
+
389
    def preload(self, *, voices: Optional[list[str]] = None, load_model: bool = True) -> dict[str, object]:
        """Optionally load the model and report which voice prompts exist.

        Returns a status dict (device, dtype, and the checkpoint/configured
        attention settings) suitable for surfacing in a UI. Presets whose
        prompt audio file is missing are silently skipped, not errors.
        """
        loaded_voices: list[str] = []
        if load_model:
            self.get_model()
        # Empty/None ``voices`` falls back to just the default voice.
        for voice_name in voices or [self.default_voice]:
            preset = self.get_voice_preset(voice_name)
            if preset.prompt_audio_path.exists():
                loaded_voices.append(preset.name)
        return {
            "loaded_voices": loaded_voices,
            "device": str(self.device),
            "dtype": str(self.dtype),
            "attn_implementation": self.attn_implementation or "auto",
            "checkpoint_default_attn_implementation": self._checkpoint_global_attn_implementation or "eager",
            "checkpoint_default_local_attn_implementation": self._checkpoint_local_attn_implementation or "eager",
            "configured_attn_implementation": self._configured_global_attn_implementation or "eager",
            "configured_local_attn_implementation": self._configured_local_attn_implementation or "eager",
            "configured_codec_attn_implementation": self._configured_audio_tokenizer_attn_implementation or "unknown",
            "configured_codec_compute_dtype": self._configured_audio_tokenizer_compute_dtype or "unknown",
        }
409
+
410
+ def _build_output_path(self, prefix: str) -> Path:
411
+ timestamp = time.strftime("%Y%m%d_%H%M%S")
412
+ random_suffix = uuid.uuid4().hex[:8]
413
+ return self.output_dir / f"{prefix}_{timestamp}_{random_suffix}.wav"
414
+
415
    def synthesize(
        self,
        *,
        text: str,
        voice: Optional[str] = None,
        mode: str = "voice_clone",
        output_audio_path: Optional[str | Path] = None,
        prompt_audio_path: Optional[str | Path] = None,
        prompt_text: Optional[str] = None,
        max_new_frames: int = 375,
        voice_clone_max_text_tokens: int = 75,
        voice_clone_max_memory_per_sample_gb: float = 1.0,
        tts_max_batch_size: int = 0,
        codec_max_batch_size: int = 0,
        do_sample: bool = True,
        text_temperature: float = 1.0,
        text_top_p: float = 1.0,
        text_top_k: int = 50,
        audio_temperature: float = 0.8,
        audio_top_p: float = 0.95,
        audio_top_k: int = 25,
        audio_repetition_penalty: float = 1.2,
        nq: Optional[int] = None,
        seed: Optional[int] = None,
        attn_implementation: Optional[str] = None,
    ) -> dict[str, object]:
        """Run one blocking synthesis request and write the result to disk.

        ``mode`` is "voice_clone" (reference audio required, resolved from
        ``prompt_audio_path`` or the voice preset) or "continuation" (an
        optional prompt audio then also requires ``prompt_text``). The whole
        model interaction runs under ``self._lock``; on inference failure the
        cached model and audio tokenizer are discarded so the next request
        starts from a clean reload. Returns a dict with the output path,
        waveform (torch + numpy), token ids, timing, and the attention
        configuration that was actually in effect.

        Raises ValueError for empty text, bad mode, or a missing required
        prompt_text; FileNotFoundError for missing prompt audio.
        """
        normalized_text = str(text or "").strip()
        if not normalized_text:
            raise ValueError("text is required")

        normalized_mode = str(mode).strip().lower()
        if normalized_mode not in {"continuation", "voice_clone"}:
            raise ValueError("mode must be either 'continuation' or 'voice_clone'")

        # Resolve reference audio up front so path errors surface before any
        # model work happens.
        effective_prompt_audio_path: Optional[Path] = None
        resolved_voice = self.get_voice_preset(voice).name
        if normalized_mode == "voice_clone":
            effective_prompt_audio_path = self.resolve_prompt_audio_path(
                voice=resolved_voice,
                prompt_audio_path=prompt_audio_path,
            )
        elif prompt_audio_path is not None:
            effective_prompt_audio_path = self.resolve_prompt_audio_path(prompt_audio_path=prompt_audio_path)
            if not prompt_text:
                raise ValueError("continuation mode with prompt_audio_path also requires prompt_text")

        output_path = (
            Path(output_audio_path).expanduser().resolve()
            if output_audio_path is not None
            else self._build_output_path(prefix=f"{resolved_voice}_{normalized_mode}")
        )
        output_path.parent.mkdir(parents=True, exist_ok=True)

        started_at = time.monotonic()
        with self._lock:
            # Load/repair the model, resolve the per-request attention
            # backends, and bring the codec into a matching configuration.
            model = self._load_model_locked()
            model = self._restore_model_execution_state(model)
            requested_attn_implementation, effective_global_attn_implementation, effective_local_attn_implementation = (
                self._resolve_request_attention_implementation(attn_implementation)
            )
            audio_tokenizer = self._load_audio_tokenizer_locked(
                tts_attn_implementation=effective_global_attn_implementation
            )
            self._apply_model_attention_implementation(
                model,
                global_attn=effective_global_attn_implementation,
                local_attn=effective_local_attn_implementation,
            )
            if seed is not None:
                torch.manual_seed(seed)
                if torch.cuda.is_available():
                    torch.cuda.manual_seed_all(seed)

            try:
                result = model.inference(
                    text=normalized_text,
                    output_audio_path=str(output_path),
                    mode=normalized_mode,
                    prompt_text=prompt_text,
                    prompt_audio_path=None if effective_prompt_audio_path is None else str(effective_prompt_audio_path),
                    text_tokenizer_path=str(self.checkpoint_path),
                    audio_tokenizer=audio_tokenizer,
                    device=self.device,
                    nq=nq,
                    max_new_frames=int(max_new_frames),
                    voice_clone_max_text_tokens=int(voice_clone_max_text_tokens),
                    voice_clone_max_memory_per_sample_gb=float(voice_clone_max_memory_per_sample_gb),
                    tts_max_batch_size=int(tts_max_batch_size),
                    codec_max_batch_size=int(codec_max_batch_size),
                    do_sample=bool(do_sample),
                    use_kv_cache=True,
                    text_temperature=float(text_temperature),
                    text_top_p=float(text_top_p),
                    text_top_k=int(text_top_k),
                    audio_temperature=float(audio_temperature),
                    audio_top_p=float(audio_top_p),
                    audio_top_k=int(audio_top_k),
                    audio_repetition_penalty=float(audio_repetition_penalty),
                )
            except Exception:
                # Any failure may leave model/tokenizer state inconsistent:
                # drop both so the next request reloads cleanly, then re-raise.
                self._discard_loaded_audio_tokenizer_locked(
                    "inference failed; reloading audio tokenizer on next request"
                )
                self._discard_loaded_model_locked("inference failed; reloading checkpoint on next request")
                raise
            # Report what was actually in effect, and self-heal if inference
            # changed the model's dtype.
            effective_global_attn_implementation, effective_local_attn_implementation = (
                self._read_model_attention_implementation(model)
            )
            current_parameter = next(model.parameters(), None)
            if current_parameter is not None and current_parameter.dtype != self.dtype:
                self._discard_loaded_model_locked(
                    f"inference left model in dtype={current_parameter.dtype}; reloading checkpoint on next request"
                )

        waveform = result["waveform"].detach().cpu()
        waveform_numpy = waveform_to_numpy(waveform)
        return {
            "audio_path": str(output_path),
            "sample_rate": int(result["sample_rate"]),
            "waveform": waveform,
            "waveform_numpy": waveform_numpy,
            "audio_token_ids": result["audio_token_ids"],
            "reference_audio_token_ids": result["reference_audio_token_ids"],
            "elapsed_seconds": time.monotonic() - started_at,
            "voice": resolved_voice,
            "mode": normalized_mode,
            "prompt_audio_path": None if effective_prompt_audio_path is None else str(effective_prompt_audio_path),
            "requested_attn_implementation": requested_attn_implementation,
            "effective_global_attn_implementation": effective_global_attn_implementation,
            "effective_local_attn_implementation": effective_local_attn_implementation,
            "voice_clone_text_chunks": result.get("voice_clone_text_chunks"),
            "voice_clone_chunk_batch_size": result.get("voice_clone_chunk_batch_size"),
            "voice_clone_codec_batch_size": result.get("voice_clone_codec_batch_size"),
        }
549
+
550
    def synthesize_stream(
        self,
        *,
        text: str,
        voice: Optional[str] = None,
        mode: str = "voice_clone",
        output_audio_path: Optional[str | Path] = None,
        prompt_audio_path: Optional[str | Path] = None,
        prompt_text: Optional[str] = None,
        max_new_frames: int = 375,
        voice_clone_max_text_tokens: int = 75,
        voice_clone_max_memory_per_sample_gb: float = 1.0,
        tts_max_batch_size: int = 0,
        codec_max_batch_size: int = 0,
        do_sample: bool = True,
        text_temperature: float = 1.0,
        text_top_p: float = 1.0,
        text_top_k: int = 50,
        audio_temperature: float = 0.8,
        audio_top_p: float = 0.95,
        audio_top_k: int = 25,
        audio_repetition_penalty: float = 1.2,
        nq: Optional[int] = None,
        seed: Optional[int] = None,
        attn_implementation: Optional[str] = None,
    ) -> Iterator[dict[str, object]]:
        """Streaming counterpart of ``synthesize``.

        Yields "audio" chunk events as they are decoded (each carrying torch
        and numpy waveforms plus chunk/lead metadata), followed by exactly one
        final "result" event shaped like ``synthesize``'s return dict. The
        lock is held while the underlying ``inference_stream`` generator is
        consumed, so concurrent requests serialize. On failure the cached
        model and tokenizer are discarded and the exception propagates.

        NOTE(review): because chunks are yielded while ``self._lock`` is
        held, a slow consumer extends the lock hold time — presumably
        intentional for this single-model service; confirm against callers.
        """
        normalized_text = str(text or "").strip()
        if not normalized_text:
            raise ValueError("text is required")

        normalized_mode = str(mode).strip().lower()
        if normalized_mode not in {"continuation", "voice_clone"}:
            raise ValueError("mode must be either 'continuation' or 'voice_clone'")

        # Same prompt-audio resolution rules as the blocking path.
        effective_prompt_audio_path: Optional[Path] = None
        resolved_voice = self.get_voice_preset(voice).name
        if normalized_mode == "voice_clone":
            effective_prompt_audio_path = self.resolve_prompt_audio_path(
                voice=resolved_voice,
                prompt_audio_path=prompt_audio_path,
            )
        elif prompt_audio_path is not None:
            effective_prompt_audio_path = self.resolve_prompt_audio_path(prompt_audio_path=prompt_audio_path)
            if not prompt_text:
                raise ValueError("continuation mode with prompt_audio_path also requires prompt_text")

        output_path = (
            Path(output_audio_path).expanduser().resolve()
            if output_audio_path is not None
            else self._build_output_path(prefix=f"{resolved_voice}_{normalized_mode}_stream")
        )
        output_path.parent.mkdir(parents=True, exist_ok=True)

        started_at = time.monotonic()
        final_result: dict[str, object] | None = None
        with self._lock:
            model = self._load_model_locked()
            model = self._restore_model_execution_state(model)
            requested_attn_implementation, effective_global_attn_implementation, effective_local_attn_implementation = (
                self._resolve_request_attention_implementation(attn_implementation)
            )
            audio_tokenizer = self._load_audio_tokenizer_locked(
                tts_attn_implementation=effective_global_attn_implementation
            )
            self._apply_model_attention_implementation(
                model,
                global_attn=effective_global_attn_implementation,
                local_attn=effective_local_attn_implementation,
            )
            if seed is not None:
                torch.manual_seed(seed)
                if torch.cuda.is_available():
                    torch.cuda.manual_seed_all(seed)

            try:
                for event in model.inference_stream(
                    text=normalized_text,
                    output_audio_path=str(output_path),
                    mode=normalized_mode,
                    prompt_text=prompt_text,
                    prompt_audio_path=None if effective_prompt_audio_path is None else str(effective_prompt_audio_path),
                    text_tokenizer_path=str(self.checkpoint_path),
                    audio_tokenizer=audio_tokenizer,
                    device=self.device,
                    nq=nq,
                    max_new_frames=int(max_new_frames),
                    voice_clone_max_text_tokens=int(voice_clone_max_text_tokens),
                    voice_clone_max_memory_per_sample_gb=float(voice_clone_max_memory_per_sample_gb),
                    tts_max_batch_size=int(tts_max_batch_size),
                    codec_max_batch_size=int(codec_max_batch_size),
                    do_sample=bool(do_sample),
                    use_kv_cache=True,
                    text_temperature=float(text_temperature),
                    text_top_p=float(text_top_p),
                    text_top_k=int(text_top_k),
                    audio_temperature=float(audio_temperature),
                    audio_top_p=float(audio_top_p),
                    audio_top_k=int(audio_top_k),
                    audio_repetition_penalty=float(audio_repetition_penalty),
                ):
                    event_type = str(event.get("type", ""))
                    if event_type == "audio":
                        # Re-emit decoded chunks immediately; CPU float32 copy
                        # so downstream consumers never touch GPU tensors.
                        waveform = torch.as_tensor(event["waveform"], dtype=torch.float32).cpu()
                        yield {
                            "type": "audio",
                            "waveform": waveform,
                            "waveform_numpy": waveform_to_numpy(waveform),
                            "sample_rate": int(event["sample_rate"]),
                            "chunk_index": int(event.get("chunk_index", 0)),
                            "is_pause": bool(event.get("is_pause", False)),
                            "emitted_audio_seconds": float(event.get("emitted_audio_seconds", 0.0)),
                            "lead_seconds": float(event.get("lead_seconds", 0.0)),
                        }
                        continue
                    if event_type == "result":
                        # Defer the final event until post-inference checks ran.
                        final_result = dict(event)
            except Exception:
                self._discard_loaded_audio_tokenizer_locked(
                    "streaming inference failed; reloading audio tokenizer on next request"
                )
                self._discard_loaded_model_locked("streaming inference failed; reloading checkpoint on next request")
                raise

        # Same post-inference reporting/self-healing as the blocking path.
        effective_global_attn_implementation, effective_local_attn_implementation = (
            self._read_model_attention_implementation(model)
        )
        current_parameter = next(model.parameters(), None)
        if current_parameter is not None and current_parameter.dtype != self.dtype:
            self._discard_loaded_model_locked(
                f"streaming inference left model in dtype={current_parameter.dtype}; reloading checkpoint on next request"
            )

        if final_result is None:
            raise RuntimeError("Streaming synthesis finished without a final result.")

        waveform = torch.as_tensor(final_result["waveform"], dtype=torch.float32).cpu()
        yield {
            "type": "result",
            "audio_path": str(final_result["audio_path"]),
            "sample_rate": int(final_result["sample_rate"]),
            "waveform": waveform,
            "waveform_numpy": waveform_to_numpy(waveform),
            "audio_token_ids": final_result["audio_token_ids"],
            "reference_audio_token_ids": final_result["reference_audio_token_ids"],
            "elapsed_seconds": time.monotonic() - started_at,
            "voice": resolved_voice,
            "mode": normalized_mode,
            "prompt_audio_path": None if effective_prompt_audio_path is None else str(effective_prompt_audio_path),
            "requested_attn_implementation": requested_attn_implementation,
            "effective_global_attn_implementation": effective_global_attn_implementation,
            "effective_local_attn_implementation": effective_local_attn_implementation,
            "voice_clone_text_chunks": final_result.get("voice_clone_text_chunks"),
            "voice_clone_chunk_batch_size": final_result.get("voice_clone_chunk_batch_size"),
            "voice_clone_codec_batch_size": final_result.get("voice_clone_codec_batch_size"),
        }
705
+
706
+ def warmup(
707
+ self,
708
+ *,
709
+ text: str = "你好,欢迎使用 Nano-TTS。",
710
+ voice: Optional[str] = None,
711
+ ) -> dict[str, object]:
712
+ return self.synthesize(
713
+ text=text,
714
+ voice=voice or self.default_voice,
715
+ mode="voice_clone",
716
+ output_audio_path=self.output_dir / "_warmup" / "warmup.wav",
717
+ max_new_frames=96,
718
+ voice_clone_max_text_tokens=75,
719
+ do_sample=False,
720
+ text_temperature=1.0,
721
+ text_top_p=1.0,
722
+ text_top_k=50,
723
+ audio_temperature=0.8,
724
+ audio_top_p=0.95,
725
+ audio_top_k=25,
726
+ audio_repetition_penalty=1.0,
727
+ )
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ numpy>=1.24
2
+ sentencepiece>=0.1.99
3
+ torch==2.7.0
4
+ torchaudio==2.7.0
5
+ transformers==4.57.1
6
+ safetensors>=0.4.3
7
+ gradio==6.5.1
text_normalization_pipeline.py ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import logging
4
+ import re
5
+ import threading
6
+ from dataclasses import dataclass
7
+
8
+ from tts_robust_normalizer_single_script import normalize_tts_text
9
+
10
+ ENGLISH_VOICES = frozenset({"Trump", "Ava", "Bella", "Adam", "Nathan"})
11
+
12
+
13
+ @dataclass(frozen=True)
14
+ class TextNormalizationSnapshot:
15
+ state: str
16
+ message: str
17
+ error: str | None = None
18
+ ready: bool = False
19
+ available: bool = False
20
+
21
+ @property
22
+ def failed(self) -> bool:
23
+ return self.state == "failed"
24
+
25
+
26
class WeTextProcessingManager:
    """Owns lazy, background loading of the WeTextProcessing zh/en normalizers.

    ``start()`` launches a one-shot daemon preload thread; ``ensure_ready()``
    starts it if needed and blocks until it finishes; ``normalize()`` runs the
    actual normalization. ``_lock`` guards all mutable state and the
    normalizer cache; ``_normalize_lock`` serializes ``normalize()`` calls.
    """

    def __init__(self) -> None:
        self._lock = threading.Lock()
        self._normalize_lock = threading.Lock()
        self._thread: threading.Thread | None = None
        self._started = False
        # State machine: pending -> running -> ready | failed.
        self._state = "pending"
        self._message = "Waiting for WeTextProcessing preload."
        self._error: str | None = None
        self._available = True
        self._normalizers: dict[str, object] | None = None

    def snapshot(self) -> TextNormalizationSnapshot:
        """Return an immutable copy of the current preload state."""
        with self._lock:
            return TextNormalizationSnapshot(
                state=self._state,
                message=self._message,
                error=self._error,
                ready=self._state == "ready",
                available=self._available,
            )

    def _set_state(self, *, state: str, message: str, error: str | None = None) -> None:
        # Update the three state fields atomically.
        with self._lock:
            self._state = state
            self._message = message
            self._error = error

    def start(self) -> None:
        """Kick off the preload thread once; subsequent calls are no-ops."""
        with self._lock:
            if self._started:
                return
            self._started = True
            self._thread = threading.Thread(target=self._run, name="wetext-preload", daemon=True)
            self._thread.start()

    def ensure_ready(self) -> TextNormalizationSnapshot:
        """Start the preload if needed, wait for it, and return the outcome."""
        with self._lock:
            if not self._started:
                self._started = True
                self._thread = threading.Thread(target=self._run, name="wetext-preload", daemon=True)
                self._thread.start()
            thread = self._thread
        # Join outside the lock: the worker needs the lock to publish state.
        if thread is not None and thread.is_alive():
            thread.join()
        return self.snapshot()

    def close(self) -> None:
        """Nothing to release; present for interface symmetry."""
        return

    def _run(self) -> None:
        # Preload-thread body: load the FST graphs and publish the outcome.
        if not self._available:
            self._set_state(
                state="failed",
                message="WeTextProcessing unavailable.",
                error="installed WeTextProcessing modules are unavailable",
            )
            return
        try:
            self._set_state(state="running", message="Loading WeTextProcessing graphs.", error=None)
            self._ensure_normalizers_loaded()
            self._set_state(state="ready", message="WeTextProcessing ready. languages=zh,en", error=None)
        except Exception as exc:
            logging.exception("WeTextProcessing preload failed")
            self._set_state(state="failed", message="WeTextProcessing preload failed.", error=str(exc))

    def _ensure_normalizers_loaded(self) -> dict[str, object]:
        """Build (once) and return the ``{language: normalizer}`` cache."""
        with self._lock:
            if self._normalizers is not None:
                return self._normalizers

            # Imported lazily: these pull in heavy normalization graphs.
            from tn.chinese.normalizer import Normalizer as ZhNormalizer
            from tn.english.normalizer import Normalizer as EnNormalizer

            logging.getLogger().setLevel(logging.INFO)
            self._normalizers = {
                "zh": ZhNormalizer(overwrite_cache=False),
                "en": EnNormalizer(overwrite_cache=False),
            }
            return self._normalizers

    def normalize(self, *, text: str, prompt_text: str, language: str) -> tuple[str, str]:
        """Normalize request and prompt text with the ``language`` normalizer.

        Raises RuntimeError when the preload failed and ValueError for an
        unsupported language. Empty inputs normalize to "".
        """
        snapshot = self.ensure_ready()
        if not snapshot.ready:
            raise RuntimeError(snapshot.error or snapshot.message)

        # One normalization at a time across all callers.
        with self._normalize_lock:
            normalizers = self._ensure_normalizers_loaded()
            if language not in normalizers:
                raise ValueError(f"Unsupported text normalization language: {language}")
            normalizer = normalizers[language]
            normalized_text = normalizer.normalize(text) if text else ""
            normalized_prompt_text = normalizer.normalize(prompt_text) if prompt_text else ""
            return normalized_text, normalized_prompt_text
120
+
121
+
122
def resolve_text_normalization_language(*, text: str, voice: str) -> str:
    """Pick the WeTextProcessing language ("zh" or "en") for a request.

    Any Han character forces "zh"; otherwise any Latin letter forces "en";
    otherwise an English-voice preset implies "en"; the default is "zh".
    """
    if re.search(r"[\u3400-\u9fff]", text) is not None:
        return "zh"
    has_latin_letter = re.search(r"[A-Za-z]", text) is not None
    if has_latin_letter or voice in ENGLISH_VOICES:
        return "en"
    return "zh"
130
+
131
+
132
def prepare_tts_request_texts(
    *,
    text: str,
    prompt_text: str,
    voice: str,
    enable_wetext: bool,
    text_normalizer_manager: WeTextProcessingManager | None,
) -> dict[str, object]:
    """Run the two-stage text normalization pipeline for one TTS request.

    Stage 1 (only when ``enable_wetext``): WeTextProcessing semantic TN in
    the language picked by ``resolve_text_normalization_language``.
    Stage 2 (always): the robust, non-semantic ``normalize_tts_text`` pass.
    Returns the final texts plus metadata describing what was applied.

    Raises RuntimeError when WeTextProcessing is requested but no manager is
    supplied; normalization errors from the manager propagate.
    """
    raw_text = str(text or "")
    raw_prompt_text = str(prompt_text or "")

    normalization_language = ""
    intermediate_text = raw_text
    intermediate_prompt_text = raw_prompt_text

    if enable_wetext:
        if text_normalizer_manager is None:
            raise RuntimeError("WeTextProcessing manager is unavailable.")
        normalization_language = resolve_text_normalization_language(text=raw_text, voice=voice)
        intermediate_text, intermediate_prompt_text = text_normalizer_manager.normalize(
            text=raw_text,
            prompt_text=raw_prompt_text,
            language=normalization_language,
        )
        # Log only when stage 1 actually changed something.
        if intermediate_text != raw_text:
            logging.info(
                "normalized text chars_before=%d chars_after=%d stage=wetext language=%s",
                len(raw_text),
                len(intermediate_text),
                normalization_language,
            )
        if raw_prompt_text and intermediate_prompt_text != raw_prompt_text:
            logging.info(
                "normalized prompt_text chars_before=%d chars_after=%d stage=wetext language=%s",
                len(raw_prompt_text),
                len(intermediate_prompt_text),
                normalization_language,
            )

    # Stage 2: robustness cleanup is applied unconditionally.
    final_text = normalize_tts_text(intermediate_text)
    final_prompt_text = normalize_tts_text(intermediate_prompt_text) if intermediate_prompt_text else ""

    if final_text != intermediate_text:
        logging.info(
            "normalized text chars_before=%d chars_after=%d stage=robust_final",
            len(intermediate_text),
            len(final_text),
        )
    if intermediate_prompt_text and final_prompt_text != intermediate_prompt_text:
        logging.info(
            "normalized prompt_text chars_before=%d chars_after=%d stage=robust_final",
            len(intermediate_prompt_text),
            len(final_prompt_text),
        )

    return {
        "text": final_text,
        "prompt_text": final_prompt_text,
        "normalized_text": final_text,
        "normalized_prompt_text": final_prompt_text,
        "normalization_method": (f"wetext:{normalization_language}+robust" if enable_wetext else "robust"),
        "text_normalization_language": normalization_language,
        "text_normalization_enabled": bool(enable_wetext),
    }
tts_robust_normalizer_single_script.py ADDED
@@ -0,0 +1,366 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
Robustness normalizer for TTS input (non-semantic TN).

Goals
-----
1. Perform only "robustness cleanup"; do not semantically expand numbers,
   units, dates, amounts, etc.
2. Protect high-risk tokens first, so `.map`, `app.js.map`, `v2.3.1`, URLs,
   emails, @mentions and #hashtags are not mangled by cleanup.
3. `[]` / `{}` / `【】` / `〖〗` / `『』` / `「」` are unified into their
   content wrapped in double quotes.
4. Structural symbols are replaced rather than deleted:
   - `【】 / 〖〗 / 『』 / 「」` become double-quoted content.
   - `《》` is only unwrapped for stand-alone title/column-name usage;
     embedded titles are left unchanged.
   - `—— / -- / ——...` become sentence boundaries.
5. Weak normalization of common social-media noise:
   - `...... / ……` -> `。`
   - `???!!!` -> `?!`
   - `!!!` -> `!`
6. Whitespace is handled per script type:
   - Inside Western-text runs: consecutive spaces collapse to one.
   - Inside Han / Japanese-kana runs: spaces are removed.
   - Han/kana adjacent to a Latin-letter-bearing token or a protected
     token: keep or insert exactly one space.
   - Han/kana adjacent to bare digits: no space is forced in.
7. Light Markdown and newline handling:
   - `[text](url)` -> `text url`
   - strip heading `#`, quote `>`, and list prefixes
   - newlines become the sentence boundary `。`

Non-goals
---------
1. Does not decide "how text should be read aloud".
2. Does not interpret HTML/SSML/semantic tags.
"""

from __future__ import annotations

import re
import unicodedata


# ---------------------------
# Base constants and regexes
# ---------------------------

# Scripts that do not rely on spaces for segmentation: Han + Japanese kana.
_CJK_CHARS = r"\u3400-\u4dbf\u4e00-\u9fff\u3040-\u30ff"
_CJK = f"[{_CJK_CHARS}]"

# Protection placeholder shape produced for spans that must survive cleanup.
_PROT = r"___PROT\d+___"

# High-risk tokens that need protection from the cleanup passes.
_URL_RE = re.compile(r"https?://[^\s\u3000,。!?;、)】》〉」』]+")
_EMAIL_RE = re.compile(r"(?<![\w.+-])[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}(?![\w.-])")
_MENTION_RE = re.compile(r"(?<![A-Za-z0-9_])@[A-Za-z0-9_]{1,32}")
_REDDIT_RE = re.compile(r"(?<![A-Za-z0-9_])(?:u|r)/[A-Za-z0-9_]+")
_HASHTAG_RE = re.compile(r"(?<![A-Za-z0-9_])#(?!\s)[^\s#]+")

# Dot-leading tokens such as `.map` / `.env` / `.gitignore`.
_DOT_TOKEN_RE = re.compile(r"(?<![A-Za-z0-9_])\.(?=[A-Za-z0-9._-]*[A-Za-z0-9])[A-Za-z0-9._-]+")

# File-like tokens: `app.js.map` / `index.d.ts` / `v2.3.1` / `foo/bar-baz.py` etc.
_FILELIKE_RE = re.compile(
    r"(?<![A-Za-z0-9_])"
    r"(?=[A-Za-z0-9._/+:-]*[A-Za-z])"
    r"(?=[A-Za-z0-9._/+:-]*[._/+:-])"
    r"[A-Za-z0-9][A-Za-z0-9._/+:-]*"
    r"(?![A-Za-z0-9_])"
)

# Tokens that take part in CJK/Latin boundary spacing: must contain at least
# one Latin letter, or be a protected-span placeholder.
_LATINISH = rf"(?:{_PROT}|(?=[A-Za-z0-9._/+:-]*[A-Za-z])[A-Za-z0-9][A-Za-z0-9._/+:-]*)"

# Zero-width characters to strip outright.
_ZERO_WIDTH_RE = re.compile(r"[\u200b-\u200d\ufeff]")
_TRAILING_CLOSERS = set('"\')]})】》〉」』”’')
78
+
79
+
80
+ # ---------------------------
81
+ # 主函数
82
+ # ---------------------------
83
+
84
def normalize_tts_text(text: str) -> str:
    """Robustness-normalize raw TTS input text.

    Pipeline: base cleanup -> markdown/newline handling -> protect risky
    tokens -> whitespace/punctuation normalization -> restore tokens ->
    guarantee terminal punctuation per line.
    """
    cleaned = _base_cleanup(text)
    cleaned = _normalize_markdown_and_lines(cleaned)
    cleaned, saved_tokens = _protect_spans(cleaned)

    # Whitespace normalization runs both before and after the punctuation
    # passes, because those passes can create new gaps around boundaries.
    for step in (
        _normalize_spaces,
        _normalize_structural_punctuation,
        _normalize_repeated_punctuation,
        _normalize_spaces,
    ):
        cleaned = step(cleaned)

    cleaned = _restore_spans(cleaned, saved_tokens).strip()
    return _ensure_terminal_punctuation_by_line(cleaned)
98
+
99
+
100
+ # ---------------------------
101
+ # 具体规则
102
+ # ---------------------------
103
+
104
def _base_cleanup(text: str) -> str:
    """Normalize line endings / full-width spaces and drop control characters."""
    text = text.replace("\r\n", "\n").replace("\r", "\n").replace("\u3000", " ")
    text = _ZERO_WIDTH_RE.sub("", text)

    # Keep newline, tab and space explicitly; drop every other character in
    # the Unicode "C*" (control/format/surrogate) categories.
    return "".join(
        ch
        for ch in text
        if ch in "\n\t " or not unicodedata.category(ch).startswith("C")
    )
114
+
115
+
116
def _normalize_markdown_and_lines(text: str) -> str:
    """Strip lightweight Markdown markup and turn newlines into sentence ends."""
    # Markdown link: [text](url) -> text url
    text = re.sub(r"\[([^\[\]]+?)\]\((https?://[^)\s]+)\)", r"\1 \2", text)

    kept: list[str] = []
    for raw_line in text.splitlines():
        stripped = raw_line.strip()
        if not stripped:
            continue

        stripped = re.sub(r"^#{1,6}\s+", "", stripped)   # heading marker
        stripped = re.sub(r"^>\s+", "", stripped)        # blockquote marker
        stripped = re.sub(r"^[-*+]\s+", "", stripped)    # unordered-list prefix
        stripped = re.sub(r"^\d+[.)]\s+", "", stripped)  # ordered-list prefix
        kept.append(stripped)

    if not kept:
        return ""

    # Terminate every line except the last so the join does not fuse two
    # sentences together; the final line is handled later by the caller.
    terminated = [_ensure_terminal_punctuation(line) for line in kept[:-1]]
    return "".join(terminated + [kept[-1]])
141
+
142
+
143
def _protect_spans(text: str) -> tuple[str, list[str]]:
    """Swap high-risk tokens for placeholders; return (text, original tokens)."""
    originals: list[str] = []

    def _stash(match: re.Match[str]) -> str:
        originals.append(match.group(0))
        return f"___PROT{len(originals) - 1}___"

    # Order matters: more specific patterns (URL, email) run before the
    # generic file-like fallback.
    patterns = (
        _URL_RE,
        _EMAIL_RE,
        _MENTION_RE,
        _REDDIT_RE,
        _HASHTAG_RE,
        _DOT_TOKEN_RE,
        _FILELIKE_RE,
    )
    for pattern in patterns:
        text = pattern.sub(_stash, text)

    return text, originals
163
+
164
+
165
+ def _restore_spans(text: str, protected: list[str]) -> str:
166
+ for idx, original in enumerate(protected):
167
+ text = text.replace(f"___PROT{idx}___", original)
168
+ return text
169
+
170
+
171
def _normalize_spaces(text: str) -> str:
    """Apply script-aware whitespace rules (Han/kana vs Latin vs digits).

    The substitutions below are order-sensitive: CJK-internal spaces are
    removed before boundary spaces are inserted, and runs of spaces are
    collapsed again after insertion.
    """
    # Collapse all horizontal whitespace to single ASCII spaces.
    text = re.sub(r"[ \t\r\f\v]+", " ", text)

    # Inside Han / kana runs: delete spaces entirely.
    text = re.sub(rf"({_CJK})\s+(?={_CJK})", r"\1", text)

    # Between Han / kana and bare digits: delete the space (no forced space).
    text = re.sub(rf"({_CJK})\s+(?=\d)", r"\1", text)
    text = re.sub(rf"(\d)\s+(?={_CJK})", r"\1", text)

    # Between Han / kana and Latin-ish / protected tokens: keep or insert one space.
    text = re.sub(rf"({_CJK})(?=({_LATINISH}))", r"\1 ", text)
    text = re.sub(rf"(({_LATINISH}))(?={_CJK})", r"\1 ", text)

    # Collapse consecutive spaces once more.
    text = re.sub(r" {2,}", " ", text)

    # No spaces around full-width (CJK) punctuation.
    text = re.sub(r"\s+([,。!?;:、”’」』】)》])", r"\1", text)
    text = re.sub(r"([(【「『《“‘])\s+", r"\1", text)
    text = re.sub(r"([,。!?;:、])\s*", r"\1", text)

    # No space before ASCII punctuation; a following English space is kept as-is.
    text = re.sub(r"\s+([,.;!?])", r"\1", text)

    return re.sub(r" {2,}", " ", text).strip()
198
+
199
+
200
+ def _normalize_structural_punctuation(text: str) -> str:
201
+ # 各类结构性括号:统一转成双引号包裹内容
202
+ text = re.sub(r"\[\s*([^\[\]]+?)\s*\]", r'"\1"', text)
203
+ text = re.sub(r"\{\s*([^{}]+?)\s*\}", r'"\1"', text)
204
+ text = re.sub(r"[【〖『「]\s*([^】〗』」]+?)\s*[】〗』」]", r'"\1"', text)
205
+
206
+ # 《》只处理独立标题,不处理嵌入式标题
207
+ # 例:重磅。《新品发布》——现在开始! -> 重磅。新品发布。现在开始!
208
+ text = re.sub(
209
+ r"(^|[。!?!?;;]\s*)《([^》]+)》(?=\s*(?:___PROT\d+___|[—–―-]{2,}|$|[。!?!?;;,,]))",
210
+ r"\1\2",
211
+ text,
212
+ )
213
+
214
+ # 长破折号 / 多连字符:转句边界
215
+ text = re.sub(r"\s*(?:—|–|―|-){2,}\s*", "。", text)
216
+
217
+ return text
218
+
219
+
220
+ def _normalize_repeated_punctuation(text: str) -> str:
221
+ # 省略号 / 连续句点
222
+ text = re.sub(r"(?:\.{3,}|…{2,}|……+)", "。", text)
223
+
224
+ # 同类重复标点
225
+ text = re.sub(r"[。.]{2,}", "。", text)
226
+ text = re.sub(r"[,,]{2,}", ",", text)
227
+ text = re.sub(r"[!!]{2,}", "!", text)
228
+ text = re.sub(r"[??]{2,}", "?", text)
229
+
230
+ # 混合问叹号:收敛到 ?!
231
+ def _mixed_qe(match: re.Match[str]) -> str:
232
+ s = match.group(0)
233
+ has_q = any(ch in s for ch in "??")
234
+ has_e = any(ch in s for ch in "!!")
235
+ if has_q and has_e:
236
+ return "?!"
237
+ return "?" if has_q else "!"
238
+
239
+ text = re.sub(r"[!?!?]{2,}", _mixed_qe, text)
240
+ return text
241
+
242
+
243
def _ensure_terminal_punctuation(text: str) -> str:
    """Append `。` unless the text already ends with punctuation.

    Trailing whitespace and closing quotes/brackets are skipped before the
    check, so a close-quote alone does not count as a sentence terminator.
    """
    if not text:
        return text

    closers = "".join(_TRAILING_CLOSERS)
    core = text.rstrip().rstrip(closers)
    if core and unicodedata.category(core[-1]).startswith("P"):
        return text
    return text + "。"
256
+
257
+
258
def _ensure_terminal_punctuation_by_line(text: str) -> str:
    """Terminate every non-empty line, then strip outer whitespace."""
    if not text:
        return text

    out_lines: list[str] = []
    for line in text.split("\n"):
        stripped = line.strip()
        out_lines.append(_ensure_terminal_punctuation(stripped) if stripped else "")
    return "\n".join(out_lines).strip()
264
+
265
+
266
+ # ---------------------------
267
+ # 测试
268
+ # ---------------------------
269
+
270
# (name, raw input, expected normalized output) triples; run_tests also
# checks that normalization is idempotent on each expected output.
TEST_CASES = [
    # 1) .map / dot-leading tokens / file names / version numbers
    (
        "dot_map_sentence",
        "2026 年 3 月 31 日,安全研究员 Chaofan Shou (@Fried_rice) 发现 Anthropic 的 npm 包中暴露了 .map 文件,",
        "2026年3月31日,安全研究员 Chaofan Shou (@Fried_rice) 发现 Anthropic 的 npm 包中暴露了 .map 文件,",
    ),
    ("dot_tokens", "别把 .env、.npmrc、.gitignore 提交上去。", "别把 .env、.npmrc、.gitignore 提交上去。"),
    ("file_names", "请检查 bundle.min.js、package.json 和 processing_moss_tts.py。", "请检查 bundle.min.js、package.json 和 processing_moss_tts.py。"),
    ("index_d_ts", "index.d.ts 里也有同样的问题。", "index.d.ts 里也有同样的问题。"),
    ("version_build", "Bug 的讨论可以精确到 v2.3.1 (Build 15)。", "Bug 的讨论可以精确到 v2.3.1 (Build 15)。"),
    ("version_rc", "3.0.0-rc.1 还不能上生产。", "3.0.0-rc.1 还不能上生产。"),
    ("jar_name", "fabric-api-0.91.3+1.20.2.jar 需要单独下载。", "fabric-api-0.91.3+1.20.2.jar 需要单独下载。"),

    # 2) URL / email / mention / hashtag / Reddit
    ("url", "仓库地址是 https://github.com/instructkr/claude-code", "仓库地址是 https://github.com/instructkr/claude-code。"),
    ("email", "联系邮箱:ops+tts@example.ai", "联系邮箱:ops+tts@example.ai。"),
    ("mention", "@Fried_rice 说这是 source map 暴露。", "@Fried_rice 说这是 source map 暴露。"),
    ("reddit", "去 r/singularity 看讨论。", "去 r/singularity 看讨论。"),
    ("hashtag_chain", "#张雪峰#张雪峰[话题]#张雪峰事件", "#张雪峰#张雪峰[话题]#张雪峰事件。"),
    ("mention_hashtag_boundary", "关注@biscuit0228_并转发#thetime_tbs", "关注 @biscuit0228_ 并转发 #thetime_tbs。"),

    # 3) bracket / control tokens: rewritten to double quotes
    ("speaker_bracket", "[S1]你好。[S2]收到。", '"S1"你好。"S2"收到。'),
    ("event_bracket", "请模仿 {whisper} 的语气说“别出声”。", '请模仿 "whisper" 的语气说“别出声”。'),
    ("order_bracket", "订单号:[AB-1234-XYZ]", '订单号:"AB-1234-XYZ"。'),

    # 4) structural symbols: replaced by quotes or sentence boundaries, not deleted
    ("struct_headline", "〖重磅〗《新品发布》——现在开始!", '"重磅"《新品发布》。现在开始!'),
    ("struct_notice", "【公告】今天 20:00 维护——预计 30 分钟。", '"公告"今天20:00维护。预计30分钟。'),
    ("struct_quote_chain", "『特别提醒』「不要外传」", '"特别提醒""不要外传"。'),
    ("struct_embedded_quote", "他说【重要通知】明天发布。", '他说"重要通知"明天发布。'),

    # 5) embedded titles: kept
    ("embedded_title", "我喜欢《哈姆雷特》这本书。", "我喜欢《哈姆雷特》这本书。"),

    # 6) repeated punctuation / social-media noise
    ("noise_qe", "真的假的???!!!", "真的假的?!"),
    ("noise_ellipsis", "这个包把 app.js.map 也发上去了......太离谱了!!!", "这个包把 app.js.map 也发上去了。太离谱了!"),
    ("noise_ellipsis_cn", "【系统提示】请模仿{sad}低沉语气,说“今天下雨了……”", '"系统提示"请模仿"sad"低沉语气,说“今天下雨了。”'),

    # 7) whitespace rules: compress Latin, delete CJK-internal, keep mixed boundaries
    ("english_spaces", "This is a test.", "This is a test."),
    ("chinese_spaces", "这 是 一 段 含有多种空白的文本。", "这是一段含有多种空白的文本。"),
    ("mixed_spaces_1", "这是Anthropic的npm包", "这是 Anthropic 的 npm 包。"),
    ("mixed_spaces_2", "今天update到v2.3.1了", "今天 update 到 v2.3.1 了。"),
    ("mixed_spaces_3", "处理app.js.map文件", "处理 app.js.map 文件。"),

    # 8) Markdown / lists / newlines
    ("markdown_link", "详情见 [release note](https://github.com/example/release)", "详情见 release note https://github.com/example/release。"),
    ("markdown_heading", "# I made a free open source app to help with markdown files", "I made a free open source app to help with markdown files。"),
    ("list_lines", "- 修复 .map 泄露\n- 发布 v2.3.1", "修复 .map 泄露。发布 v2.3.1。"),
    ("numbered_lines", "1. 安装依赖\n2. 运行测试\n3. 发布 v2.3.1", "安装依赖。运行测试。发布 v2.3.1。"),
    ("newlines", "第一行\n第二行\n第三行", "第一行。第二行。第三行。"),

    # 9) terminal punctuation insertion
    ("terminal_punct_plain", "今天发布", "今天发布。"),
    ("terminal_punct_quoted", '他说"你好"', '他说"你好"。'),
    ("terminal_punct_existing", "今天发布。", "今天发布。"),
    ("terminal_punct_newlines", "第一行\n第二行。", "第一行。第二行。"),
    ("terminal_punct_blank_lines", "第一行\n\n第二行", "第一行。第二行。"),

    # 10) zero-width characters / idempotence
    ("zero_width_url", "详见 https://x.com/\u200bSafety", "详见 https://x.com/Safety。"),
]
335
+
336
+
337
def run_tests(verbose: bool = True) -> None:
    """Run every TEST_CASES entry and verify normalization plus idempotence.

    Raises AssertionError with a formatted report if any case fails.
    """
    failures = []

    for name, source, expected in TEST_CASES:
        produced = normalize_tts_text(source)
        if produced != expected:
            failures.append((name, source, expected, produced))
            continue

        # Idempotence: re-normalizing an already-normalized string must be a no-op.
        twice = normalize_tts_text(produced)
        if twice != produced:
            failures.append((name + "_idempotence", produced, produced, twice))

    if failures:
        report = ["\nTEST FAILED:\n"]
        for name, source, expected, produced in failures:
            report.extend(
                [
                    f"[{name}]",
                    f"input : {source}",
                    f"expected: {expected}",
                    f"actual : {produced}",
                    "",
                ]
            )
        raise AssertionError("\n".join(report))

    if verbose:
        print(f"All {len(TEST_CASES)} tests passed.")
363
+
364
+
365
if __name__ == "__main__":
    # Running the module directly executes the built-in self-tests.
    run_tests()
weights/codec/.gitattributes ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
weights/codec/README.md ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: apache-2.0
3
+ library_name: transformers
4
+ tags:
5
+ - audio
6
+ - audio-tokenizer
7
+ - neural-codec
8
+ - moss-tts-family
9
+ - MOSS Audio Tokenizer
10
+ - speech-tokenizer
11
+ - trust-remote-code
12
+ ---
13
+
14
+ # MossAudioTokenizer
15
+
16
+ This is the code for MOSS-Audio-Tokenizer presented in [MOSS-Audio-Tokenizer: Scaling Audio Tokenizers for Future Audio Foundation Models](https://arxiv.org/abs/2602.10934).
17
+
18
+ **MOSSAudioTokenizer** is a unified discrete audio tokenizer based on the **Cat** (**C**ausal **A**udio **T**okenizer with **T**ransformer) architecture. Scaling to 1.6 billion parameters, it functions as a unified discrete interface, delivering both lossless-quality reconstruction and high-level semantic alignment.
19
+
20
+ **Key Features:**
21
+
22
+ * **Extreme Compression & Variable Bitrate**: It compresses 48kHz stereo audio into a remarkably low frame rate of 12.5Hz. Utilizing a 32-layer Residual LFQ quantizer stack, it supports high-fidelity reconstruction across a wide range of bitrates.
23
+ * **Pure Transformer Architecture**: The model features a "CNN-free" homogeneous architecture built entirely from Causal Transformer blocks. With 1.6B combined parameters (Encoder + Decoder), it ensures exceptional scalability and supports low-latency streaming inference.
24
+ * **Large-Scale General Audio Training**: Trained on 3 million hours of diverse audio data, the model excels at encoding and reconstructing all audio domains, including speech, sound effects, and music.
25
+ * **Unified Semantic-Acoustic Representation**: While achieving state-of-the-art reconstruction quality, Cat produces discrete tokens that are "semantic-rich," making them ideal for downstream tasks like speech understanding (ASR) and generation (TTS).
26
+ * **Fully Trained From Scratch**: Cat does not rely on any pretrained encoders (such as HuBERT or Whisper) or distillation from teacher models. All representations are learned autonomously from raw data.
27
+ * **End-to-End Joint Optimization**: All components—including the encoder, quantizer, decoder, discriminator, and a decoder-only LLM for semantic alignment—are optimized jointly in a single unified training pipeline.
28
+
29
+ **Summary:**
30
+ By combining a simple, scalable architecture with massive-scale data, the Cat architecture overcomes the bottlenecks of traditional audio tokenizers. It provides a robust, high-fidelity, and semantically grounded interface for the next generation of native audio foundation models.
31
+
32
+ This repository contains a lightweight remote-code implementation that mirrors the current 🤗 Transformers
33
+ `transformers.models.moss_audio_tokenizer` module. It is intended to be uploaded to a Hugging Face Hub model repository
34
+ and loaded with `trust_remote_code=True` when needed.
35
+
36
+
37
+ ## Usage
38
+
39
+ ### Quickstart
40
+
41
+ ```python
42
+ import torch
43
+ from transformers import AutoModel
44
+ import torchaudio
45
+
46
+ repo_id = "OpenMOSS-Team/MOSS-Audio-Tokenizer"
47
+ model = AutoModel.from_pretrained(repo_id, trust_remote_code=True).eval()
48
+
49
+ wav, sr = torchaudio.load('demo/demo_gt.wav')
50
+ if sr != model.sampling_rate:
51
+ wav = torchaudio.functional.resample(wav, sr, model.sampling_rate)
52
+ if wav.shape[0] == 1:
53
+ wav = wav.repeat(model.config.number_channels, 1)
54
+ else:
55
+ wav = wav[: model.config.number_channels]
56
+ wav = wav.unsqueeze(0)
57
+ enc = model.encode(wav, return_dict=True)
58
+ print(f"enc.audio_codes.shape: {enc.audio_codes.shape}")
59
+ dec = model.decode(enc.audio_codes, return_dict=True)
60
+ print(f"dec.audio.shape: {dec.audio.shape}")
61
+ wav = dec.audio.squeeze(0)
62
+ torchaudio.save("demo/demo_rec.wav", wav, sample_rate=model.sampling_rate)
63
+
64
+ # Decode using only the first 8 layers of the RVQ
65
+ dec_rvq8 = model.decode(enc.audio_codes[:8], return_dict=True)
66
+ wav_rvq8 = dec_rvq8.audio.squeeze(0)
67
+ torchaudio.save("demo/demo_rec_rvq8.wav", wav_rvq8, sample_rate=model.sampling_rate)
68
+ ```
69
+
70
+ ### Attention Backend And Compute Dtype
71
+
72
+ `config.attention_implementation` controls whether transformer layers prefer `sdpa` or `flash_attention_2`.
73
+ `config.compute_dtype` controls the non-quantizer autocast dtype and supports `fp32`, `bf16`, and `fp16`.
74
+
75
+ ```python
76
+ model.set_attention_implementation("flash_attention_2")
77
+ model.set_compute_dtype("fp16")
78
+ ```
79
+
80
+ The quantizer always runs in fp32.
81
+
82
+ ### Streaming
83
+
84
+ `MossAudioTokenizerModel.encode`, `decode`, `batch_encode`, and `batch_decode` all support streaming through a
85
+ `chunk_duration` argument.
86
+
87
+ - `chunk_duration` is expressed in seconds.
88
+ - `chunk_duration * MossAudioTokenizerConfig.sampling_rate` must be divisible by `MossAudioTokenizerConfig.downsample_rate`.
89
+ - Streaming batch inference is supported.
90
+ - The public waveform interface expects stereo inputs shaped `(2, T)` or batched stereo inputs shaped `(B, 2, T)`.
91
+
92
+ ```python
93
+ import torch
94
+ from transformers import AutoModel
95
+
96
+ repo_id = "OpenMOSS-Team/MOSS-Audio-Tokenizer"
97
+ model = AutoModel.from_pretrained(repo_id, trust_remote_code=True).eval()
98
+ audio = torch.randn(2, 48000 * 6) # dummy stereo waveform
99
+
100
+ # 6.0s @ 48kHz = 288000 samples, divisible by downsample_rate=3840
101
+ enc = model.encode(audio.unsqueeze(0), return_dict=True, chunk_duration=0.08)
102
+ dec = model.decode(enc.audio_codes, return_dict=True, chunk_duration=0.08)
103
+
104
+ batch_enc = model.batch_encode([audio, audio[:, : 48000 * 3]], chunk_duration=0.08)
105
+ codes_list = [
106
+ batch_enc.audio_codes[:, i, : batch_enc.audio_codes_lengths[i]]
107
+ for i in range(batch_enc.audio_codes.shape[1])
108
+ ]
109
+ batch_dec = model.batch_decode(codes_list, chunk_duration=0.08)
110
+ ```
111
+
112
+ #### Continuous Batch Streaming Decode
113
+
114
+ For decoder-side continuous batching, prefer `batch_decode(..., streaming=True, ...)`.
115
+
116
+ - The first streaming call may pass `max_batch_size=...`. If it is omitted, the first batch size reserves the
117
+ fixed-slot decoder budget for that public stream.
118
+ - Same-size calls continue the existing logical rows in-order.
119
+ - If a later call is larger, the new rows are admitted by tail append.
120
+ - `finalize_indices` means "decode these rows one last time, then evict them". The indices are interpreted against the
121
+ pre-call logical order.
122
+ - After a finalize call returns, the next streaming call may use the smaller survivor batch.
123
+ - `reset_stream=True` discards the hidden public streaming state and starts a fresh stream.
124
+
125
+ Milestone 1 boundaries:
126
+
127
+ - decode-only continuous batching
128
+ - one active streaming decode state per model instance
129
+ - fixed-slot decoder reservation from `max_batch_size`
130
+ - no encode-side continuous batching
131
+ - no physical compaction of surviving decode slots
132
+ - no multi-session concurrency on one model instance
133
+
134
+ ```python
135
+ import torch
136
+ from transformers import AutoModel
137
+
138
+ repo_id = "OpenMOSS-Team/MOSS-Audio-Tokenizer"
139
+ model = AutoModel.from_pretrained(repo_id, trust_remote_code=True).eval()
140
+ num_quantizers = model.config.quantizer_kwargs["num_quantizers"]
141
+
142
+ codes_a0 = torch.randint(0, 8, (num_quantizers, 2))
143
+ codes_b0 = torch.randint(0, 8, (num_quantizers, 3))
144
+ codes_a1 = torch.randint(0, 8, (num_quantizers, 2))
145
+ codes_b1 = torch.randint(0, 8, (num_quantizers, 2))
146
+ codes_c0 = torch.randint(0, 8, (num_quantizers, 1))
147
+ codes_a2 = torch.randint(0, 8, (num_quantizers, 1))
148
+ codes_b2 = torch.randint(0, 8, (num_quantizers, 2))
149
+ codes_c1 = torch.randint(0, 8, (num_quantizers, 2))
150
+ codes_b3 = torch.randint(0, 8, (num_quantizers, 1))
151
+ codes_c2 = torch.randint(0, 8, (num_quantizers, 1))
152
+
153
+ # First call reserves 3 fixed decoder slots for A and B.
154
+ out_ab0 = model.batch_decode(
155
+ [codes_a0, codes_b0],
156
+ streaming=True,
157
+ max_batch_size=3,
158
+ reset_stream=True,
159
+ )
160
+
161
+ # Same logical rows continue in-order; C is a tail append.
162
+ out_abc1 = model.batch_decode(
163
+ [codes_a1, codes_b1, codes_c0],
164
+ streaming=True,
165
+ )
166
+
167
+ # Finalize A against the pre-call logical order. A still decodes in this call,
168
+ # then is evicted immediately afterward.
169
+ out_abc2 = model.batch_decode(
170
+ [codes_a2, codes_b2, codes_c1],
171
+ streaming=True,
172
+ finalize_indices=[0],
173
+ )
174
+
175
+ # The next call can shrink to the surviving logical rows only.
176
+ out_bc3 = model.batch_decode(
177
+ [codes_b3, codes_c2],
178
+ streaming=True,
179
+ )
180
+ ```
181
+
182
+ ## Repository layout
183
+
184
+ - `configuration_moss_audio_tokenizer.py`
185
+ - `modeling_moss_audio_tokenizer.py`
186
+ - `__init__.py`
187
+ - `config.json`
188
+ - model weights
189
+
190
+
191
+ ## Citation
192
+ If you use this code or result in your paper, please cite our work as:
193
+ ```tex
194
+
195
+ ```
weights/codec/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """Remote code package for Moss audio tokenizer."""
weights/codec/config.json ADDED
@@ -0,0 +1,304 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "MossAudioTokenizerModel"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "configuration_moss_audio_tokenizer.MossAudioTokenizerConfig",
7
+ "AutoModel": "modeling_moss_audio_tokenizer.MossAudioTokenizerModel"
8
+ },
9
+ "model_type": "moss-audio-tokenizer",
10
+ "sample_rate": 48000,
11
+ "sampling_rate": 48000,
12
+ "downsample_rate": 3840,
13
+ "causal_transformer_context_duration": 10.0,
14
+ "number_channels": 2,
15
+ "enable_channel_interleave": true,
16
+ "attention_implementation": "sdpa",
17
+ "compute_dtype": "fp32",
18
+ "dtype": "float32",
19
+ "code_dim": 768,
20
+ "encoder_kwargs": [
21
+ {
22
+ "module_type": "PatchedPretransform",
23
+ "patch_size": 240
24
+ },
25
+ {
26
+ "causal": true,
27
+ "context_duration": 4.0,
28
+ "conv_layout": true,
29
+ "d_model": 256,
30
+ "dim_feedforward": 1024,
31
+ "gating": "none",
32
+ "input_dimension": 240,
33
+ "layer_scale": 0.01,
34
+ "max_period": 10000,
35
+ "module_type": "Transformer",
36
+ "norm": "layer_norm",
37
+ "num_heads": 4,
38
+ "num_layers": 4,
39
+ "output_dimension": 384,
40
+ "positional_embedding": "rope"
41
+ },
42
+ {
43
+ "module_type": "PatchedPretransform",
44
+ "patch_size": 2
45
+ },
46
+ {
47
+ "causal": true,
48
+ "context_duration": 6.0,
49
+ "conv_layout": true,
50
+ "d_model": 256,
51
+ "dim_feedforward": 1024,
52
+ "gating": "none",
53
+ "input_dimension": 768,
54
+ "layer_scale": 0.01,
55
+ "max_period": 10000,
56
+ "module_type": "Transformer",
57
+ "norm": "layer_norm",
58
+ "num_heads": 4,
59
+ "num_layers": 2,
60
+ "output_dimension": 384,
61
+ "positional_embedding": "rope"
62
+ },
63
+ {
64
+ "module_type": "PatchedPretransform",
65
+ "patch_size": 2
66
+ },
67
+ {
68
+ "causal": true,
69
+ "context_duration": 8.0,
70
+ "conv_layout": true,
71
+ "d_model": 256,
72
+ "dim_feedforward": 1024,
73
+ "gating": "none",
74
+ "input_dimension": 768,
75
+ "layer_scale": 0.01,
76
+ "max_period": 10000,
77
+ "module_type": "Transformer",
78
+ "norm": "layer_norm",
79
+ "num_heads": 4,
80
+ "num_layers": 2,
81
+ "output_dimension": 384,
82
+ "positional_embedding": "rope"
83
+ },
84
+ {
85
+ "module_type": "PatchedPretransform",
86
+ "patch_size": 2
87
+ },
88
+ {
89
+ "causal": true,
90
+ "context_duration": 10.0,
91
+ "conv_layout": true,
92
+ "d_model": 256,
93
+ "dim_feedforward": 1024,
94
+ "gating": "none",
95
+ "input_dimension": 768,
96
+ "layer_scale": 0.01,
97
+ "max_period": 10000,
98
+ "module_type": "Transformer",
99
+ "norm": "layer_norm",
100
+ "num_heads": 4,
101
+ "num_layers": 4,
102
+ "output_dimension": 192,
103
+ "positional_embedding": "rope"
104
+ },
105
+ {
106
+ "module_type": "PatchedPretransform",
107
+ "patch_size": 4
108
+ }
109
+ ],
110
+ "decoder_kwargs": [
111
+ {
112
+ "module_type": "PatchedPretransform",
113
+ "patch_size": 4
114
+ },
115
+ {
116
+ "causal": true,
117
+ "context_duration": 10.0,
118
+ "conv_layout": true,
119
+ "d_model": 256,
120
+ "dim_feedforward": 1024,
121
+ "gating": "none",
122
+ "input_dimension": 192,
123
+ "layer_scale": 0.01,
124
+ "max_period": 10000,
125
+ "module_type": "Transformer",
126
+ "norm": "layer_norm",
127
+ "num_heads": 4,
128
+ "num_layers": 4,
129
+ "output_dimension": 768,
130
+ "positional_embedding": "rope"
131
+ },
132
+ {
133
+ "module_type": "PatchedPretransform",
134
+ "patch_size": 2
135
+ },
136
+ {
137
+ "causal": true,
138
+ "context_duration": 8.0,
139
+ "conv_layout": true,
140
+ "d_model": 256,
141
+ "dim_feedforward": 1024,
142
+ "gating": "none",
143
+ "input_dimension": 384,
144
+ "layer_scale": 0.01,
145
+ "max_period": 10000,
146
+ "module_type": "Transformer",
147
+ "norm": "layer_norm",
148
+ "num_heads": 4,
149
+ "num_layers": 2,
150
+ "output_dimension": 768,
151
+ "positional_embedding": "rope"
152
+ },
153
+ {
154
+ "module_type": "PatchedPretransform",
155
+ "patch_size": 2
156
+ },
157
+ {
158
+ "causal": true,
159
+ "context_duration": 6.0,
160
+ "conv_layout": true,
161
+ "d_model": 256,
162
+ "dim_feedforward": 1024,
163
+ "gating": "none",
164
+ "input_dimension": 384,
165
+ "layer_scale": 0.01,
166
+ "max_period": 10000,
167
+ "module_type": "Transformer",
168
+ "norm": "layer_norm",
169
+ "num_heads": 4,
170
+ "num_layers": 2,
171
+ "output_dimension": 768,
172
+ "positional_embedding": "rope"
173
+ },
174
+ {
175
+ "module_type": "PatchedPretransform",
176
+ "patch_size": 2
177
+ },
178
+ {
179
+ "causal": true,
180
+ "context_duration": 4.0,
181
+ "conv_layout": true,
182
+ "d_model": 256,
183
+ "dim_feedforward": 1024,
184
+ "gating": "none",
185
+ "input_dimension": 384,
186
+ "layer_scale": 0.01,
187
+ "max_period": 10000,
188
+ "module_type": "Transformer",
189
+ "norm": "layer_norm",
190
+ "num_heads": 4,
191
+ "num_layers": 4,
192
+ "output_dimension": 240,
193
+ "positional_embedding": "rope"
194
+ },
195
+ {
196
+ "module_type": "PatchedPretransform",
197
+ "patch_size": 240
198
+ }
199
+ ],
200
+ "quantizer_type": "rlfq",
201
+ "quantizer_kwargs": {
202
+ "codebook_dim": 8,
203
+ "codebook_loss_weight": 1.0,
204
+ "codebook_size": 1024,
205
+ "commitment_loss_weight": 0.25,
206
+ "input_dim": 768,
207
+ "num_quantizers": 16,
208
+ "output_dim": 768,
209
+ "quantizer_dropout": 1.0,
210
+ "quantizer_type": "rlfq",
211
+ "rvq_dim": 512
212
+ },
213
+ "transformers_version": "4.56.0.dev0",
214
+ "reversed_decoder_kwargs": [
215
+ {
216
+ "module_type": "PatchedPretransform",
217
+ "patch_size": 240
218
+ },
219
+ {
220
+ "causal": true,
221
+ "context_duration": 4.0,
222
+ "conv_layout": true,
223
+ "d_model": 256,
224
+ "dim_feedforward": 1024,
225
+ "gating": "none",
226
+ "input_dimension": 240,
227
+ "layer_scale": 0.01,
228
+ "max_period": 10000,
229
+ "module_type": "Transformer",
230
+ "norm": "layer_norm",
231
+ "num_heads": 4,
232
+ "num_layers": 4,
233
+ "output_dimension": 384,
234
+ "positional_embedding": "rope"
235
+ },
236
+ {
237
+ "module_type": "PatchedPretransform",
238
+ "patch_size": 2
239
+ },
240
+ {
241
+ "causal": true,
242
+ "context_duration": 6.0,
243
+ "conv_layout": true,
244
+ "d_model": 256,
245
+ "dim_feedforward": 1024,
246
+ "gating": "none",
247
+ "input_dimension": 768,
248
+ "layer_scale": 0.01,
249
+ "max_period": 10000,
250
+ "module_type": "Transformer",
251
+ "norm": "layer_norm",
252
+ "num_heads": 4,
253
+ "num_layers": 2,
254
+ "output_dimension": 384,
255
+ "positional_embedding": "rope"
256
+ },
257
+ {
258
+ "module_type": "PatchedPretransform",
259
+ "patch_size": 2
260
+ },
261
+ {
262
+ "causal": true,
263
+ "context_duration": 8.0,
264
+ "conv_layout": true,
265
+ "d_model": 256,
266
+ "dim_feedforward": 1024,
267
+ "gating": "none",
268
+ "input_dimension": 768,
269
+ "layer_scale": 0.01,
270
+ "max_period": 10000,
271
+ "module_type": "Transformer",
272
+ "norm": "layer_norm",
273
+ "num_heads": 4,
274
+ "num_layers": 2,
275
+ "output_dimension": 384,
276
+ "positional_embedding": "rope"
277
+ },
278
+ {
279
+ "module_type": "PatchedPretransform",
280
+ "patch_size": 2
281
+ },
282
+ {
283
+ "causal": true,
284
+ "context_duration": 10.0,
285
+ "conv_layout": true,
286
+ "d_model": 256,
287
+ "dim_feedforward": 1024,
288
+ "gating": "none",
289
+ "input_dimension": 768,
290
+ "layer_scale": 0.01,
291
+ "max_period": 10000,
292
+ "module_type": "Transformer",
293
+ "norm": "layer_norm",
294
+ "num_heads": 4,
295
+ "num_layers": 4,
296
+ "output_dimension": 192,
297
+ "positional_embedding": "rope"
298
+ },
299
+ {
300
+ "module_type": "PatchedPretransform",
301
+ "patch_size": 4
302
+ }
303
+ ]
304
+ }
weights/codec/configuration_moss_audio_tokenizer.py ADDED
@@ -0,0 +1,467 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2026 OpenMOSS and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """MossAudioTokenizer model configuration."""
16
+
17
+ from typing import Any
18
+
19
+ try:
20
+ from transformers.configuration_utils import PreTrainedConfig
21
+ except ImportError:
22
+ from transformers.configuration_utils import PretrainedConfig as PreTrainedConfig
23
+ from transformers.utils import logging
24
+
25
+
26
+ logger = logging.get_logger(__name__)
27
+
28
+
29
+ class MossAudioTokenizerConfig(PreTrainedConfig):
30
+ r"""
31
+ This is the configuration class to store the configuration of a [`MossAudioTokenizerModel`]. It is used to instantiate a
32
+ MossAudioTokenizer model according to the specified arguments, defining the model architecture.
33
+
34
+ Instantiating a configuration with the defaults will yield a similar configuration to that of the
35
+ [VoiceAgentGroup/moss_audio_tokenizer](https://huggingface.co/VoiceAgentGroup/moss_audio_tokenizer) architecture.
36
+
37
+ Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the
38
+ documentation from [`PreTrainedConfig`] for more information.
39
+
40
+ Args:
41
+ sampling_rate (`int`, *optional*, defaults to 48000):
42
+ The sampling rate at which the audio waveform should be digitalized expressed in hertz (Hz).
43
+ downsample_rate (`int`, *optional*, defaults to 3840):
44
+ Total downsampling rate from waveform to tokens.
45
+ causal_transformer_context_duration (`float`, *optional*, defaults to 10.0):
46
+ Legacy global fallback context duration in seconds for causal transformer. If an individual transformer
47
+ entry in `encoder_kwargs` or `decoder_kwargs` provides `context_duration`, that per-module value takes
48
+ precedence.
49
+ encoder_kwargs (`list[dict]`, *optional*):
50
+ List of encoder module configurations. Each dict specifies a module type and its parameters.
51
+ decoder_kwargs (`list[dict]`, *optional*):
52
+ List of decoder module configurations in execution order.
53
+ number_channels (`int`, *optional*, defaults to 2):
54
+ Number of audio channels exposed by the public waveform interface.
55
+ enable_channel_interleave (`bool`, *optional*, defaults to `True`):
56
+ Whether to flatten multi-channel waveforms into a single internal stream before codec inference.
57
+ attention_implementation (`str`, *optional*, defaults to `"sdpa"`):
58
+ Attention implementation to prefer for transformer layers. Supported values are `"sdpa"` and
59
+ `"flash_attention_2"`.
60
+ compute_dtype (`str`, *optional*, defaults to `"fp32"`):
61
+ Inference compute dtype for non-quantizer modules. Supported values are `"fp32"`, `"bf16"`, and `"fp16"`.
62
+ quantizer_type (`str`, *optional*, defaults to `"rlfq"`):
63
+ Quantizer type. Options include `"rvq"`, `"spec_rvq"`, `"rlfq"`, `"random_prefix_rlfq"`.
64
+ quantizer_kwargs (`dict`, *optional*):
65
+ Configuration for the quantizer including `input_dim`, `rvq_dim`, `output_dim`, `num_quantizers`,
66
+ `codebook_size`, and `codebook_dim`.
67
+
68
+ Example:
69
+
70
+ ```python
71
+ >>> from transformers import MossAudioTokenizerModel, MossAudioTokenizerConfig
72
+
73
+ >>> # Initializing a MossAudioTokenizer style configuration
74
+ >>> configuration = MossAudioTokenizerConfig()
75
+
76
+ >>> # Initializing a model (with random weights) from the configuration
77
+ >>> model = MossAudioTokenizerModel(configuration)
78
+
79
+ >>> # Accessing the model configuration
80
+ >>> configuration = model.config
81
+ ```
82
+ """
83
+
84
+ model_type = "moss-audio-tokenizer"
85
+
86
+ # Backward-compatible alias used by some checkpoints.
87
+ attribute_map = {"sample_rate": "sampling_rate"}
88
+
89
+ sampling_rate: int
90
+ downsample_rate: int
91
+ causal_transformer_context_duration: float
92
+ encoder_kwargs: list[dict[str, Any]]
93
+ decoder_kwargs: list[dict[str, Any]]
94
+ number_channels: int
95
+ enable_channel_interleave: bool
96
+ attention_implementation: str
97
+ compute_dtype: str
98
+ quantizer_type: str
99
+ quantizer_kwargs: dict[str, Any]
100
+
101
+ def __init__(
102
+ self,
103
+ version: str | None = None,
104
+ sampling_rate: int = 48000,
105
+ downsample_rate: int = 3840,
106
+ causal_transformer_context_duration: float = 10.0,
107
+ encoder_kwargs: list[dict[str, Any]] | None = None,
108
+ decoder_kwargs: list[dict[str, Any]] | None = None,
109
+ number_channels: int = 2,
110
+ enable_channel_interleave: bool = True,
111
+ attention_implementation: str = "sdpa",
112
+ compute_dtype: str = "fp32",
113
+ quantizer_type: str = "rlfq",
114
+ quantizer_kwargs: dict[str, Any] | None = None,
115
+ **kwargs,
116
+ ):
117
+ # Some checkpoints might include an incorrect/legacy `model_type` (e.g. "speech_tokenizer").
118
+ # We drop it to avoid overriding the class-level `model_type`.
119
+ kwargs.pop("model_type", None)
120
+ if "channels_numbers" in kwargs:
121
+ number_channels = kwargs.pop("channels_numbers")
122
+ if "enable_channel_interleave" in kwargs:
123
+ enable_channel_interleave = kwargs.pop("enable_channel_interleave")
124
+ if "attention_backend" in kwargs and attention_implementation == "sdpa":
125
+ attention_implementation = kwargs.pop("attention_backend")
126
+ if "codec_compute_dtype" in kwargs and compute_dtype == "fp32":
127
+ compute_dtype = kwargs.pop("codec_compute_dtype")
128
+ reversed_decoder_kwargs = kwargs.pop("reversed_decoder_kwargs", None)
129
+
130
+ # `version` is accepted for compatibility but not used in modeling.
131
+ self.version = version
132
+ self.sampling_rate = sampling_rate
133
+ self.downsample_rate = downsample_rate
134
+ self.causal_transformer_context_duration = causal_transformer_context_duration
135
+ self.number_channels = number_channels
136
+ self.enable_channel_interleave = enable_channel_interleave
137
+ self.attention_implementation = attention_implementation
138
+ self.compute_dtype = compute_dtype
139
+ # Default encoder configuration
140
+ if encoder_kwargs is None:
141
+ encoder_kwargs = [
142
+ {
143
+ "module_type": "PatchedPretransform",
144
+ "patch_size": 240,
145
+ },
146
+ {
147
+ "module_type": "Transformer",
148
+ "input_dimension": 240,
149
+ "output_dimension": 384,
150
+ "d_model": 768,
151
+ "num_heads": 12,
152
+ "num_layers": 12,
153
+ "dim_feedforward": 3072,
154
+ "causal": True,
155
+ "norm": "layer_norm",
156
+ "positional_embedding": "rope",
157
+ "max_period": 10000,
158
+ "gating": "none",
159
+ "layer_scale": 0.01,
160
+ "conv_layout": True,
161
+ "context_duration": 1.0,
162
+ },
163
+ {
164
+ "module_type": "PatchedPretransform",
165
+ "patch_size": 2,
166
+ },
167
+ {
168
+ "module_type": "Transformer",
169
+ "input_dimension": 768,
170
+ "output_dimension": 384,
171
+ "d_model": 768,
172
+ "num_heads": 12,
173
+ "num_layers": 12,
174
+ "dim_feedforward": 3072,
175
+ "causal": True,
176
+ "norm": "layer_norm",
177
+ "positional_embedding": "rope",
178
+ "max_period": 10000,
179
+ "gating": "none",
180
+ "layer_scale": 0.01,
181
+ "conv_layout": True,
182
+ "context_duration": 2.0,
183
+ },
184
+ {
185
+ "module_type": "PatchedPretransform",
186
+ "patch_size": 2,
187
+ },
188
+ {
189
+ "module_type": "Transformer",
190
+ "input_dimension": 768,
191
+ "output_dimension": 384,
192
+ "d_model": 768,
193
+ "num_heads": 12,
194
+ "num_layers": 12,
195
+ "dim_feedforward": 3072,
196
+ "causal": True,
197
+ "norm": "layer_norm",
198
+ "positional_embedding": "rope",
199
+ "max_period": 10000,
200
+ "gating": "none",
201
+ "layer_scale": 0.01,
202
+ "conv_layout": True,
203
+ "context_duration": 4.0,
204
+ },
205
+ {
206
+ "module_type": "PatchedPretransform",
207
+ "patch_size": 2,
208
+ },
209
+ {
210
+ "module_type": "Transformer",
211
+ "input_dimension": 768,
212
+ "output_dimension": 384,
213
+ "d_model": 768,
214
+ "num_heads": 12,
215
+ "num_layers": 12,
216
+ "dim_feedforward": 3072,
217
+ "causal": True,
218
+ "norm": "layer_norm",
219
+ "positional_embedding": "rope",
220
+ "max_period": 10000,
221
+ "gating": "none",
222
+ "layer_scale": 0.01,
223
+ "conv_layout": True,
224
+ "context_duration": 8.0,
225
+ },
226
+ {
227
+ "module_type": "PatchedPretransform",
228
+ "patch_size": 2,
229
+ },
230
+ {
231
+ "module_type": "Transformer",
232
+ "input_dimension": 768,
233
+ "output_dimension": 640,
234
+ "d_model": 768,
235
+ "num_heads": 12,
236
+ "num_layers": 12,
237
+ "dim_feedforward": 3072,
238
+ "causal": True,
239
+ "norm": "layer_norm",
240
+ "positional_embedding": "rope",
241
+ "max_period": 10000,
242
+ "gating": "none",
243
+ "layer_scale": 0.01,
244
+ "conv_layout": True,
245
+ "context_duration": 10.0,
246
+ },
247
+ {
248
+ "module_type": "PatchedPretransform",
249
+ "patch_size": 2,
250
+ },
251
+ {
252
+ "module_type": "Transformer",
253
+ "input_dimension": 1280,
254
+ "output_dimension": 768,
255
+ "d_model": 1280,
256
+ "num_heads": 20,
257
+ "num_layers": 32,
258
+ "dim_feedforward": 5120,
259
+ "causal": True,
260
+ "norm": "layer_norm",
261
+ "positional_embedding": "rope",
262
+ "max_period": 10000,
263
+ "gating": "none",
264
+ "layer_scale": 0.01,
265
+ "conv_layout": True,
266
+ "context_duration": 10.0,
267
+ },
268
+ ]
269
+ else:
270
+ encoder_kwargs = [dict(module_kwargs) for module_kwargs in encoder_kwargs]
271
+ for module_kwargs in encoder_kwargs:
272
+ if module_kwargs.get("module_type") == "Transformer":
273
+ module_kwargs.setdefault("context_duration", causal_transformer_context_duration)
274
+ self.encoder_kwargs = encoder_kwargs
275
+
276
+ # Default decoder configuration (execution order)
277
+ if decoder_kwargs is None and reversed_decoder_kwargs is not None:
278
+ reversed_decoder_kwargs = [dict(module_kwargs) for module_kwargs in reversed_decoder_kwargs]
279
+ decoder_kwargs = []
280
+ for module_kwargs in reversed_decoder_kwargs[::-1]:
281
+ if module_kwargs.get("module_type") != "Transformer":
282
+ decoder_kwargs.append(module_kwargs)
283
+ continue
284
+ module_kwargs = dict(module_kwargs)
285
+ module_kwargs["input_dimension"], module_kwargs["output_dimension"] = (
286
+ module_kwargs["output_dimension"],
287
+ module_kwargs["input_dimension"],
288
+ )
289
+ decoder_kwargs.append(module_kwargs)
290
+
291
+ if decoder_kwargs is None:
292
+ decoder_kwargs = [
293
+ {
294
+ "module_type": "Transformer",
295
+ "input_dimension": 768,
296
+ "output_dimension": 1280,
297
+ "d_model": 1280,
298
+ "num_heads": 20,
299
+ "num_layers": 32,
300
+ "dim_feedforward": 5120,
301
+ "causal": True,
302
+ "norm": "layer_norm",
303
+ "positional_embedding": "rope",
304
+ "max_period": 10000,
305
+ "gating": "none",
306
+ "layer_scale": 0.01,
307
+ "conv_layout": True,
308
+ "context_duration": 10.0,
309
+ },
310
+ {
311
+ "module_type": "PatchedPretransform",
312
+ "patch_size": 2,
313
+ },
314
+ {
315
+ "module_type": "Transformer",
316
+ "input_dimension": 640,
317
+ "output_dimension": 768,
318
+ "d_model": 768,
319
+ "num_heads": 12,
320
+ "num_layers": 12,
321
+ "dim_feedforward": 3072,
322
+ "causal": True,
323
+ "norm": "layer_norm",
324
+ "positional_embedding": "rope",
325
+ "max_period": 10000,
326
+ "gating": "none",
327
+ "layer_scale": 0.01,
328
+ "conv_layout": True,
329
+ "context_duration": 10.0,
330
+ },
331
+ {
332
+ "module_type": "PatchedPretransform",
333
+ "patch_size": 2,
334
+ },
335
+ {
336
+ "module_type": "Transformer",
337
+ "input_dimension": 384,
338
+ "output_dimension": 768,
339
+ "d_model": 768,
340
+ "num_heads": 12,
341
+ "num_layers": 12,
342
+ "dim_feedforward": 3072,
343
+ "causal": True,
344
+ "norm": "layer_norm",
345
+ "positional_embedding": "rope",
346
+ "max_period": 10000,
347
+ "gating": "none",
348
+ "layer_scale": 0.01,
349
+ "conv_layout": True,
350
+ "context_duration": 8.0,
351
+ },
352
+ {
353
+ "module_type": "PatchedPretransform",
354
+ "patch_size": 2,
355
+ },
356
+ {
357
+ "module_type": "Transformer",
358
+ "input_dimension": 384,
359
+ "output_dimension": 768,
360
+ "d_model": 768,
361
+ "num_heads": 12,
362
+ "num_layers": 12,
363
+ "dim_feedforward": 3072,
364
+ "causal": True,
365
+ "norm": "layer_norm",
366
+ "positional_embedding": "rope",
367
+ "max_period": 10000,
368
+ "gating": "none",
369
+ "layer_scale": 0.01,
370
+ "conv_layout": True,
371
+ "context_duration": 4.0,
372
+ },
373
+ {
374
+ "module_type": "PatchedPretransform",
375
+ "patch_size": 2,
376
+ },
377
+ {
378
+ "module_type": "Transformer",
379
+ "input_dimension": 384,
380
+ "output_dimension": 768,
381
+ "d_model": 768,
382
+ "num_heads": 12,
383
+ "num_layers": 12,
384
+ "dim_feedforward": 3072,
385
+ "causal": True,
386
+ "norm": "layer_norm",
387
+ "positional_embedding": "rope",
388
+ "max_period": 10000,
389
+ "gating": "none",
390
+ "layer_scale": 0.01,
391
+ "conv_layout": True,
392
+ "context_duration": 2.0,
393
+ },
394
+ {
395
+ "module_type": "PatchedPretransform",
396
+ "patch_size": 2,
397
+ },
398
+ {
399
+ "module_type": "Transformer",
400
+ "input_dimension": 384,
401
+ "output_dimension": 240,
402
+ "d_model": 768,
403
+ "num_heads": 12,
404
+ "num_layers": 12,
405
+ "dim_feedforward": 3072,
406
+ "causal": True,
407
+ "norm": "layer_norm",
408
+ "positional_embedding": "rope",
409
+ "max_period": 10000,
410
+ "gating": "none",
411
+ "layer_scale": 0.01,
412
+ "conv_layout": True,
413
+ "context_duration": 1.0,
414
+ },
415
+ {
416
+ "module_type": "PatchedPretransform",
417
+ "patch_size": 240,
418
+ },
419
+ ]
420
+ else:
421
+ decoder_kwargs = [dict(module_kwargs) for module_kwargs in decoder_kwargs]
422
+ for module_kwargs in decoder_kwargs:
423
+ if module_kwargs.get("module_type") == "Transformer":
424
+ module_kwargs.setdefault("context_duration", causal_transformer_context_duration)
425
+ self.decoder_kwargs = decoder_kwargs
426
+
427
+ # Default quantizer configuration
428
+ if quantizer_kwargs is None:
429
+ quantizer_kwargs = {
430
+ "input_dim": 768,
431
+ "rvq_dim": 512,
432
+ "output_dim": 768,
433
+ "num_quantizers": 32,
434
+ "codebook_size": 1024,
435
+ "codebook_dim": 8,
436
+ "quantizer_type": "rlfq",
437
+ }
438
+
439
+ # Handle quantizer_type from kwargs or config
440
+ kw_qtype = quantizer_kwargs.get("quantizer_type", None)
441
+ if kw_qtype is not None:
442
+ self.quantizer_type = kw_qtype
443
+ else:
444
+ self.quantizer_type = quantizer_type
445
+ quantizer_kwargs["quantizer_type"] = quantizer_type
446
+
447
+ self.quantizer_kwargs = quantizer_kwargs
448
+
449
+ super().__init__(**kwargs)
450
+
451
+ @property
452
+ def num_quantizers(self) -> int:
453
+ """Return the number of quantizers from quantizer_kwargs."""
454
+ return self.quantizer_kwargs.get("num_quantizers", 32)
455
+
456
+ @property
457
+ def codebook_size(self) -> int:
458
+ """Return the codebook size from quantizer_kwargs."""
459
+ return self.quantizer_kwargs.get("codebook_size", 4096)
460
+
461
+ @property
462
+ def frame_rate(self) -> float:
463
+ """Return the frame rate (tokens per second)."""
464
+ return self.sampling_rate / self.downsample_rate
465
+
466
+
467
+ __all__ = ["MossAudioTokenizerConfig"]
weights/codec/model-00001-of-00001.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:34d9880d805eecb21bde975202b1c256dbd0eb98c8680b9d3aeffd2bc6ac2f67
3
+ size 87922568
weights/codec/model.safetensors.index.json ADDED
@@ -0,0 +1,382 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "metadata": {
3
+ "total_parameters": 21969664,
4
+ "total_size": 87878656
5
+ },
6
+ "weight_map": {
7
+ "encoder.1.input_proj.weight": "model-00001-of-00001.safetensors",
8
+ "encoder.1.transformer.layers.0.norm1.weight": "model-00001-of-00001.safetensors",
9
+ "encoder.1.transformer.layers.0.norm1.bias": "model-00001-of-00001.safetensors",
10
+ "encoder.1.transformer.layers.0.self_attn.in_proj.weight": "model-00001-of-00001.safetensors",
11
+ "encoder.1.transformer.layers.0.self_attn.out_proj.weight": "model-00001-of-00001.safetensors",
12
+ "encoder.1.transformer.layers.0.norm2.weight": "model-00001-of-00001.safetensors",
13
+ "encoder.1.transformer.layers.0.norm2.bias": "model-00001-of-00001.safetensors",
14
+ "encoder.1.transformer.layers.0.ffn.0.weight": "model-00001-of-00001.safetensors",
15
+ "encoder.1.transformer.layers.0.ffn.2.weight": "model-00001-of-00001.safetensors",
16
+ "encoder.1.transformer.layers.0.layer_scale_1.scale": "model-00001-of-00001.safetensors",
17
+ "encoder.1.transformer.layers.0.layer_scale_2.scale": "model-00001-of-00001.safetensors",
18
+ "encoder.1.transformer.layers.1.norm1.weight": "model-00001-of-00001.safetensors",
19
+ "encoder.1.transformer.layers.1.norm1.bias": "model-00001-of-00001.safetensors",
20
+ "encoder.1.transformer.layers.1.self_attn.in_proj.weight": "model-00001-of-00001.safetensors",
21
+ "encoder.1.transformer.layers.1.self_attn.out_proj.weight": "model-00001-of-00001.safetensors",
22
+ "encoder.1.transformer.layers.1.norm2.weight": "model-00001-of-00001.safetensors",
23
+ "encoder.1.transformer.layers.1.norm2.bias": "model-00001-of-00001.safetensors",
24
+ "encoder.1.transformer.layers.1.ffn.0.weight": "model-00001-of-00001.safetensors",
25
+ "encoder.1.transformer.layers.1.ffn.2.weight": "model-00001-of-00001.safetensors",
26
+ "encoder.1.transformer.layers.1.layer_scale_1.scale": "model-00001-of-00001.safetensors",
27
+ "encoder.1.transformer.layers.1.layer_scale_2.scale": "model-00001-of-00001.safetensors",
28
+ "encoder.1.transformer.layers.2.norm1.weight": "model-00001-of-00001.safetensors",
29
+ "encoder.1.transformer.layers.2.norm1.bias": "model-00001-of-00001.safetensors",
30
+ "encoder.1.transformer.layers.2.self_attn.in_proj.weight": "model-00001-of-00001.safetensors",
31
+ "encoder.1.transformer.layers.2.self_attn.out_proj.weight": "model-00001-of-00001.safetensors",
32
+ "encoder.1.transformer.layers.2.norm2.weight": "model-00001-of-00001.safetensors",
33
+ "encoder.1.transformer.layers.2.norm2.bias": "model-00001-of-00001.safetensors",
34
+ "encoder.1.transformer.layers.2.ffn.0.weight": "model-00001-of-00001.safetensors",
35
+ "encoder.1.transformer.layers.2.ffn.2.weight": "model-00001-of-00001.safetensors",
36
+ "encoder.1.transformer.layers.2.layer_scale_1.scale": "model-00001-of-00001.safetensors",
37
+ "encoder.1.transformer.layers.2.layer_scale_2.scale": "model-00001-of-00001.safetensors",
38
+ "encoder.1.transformer.layers.3.norm1.weight": "model-00001-of-00001.safetensors",
39
+ "encoder.1.transformer.layers.3.norm1.bias": "model-00001-of-00001.safetensors",
40
+ "encoder.1.transformer.layers.3.self_attn.in_proj.weight": "model-00001-of-00001.safetensors",
41
+ "encoder.1.transformer.layers.3.self_attn.out_proj.weight": "model-00001-of-00001.safetensors",
42
+ "encoder.1.transformer.layers.3.norm2.weight": "model-00001-of-00001.safetensors",
43
+ "encoder.1.transformer.layers.3.norm2.bias": "model-00001-of-00001.safetensors",
44
+ "encoder.1.transformer.layers.3.ffn.0.weight": "model-00001-of-00001.safetensors",
45
+ "encoder.1.transformer.layers.3.ffn.2.weight": "model-00001-of-00001.safetensors",
46
+ "encoder.1.transformer.layers.3.layer_scale_1.scale": "model-00001-of-00001.safetensors",
47
+ "encoder.1.transformer.layers.3.layer_scale_2.scale": "model-00001-of-00001.safetensors",
48
+ "encoder.1.output_proj.weight": "model-00001-of-00001.safetensors",
49
+ "encoder.3.input_proj.weight": "model-00001-of-00001.safetensors",
50
+ "encoder.3.transformer.layers.0.norm1.weight": "model-00001-of-00001.safetensors",
51
+ "encoder.3.transformer.layers.0.norm1.bias": "model-00001-of-00001.safetensors",
52
+ "encoder.3.transformer.layers.0.self_attn.in_proj.weight": "model-00001-of-00001.safetensors",
53
+ "encoder.3.transformer.layers.0.self_attn.out_proj.weight": "model-00001-of-00001.safetensors",
54
+ "encoder.3.transformer.layers.0.norm2.weight": "model-00001-of-00001.safetensors",
55
+ "encoder.3.transformer.layers.0.norm2.bias": "model-00001-of-00001.safetensors",
56
+ "encoder.3.transformer.layers.0.ffn.0.weight": "model-00001-of-00001.safetensors",
57
+ "encoder.3.transformer.layers.0.ffn.2.weight": "model-00001-of-00001.safetensors",
58
+ "encoder.3.transformer.layers.0.layer_scale_1.scale": "model-00001-of-00001.safetensors",
59
+ "encoder.3.transformer.layers.0.layer_scale_2.scale": "model-00001-of-00001.safetensors",
60
+ "encoder.3.transformer.layers.1.norm1.weight": "model-00001-of-00001.safetensors",
61
+ "encoder.3.transformer.layers.1.norm1.bias": "model-00001-of-00001.safetensors",
62
+ "encoder.3.transformer.layers.1.self_attn.in_proj.weight": "model-00001-of-00001.safetensors",
63
+ "encoder.3.transformer.layers.1.self_attn.out_proj.weight": "model-00001-of-00001.safetensors",
64
+ "encoder.3.transformer.layers.1.norm2.weight": "model-00001-of-00001.safetensors",
65
+ "encoder.3.transformer.layers.1.norm2.bias": "model-00001-of-00001.safetensors",
66
+ "encoder.3.transformer.layers.1.ffn.0.weight": "model-00001-of-00001.safetensors",
67
+ "encoder.3.transformer.layers.1.ffn.2.weight": "model-00001-of-00001.safetensors",
68
+ "encoder.3.transformer.layers.1.layer_scale_1.scale": "model-00001-of-00001.safetensors",
69
+ "encoder.3.transformer.layers.1.layer_scale_2.scale": "model-00001-of-00001.safetensors",
70
+ "encoder.3.output_proj.weight": "model-00001-of-00001.safetensors",
71
+ "encoder.5.input_proj.weight": "model-00001-of-00001.safetensors",
72
+ "encoder.5.transformer.layers.0.norm1.weight": "model-00001-of-00001.safetensors",
73
+ "encoder.5.transformer.layers.0.norm1.bias": "model-00001-of-00001.safetensors",
74
+ "encoder.5.transformer.layers.0.self_attn.in_proj.weight": "model-00001-of-00001.safetensors",
75
+ "encoder.5.transformer.layers.0.self_attn.out_proj.weight": "model-00001-of-00001.safetensors",
76
+ "encoder.5.transformer.layers.0.norm2.weight": "model-00001-of-00001.safetensors",
77
+ "encoder.5.transformer.layers.0.norm2.bias": "model-00001-of-00001.safetensors",
78
+ "encoder.5.transformer.layers.0.ffn.0.weight": "model-00001-of-00001.safetensors",
79
+ "encoder.5.transformer.layers.0.ffn.2.weight": "model-00001-of-00001.safetensors",
80
+ "encoder.5.transformer.layers.0.layer_scale_1.scale": "model-00001-of-00001.safetensors",
81
+ "encoder.5.transformer.layers.0.layer_scale_2.scale": "model-00001-of-00001.safetensors",
82
+ "encoder.5.transformer.layers.1.norm1.weight": "model-00001-of-00001.safetensors",
83
+ "encoder.5.transformer.layers.1.norm1.bias": "model-00001-of-00001.safetensors",
84
+ "encoder.5.transformer.layers.1.self_attn.in_proj.weight": "model-00001-of-00001.safetensors",
85
+ "encoder.5.transformer.layers.1.self_attn.out_proj.weight": "model-00001-of-00001.safetensors",
86
+ "encoder.5.transformer.layers.1.norm2.weight": "model-00001-of-00001.safetensors",
87
+ "encoder.5.transformer.layers.1.norm2.bias": "model-00001-of-00001.safetensors",
88
+ "encoder.5.transformer.layers.1.ffn.0.weight": "model-00001-of-00001.safetensors",
89
+ "encoder.5.transformer.layers.1.ffn.2.weight": "model-00001-of-00001.safetensors",
90
+ "encoder.5.transformer.layers.1.layer_scale_1.scale": "model-00001-of-00001.safetensors",
91
+ "encoder.5.transformer.layers.1.layer_scale_2.scale": "model-00001-of-00001.safetensors",
92
+ "encoder.5.output_proj.weight": "model-00001-of-00001.safetensors",
93
+ "encoder.7.input_proj.weight": "model-00001-of-00001.safetensors",
94
+ "encoder.7.transformer.layers.0.norm1.weight": "model-00001-of-00001.safetensors",
95
+ "encoder.7.transformer.layers.0.norm1.bias": "model-00001-of-00001.safetensors",
96
+ "encoder.7.transformer.layers.0.self_attn.in_proj.weight": "model-00001-of-00001.safetensors",
97
+ "encoder.7.transformer.layers.0.self_attn.out_proj.weight": "model-00001-of-00001.safetensors",
98
+ "encoder.7.transformer.layers.0.norm2.weight": "model-00001-of-00001.safetensors",
99
+ "encoder.7.transformer.layers.0.norm2.bias": "model-00001-of-00001.safetensors",
100
+ "encoder.7.transformer.layers.0.ffn.0.weight": "model-00001-of-00001.safetensors",
101
+ "encoder.7.transformer.layers.0.ffn.2.weight": "model-00001-of-00001.safetensors",
102
+ "encoder.7.transformer.layers.0.layer_scale_1.scale": "model-00001-of-00001.safetensors",
103
+ "encoder.7.transformer.layers.0.layer_scale_2.scale": "model-00001-of-00001.safetensors",
104
+ "encoder.7.transformer.layers.1.norm1.weight": "model-00001-of-00001.safetensors",
105
+ "encoder.7.transformer.layers.1.norm1.bias": "model-00001-of-00001.safetensors",
106
+ "encoder.7.transformer.layers.1.self_attn.in_proj.weight": "model-00001-of-00001.safetensors",
107
+ "encoder.7.transformer.layers.1.self_attn.out_proj.weight": "model-00001-of-00001.safetensors",
108
+ "encoder.7.transformer.layers.1.norm2.weight": "model-00001-of-00001.safetensors",
109
+ "encoder.7.transformer.layers.1.norm2.bias": "model-00001-of-00001.safetensors",
110
+ "encoder.7.transformer.layers.1.ffn.0.weight": "model-00001-of-00001.safetensors",
111
+ "encoder.7.transformer.layers.1.ffn.2.weight": "model-00001-of-00001.safetensors",
112
+ "encoder.7.transformer.layers.1.layer_scale_1.scale": "model-00001-of-00001.safetensors",
113
+ "encoder.7.transformer.layers.1.layer_scale_2.scale": "model-00001-of-00001.safetensors",
114
+ "encoder.7.transformer.layers.2.norm1.weight": "model-00001-of-00001.safetensors",
115
+ "encoder.7.transformer.layers.2.norm1.bias": "model-00001-of-00001.safetensors",
116
+ "encoder.7.transformer.layers.2.self_attn.in_proj.weight": "model-00001-of-00001.safetensors",
117
+ "encoder.7.transformer.layers.2.self_attn.out_proj.weight": "model-00001-of-00001.safetensors",
118
+ "encoder.7.transformer.layers.2.norm2.weight": "model-00001-of-00001.safetensors",
119
+ "encoder.7.transformer.layers.2.norm2.bias": "model-00001-of-00001.safetensors",
120
+ "encoder.7.transformer.layers.2.ffn.0.weight": "model-00001-of-00001.safetensors",
121
+ "encoder.7.transformer.layers.2.ffn.2.weight": "model-00001-of-00001.safetensors",
122
+ "encoder.7.transformer.layers.2.layer_scale_1.scale": "model-00001-of-00001.safetensors",
123
+ "encoder.7.transformer.layers.2.layer_scale_2.scale": "model-00001-of-00001.safetensors",
124
+ "encoder.7.transformer.layers.3.norm1.weight": "model-00001-of-00001.safetensors",
125
+ "encoder.7.transformer.layers.3.norm1.bias": "model-00001-of-00001.safetensors",
126
+ "encoder.7.transformer.layers.3.self_attn.in_proj.weight": "model-00001-of-00001.safetensors",
127
+ "encoder.7.transformer.layers.3.self_attn.out_proj.weight": "model-00001-of-00001.safetensors",
128
+ "encoder.7.transformer.layers.3.norm2.weight": "model-00001-of-00001.safetensors",
129
+ "encoder.7.transformer.layers.3.norm2.bias": "model-00001-of-00001.safetensors",
130
+ "encoder.7.transformer.layers.3.ffn.0.weight": "model-00001-of-00001.safetensors",
131
+ "encoder.7.transformer.layers.3.ffn.2.weight": "model-00001-of-00001.safetensors",
132
+ "encoder.7.transformer.layers.3.layer_scale_1.scale": "model-00001-of-00001.safetensors",
133
+ "encoder.7.transformer.layers.3.layer_scale_2.scale": "model-00001-of-00001.safetensors",
134
+ "encoder.7.output_proj.weight": "model-00001-of-00001.safetensors",
135
+ "quantizer.input_proj.bias": "model-00001-of-00001.safetensors",
136
+ "quantizer.input_proj.parametrizations.weight.original0": "model-00001-of-00001.safetensors",
137
+ "quantizer.input_proj.parametrizations.weight.original1": "model-00001-of-00001.safetensors",
138
+ "quantizer.output_proj.bias": "model-00001-of-00001.safetensors",
139
+ "quantizer.output_proj.parametrizations.weight.original0": "model-00001-of-00001.safetensors",
140
+ "quantizer.output_proj.parametrizations.weight.original1": "model-00001-of-00001.safetensors",
141
+ "quantizer.quantizers.0.in_proj.bias": "model-00001-of-00001.safetensors",
142
+ "quantizer.quantizers.0.in_proj.parametrizations.weight.original0": "model-00001-of-00001.safetensors",
143
+ "quantizer.quantizers.0.in_proj.parametrizations.weight.original1": "model-00001-of-00001.safetensors",
144
+ "quantizer.quantizers.0.out_proj.bias": "model-00001-of-00001.safetensors",
145
+ "quantizer.quantizers.0.out_proj.parametrizations.weight.original0": "model-00001-of-00001.safetensors",
146
+ "quantizer.quantizers.0.out_proj.parametrizations.weight.original1": "model-00001-of-00001.safetensors",
147
+ "quantizer.quantizers.0.codebook.weight": "model-00001-of-00001.safetensors",
148
+ "quantizer.quantizers.1.in_proj.bias": "model-00001-of-00001.safetensors",
149
+ "quantizer.quantizers.1.in_proj.parametrizations.weight.original0": "model-00001-of-00001.safetensors",
150
+ "quantizer.quantizers.1.in_proj.parametrizations.weight.original1": "model-00001-of-00001.safetensors",
151
+ "quantizer.quantizers.1.out_proj.bias": "model-00001-of-00001.safetensors",
152
+ "quantizer.quantizers.1.out_proj.parametrizations.weight.original0": "model-00001-of-00001.safetensors",
153
+ "quantizer.quantizers.1.out_proj.parametrizations.weight.original1": "model-00001-of-00001.safetensors",
154
+ "quantizer.quantizers.1.codebook.weight": "model-00001-of-00001.safetensors",
155
+ "quantizer.quantizers.2.in_proj.bias": "model-00001-of-00001.safetensors",
156
+ "quantizer.quantizers.2.in_proj.parametrizations.weight.original0": "model-00001-of-00001.safetensors",
157
+ "quantizer.quantizers.2.in_proj.parametrizations.weight.original1": "model-00001-of-00001.safetensors",
158
+ "quantizer.quantizers.2.out_proj.bias": "model-00001-of-00001.safetensors",
159
+ "quantizer.quantizers.2.out_proj.parametrizations.weight.original0": "model-00001-of-00001.safetensors",
160
+ "quantizer.quantizers.2.out_proj.parametrizations.weight.original1": "model-00001-of-00001.safetensors",
161
+ "quantizer.quantizers.2.codebook.weight": "model-00001-of-00001.safetensors",
162
+ "quantizer.quantizers.3.in_proj.bias": "model-00001-of-00001.safetensors",
163
+ "quantizer.quantizers.3.in_proj.parametrizations.weight.original0": "model-00001-of-00001.safetensors",
164
+ "quantizer.quantizers.3.in_proj.parametrizations.weight.original1": "model-00001-of-00001.safetensors",
165
+ "quantizer.quantizers.3.out_proj.bias": "model-00001-of-00001.safetensors",
166
+ "quantizer.quantizers.3.out_proj.parametrizations.weight.original0": "model-00001-of-00001.safetensors",
167
+ "quantizer.quantizers.3.out_proj.parametrizations.weight.original1": "model-00001-of-00001.safetensors",
168
+ "quantizer.quantizers.3.codebook.weight": "model-00001-of-00001.safetensors",
169
+ "quantizer.quantizers.4.in_proj.bias": "model-00001-of-00001.safetensors",
170
+ "quantizer.quantizers.4.in_proj.parametrizations.weight.original0": "model-00001-of-00001.safetensors",
171
+ "quantizer.quantizers.4.in_proj.parametrizations.weight.original1": "model-00001-of-00001.safetensors",
172
+ "quantizer.quantizers.4.out_proj.bias": "model-00001-of-00001.safetensors",
173
+ "quantizer.quantizers.4.out_proj.parametrizations.weight.original0": "model-00001-of-00001.safetensors",
174
+ "quantizer.quantizers.4.out_proj.parametrizations.weight.original1": "model-00001-of-00001.safetensors",
175
+ "quantizer.quantizers.4.codebook.weight": "model-00001-of-00001.safetensors",
176
+ "quantizer.quantizers.5.in_proj.bias": "model-00001-of-00001.safetensors",
177
+ "quantizer.quantizers.5.in_proj.parametrizations.weight.original0": "model-00001-of-00001.safetensors",
178
+ "quantizer.quantizers.5.in_proj.parametrizations.weight.original1": "model-00001-of-00001.safetensors",
179
+ "quantizer.quantizers.5.out_proj.bias": "model-00001-of-00001.safetensors",
180
+ "quantizer.quantizers.5.out_proj.parametrizations.weight.original0": "model-00001-of-00001.safetensors",
181
+ "quantizer.quantizers.5.out_proj.parametrizations.weight.original1": "model-00001-of-00001.safetensors",
182
+ "quantizer.quantizers.5.codebook.weight": "model-00001-of-00001.safetensors",
183
+ "quantizer.quantizers.6.in_proj.bias": "model-00001-of-00001.safetensors",
184
+ "quantizer.quantizers.6.in_proj.parametrizations.weight.original0": "model-00001-of-00001.safetensors",
185
+ "quantizer.quantizers.6.in_proj.parametrizations.weight.original1": "model-00001-of-00001.safetensors",
186
+ "quantizer.quantizers.6.out_proj.bias": "model-00001-of-00001.safetensors",
187
+ "quantizer.quantizers.6.out_proj.parametrizations.weight.original0": "model-00001-of-00001.safetensors",
188
+ "quantizer.quantizers.6.out_proj.parametrizations.weight.original1": "model-00001-of-00001.safetensors",
189
+ "quantizer.quantizers.6.codebook.weight": "model-00001-of-00001.safetensors",
190
+ "quantizer.quantizers.7.in_proj.bias": "model-00001-of-00001.safetensors",
191
+ "quantizer.quantizers.7.in_proj.parametrizations.weight.original0": "model-00001-of-00001.safetensors",
192
+ "quantizer.quantizers.7.in_proj.parametrizations.weight.original1": "model-00001-of-00001.safetensors",
193
+ "quantizer.quantizers.7.out_proj.bias": "model-00001-of-00001.safetensors",
194
+ "quantizer.quantizers.7.out_proj.parametrizations.weight.original0": "model-00001-of-00001.safetensors",
195
+ "quantizer.quantizers.7.out_proj.parametrizations.weight.original1": "model-00001-of-00001.safetensors",
196
+ "quantizer.quantizers.7.codebook.weight": "model-00001-of-00001.safetensors",
197
+ "quantizer.quantizers.8.in_proj.bias": "model-00001-of-00001.safetensors",
198
+ "quantizer.quantizers.8.in_proj.parametrizations.weight.original0": "model-00001-of-00001.safetensors",
199
+ "quantizer.quantizers.8.in_proj.parametrizations.weight.original1": "model-00001-of-00001.safetensors",
200
+ "quantizer.quantizers.8.out_proj.bias": "model-00001-of-00001.safetensors",
201
+ "quantizer.quantizers.8.out_proj.parametrizations.weight.original0": "model-00001-of-00001.safetensors",
202
+ "quantizer.quantizers.8.out_proj.parametrizations.weight.original1": "model-00001-of-00001.safetensors",
203
+ "quantizer.quantizers.8.codebook.weight": "model-00001-of-00001.safetensors",
204
+ "quantizer.quantizers.9.in_proj.bias": "model-00001-of-00001.safetensors",
205
+ "quantizer.quantizers.9.in_proj.parametrizations.weight.original0": "model-00001-of-00001.safetensors",
206
+ "quantizer.quantizers.9.in_proj.parametrizations.weight.original1": "model-00001-of-00001.safetensors",
207
+ "quantizer.quantizers.9.out_proj.bias": "model-00001-of-00001.safetensors",
208
+ "quantizer.quantizers.9.out_proj.parametrizations.weight.original0": "model-00001-of-00001.safetensors",
209
+ "quantizer.quantizers.9.out_proj.parametrizations.weight.original1": "model-00001-of-00001.safetensors",
210
+ "quantizer.quantizers.9.codebook.weight": "model-00001-of-00001.safetensors",
211
+ "quantizer.quantizers.10.in_proj.bias": "model-00001-of-00001.safetensors",
212
+ "quantizer.quantizers.10.in_proj.parametrizations.weight.original0": "model-00001-of-00001.safetensors",
213
+ "quantizer.quantizers.10.in_proj.parametrizations.weight.original1": "model-00001-of-00001.safetensors",
214
+ "quantizer.quantizers.10.out_proj.bias": "model-00001-of-00001.safetensors",
215
+ "quantizer.quantizers.10.out_proj.parametrizations.weight.original0": "model-00001-of-00001.safetensors",
216
+ "quantizer.quantizers.10.out_proj.parametrizations.weight.original1": "model-00001-of-00001.safetensors",
217
+ "quantizer.quantizers.10.codebook.weight": "model-00001-of-00001.safetensors",
218
+ "quantizer.quantizers.11.in_proj.bias": "model-00001-of-00001.safetensors",
219
+ "quantizer.quantizers.11.in_proj.parametrizations.weight.original0": "model-00001-of-00001.safetensors",
220
+ "quantizer.quantizers.11.in_proj.parametrizations.weight.original1": "model-00001-of-00001.safetensors",
221
+ "quantizer.quantizers.11.out_proj.bias": "model-00001-of-00001.safetensors",
222
+ "quantizer.quantizers.11.out_proj.parametrizations.weight.original0": "model-00001-of-00001.safetensors",
223
+ "quantizer.quantizers.11.out_proj.parametrizations.weight.original1": "model-00001-of-00001.safetensors",
224
+ "quantizer.quantizers.11.codebook.weight": "model-00001-of-00001.safetensors",
225
+ "quantizer.quantizers.12.in_proj.bias": "model-00001-of-00001.safetensors",
226
+ "quantizer.quantizers.12.in_proj.parametrizations.weight.original0": "model-00001-of-00001.safetensors",
227
+ "quantizer.quantizers.12.in_proj.parametrizations.weight.original1": "model-00001-of-00001.safetensors",
228
+ "quantizer.quantizers.12.out_proj.bias": "model-00001-of-00001.safetensors",
229
+ "quantizer.quantizers.12.out_proj.parametrizations.weight.original0": "model-00001-of-00001.safetensors",
230
+ "quantizer.quantizers.12.out_proj.parametrizations.weight.original1": "model-00001-of-00001.safetensors",
231
+ "quantizer.quantizers.12.codebook.weight": "model-00001-of-00001.safetensors",
232
+ "quantizer.quantizers.13.in_proj.bias": "model-00001-of-00001.safetensors",
233
+ "quantizer.quantizers.13.in_proj.parametrizations.weight.original0": "model-00001-of-00001.safetensors",
234
+ "quantizer.quantizers.13.in_proj.parametrizations.weight.original1": "model-00001-of-00001.safetensors",
235
+ "quantizer.quantizers.13.out_proj.bias": "model-00001-of-00001.safetensors",
236
+ "quantizer.quantizers.13.out_proj.parametrizations.weight.original0": "model-00001-of-00001.safetensors",
237
+ "quantizer.quantizers.13.out_proj.parametrizations.weight.original1": "model-00001-of-00001.safetensors",
238
+ "quantizer.quantizers.13.codebook.weight": "model-00001-of-00001.safetensors",
239
+ "quantizer.quantizers.14.in_proj.bias": "model-00001-of-00001.safetensors",
240
+ "quantizer.quantizers.14.in_proj.parametrizations.weight.original0": "model-00001-of-00001.safetensors",
241
+ "quantizer.quantizers.14.in_proj.parametrizations.weight.original1": "model-00001-of-00001.safetensors",
242
+ "quantizer.quantizers.14.out_proj.bias": "model-00001-of-00001.safetensors",
243
+ "quantizer.quantizers.14.out_proj.parametrizations.weight.original0": "model-00001-of-00001.safetensors",
244
+ "quantizer.quantizers.14.out_proj.parametrizations.weight.original1": "model-00001-of-00001.safetensors",
245
+ "quantizer.quantizers.14.codebook.weight": "model-00001-of-00001.safetensors",
246
+ "quantizer.quantizers.15.in_proj.bias": "model-00001-of-00001.safetensors",
247
+ "quantizer.quantizers.15.in_proj.parametrizations.weight.original0": "model-00001-of-00001.safetensors",
248
+ "quantizer.quantizers.15.in_proj.parametrizations.weight.original1": "model-00001-of-00001.safetensors",
249
+ "quantizer.quantizers.15.out_proj.bias": "model-00001-of-00001.safetensors",
250
+ "quantizer.quantizers.15.out_proj.parametrizations.weight.original0": "model-00001-of-00001.safetensors",
251
+ "quantizer.quantizers.15.out_proj.parametrizations.weight.original1": "model-00001-of-00001.safetensors",
252
+ "quantizer.quantizers.15.codebook.weight": "model-00001-of-00001.safetensors",
253
+ "decoder.1.input_proj.weight": "model-00001-of-00001.safetensors",
254
+ "decoder.1.transformer.layers.0.norm1.weight": "model-00001-of-00001.safetensors",
255
+ "decoder.1.transformer.layers.0.norm1.bias": "model-00001-of-00001.safetensors",
256
+ "decoder.1.transformer.layers.0.self_attn.in_proj.weight": "model-00001-of-00001.safetensors",
257
+ "decoder.1.transformer.layers.0.self_attn.out_proj.weight": "model-00001-of-00001.safetensors",
258
+ "decoder.1.transformer.layers.0.norm2.weight": "model-00001-of-00001.safetensors",
259
+ "decoder.1.transformer.layers.0.norm2.bias": "model-00001-of-00001.safetensors",
260
+ "decoder.1.transformer.layers.0.ffn.0.weight": "model-00001-of-00001.safetensors",
261
+ "decoder.1.transformer.layers.0.ffn.2.weight": "model-00001-of-00001.safetensors",
262
+ "decoder.1.transformer.layers.0.layer_scale_1.scale": "model-00001-of-00001.safetensors",
263
+ "decoder.1.transformer.layers.0.layer_scale_2.scale": "model-00001-of-00001.safetensors",
264
+ "decoder.1.transformer.layers.1.norm1.weight": "model-00001-of-00001.safetensors",
265
+ "decoder.1.transformer.layers.1.norm1.bias": "model-00001-of-00001.safetensors",
266
+ "decoder.1.transformer.layers.1.self_attn.in_proj.weight": "model-00001-of-00001.safetensors",
267
+ "decoder.1.transformer.layers.1.self_attn.out_proj.weight": "model-00001-of-00001.safetensors",
268
+ "decoder.1.transformer.layers.1.norm2.weight": "model-00001-of-00001.safetensors",
269
+ "decoder.1.transformer.layers.1.norm2.bias": "model-00001-of-00001.safetensors",
270
+ "decoder.1.transformer.layers.1.ffn.0.weight": "model-00001-of-00001.safetensors",
271
+ "decoder.1.transformer.layers.1.ffn.2.weight": "model-00001-of-00001.safetensors",
272
+ "decoder.1.transformer.layers.1.layer_scale_1.scale": "model-00001-of-00001.safetensors",
273
+ "decoder.1.transformer.layers.1.layer_scale_2.scale": "model-00001-of-00001.safetensors",
274
+ "decoder.1.transformer.layers.2.norm1.weight": "model-00001-of-00001.safetensors",
275
+ "decoder.1.transformer.layers.2.norm1.bias": "model-00001-of-00001.safetensors",
276
+ "decoder.1.transformer.layers.2.self_attn.in_proj.weight": "model-00001-of-00001.safetensors",
277
+ "decoder.1.transformer.layers.2.self_attn.out_proj.weight": "model-00001-of-00001.safetensors",
278
+ "decoder.1.transformer.layers.2.norm2.weight": "model-00001-of-00001.safetensors",
279
+ "decoder.1.transformer.layers.2.norm2.bias": "model-00001-of-00001.safetensors",
280
+ "decoder.1.transformer.layers.2.ffn.0.weight": "model-00001-of-00001.safetensors",
281
+ "decoder.1.transformer.layers.2.ffn.2.weight": "model-00001-of-00001.safetensors",
282
+ "decoder.1.transformer.layers.2.layer_scale_1.scale": "model-00001-of-00001.safetensors",
283
+ "decoder.1.transformer.layers.2.layer_scale_2.scale": "model-00001-of-00001.safetensors",
284
+ "decoder.1.transformer.layers.3.norm1.weight": "model-00001-of-00001.safetensors",
285
+ "decoder.1.transformer.layers.3.norm1.bias": "model-00001-of-00001.safetensors",
286
+ "decoder.1.transformer.layers.3.self_attn.in_proj.weight": "model-00001-of-00001.safetensors",
287
+ "decoder.1.transformer.layers.3.self_attn.out_proj.weight": "model-00001-of-00001.safetensors",
288
+ "decoder.1.transformer.layers.3.norm2.weight": "model-00001-of-00001.safetensors",
289
+ "decoder.1.transformer.layers.3.norm2.bias": "model-00001-of-00001.safetensors",
290
+ "decoder.1.transformer.layers.3.ffn.0.weight": "model-00001-of-00001.safetensors",
291
+ "decoder.1.transformer.layers.3.ffn.2.weight": "model-00001-of-00001.safetensors",
292
+ "decoder.1.transformer.layers.3.layer_scale_1.scale": "model-00001-of-00001.safetensors",
293
+ "decoder.1.transformer.layers.3.layer_scale_2.scale": "model-00001-of-00001.safetensors",
294
+ "decoder.1.output_proj.weight": "model-00001-of-00001.safetensors",
295
+ "decoder.3.input_proj.weight": "model-00001-of-00001.safetensors",
296
+ "decoder.3.transformer.layers.0.norm1.weight": "model-00001-of-00001.safetensors",
297
+ "decoder.3.transformer.layers.0.norm1.bias": "model-00001-of-00001.safetensors",
298
+ "decoder.3.transformer.layers.0.self_attn.in_proj.weight": "model-00001-of-00001.safetensors",
299
+ "decoder.3.transformer.layers.0.self_attn.out_proj.weight": "model-00001-of-00001.safetensors",
300
+ "decoder.3.transformer.layers.0.norm2.weight": "model-00001-of-00001.safetensors",
301
+ "decoder.3.transformer.layers.0.norm2.bias": "model-00001-of-00001.safetensors",
302
+ "decoder.3.transformer.layers.0.ffn.0.weight": "model-00001-of-00001.safetensors",
303
+ "decoder.3.transformer.layers.0.ffn.2.weight": "model-00001-of-00001.safetensors",
304
+ "decoder.3.transformer.layers.0.layer_scale_1.scale": "model-00001-of-00001.safetensors",
305
+ "decoder.3.transformer.layers.0.layer_scale_2.scale": "model-00001-of-00001.safetensors",
306
+ "decoder.3.transformer.layers.1.norm1.weight": "model-00001-of-00001.safetensors",
307
+ "decoder.3.transformer.layers.1.norm1.bias": "model-00001-of-00001.safetensors",
308
+ "decoder.3.transformer.layers.1.self_attn.in_proj.weight": "model-00001-of-00001.safetensors",
309
+ "decoder.3.transformer.layers.1.self_attn.out_proj.weight": "model-00001-of-00001.safetensors",
310
+ "decoder.3.transformer.layers.1.norm2.weight": "model-00001-of-00001.safetensors",
311
+ "decoder.3.transformer.layers.1.norm2.bias": "model-00001-of-00001.safetensors",
312
+ "decoder.3.transformer.layers.1.ffn.0.weight": "model-00001-of-00001.safetensors",
313
+ "decoder.3.transformer.layers.1.ffn.2.weight": "model-00001-of-00001.safetensors",
314
+ "decoder.3.transformer.layers.1.layer_scale_1.scale": "model-00001-of-00001.safetensors",
315
+ "decoder.3.transformer.layers.1.layer_scale_2.scale": "model-00001-of-00001.safetensors",
316
+ "decoder.3.output_proj.weight": "model-00001-of-00001.safetensors",
317
+ "decoder.5.input_proj.weight": "model-00001-of-00001.safetensors",
318
+ "decoder.5.transformer.layers.0.norm1.weight": "model-00001-of-00001.safetensors",
319
+ "decoder.5.transformer.layers.0.norm1.bias": "model-00001-of-00001.safetensors",
320
+ "decoder.5.transformer.layers.0.self_attn.in_proj.weight": "model-00001-of-00001.safetensors",
321
+ "decoder.5.transformer.layers.0.self_attn.out_proj.weight": "model-00001-of-00001.safetensors",
322
+ "decoder.5.transformer.layers.0.norm2.weight": "model-00001-of-00001.safetensors",
323
+ "decoder.5.transformer.layers.0.norm2.bias": "model-00001-of-00001.safetensors",
324
+ "decoder.5.transformer.layers.0.ffn.0.weight": "model-00001-of-00001.safetensors",
325
+ "decoder.5.transformer.layers.0.ffn.2.weight": "model-00001-of-00001.safetensors",
326
+ "decoder.5.transformer.layers.0.layer_scale_1.scale": "model-00001-of-00001.safetensors",
327
+ "decoder.5.transformer.layers.0.layer_scale_2.scale": "model-00001-of-00001.safetensors",
328
+ "decoder.5.transformer.layers.1.norm1.weight": "model-00001-of-00001.safetensors",
329
+ "decoder.5.transformer.layers.1.norm1.bias": "model-00001-of-00001.safetensors",
330
+ "decoder.5.transformer.layers.1.self_attn.in_proj.weight": "model-00001-of-00001.safetensors",
331
+ "decoder.5.transformer.layers.1.self_attn.out_proj.weight": "model-00001-of-00001.safetensors",
332
+ "decoder.5.transformer.layers.1.norm2.weight": "model-00001-of-00001.safetensors",
333
+ "decoder.5.transformer.layers.1.norm2.bias": "model-00001-of-00001.safetensors",
334
+ "decoder.5.transformer.layers.1.ffn.0.weight": "model-00001-of-00001.safetensors",
335
+ "decoder.5.transformer.layers.1.ffn.2.weight": "model-00001-of-00001.safetensors",
336
+ "decoder.5.transformer.layers.1.layer_scale_1.scale": "model-00001-of-00001.safetensors",
337
+ "decoder.5.transformer.layers.1.layer_scale_2.scale": "model-00001-of-00001.safetensors",
338
+ "decoder.5.output_proj.weight": "model-00001-of-00001.safetensors",
339
+ "decoder.7.input_proj.weight": "model-00001-of-00001.safetensors",
340
+ "decoder.7.transformer.layers.0.norm1.weight": "model-00001-of-00001.safetensors",
341
+ "decoder.7.transformer.layers.0.norm1.bias": "model-00001-of-00001.safetensors",
342
+ "decoder.7.transformer.layers.0.self_attn.in_proj.weight": "model-00001-of-00001.safetensors",
343
+ "decoder.7.transformer.layers.0.self_attn.out_proj.weight": "model-00001-of-00001.safetensors",
344
+ "decoder.7.transformer.layers.0.norm2.weight": "model-00001-of-00001.safetensors",
345
+ "decoder.7.transformer.layers.0.norm2.bias": "model-00001-of-00001.safetensors",
346
+ "decoder.7.transformer.layers.0.ffn.0.weight": "model-00001-of-00001.safetensors",
347
+ "decoder.7.transformer.layers.0.ffn.2.weight": "model-00001-of-00001.safetensors",
348
+ "decoder.7.transformer.layers.0.layer_scale_1.scale": "model-00001-of-00001.safetensors",
349
+ "decoder.7.transformer.layers.0.layer_scale_2.scale": "model-00001-of-00001.safetensors",
350
+ "decoder.7.transformer.layers.1.norm1.weight": "model-00001-of-00001.safetensors",
351
+ "decoder.7.transformer.layers.1.norm1.bias": "model-00001-of-00001.safetensors",
352
+ "decoder.7.transformer.layers.1.self_attn.in_proj.weight": "model-00001-of-00001.safetensors",
353
+ "decoder.7.transformer.layers.1.self_attn.out_proj.weight": "model-00001-of-00001.safetensors",
354
+ "decoder.7.transformer.layers.1.norm2.weight": "model-00001-of-00001.safetensors",
355
+ "decoder.7.transformer.layers.1.norm2.bias": "model-00001-of-00001.safetensors",
356
+ "decoder.7.transformer.layers.1.ffn.0.weight": "model-00001-of-00001.safetensors",
357
+ "decoder.7.transformer.layers.1.ffn.2.weight": "model-00001-of-00001.safetensors",
358
+ "decoder.7.transformer.layers.1.layer_scale_1.scale": "model-00001-of-00001.safetensors",
359
+ "decoder.7.transformer.layers.1.layer_scale_2.scale": "model-00001-of-00001.safetensors",
360
+ "decoder.7.transformer.layers.2.norm1.weight": "model-00001-of-00001.safetensors",
361
+ "decoder.7.transformer.layers.2.norm1.bias": "model-00001-of-00001.safetensors",
362
+ "decoder.7.transformer.layers.2.self_attn.in_proj.weight": "model-00001-of-00001.safetensors",
363
+ "decoder.7.transformer.layers.2.self_attn.out_proj.weight": "model-00001-of-00001.safetensors",
364
+ "decoder.7.transformer.layers.2.norm2.weight": "model-00001-of-00001.safetensors",
365
+ "decoder.7.transformer.layers.2.norm2.bias": "model-00001-of-00001.safetensors",
366
+ "decoder.7.transformer.layers.2.ffn.0.weight": "model-00001-of-00001.safetensors",
367
+ "decoder.7.transformer.layers.2.ffn.2.weight": "model-00001-of-00001.safetensors",
368
+ "decoder.7.transformer.layers.2.layer_scale_1.scale": "model-00001-of-00001.safetensors",
369
+ "decoder.7.transformer.layers.2.layer_scale_2.scale": "model-00001-of-00001.safetensors",
370
+ "decoder.7.transformer.layers.3.norm1.weight": "model-00001-of-00001.safetensors",
371
+ "decoder.7.transformer.layers.3.norm1.bias": "model-00001-of-00001.safetensors",
372
+ "decoder.7.transformer.layers.3.self_attn.in_proj.weight": "model-00001-of-00001.safetensors",
373
+ "decoder.7.transformer.layers.3.self_attn.out_proj.weight": "model-00001-of-00001.safetensors",
374
+ "decoder.7.transformer.layers.3.norm2.weight": "model-00001-of-00001.safetensors",
375
+ "decoder.7.transformer.layers.3.norm2.bias": "model-00001-of-00001.safetensors",
376
+ "decoder.7.transformer.layers.3.ffn.0.weight": "model-00001-of-00001.safetensors",
377
+ "decoder.7.transformer.layers.3.ffn.2.weight": "model-00001-of-00001.safetensors",
378
+ "decoder.7.transformer.layers.3.layer_scale_1.scale": "model-00001-of-00001.safetensors",
379
+ "decoder.7.transformer.layers.3.layer_scale_2.scale": "model-00001-of-00001.safetensors",
380
+ "decoder.7.output_proj.weight": "model-00001-of-00001.safetensors"
381
+ }
382
+ }
weights/codec/modeling_moss_audio_tokenizer.py ADDED
The diff for this file is too large to render. See raw diff
 
weights/tts/.gitattributes ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
weights/tts/README.md ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ ---
2
+ license: apache-2.0
3
+ ---
weights/tts/__init__.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .configuration_nanotts import NanoTTSConfig
2
+ from .modeling_nanotts_global_local import (
3
+ NanoTTSGenerationOutput,
4
+ NanoTTSGlobalLocalForCausalLM,
5
+ NanoTTSOutput,
6
+ )
7
+ from .tokenization_nanotts_sentencepiece import NanoTTSSentencePieceTokenizer
8
+
9
+ try:
10
+ NanoTTSConfig.register_for_auto_class()
11
+ except Exception:
12
+ pass
13
+
14
+ for auto_class_name in ("AutoModel", "AutoModelForCausalLM"):
15
+ try:
16
+ NanoTTSGlobalLocalForCausalLM.register_for_auto_class(auto_class_name)
17
+ except Exception:
18
+ pass
19
+
20
+ try:
21
+ NanoTTSSentencePieceTokenizer.register_for_auto_class("AutoTokenizer")
22
+ except Exception:
23
+ pass
24
+
25
+ __all__ = [
26
+ "NanoTTSConfig",
27
+ "NanoTTSGlobalLocalForCausalLM",
28
+ "NanoTTSSentencePieceTokenizer",
29
+ "NanoTTSGenerationOutput",
30
+ "NanoTTSOutput",
31
+ ]
weights/tts/config.json ADDED
@@ -0,0 +1,197 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_cross_attention": false,
3
+ "architectures": [
4
+ "NanoTTSGlobalLocalForCausalLM"
5
+ ],
6
+ "attn_implementation": "sdpa",
7
+ "audio_assistant_slot_token_id": 9,
8
+ "audio_codebook_sizes": [
9
+ 1024,
10
+ 1024,
11
+ 1024,
12
+ 1024,
13
+ 1024,
14
+ 1024,
15
+ 1024,
16
+ 1024,
17
+ 1024,
18
+ 1024,
19
+ 1024,
20
+ 1024,
21
+ 1024,
22
+ 1024,
23
+ 1024,
24
+ 1024
25
+ ],
26
+ "audio_end_token_id": 7,
27
+ "audio_pad_token_id": 1024,
28
+ "audio_start_token_id": 6,
29
+ "audio_tokenizer_pretrained_name_or_path": "OpenMOSS-Team/MOSS-Audio-Tokenizer-Nano",
30
+ "audio_tokenizer_sample_rate": 48000,
31
+ "audio_tokenizer_type": "moss-audio-tokenizer-nano",
32
+ "audio_user_slot_token_id": 8,
33
+ "audio_vocab_size": 1024,
34
+ "bad_words_ids": null,
35
+ "begin_suppress_tokens": null,
36
+ "bos_token_id": null,
37
+ "chunk_size_feed_forward": 0,
38
+ "cross_attention_hidden_size": null,
39
+ "decoder_start_token_id": null,
40
+ "diversity_penalty": 0.0,
41
+ "do_sample": false,
42
+ "dtype": "float32",
43
+ "early_stopping": false,
44
+ "encoder_no_repeat_ngram_size": 0,
45
+ "eos_token_id": null,
46
+ "exponential_decay_length_penalty": null,
47
+ "finetuning_task": null,
48
+ "forced_bos_token_id": null,
49
+ "forced_eos_token_id": null,
50
+ "gpt2_config": {
51
+ "_name_or_path": "",
52
+ "activation_function": "gelu_new",
53
+ "add_cross_attention": false,
54
+ "architectures": null,
55
+ "attn_pdrop": 0.0,
56
+ "bad_words_ids": null,
57
+ "begin_suppress_tokens": null,
58
+ "bos_token_id": 1,
59
+ "chunk_size_feed_forward": 0,
60
+ "cross_attention_hidden_size": null,
61
+ "decoder_start_token_id": null,
62
+ "diversity_penalty": 0.0,
63
+ "do_sample": false,
64
+ "dtype": null,
65
+ "early_stopping": false,
66
+ "embd_pdrop": 0.0,
67
+ "encoder_no_repeat_ngram_size": 0,
68
+ "eos_token_id": 2,
69
+ "exponential_decay_length_penalty": null,
70
+ "finetuning_task": null,
71
+ "forced_bos_token_id": null,
72
+ "forced_eos_token_id": null,
73
+ "id2label": {
74
+ "0": "LABEL_0",
75
+ "1": "LABEL_1"
76
+ },
77
+ "initializer_range": 0.02,
78
+ "is_decoder": false,
79
+ "is_encoder_decoder": false,
80
+ "label2id": {
81
+ "LABEL_0": 0,
82
+ "LABEL_1": 1
83
+ },
84
+ "layer_norm_epsilon": 1e-05,
85
+ "length_penalty": 1.0,
86
+ "max_length": 20,
87
+ "min_length": 0,
88
+ "model_type": "gpt2",
89
+ "n_ctx": 32768,
90
+ "n_embd": 768,
91
+ "n_head": 12,
92
+ "n_inner": 3072,
93
+ "n_layer": 12,
94
+ "n_positions": 32768,
95
+ "no_repeat_ngram_size": 0,
96
+ "num_beam_groups": 1,
97
+ "num_beams": 1,
98
+ "num_return_sequences": 1,
99
+ "output_attentions": false,
100
+ "output_hidden_states": false,
101
+ "output_scores": false,
102
+ "pad_token_id": 3,
103
+ "position_embedding_type": "rope",
104
+ "prefix": null,
105
+ "problem_type": null,
106
+ "pruned_heads": {},
107
+ "remove_invalid_values": false,
108
+ "reorder_and_upcast_attn": false,
109
+ "repetition_penalty": 1.0,
110
+ "resid_pdrop": 0.0,
111
+ "return_dict": true,
112
+ "return_dict_in_generate": false,
113
+ "rope_base": 10000.0,
114
+ "scale_attn_by_inverse_layer_idx": false,
115
+ "scale_attn_weights": true,
116
+ "sep_token_id": null,
117
+ "summary_activation": null,
118
+ "summary_first_dropout": 0.1,
119
+ "summary_proj_to_labels": true,
120
+ "summary_type": "cls_index",
121
+ "summary_use_proj": true,
122
+ "suppress_tokens": null,
123
+ "task_specific_params": null,
124
+ "temperature": 1.0,
125
+ "tf_legacy_loss": false,
126
+ "tie_encoder_decoder": false,
127
+ "tie_word_embeddings": true,
128
+ "tokenizer_class": null,
129
+ "top_k": 50,
130
+ "top_p": 1.0,
131
+ "torchscript": false,
132
+ "transformers_version": "4.57.1",
133
+ "typical_p": 1.0,
134
+ "use_bfloat16": false,
135
+ "use_cache": true,
136
+ "vocab_size": 16384
137
+ },
138
+ "hidden_size": 768,
139
+ "id2label": {
140
+ "0": "LABEL_0",
141
+ "1": "LABEL_1"
142
+ },
143
+ "im_end_token_id": 5,
144
+ "im_start_token_id": 4,
145
+ "initializer_range": 0.02,
146
+ "is_decoder": false,
147
+ "is_encoder_decoder": false,
148
+ "label2id": {
149
+ "LABEL_0": 0,
150
+ "LABEL_1": 1
151
+ },
152
+ "length_penalty": 1.0,
153
+ "local_transformer_attn_implementation": "sdpa",
154
+ "local_transformer_layers": 1,
155
+ "max_length": 20,
156
+ "max_position_embeddings": 32768,
157
+ "min_length": 0,
158
+ "model_architecture": "global_local_transformer",
159
+ "model_type": "nano_tts",
160
+ "n_vq": 16,
161
+ "no_repeat_ngram_size": 0,
162
+ "num_beam_groups": 1,
163
+ "num_beams": 1,
164
+ "num_return_sequences": 1,
165
+ "output_attentions": false,
166
+ "output_hidden_states": false,
167
+ "output_scores": false,
168
+ "pad_token_id": 3,
169
+ "prefix": null,
170
+ "problem_type": null,
171
+ "pruned_heads": {},
172
+ "remove_invalid_values": false,
173
+ "repetition_penalty": 1.0,
174
+ "return_dict": true,
175
+ "return_dict_in_generate": false,
176
+ "sep_token_id": null,
177
+ "suppress_tokens": null,
178
+ "task_specific_params": null,
179
+ "temperature": 1.0,
180
+ "tf_legacy_loss": false,
181
+ "tie_encoder_decoder": false,
182
+ "tie_word_embeddings": true,
183
+ "tokenizer_class": "NanoTTSSentencePieceTokenizer",
184
+ "tokenizer_use_fast": false,
185
+ "top_k": 50,
186
+ "top_p": 1.0,
187
+ "torchscript": false,
188
+ "transformers_version": "4.57.1",
189
+ "typical_p": 1.0,
190
+ "use_bfloat16": false,
191
+ "vocab_size": 16384,
192
+ "auto_map": {
193
+ "AutoConfig": "configuration_nanotts.NanoTTSConfig",
194
+ "AutoModel": "modeling_nanotts_global_local.NanoTTSGlobalLocalForCausalLM",
195
+ "AutoModelForCausalLM": "modeling_nanotts_global_local.NanoTTSGlobalLocalForCausalLM"
196
+ }
197
+ }
weights/tts/configuration_nanotts.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ from typing import Any, Dict, Optional, Union
3
+
4
+ from transformers.configuration_utils import PretrainedConfig
5
+ from transformers.models.gpt2.configuration_gpt2 import GPT2Config
6
+
7
+
8
+ class NanoTTSConfig(PretrainedConfig):
9
+ model_type = "nano_tts"
10
+ keys_to_ignore_at_inference = ["past_key_values"]
11
+
12
+ def __init__(
13
+ self,
14
+ gpt2_config: Optional[Union[GPT2Config, Dict[str, Any]]] = None,
15
+ n_vq: int = 8,
16
+ audio_vocab_size: Optional[int] = 1024,
17
+ audio_codebook_sizes: Optional[list[int]] = None,
18
+ audio_pad_token_id: int = 1024,
19
+ pad_token_id: int = 151643,
20
+ im_start_token_id: int = 151644,
21
+ im_end_token_id: int = 151645,
22
+ audio_start_token_id: int = 151652,
23
+ audio_end_token_id: int = 151653,
24
+ audio_user_slot_token_id: int = 151654,
25
+ audio_assistant_slot_token_id: int = 151656,
26
+ tokenizer_use_fast: bool = False,
27
+ audio_tokenizer_type: str = "moss-audio-tokenizer-nano",
28
+ audio_tokenizer_pretrained_name_or_path: Optional[str] = "OpenMOSS-Team/MOSS-Audio-Tokenizer-Nano",
29
+ audio_tokenizer_sample_rate: int = 48000,
30
+ attn_implementation: str = "flash_attention_2",
31
+ initializer_range: float = 0.02,
32
+ model_architecture: str = "global_local_transformer",
33
+ local_transformer_layers: int = 4,
34
+ local_transformer_attn_implementation: Optional[str] = None,
35
+ **kwargs: Any,
36
+ ) -> None:
37
+ if isinstance(gpt2_config, dict):
38
+ self.gpt2_config = GPT2Config(**gpt2_config)
39
+ elif gpt2_config is None:
40
+ self.gpt2_config = GPT2Config()
41
+ else:
42
+ self.gpt2_config = gpt2_config
43
+
44
+ self.n_vq = int(n_vq)
45
+ if audio_codebook_sizes is None:
46
+ if audio_vocab_size is None:
47
+ raise ValueError("audio_vocab_size must be set when audio_codebook_sizes is not provided.")
48
+ resolved_audio_codebook_sizes = [int(audio_vocab_size)] * self.n_vq
49
+ else:
50
+ resolved_audio_codebook_sizes = [int(codebook_size) for codebook_size in audio_codebook_sizes]
51
+ if len(resolved_audio_codebook_sizes) != self.n_vq:
52
+ raise ValueError(
53
+ "audio_codebook_sizes must have length n_vq "
54
+ f"(expected {self.n_vq}, got {len(resolved_audio_codebook_sizes)})."
55
+ )
56
+ if any(codebook_size <= 0 for codebook_size in resolved_audio_codebook_sizes):
57
+ raise ValueError("audio_codebook_sizes must contain positive integers.")
58
+
59
+ max_audio_codebook_size = max(resolved_audio_codebook_sizes)
60
+ if audio_vocab_size is not None and int(audio_vocab_size) < max_audio_codebook_size:
61
+ raise ValueError(
62
+ "audio_vocab_size must be >= max(audio_codebook_sizes) "
63
+ f"(got {audio_vocab_size}, expected at least {max_audio_codebook_size})."
64
+ )
65
+
66
+ self.audio_codebook_sizes = resolved_audio_codebook_sizes
67
+ self.audio_vocab_size = (
68
+ max_audio_codebook_size if audio_vocab_size is None else int(audio_vocab_size)
69
+ )
70
+ self.audio_pad_token_id = int(audio_pad_token_id)
71
+ if self.audio_pad_token_id < max_audio_codebook_size:
72
+ raise ValueError(
73
+ "audio_pad_token_id must be >= max(audio_codebook_sizes) so pad stays outside every codebook "
74
+ f"(got {self.audio_pad_token_id}, max codebook size {max_audio_codebook_size})."
75
+ )
76
+ self.pad_token_id = pad_token_id
77
+ self.im_start_token_id = im_start_token_id
78
+ self.im_end_token_id = im_end_token_id
79
+ self.audio_start_token_id = audio_start_token_id
80
+ self.audio_end_token_id = audio_end_token_id
81
+ self.audio_user_slot_token_id = audio_user_slot_token_id
82
+ self.audio_assistant_slot_token_id = audio_assistant_slot_token_id
83
+ self.tokenizer_use_fast = tokenizer_use_fast
84
+ self.audio_tokenizer_type = audio_tokenizer_type
85
+ self.audio_tokenizer_pretrained_name_or_path = audio_tokenizer_pretrained_name_or_path
86
+ self.audio_tokenizer_sample_rate = audio_tokenizer_sample_rate
87
+ self.attn_implementation = attn_implementation
88
+ self.initializer_range = initializer_range
89
+ self.model_architecture = model_architecture
90
+ self.local_transformer_layers = local_transformer_layers
91
+ self.local_transformer_attn_implementation = (
92
+ attn_implementation
93
+ if local_transformer_attn_implementation is None
94
+ else local_transformer_attn_implementation
95
+ )
96
+ self.vocab_size = self.gpt2_config.vocab_size
97
+ self.hidden_size = self.gpt2_config.hidden_size
98
+ self.max_position_embeddings = self.gpt2_config.n_positions
99
+
100
+ super().__init__(pad_token_id=pad_token_id, **kwargs)
101
+
102
+ def to_dict(self) -> Dict[str, Any]:
103
+ output = super().to_dict()
104
+ output["gpt2_config"] = self.gpt2_config.to_dict()
105
+ return output
weights/tts/gpt2_decoder.py ADDED
@@ -0,0 +1,618 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ from __future__ import annotations
3
+
4
+ from dataclasses import dataclass
5
+ from typing import Optional
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.utils.checkpoint
10
+ from transformers.activations import ACT2FN
11
+ from transformers.modeling_outputs import BaseModelOutputWithPast
12
+ from transformers.models.gpt2.configuration_gpt2 import GPT2Config
13
+
14
+ try:
15
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
16
+ from flash_attn.bert_padding import pad_input, unpad_input
17
+
18
+ _FLASH_ATTN_AVAILABLE = True
19
+ except Exception:
20
+ flash_attn_func = None
21
+ flash_attn_varlen_func = None
22
+ pad_input = None
23
+ unpad_input = None
24
+ _FLASH_ATTN_AVAILABLE = False
25
+
26
+
27
+ @dataclass
28
+ class PackedSequenceMetadata:
29
+ cu_seqlens: torch.Tensor
30
+ max_seqlen: int
31
+ indices: Optional[torch.Tensor] = None
32
+ batch_size: Optional[int] = None
33
+ seq_len: Optional[int] = None
34
+
35
+
36
+ class NanoGPT2RotaryEmbedding(nn.Module):
37
+ def __init__(self, dim: int, base: float = 10000.0) -> None:
38
+ super().__init__()
39
+ if dim % 2 != 0:
40
+ raise ValueError(f"RoPE head_dim must be even, got {dim}")
41
+ inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
42
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
43
+
44
+ def forward(
45
+ self,
46
+ position_ids: torch.LongTensor,
47
+ *,
48
+ device: torch.device,
49
+ dtype: torch.dtype,
50
+ ) -> tuple[torch.Tensor, torch.Tensor]:
51
+ if position_ids.ndim == 1:
52
+ position_ids = position_ids.unsqueeze(0)
53
+ freqs = torch.einsum("bs,d->bsd", position_ids.to(device=device, dtype=self.inv_freq.dtype), self.inv_freq)
54
+ cos = freqs.cos().repeat_interleave(2, dim=-1).unsqueeze(2).to(dtype=dtype)
55
+ sin = freqs.sin().repeat_interleave(2, dim=-1).unsqueeze(2).to(dtype=dtype)
56
+ return cos, sin
57
+
58
+
59
def rotate_half(hidden_states: torch.Tensor) -> torch.Tensor:
    """Rotate interleaved channel pairs for RoPE: (x0, x1) -> (-x1, x0)."""
    pairs = hidden_states.reshape(*hidden_states.shape[:-1], -1, 2)
    rotated = torch.stack((-pairs[..., 1], pairs[..., 0]), dim=-1)
    return rotated.reshape_as(hidden_states)
63
+
64
+
65
def apply_rotary_pos_emb(
    hidden_states: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
) -> torch.Tensor:
    """Apply the RoPE rotation: x*cos + rotate_half(x)*sin."""
    rotated = rotate_half(hidden_states)
    return hidden_states * cos + rotated * sin
71
+
72
+
73
class NanoGPT2MLP(nn.Module):
    """Position-wise feed-forward block: linear -> activation -> linear -> dropout."""

    def __init__(self, config: GPT2Config) -> None:
        super().__init__()
        embed_dim = int(config.hidden_size)
        # GPT-2 convention: inner size defaults to 4x the embedding width.
        ffn_dim = int(config.n_inner or 4 * embed_dim)
        self.fc_in = nn.Linear(embed_dim, ffn_dim)
        self.fc_out = nn.Linear(ffn_dim, embed_dim)
        self.act = ACT2FN[config.activation_function]
        self.dropout = nn.Dropout(config.resid_pdrop)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        projected = self.act(self.fc_in(hidden_states))
        return self.dropout(self.fc_out(projected))
88
+
89
+
90
class NanoGPT2Attention(nn.Module):
    """Causal multi-head self-attention with selectable backends.

    Three implementations are supported, chosen by ``attn_implementation``:
    ``eager`` (explicit matmul + softmax), ``sdpa``
    (``torch.nn.functional.scaled_dot_product_attention``) and
    ``flash_attention_2`` (flash-attn kernels, optionally over packed
    variable-length sequences). Positions may be encoded with rotary
    embeddings applied here when ``position_embedding_type == "rope"``;
    otherwise absolute positions are assumed to be added by the caller.
    """

    def __init__(self, config: GPT2Config, layer_idx: int, attn_implementation: str) -> None:
        super().__init__()
        hidden_size = int(config.hidden_size)
        num_heads = int(config.num_attention_heads)
        if hidden_size % num_heads != 0:
            raise ValueError(f"hidden_size={hidden_size} must be divisible by num_attention_heads={num_heads}")

        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.embed_dim = hidden_size
        self.layer_idx = layer_idx
        self.attn_implementation = attn_implementation
        self.attn_dropout = float(config.attn_pdrop)
        self.resid_dropout = nn.Dropout(config.resid_pdrop)
        self.scale_attn_weights = bool(getattr(config, "scale_attn_weights", True))
        # Optional GPT-2 stabilization: additionally scale scores by 1/(layer_idx+1).
        self.scale_attn_by_inverse_layer_idx = bool(getattr(config, "scale_attn_by_inverse_layer_idx", False))
        self.position_embedding_type = str(getattr(config, "position_embedding_type", "absolute")).lower()
        if self.position_embedding_type not in {"absolute", "rope"}:
            raise ValueError(f"Unsupported position_embedding_type={self.position_embedding_type!r}")

        # Fused QKV projection (GPT-2 style) and output projection.
        self.c_attn = nn.Linear(hidden_size, 3 * hidden_size)
        self.c_proj = nn.Linear(hidden_size, hidden_size)
        self.rotary_emb = None
        if self.position_embedding_type == "rope":
            self.rotary_emb = NanoGPT2RotaryEmbedding(
                self.head_dim,
                base=float(getattr(config, "rope_base", 10000.0)),
            )

    def _split_heads(self, tensor: torch.Tensor) -> torch.Tensor:
        """Split the hidden dim into (num_heads, head_dim).

        Accepts both padded (batch, seq, hidden) and packed (tokens, hidden) layouts.
        """
        if tensor.ndim == 3:
            batch_size, seq_len, _ = tensor.shape
            return tensor.view(batch_size, seq_len, self.num_heads, self.head_dim)
        if tensor.ndim == 2:
            total_tokens, _ = tensor.shape
            return tensor.view(total_tokens, self.num_heads, self.head_dim)
        raise ValueError(f"Unsupported tensor rank for attention split: {tensor.ndim}")

    def _merge_heads(self, tensor: torch.Tensor) -> torch.Tensor:
        """Inverse of :meth:`_split_heads`: collapse heads back into the hidden dim."""
        if tensor.ndim == 4:
            batch_size, seq_len, _, _ = tensor.shape
            return tensor.reshape(batch_size, seq_len, self.embed_dim)
        if tensor.ndim == 3:
            total_tokens, _, _ = tensor.shape
            return tensor.reshape(total_tokens, self.embed_dim)
        raise ValueError(f"Unsupported tensor rank for attention merge: {tensor.ndim}")

    def _causal_attention_mask(
        self,
        attention_mask: Optional[torch.Tensor],
        query_length: int,
        key_length: int,
        device: torch.device,
    ) -> torch.Tensor:
        """Build a boolean (1-or-B, 1, Q, K) mask that is True where attending is allowed.

        Query positions are offset by ``key_length - query_length`` so that
        decoding with cached keys (Q < K) remains causal. An optional key
        padding mask is AND-ed in.
        """
        query_positions = torch.arange(query_length, device=device, dtype=torch.long)
        query_positions = query_positions + max(key_length - query_length, 0)
        key_positions = torch.arange(key_length, device=device, dtype=torch.long)
        causal = key_positions.unsqueeze(0) <= query_positions.unsqueeze(1)
        causal = causal.unsqueeze(0).unsqueeze(0)
        if attention_mask is None:
            return causal
        key_mask = attention_mask[:, None, None, :].to(dtype=torch.bool)
        return causal & key_mask

    def _eager_attention(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Reference attention path: explicit scores, mask, softmax, weighted sum."""
        # (B, S, H, D) -> (B, H, S, D) for batched matmul.
        query = query.transpose(1, 2)
        key = key.transpose(1, 2)
        value = value.transpose(1, 2)

        scale = 1.0
        if self.scale_attn_weights:
            scale /= self.head_dim ** 0.5
        if self.scale_attn_by_inverse_layer_idx:
            scale /= float(self.layer_idx + 1)

        scores = torch.matmul(query, key.transpose(-1, -2)) * scale
        causal_mask = self._causal_attention_mask(
            attention_mask=attention_mask,
            query_length=query.shape[-2],
            key_length=key.shape[-2],
            device=query.device,
        )
        # Disallowed positions get the dtype's most negative value before softmax.
        scores = scores.masked_fill(~causal_mask, torch.finfo(scores.dtype).min)
        probs = torch.softmax(scores, dim=-1)
        if self.training and self.attn_dropout > 0:
            probs = torch.dropout(probs, self.attn_dropout, train=True)
        output = torch.matmul(probs, value)
        return output.transpose(1, 2).contiguous()

    def _sdpa_attention(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Attention via torch SDPA; handles fully-masked query rows explicitly."""
        query = query.transpose(1, 2)
        key = key.transpose(1, 2)
        value = value.transpose(1, 2)
        mask = None
        query_attention_mask = None
        if attention_mask is not None:
            query_length = query.shape[-2]
            key_length = key.shape[-2]
            mask = self._causal_attention_mask(
                attention_mask=attention_mask,
                query_length=query_length,
                key_length=key_length,
                device=query.device,
            )
            query_attention_mask = attention_mask[:, -query_length:].to(dtype=torch.bool, device=query.device)
            if not bool(query_attention_mask.all()):
                # SDPA can produce NaNs when a query row is fully masked. For padded query positions,
                # keep a single aligned key visible, then zero the query output after attention.
                mask = mask.expand(query.shape[0], -1, -1, -1).clone()
                invalid_batch, invalid_query = torch.nonzero(~query_attention_mask, as_tuple=True)
                aligned_key = invalid_query + max(key_length - query_length, 0)
                mask[invalid_batch, :, invalid_query, aligned_key] = True
        output = torch.nn.functional.scaled_dot_product_attention(
            query,
            key,
            value,
            attn_mask=mask,
            dropout_p=self.attn_dropout if self.training else 0.0,
            # With no explicit mask, let the kernel apply its own causal masking.
            is_causal=mask is None,
        )
        if query_attention_mask is not None and not bool(query_attention_mask.all()):
            output = output.masked_fill(~query_attention_mask[:, None, :, None], 0.0)
        return output.transpose(1, 2).contiguous()

    def _flash_attention(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
        packed_metadata: Optional[PackedSequenceMetadata],
    ) -> torch.Tensor:
        """Attention via flash-attn; requires CUDA and fp16/bf16 inputs.

        Three sub-paths: pre-packed sequences (``packed_metadata``), fully dense
        batches (no padding), and padded batches (unpad -> varlen kernel -> re-pad).
        """
        if not _FLASH_ATTN_AVAILABLE:
            raise ImportError("flash_attn is not installed, but attn_implementation='flash_attention_2' was requested.")
        if query.device.type != "cuda":
            raise ValueError("flash_attention_2 requires CUDA tensors.")
        if query.dtype not in (torch.float16, torch.bfloat16):
            raise ValueError(
                f"flash_attention_2 requires fp16/bf16 tensors, but received dtype={query.dtype}."
            )

        dropout_p = self.attn_dropout if self.training else 0.0
        if packed_metadata is not None:
            if packed_metadata.indices is not None:
                # Gather only the valid tokens out of the padded layout.
                query = query.reshape(-1, self.num_heads, self.head_dim).index_select(0, packed_metadata.indices)
                key = key.reshape(-1, self.num_heads, self.head_dim).index_select(0, packed_metadata.indices)
                value = value.reshape(-1, self.num_heads, self.head_dim).index_select(0, packed_metadata.indices)
            output = flash_attn_varlen_func(
                query,
                key,
                value,
                packed_metadata.cu_seqlens,
                packed_metadata.cu_seqlens,
                packed_metadata.max_seqlen,
                packed_metadata.max_seqlen,
                dropout_p=dropout_p,
                causal=True,
            )
            if packed_metadata.indices is None:
                return output
            # Scatter the flat result back to the padded (batch, seq) layout.
            return pad_input(
                output,
                packed_metadata.indices,
                packed_metadata.batch_size,
                packed_metadata.seq_len,
            )

        if attention_mask is None or bool(attention_mask.all()):
            # Dense path: no padding anywhere, use the fixed-length kernel.
            return flash_attn_func(
                query,
                key,
                value,
                dropout_p=dropout_p,
                causal=True,
            )

        # Padded path: strip padding, run the varlen kernel, then re-pad.
        unpadded_query, indices, cu_seqlens, max_seqlen, _ = unpad_input(query, attention_mask)
        unpadded_key, _, _, _, _ = unpad_input(key, attention_mask)
        unpadded_value, _, _, _, _ = unpad_input(value, attention_mask)
        output = flash_attn_varlen_func(
            unpadded_query,
            unpadded_key,
            unpadded_value,
            cu_seqlens,
            cu_seqlens,
            max_seqlen,
            max_seqlen,
            dropout_p=dropout_p,
            causal=True,
        )
        return pad_input(output, indices, query.shape[0], query.shape[1])

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        packed_metadata: Optional[PackedSequenceMetadata] = None,
        layer_past: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False,
    ) -> tuple[torch.Tensor, Optional[tuple[torch.Tensor, torch.Tensor]]]:
        """Run self-attention; returns (output, present_kv_or_None)."""
        qkv = self.c_attn(hidden_states)
        query, key, value = qkv.split(self.embed_dim, dim=-1)
        query = self._split_heads(query)
        key = self._split_heads(key)
        value = self._split_heads(value)

        if self.rotary_emb is not None:
            if position_ids is None:
                raise ValueError("position_ids must be provided when position_embedding_type='rope'.")
            cos, sin = self.rotary_emb(
                position_ids.to(device=query.device),
                device=query.device,
                dtype=query.dtype,
            )
            query = apply_rotary_pos_emb(query, cos, sin)
            key = apply_rotary_pos_emb(key, cos, sin)

        if layer_past is not None:
            # Prepend cached keys/values along the sequence axis for decoding.
            past_key, past_value = layer_past
            key = torch.cat([past_key.to(device=key.device, dtype=key.dtype), key], dim=1)
            value = torch.cat([past_value.to(device=value.device, dtype=value.dtype), value], dim=1)

        present = (key, value) if use_cache else None

        # NOTE: flash-attn is only used for the cache-free (prefill) case; with a
        # KV cache present the code falls through to the eager path below.
        if self.attn_implementation == "flash_attention_2" and layer_past is None:
            attn_output = self._flash_attention(
                query=query,
                key=key,
                value=value,
                attention_mask=attention_mask,
                packed_metadata=packed_metadata,
            )
        elif self.attn_implementation == "sdpa":
            attn_output = self._sdpa_attention(
                query=query,
                key=key,
                value=value,
                attention_mask=attention_mask,
            )
        else:
            attn_output = self._eager_attention(
                query=query,
                key=key,
                value=value,
                attention_mask=attention_mask,
            )

        attn_output = self._merge_heads(attn_output)
        attn_output = self.c_proj(attn_output)
        return self.resid_dropout(attn_output), present
354
+
355
+
356
class NanoGPT2Block(nn.Module):
    """Pre-LayerNorm transformer layer: attention and MLP, each with a residual add."""

    def __init__(self, config: GPT2Config, layer_idx: int, attn_implementation: str) -> None:
        super().__init__()
        embed_dim = int(config.hidden_size)
        self.ln_1 = nn.LayerNorm(embed_dim, eps=config.layer_norm_epsilon)
        self.attn = NanoGPT2Attention(config, layer_idx=layer_idx, attn_implementation=attn_implementation)
        self.ln_2 = nn.LayerNorm(embed_dim, eps=config.layer_norm_epsilon)
        self.mlp = NanoGPT2MLP(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        packed_metadata: Optional[PackedSequenceMetadata] = None,
        layer_past: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False,
    ) -> tuple[torch.Tensor, Optional[tuple[torch.Tensor, torch.Tensor]]]:
        """Apply attention then MLP (both pre-normed); returns (hidden, present_kv)."""
        residual = hidden_states
        attn_output, present = self.attn(
            self.ln_1(hidden_states),
            attention_mask=attention_mask,
            position_ids=position_ids,
            packed_metadata=packed_metadata,
            layer_past=layer_past,
            use_cache=use_cache,
        )
        hidden_states = residual + attn_output
        hidden_states = hidden_states + self.mlp(self.ln_2(hidden_states))
        return hidden_states, present
385
+
386
+
387
class NanoGPT2Model(nn.Module):
    """GPT-2 style decoder stack operating on precomputed input embeddings.

    ``forward`` requires ``inputs_embeds`` (token embedding is done by the
    caller; ``input_ids`` is accepted but ignored). Supports three execution
    modes: plain padded batches, incremental decoding with a per-layer KV
    cache, and flash-attn packed sequences described by ``cu_seqlens``.
    Padded query positions are zeroed after every layer and after the final
    LayerNorm so padding cannot leak into downstream computation.
    """

    def __init__(self, config: GPT2Config, attn_implementation: str = "eager") -> None:
        super().__init__()
        self.config = config
        self.attn_implementation = attn_implementation
        self.position_embedding_type = str(getattr(config, "position_embedding_type", "absolute")).lower()
        if self.position_embedding_type not in {"absolute", "rope"}:
            raise ValueError(f"Unsupported position_embedding_type={self.position_embedding_type!r}")
        hidden_size = int(config.hidden_size)
        self.wte = nn.Embedding(config.vocab_size, hidden_size)
        # Learned absolute positions; an Identity placeholder under RoPE (positions
        # are then applied inside attention instead).
        self.wpe = nn.Embedding(config.n_positions, hidden_size) if self.position_embedding_type == "absolute" else nn.Identity()
        self.drop = nn.Dropout(config.embd_pdrop)
        self.h = nn.ModuleList(
            [NanoGPT2Block(config, layer_idx=index, attn_implementation=attn_implementation) for index in range(config.n_layer)]
        )
        self.ln_f = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        self.gradient_checkpointing = False
        self._reset_parameters()

    def _reset_parameters(self) -> None:
        """Initialize Linear/Embedding weights from N(0, initializer_range); LayerNorm to identity."""
        init_std = float(self.config.initializer_range)
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, mean=0.0, std=init_std)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Embedding):
                nn.init.normal_(module.weight, mean=0.0, std=init_std)
            elif isinstance(module, nn.LayerNorm):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)

    @staticmethod
    def _normalize_num_sequences(
        cu_seqlens: torch.Tensor,
        num_sequences: Optional[torch.Tensor],
        device: torch.device,
    ) -> torch.Tensor:
        """Return a per-batch-row count of non-empty packed sequences.

        When ``num_sequences`` is not given, the count is inferred from strictly
        increasing steps in each row's cumulative-length boundaries.
        """
        if cu_seqlens.ndim == 1:
            cu_seqlens = cu_seqlens.unsqueeze(0)
        if num_sequences is None:
            counts = []
            for boundary in cu_seqlens:
                diffs = boundary[1:] - boundary[:-1]
                counts.append(int((diffs > 0).sum().item()))
            return torch.tensor(counts, dtype=torch.int32, device=device)
        if num_sequences.ndim == 0:
            return num_sequences.unsqueeze(0)
        return num_sequences

    @staticmethod
    def build_packed_position_ids(
        attention_mask: Optional[torch.Tensor],
        cu_seqlens: torch.Tensor,
        num_sequences: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Build position ids that restart at 0 at every packed-sequence boundary."""
        if cu_seqlens.ndim == 1:
            cu_seqlens = cu_seqlens.unsqueeze(0)
        # NOTE(review): seq_len is inferred from the boundary count, which assumes
        # cu_seqlens rows are padded out to seq_len + 1 entries.
        batch_size, seq_len = cu_seqlens.shape[0], cu_seqlens.shape[1] - 1
        device = cu_seqlens.device
        position_ids = torch.zeros((batch_size, seq_len), dtype=torch.long, device=device)
        counts = NanoGPT2Model._normalize_num_sequences(cu_seqlens, num_sequences, device=device)
        for batch_index in range(batch_size):
            sequence_count = int(counts[batch_index].item())
            boundaries = cu_seqlens[batch_index, : sequence_count + 1].tolist()
            for start, end in zip(boundaries[:-1], boundaries[1:]):
                start = int(start)
                end = int(end)
                if end > start:
                    position_ids[batch_index, start:end] = torch.arange(end - start, device=device)
        if attention_mask is not None:
            # Zero out positions on padded tokens.
            position_ids = position_ids * attention_mask.to(dtype=position_ids.dtype)
        return position_ids

    @staticmethod
    def build_packed_metadata(
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        num_sequences: Optional[torch.Tensor],
    ) -> PackedSequenceMetadata:
        """Flatten per-row sequence boundaries into global varlen metadata.

        Produces the flat token indices plus a single global cu_seqlens tensor
        covering all rows, as consumed by flash-attn's varlen kernels.
        """
        if cu_seqlens.ndim == 1:
            cu_seqlens = cu_seqlens.unsqueeze(0)
        device = hidden_states.device
        counts = NanoGPT2Model._normalize_num_sequences(cu_seqlens, num_sequences, device=device)
        flat_indices = []
        cumulative = [0]
        max_seqlen = 0
        seq_len = hidden_states.shape[1]

        for batch_index in range(hidden_states.shape[0]):
            sequence_count = int(counts[batch_index].item())
            boundaries = cu_seqlens[batch_index, : sequence_count + 1].tolist()
            for start, end in zip(boundaries[:-1], boundaries[1:]):
                start = int(start)
                end = int(end)
                if end <= start:
                    continue
                # Index into the flattened (batch*seq,) token stream.
                segment_indices = batch_index * seq_len + torch.arange(start, end, device=device)
                flat_indices.append(segment_indices)
                cumulative.append(cumulative[-1] + (end - start))
                max_seqlen = max(max_seqlen, end - start)

        if not flat_indices:
            raise ValueError("cu_seqlens did not describe any non-empty packed sequences.")

        indices = torch.cat(flat_indices, dim=0)
        return PackedSequenceMetadata(
            cu_seqlens=torch.tensor(cumulative, dtype=torch.int32, device=device),
            max_seqlen=max_seqlen,
            indices=indices,
            batch_size=hidden_states.shape[0],
            seq_len=hidden_states.shape[1],
        )

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[tuple[tuple[torch.Tensor, torch.Tensor], ...]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: bool = True,
        cu_seqlens: Optional[torch.Tensor] = None,
        num_sequences: Optional[torch.Tensor] = None,
    ) -> BaseModelOutputWithPast:
        """Run the decoder stack over ``inputs_embeds``.

        Raises ValueError if ``inputs_embeds`` is missing, or if ``use_cache``
        is combined with packing / gradient checkpointing. ``input_ids`` and
        ``output_attentions`` are accepted for interface compatibility but ignored.
        """
        del input_ids, output_attentions

        if inputs_embeds is None:
            raise ValueError("inputs_embeds must be provided.")

        use_cache = bool(use_cache)
        if use_cache and cu_seqlens is not None:
            raise ValueError("use_cache=True is not supported together with cu_seqlens packing.")

        hidden_states = inputs_embeds
        if attention_mask is None:
            attention_mask = torch.ones(hidden_states.shape[:2], dtype=torch.bool, device=hidden_states.device)
        else:
            attention_mask = attention_mask.to(dtype=torch.bool, device=hidden_states.device)
        # During cached decoding attention_mask covers past+current keys; the
        # trailing slice covers only the current query positions.
        query_attention_mask = attention_mask[:, -hidden_states.shape[1] :]

        packed_metadata = None
        if position_ids is None:
            if cu_seqlens is not None:
                position_ids = self.build_packed_position_ids(
                    attention_mask=attention_mask,
                    cu_seqlens=cu_seqlens.to(device=hidden_states.device),
                    num_sequences=num_sequences.to(device=hidden_states.device) if num_sequences is not None else None,
                )
            elif attention_mask is not None:
                # Positions count only unpadded tokens (cumsum of the mask).
                position_ids = attention_mask.long().cumsum(dim=-1) - 1
                position_ids = position_ids.masked_fill(~attention_mask, 0)
                position_ids = position_ids[:, -hidden_states.shape[1] :]
            else:
                # NOTE: unreachable in practice — attention_mask is defaulted above.
                past_length = 0
                if past_key_values is not None and len(past_key_values) > 0:
                    past_length = past_key_values[0][0].shape[1]
                position_ids = torch.arange(hidden_states.shape[1], device=hidden_states.device, dtype=torch.long)
                position_ids = position_ids + past_length
                position_ids = position_ids.unsqueeze(0).expand(hidden_states.shape[0], -1)

        # Varlen metadata is only consumed by the flash-attn backend.
        if cu_seqlens is not None and self.attn_implementation == "flash_attention_2":
            packed_metadata = self.build_packed_metadata(
                hidden_states=hidden_states,
                cu_seqlens=cu_seqlens.to(device=hidden_states.device),
                num_sequences=num_sequences.to(device=hidden_states.device) if num_sequences is not None else None,
            )

        if self.position_embedding_type == "absolute":
            hidden_states = hidden_states + self.wpe(position_ids)
        hidden_states = self.drop(hidden_states)
        # Zero padded positions so they contribute nothing downstream.
        hidden_states = hidden_states * query_attention_mask.unsqueeze(-1).to(dtype=hidden_states.dtype)

        all_hidden_states = () if output_hidden_states else None
        presents = [] if use_cache else None
        for layer_index, block in enumerate(self.h):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            if self.gradient_checkpointing and self.training:
                if use_cache:
                    raise ValueError("use_cache=True is not supported when gradient checkpointing is enabled during training.")

                def custom_forward(*inputs):
                    # Re-run the block in the backward pass; presents are discarded.
                    output, _ = block(
                        inputs[0],
                        attention_mask=inputs[1],
                        position_ids=inputs[2],
                        packed_metadata=packed_metadata,
                        layer_past=None,
                        use_cache=False,
                    )
                    return output

                hidden_states = torch.utils.checkpoint.checkpoint(
                    custom_forward,
                    hidden_states,
                    attention_mask,
                    position_ids,
                    use_reentrant=False,
                )
                present = None
            else:
                hidden_states, present = block(
                    hidden_states,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    packed_metadata=packed_metadata,
                    layer_past=None if past_key_values is None else past_key_values[layer_index],
                    use_cache=use_cache,
                )
            hidden_states = hidden_states * query_attention_mask.unsqueeze(-1).to(dtype=hidden_states.dtype)
            if presents is not None:
                presents.append(present)

        hidden_states = self.ln_f(hidden_states)
        hidden_states = hidden_states * query_attention_mask.unsqueeze(-1).to(dtype=hidden_states.dtype)
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return (hidden_states, tuple(presents) if presents is not None else None, all_hidden_states, None)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=tuple(presents) if presents is not None else None,
            hidden_states=all_hidden_states,
            attentions=None,
        )
+ )
weights/tts/modeling_nanotts_global_local.py ADDED
The diff for this file is too large to render. See raw diff
 
weights/tts/prompting.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from typing import List, Sequence
4
+
5
+ from .configuration_nanotts import NanoTTSConfig
6
+
7
+
8
# Chat-template fragments for the single-turn TTS prompt. These literals appear
# to define the model's fixed prompt format — presumably baked in at training
# time — so their contents must not be edited.
USER_ROLE_PREFIX = "user\n"
# Opens the user instruction template, up to the reference slot.
USER_TEMPLATE_REFERENCE_PREFIX = (
    "<user_inst>\n"
    "- Reference(s):\n"
)
# Template fields that follow the reference slot; every optional field is "None".
USER_TEMPLATE_AFTER_REFERENCE = (
    "\n- Instruction:\nNone\n"
    "- Tokens:\nNone\n"
    "- Quality:\nNone\n"
    "- Sound Event:\nNone\n"
    "- Ambient Sound:\nNone\n"
    "- Language:\nNone\n"
    "- Text:\n"
)
# Full user-template prefix with the reference slot itself filled by "None".
USER_TEMPLATE_PREFIX = USER_TEMPLATE_REFERENCE_PREFIX + "None" + USER_TEMPLATE_AFTER_REFERENCE
USER_TEMPLATE_SUFFIX = "\n</user_inst>"
# Separator emitted between the user turn and the assistant turn.
ASSISTANT_TURN_PREFIX = "\n"
ASSISTANT_ROLE_PREFIX = "assistant\n"
26
+
27
+
28
def encode_text(tokenizer, text: str) -> List[int]:
    """Encode ``text`` to token ids without adding special tokens.

    Falls back to a plain ``encode(text)`` call for tokenizers whose ``encode``
    does not accept the ``add_special_tokens`` keyword.
    """
    try:
        return list(tokenizer.encode(text, add_special_tokens=False))
    except TypeError:
        # Tokenizer with a narrower encode() signature.
        return list(tokenizer.encode(text))
33
+
34
+
35
def decode_text(tokenizer, token_ids: Sequence[int]) -> str:
    """Decode ``token_ids`` to text, keeping special tokens and raw spacing.

    Degrades gracefully for tokenizers whose ``decode`` lacks the
    ``clean_up_tokenization_spaces`` and/or ``skip_special_tokens`` keywords.
    """
    ids = list(token_ids)
    try:
        decoded = tokenizer.decode(
            ids,
            skip_special_tokens=False,
            clean_up_tokenization_spaces=False,
        )
    except TypeError:
        try:
            decoded = tokenizer.decode(ids, skip_special_tokens=False)
        except TypeError:
            decoded = tokenizer.decode(ids)
    return str(decoded)
49
+
50
+
51
def build_user_prompt_prefix(tokenizer, config: NanoTTSConfig) -> List[int]:
    """Token ids opening the user turn, up to (and including) the reference slot header."""
    token_ids: List[int] = [config.im_start_token_id]
    token_ids += encode_text(tokenizer, USER_ROLE_PREFIX)
    token_ids += encode_text(tokenizer, USER_TEMPLATE_REFERENCE_PREFIX)
    return token_ids
56
+
57
+
58
def build_user_prompt_after_reference(tokenizer) -> List[int]:
    """Token ids for the template section that follows the reference slot."""
    return encode_text(tokenizer, USER_TEMPLATE_AFTER_REFERENCE)
60
+
61
+
62
def build_assistant_prompt_prefix(tokenizer, config: NanoTTSConfig) -> List[int]:
    """Token ids that close the user turn and open the assistant turn."""
    token_ids = encode_text(tokenizer, USER_TEMPLATE_SUFFIX)
    token_ids.append(config.im_end_token_id)
    token_ids += encode_text(tokenizer, ASSISTANT_TURN_PREFIX)
    token_ids.append(config.im_start_token_id)
    token_ids += encode_text(tokenizer, ASSISTANT_ROLE_PREFIX)
    return token_ids
70
+
71
+
72
def build_prompt_prefix(tokenizer, config: NanoTTSConfig) -> List[int]:
    """Full user-turn prefix, with the reference slot filled by the literal "None"."""
    prefix = build_user_prompt_prefix(tokenizer, config)
    prefix += encode_text(tokenizer, "None")
    prefix += build_user_prompt_after_reference(tokenizer)
    return prefix
78
+
79
+
80
def build_prompt_suffix(tokenizer, config: NanoTTSConfig) -> List[int]:
    """Token ids appended after the text: close the user turn and open the assistant turn."""
    return build_assistant_prompt_prefix(tokenizer, config)
82
+
83
+
84
def build_prompt_token_ids(
    tokenizer,
    config: NanoTTSConfig,
    text_token_ids: Sequence[int],
) -> List[int]:
    """Assemble the complete prompt: user prefix + input-text tokens + assistant prefix."""
    text_ids = [int(token_id) for token_id in text_token_ids]
    return build_prompt_prefix(tokenizer, config) + text_ids + build_prompt_suffix(tokenizer, config)
weights/tts/pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:24003f2f11ac8a2cbf70514db2d8f1c02fb451aa6b3c0bffc9da09f31cd7caa5
3
+ size 234693095
weights/tts/special_tokens_map.json ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": {
3
+ "content": "<s>",
4
+ "lstrip": false,
5
+ "normalized": false,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "eos_token": {
10
+ "content": "</s>",
11
+ "lstrip": false,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "pad_token": {
17
+ "content": "<pad>",
18
+ "lstrip": false,
19
+ "normalized": false,
20
+ "rstrip": false,
21
+ "single_word": false
22
+ },
23
+ "unk_token": {
24
+ "content": "<unk>",
25
+ "lstrip": false,
26
+ "normalized": false,
27
+ "rstrip": false,
28
+ "single_word": false
29
+ }
30
+ }
weights/tts/tokenization_nanotts_sentencepiece.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import shutil
4
+ from pathlib import Path
5
+ from typing import Any
6
+
7
+ import sentencepiece as spm
8
+ from transformers import PreTrainedTokenizer
9
+
10
+
11
# Maps the Hugging Face vocab-file key to the on-disk SentencePiece model filename.
VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model"}
12
+
13
+
14
class NanoTTSSentencePieceTokenizer(PreTrainedTokenizer):
    """Slow Hugging Face tokenizer backed by a SentencePiece model.

    A thin wrapper delegating piece splitting, id lookup and detokenization to
    ``sentencepiece`` while inheriting special-token bookkeeping from
    ``PreTrainedTokenizer``. No special tokens are inserted automatically.
    """

    vocab_files_names = VOCAB_FILES_NAMES
    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        vocab_file: str,
        unk_token: str = "<unk>",
        bos_token: str = "<s>",
        eos_token: str = "</s>",
        pad_token: str = "<pad>",
        sp_model_kwargs: dict[str, Any] | None = None,
        **kwargs,
    ) -> None:
        # The SentencePiece processor must be ready before the base __init__,
        # which may call _tokenize while registering special tokens.
        self.vocab_file = str(vocab_file)
        self.sp_model_kwargs = dict(sp_model_kwargs) if sp_model_kwargs is not None else {}
        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
        self.sp_model.Load(self.vocab_file)
        super().__init__(
            unk_token=unk_token,
            bos_token=bos_token,
            eos_token=eos_token,
            pad_token=pad_token,
            **kwargs,
        )

    @property
    def vocab_size(self) -> int:
        """Number of pieces in the underlying SentencePiece model."""
        return int(self.sp_model.get_piece_size())

    def get_vocab(self) -> dict[str, int]:
        """Return a piece -> id map, including any post-hoc added tokens."""
        mapping = {self.sp_model.id_to_piece(index): index for index in range(self.vocab_size)}
        mapping.update(self.added_tokens_encoder)
        return mapping

    def _tokenize(self, text: str) -> list[str]:
        """Split raw text into SentencePiece surface pieces."""
        return list(self.sp_model.encode(text, out_type=str))

    def _convert_token_to_id(self, token: str) -> int:
        return int(self.sp_model.piece_to_id(token))

    def _convert_id_to_token(self, index: int) -> str:
        return str(self.sp_model.id_to_piece(int(index)))

    def convert_tokens_to_string(self, tokens: list[str]) -> str:
        return str(self.sp_model.decode(tokens))

    def save_vocabulary(self, save_directory: str, filename_prefix: str | None = None) -> tuple[str]:
        """Copy the SentencePiece model file into ``save_directory`` and return its path."""
        target_dir = Path(save_directory)
        target_dir.mkdir(parents=True, exist_ok=True)
        file_name = "tokenizer.model" if filename_prefix is None else f"{filename_prefix}-tokenizer.model"
        target_path = target_dir / file_name
        # Skip the copy when saving in place (same resolved file).
        if Path(self.vocab_file).resolve() != target_path.resolve():
            shutil.copyfile(self.vocab_file, target_path)
        return (str(target_path),)

    def build_inputs_with_special_tokens(
        self,
        token_ids_0: list[int],
        token_ids_1: list[int] | None = None,
    ) -> list[int]:
        # No BOS/EOS insertion: sequences are simply concatenated.
        if token_ids_1 is None:
            return list(token_ids_0)
        return list(token_ids_0) + list(token_ids_1)

    def get_special_tokens_mask(
        self,
        token_ids_0: list[int],
        token_ids_1: list[int] | None = None,
        already_has_special_tokens: bool = False,
    ) -> list[int]:
        """All-zeros mask: this tokenizer never adds special tokens itself."""
        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0,
                token_ids_1=token_ids_1,
                already_has_special_tokens=True,
            )
        total = len(token_ids_0) + (0 if token_ids_1 is None else len(token_ids_1))
        return [0] * total

    def create_token_type_ids_from_sequences(
        self,
        token_ids_0: list[int],
        token_ids_1: list[int] | None = None,
    ) -> list[int]:
        """Single-segment model: every token gets type id 0."""
        total = len(token_ids_0) + (0 if token_ids_1 is None else len(token_ids_1))
        return [0] * total
weights/tts/tokenizer.model ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c353ee1479b536bf414c1b247f5542b6607fb8ae91320e5af1781fee200fddff
3
+ size 470897
weights/tts/tokenizer_config.json ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "added_tokens_decoder": {
3
+ "0": {
4
+ "content": "<unk>",
5
+ "lstrip": false,
6
+ "normalized": false,
7
+ "rstrip": false,
8
+ "single_word": false,
9
+ "special": true
10
+ },
11
+ "1": {
12
+ "content": "<s>",
13
+ "lstrip": false,
14
+ "normalized": false,
15
+ "rstrip": false,
16
+ "single_word": false,
17
+ "special": true
18
+ },
19
+ "2": {
20
+ "content": "</s>",
21
+ "lstrip": false,
22
+ "normalized": false,
23
+ "rstrip": false,
24
+ "single_word": false,
25
+ "special": true
26
+ },
27
+ "3": {
28
+ "content": "<pad>",
29
+ "lstrip": false,
30
+ "normalized": false,
31
+ "rstrip": false,
32
+ "single_word": false,
33
+ "special": true
34
+ }
35
+ },
36
+ "additional_special_tokens": [],
37
+ "auto_map": {
38
+ "AutoTokenizer": [
39
+ "tokenization_nanotts_sentencepiece.NanoTTSSentencePieceTokenizer",
40
+ null
41
+ ]
42
+ },
43
+ "backend": "custom",
44
+ "bos_token": "<s>",
45
+ "clean_up_tokenization_spaces": false,
46
+ "eos_token": "</s>",
47
+ "extra_special_tokens": {},
48
+ "model_max_length": 16384,
49
+ "pad_token": "<pad>",
50
+ "tokenizer_class": "NanoTTSSentencePieceTokenizer",
51
+ "unk_token": "<unk>"
52
+ }