Upload folder using huggingface_hub
- README.md +53 -3
- config.json +33 -0
- model.safetensors +3 -0
- tts_mlx/__init__.py +3 -0
- tts_mlx/config.py +94 -0
- tts_mlx/inference.py +303 -0
- tts_mlx/model.py +375 -0
- tts_mlx/tokenizer.py +48 -0
README.md
CHANGED
@@ -1,3 +1,53 @@
----
-
--
+---
+language:
+- bg
+tags:
+- text-to-speech
+- bulgarian
+- mlx
+- apple-silicon
+library_name: mlx
+license: mit
+---
+
+# 🇧🇬 BG-TTS V5 — MLX (Apple Silicon)
+
+Native MLX port of [beleata74/bg-tts-v5](https://huggingface.co/beleata74/bg-tts-v5) for Apple Silicon (M1/M2/M3/M4).
+
+No CUDA, no NeMo, no PyTorch required. Runs fully on Apple Silicon via MLX.
+
+## Requirements
+
+```bash
+pip install mlx soundfile numpy
+pip install "nanocodec-mlx @ git+https://github.com/nineninesix-ai/nanocodec-mlx.git"
+```
+
+## Quick Start
+
+```python
+from tts_mlx.inference import synthesize
+
+synthesize(
+    checkpoint=".",  # path to this repo
+    text="Здравейте, аз съм българска система за синтез на реч.",
+    output="output.wav",
+    speaker_id=0,  # 0 = AI voice, 1 = audiobook narrator
+    temperature=0.25,
+    top_k=50,
+    top_p=0.8,
+)
+```
+
+## Speakers
+
+| Speaker | Description | Best text length |
+|---------|-------------|-----------------|
+| 0 | AI-generated, clear & fast | Any (20–500+ chars) |
+| 1 | Real female, audiobook narrator | 250–320 chars |
+
+## Credits
+
+Original model by [beleata74](https://huggingface.co/beleata74/bg-tts-v5), created with Claude.
+MLX port by Radi Totev.
+NanoCodec MLX by [nineninesix-ai](https://github.com/nineninesix-ai/nanocodec-mlx).
config.json
ADDED
@@ -0,0 +1,33 @@
+{
+  "model_type": "bg-tts-v5-mlx",
+  "framework": "mlx",
+  "language": "bg",
+  "encoder": {
+    "vocab_size": 155,
+    "d_model": 512,
+    "n_heads": 8,
+    "n_layers": 6,
+    "d_ff": 2048,
+    "max_len": 512
+  },
+  "decoder": {
+    "vocab_size": 16283,
+    "d_model": 768,
+    "n_heads": 12,
+    "n_layers": 18,
+    "d_ff": 3072,
+    "max_len": 2048,
+    "tokens_per_frame": 4
+  },
+  "codec": {
+    "model": "nineninesix/nemo-nano-codec-22khz-0.6kbps-12.5fps-MLX",
+    "sample_rate": 22050,
+    "num_codebooks": 4,
+    "codebook_size": 4032,
+    "frame_rate": 12.5
+  },
+  "speakers": {
+    "0": "AI-generated female voice, clear and fast",
+    "1": "Real female voice, audiobook narrator (use 250-320 chars)"
+  }
+}
model.safetensors
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0e473ad925047a7300f80fdb98afa9c80f7d1ab6b4e0f81a726e05d08738d38d
+size 1003201311
tts_mlx/__init__.py
ADDED
@@ -0,0 +1,3 @@
+from .inference import synthesize, load_from_pytorch_checkpoint
+from .model import TTSEncoderDecoder, V5Config
+from .tokenizer import TTSTokenizer
tts_mlx/config.py
ADDED
@@ -0,0 +1,94 @@
+# Identical to original tts_v5/config.py — no changes needed
+
+NANOCODEC_MODEL_NAME = "nineninesix/nemo-nano-codec-22khz-0.6kbps-12.5fps-MLX"  # MLX version
+CODEC_SAMPLE_RATE = 22_050
+CODEC_NUM_CODEBOOKS = 4
+CODEC_CODEBOOK_SIZE = 4_032
+CODEC_FRAME_RATE = 12.5
+CODEC_TOKENS_PER_SEC = 50
+TOKENS_PER_FRAME = 4
+
+BG_LOWER = "абвгдежзийклмнопрстуфхцчшщъьюя"
+BG_UPPER = "АБВГДЕЖЗИЙКЛМНОПРСТУФХЦЧШЩЪЬЮЯ"
+EN_LOWER = "abcdefghijklmnopqrstuvwxyz"
+EN_UPPER = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
+DIGITS = "0123456789"
+PUNCT = '.,!?;:-–—…"\'()[]{}«»„"" '
+EXTRA = "\n\t"
+
+_ALL_CHARS: list[str] = []
+_seen: set[str] = set()
+for _src in [BG_LOWER, BG_UPPER, EN_LOWER, EN_UPPER, DIGITS, PUNCT, EXTRA]:
+    for _ch in _src:
+        if _ch not in _seen:
+            _ALL_CHARS.append(_ch)
+            _seen.add(_ch)
+
+SPECIAL_TOKENS = {
+    "<pad>": 0,
+    "<start_of_text>": 1,
+    "<end_of_text>": 2,
+    "<start_of_speech>": 3,
+    "<end_of_speech>": 4,
+    "<spk_0>": 5,
+    "<spk_1>": 6,
+    "<spk_2>": 7,
+    "<spk_3>": 8,
+}
+NUM_SPECIAL_TOKENS = len(SPECIAL_TOKENS)
+
+TEXT_CHARS = _ALL_CHARS
+TEXT_VOCAB_SIZE = len(TEXT_CHARS)
+TEXT_OFFSET = NUM_SPECIAL_TOKENS
+AUDIO_OFFSET = TEXT_OFFSET + TEXT_VOCAB_SIZE
+NUM_AUDIO_TOKENS = CODEC_NUM_CODEBOOKS * CODEC_CODEBOOK_SIZE
+TOTAL_VOCAB_SIZE = AUDIO_OFFSET + NUM_AUDIO_TOKENS
+
+ENCODER_VOCAB_SIZE = AUDIO_OFFSET
+DECODER_VOCAB_SIZE = TOTAL_VOCAB_SIZE
+
+PAD_TOKEN_ID = SPECIAL_TOKENS["<pad>"]
+START_OF_TEXT_TOKEN_ID = SPECIAL_TOKENS["<start_of_text>"]
+END_OF_TEXT_TOKEN_ID = SPECIAL_TOKENS["<end_of_text>"]
+START_OF_SPEECH_TOKEN_ID = SPECIAL_TOKENS["<start_of_speech>"]
+END_OF_SPEECH_TOKEN_ID = SPECIAL_TOKENS["<end_of_speech>"]
+SPK_0_TOKEN_ID = SPECIAL_TOKENS["<spk_0>"]
+SPK_1_TOKEN_ID = SPECIAL_TOKENS["<spk_1>"]
+
+def audio_token_id(codebook: int, code: int) -> int:
+    return AUDIO_OFFSET + codebook * CODEC_CODEBOOK_SIZE + code
+
+def decode_audio_token(token_id: int) -> tuple[int, int]:
+    offset = token_id - AUDIO_OFFSET
+    return offset // CODEC_CODEBOOK_SIZE, offset % CODEC_CODEBOOK_SIZE
+
+def is_audio_token(token_id: int) -> bool:
+    return AUDIO_OFFSET <= token_id < AUDIO_OFFSET + NUM_AUDIO_TOKENS
+
+def is_special_token(token_id: int) -> bool:
+    return 0 <= token_id < NUM_SPECIAL_TOKENS
+
+def is_text_token(token_id: int) -> bool:
+    return TEXT_OFFSET <= token_id < AUDIO_OFFSET
+
+ENC_D_MODEL = 512
+ENC_N_HEADS = 8
+ENC_N_LAYERS = 6
+ENC_D_FF = 2048
+
+DEC_D_MODEL = 768
+DEC_N_HEADS = 12
+DEC_N_LAYERS = 18
+DEC_D_FF = 3072
+
+MAX_TEXT_LEN = 512
+MAX_AUDIO_LEN = 2048
+DROPOUT = 0.10
+CTC_WEIGHT = 0.1
+
+BATCH_SIZE = 8
+GRAD_ACCUM = 2
+LR = 3e-4
+WEIGHT_DECAY = 0.1
+WARMUP_STEPS = 500
+NUM_EPOCHS = 3
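
The constants above define a flat token-ID layout: special tokens first, then the deduplicated character set, then the 4 × 4032 audio codes. A quick illustrative round trip (the numeric values follow from these constants and match `config.json` above):

```python
from tts_mlx.config import (
    AUDIO_OFFSET, ENCODER_VOCAB_SIZE, DECODER_VOCAB_SIZE,
    audio_token_id, decode_audio_token, is_audio_token,
)

# 9 special tokens + 146 unique characters → encoder vocab of 155
assert ENCODER_VOCAB_SIZE == AUDIO_OFFSET == 155
# 155 + 4 codebooks × 4032 codes → decoder vocab of 16283
assert DECODER_VOCAB_SIZE == 155 + 4 * 4032 == 16283

tid = audio_token_id(codebook=2, code=100)  # flat ID for (codebook 2, code 100)
assert is_audio_token(tid)
assert decode_audio_token(tid) == (2, 100)  # round-trips back to (codebook, code)
```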
tts_mlx/inference.py
ADDED
@@ -0,0 +1,303 @@
+"""
+MLX Inference — Encoder-Decoder TTS
+===================================
+1. Load PyTorch checkpoint weights → convert to MLX arrays
+2. Encode text with encoder (once, bidirectional)
+3. Cache cross-attention KVs from encoder (computed once per layer)
+4. Autoregressively decode audio tokens
+5. Decode tokens → wav using nanocodec-mlx
+"""
+
+import os
+import numpy as np
+import mlx.core as mx
+import soundfile as sf
+
+from .config import (
+    AUDIO_OFFSET, NUM_AUDIO_TOKENS, END_OF_SPEECH_TOKEN_ID,
+    START_OF_SPEECH_TOKEN_ID, CODEC_NUM_CODEBOOKS, NANOCODEC_MODEL_NAME,
+    CODEC_SAMPLE_RATE, CODEC_FRAME_RATE, CODEC_TOKENS_PER_SEC,
+)
+from .tokenizer import TTSTokenizer
+from .model import TTSEncoderDecoder, V5Config
+
+
+# ── Weight Loading ─────────────────────────────────────────────
+
+def _pt_to_mx(t):
+    """Convert a PyTorch tensor to an MLX array."""
+    return mx.array(t.float().numpy())
+
+
+def load_from_pytorch_checkpoint(checkpoint_path: str) -> TTSEncoderDecoder:
+    """Load a PyTorch checkpoint and convert its weights to MLX."""
+    import torch  # only needed when loading from a PyTorch checkpoint
+    ckpt_file = os.path.join(checkpoint_path, "checkpoint.pt")
+    print(f"📂 Loading checkpoint: {ckpt_file}")
+    ckpt = torch.load(ckpt_file, map_location="cpu", weights_only=False)
+
+    cfg = ckpt["config"]
+    config = V5Config(
+        enc_vocab_size=cfg["enc_vocab_size"],
+        enc_d_model=cfg["enc_d_model"],
+        enc_n_heads=cfg["enc_n_heads"],
+        enc_n_layers=cfg["enc_n_layers"],
+        enc_d_ff=cfg["enc_d_ff"],
+        max_text_len=cfg["max_text_len"],
+        dec_vocab_size=cfg["dec_vocab_size"],
+        dec_d_model=cfg["dec_d_model"],
+        dec_n_heads=cfg["dec_n_heads"],
+        dec_n_layers=cfg["dec_n_layers"],
+        dec_d_ff=cfg["dec_d_ff"],
+        max_audio_len=cfg["max_audio_len"],
+        dropout=0.0,
+        ctc_weight=0.0,
+        tokens_per_frame=cfg.get("tokens_per_frame", 1),
+    )
+
+    model = TTSEncoderDecoder(config)
+    state = ckpt["model_state_dict"]
+
+    # Build MLX weight dict by mapping PyTorch keys → MLX parameter paths
+    mlx_weights = {}
+    for key, val in state.items():
+        # Skip CTC head (not needed for inference)
+        if key.startswith("ctc_head"):
+            continue
+        # Key format is unchanged: PyTorch and MLX both use dotted paths,
+        # e.g. "encoder.layers.0.attention.q_proj.weight"
+        mlx_weights[key] = _pt_to_mx(val)
+
+    model.load_weights(list(mlx_weights.items()), strict=False)
+    mx.eval(model.parameters())
+
+    step = ckpt.get("step", "?")
+    loss = ckpt.get("loss", 0.0)
+    print(f"✅ Loaded! step={step}, loss={loss:.4f}, tpf={config.tokens_per_frame}")
+    return model
+
+
+def load_from_safetensors(repo_path: str) -> TTSEncoderDecoder:
+    """Load the MLX model from safetensors — no PyTorch required."""
+    import json
+    weights_file = os.path.join(repo_path, "model.safetensors")
+    config_file = os.path.join(repo_path, "config.json")
+
+    with open(config_file) as f:
+        cfg = json.load(f)
+
+    config = V5Config(
+        enc_vocab_size=cfg["encoder"]["vocab_size"],
+        enc_d_model=cfg["encoder"]["d_model"],
+        enc_n_heads=cfg["encoder"]["n_heads"],
+        enc_n_layers=cfg["encoder"]["n_layers"],
+        enc_d_ff=cfg["encoder"]["d_ff"],
+        max_text_len=cfg["encoder"]["max_len"],
+        dec_vocab_size=cfg["decoder"]["vocab_size"],
+        dec_d_model=cfg["decoder"]["d_model"],
+        dec_n_heads=cfg["decoder"]["n_heads"],
+        dec_n_layers=cfg["decoder"]["n_layers"],
+        dec_d_ff=cfg["decoder"]["d_ff"],
+        max_audio_len=cfg["decoder"]["max_len"],
+        tokens_per_frame=cfg["decoder"]["tokens_per_frame"],
+        dropout=0.0,
+        ctc_weight=0.0,
+    )
+    model = TTSEncoderDecoder(config)
+    model.load_weights(weights_file, strict=False)
+    mx.eval(model.parameters())
+    print("✅ Loaded from safetensors!")
+    return model
+
+
+# ── Generation ─────────────────────────────────────────────────
+
+def sample_token(logits: np.ndarray, temperature: float, top_k: int, top_p: float,
+                 recent_tokens: list, rep_penalty: float) -> int:
+    """Sample the next token from a [vocab_size] numpy logits vector."""
+    # Mask: only audio tokens + <end_of_speech> are allowed
+    mask_np = np.full(logits.shape, -1e9, dtype=np.float32)
+    mask_np[AUDIO_OFFSET: AUDIO_OFFSET + NUM_AUDIO_TOKENS] = 0.0
+    mask_np[END_OF_SPEECH_TOKEN_ID] = 0.0
+    logits_np = np.asarray(logits) + mask_np
+
+    # Repetition penalty over recently generated audio tokens
+    if rep_penalty != 1.0 and recent_tokens:
+        for tid in set(recent_tokens[-200:]):
+            if AUDIO_OFFSET <= tid < AUDIO_OFFSET + NUM_AUDIO_TOKENS:
+                logits_np[tid] /= rep_penalty
+
+    # Temperature
+    logits_np = logits_np / temperature
+
+    # Top-k
+    if top_k > 0:
+        k = min(top_k, len(logits_np))
+        kth_val = np.partition(logits_np, -k)[-k]
+        logits_np[logits_np < kth_val] = -1e9
+
+    # Top-p (nucleus)
+    if top_p < 1.0:
+        sorted_idx = np.argsort(logits_np)[::-1]
+        sorted_logits = logits_np[sorted_idx]
+        probs = np.exp(sorted_logits - sorted_logits[0])
+        probs /= probs.sum()
+        cum = np.cumsum(probs)
+        remove = cum > top_p
+        remove[1:] = remove[:-1].copy()
+        remove[0] = False
+        logits_np[sorted_idx[remove]] = -1e9
+
+    # Sample from the remaining distribution
+    probs = np.exp(logits_np - logits_np.max())
+    probs /= probs.sum()
+    return int(np.random.choice(len(probs), p=probs))
+
+
+def generate(model: TTSEncoderDecoder, tokenizer: TTSTokenizer,
+             text: str, speaker_id: int = 0,
+             max_new_tokens: int = 2000, temperature: float = 0.25,
+             top_k: int = 50, top_p: float = 0.8, rep_penalty: float = 1.1):
+    """Generate audio tokens from text."""
+
+    # 1. Encode text (once)
+    enc_ids_np = tokenizer.build_encoder_input(text, speaker_id)
+    enc_ids = mx.array(enc_ids_np[None, :])  # [1, T_enc]
+    enc_mask = mx.ones_like(enc_ids)
+
+    enc_out = model.encode(enc_ids, enc_mask)  # [1, T_enc, dec_d]
+    mx.eval(enc_out)
+    print(f"📝 Encoded: {enc_ids.shape[1]} tokens → enc_out {enc_out.shape}")
+
+    # 2. Autoregressive decode
+    dec_ids = mx.array([[START_OF_SPEECH_TOKEN_ID]])
+    past_self_kvs = None
+    cached_cross_kvs = None
+    generated = []
+    offset = 0
+
+    for step in range(max_new_tokens):
+        inp = dec_ids[:, -1:] if past_self_kvs is not None else dec_ids
+
+        logits, new_self_kvs, new_cross_kvs = model.decoder(
+            inp, enc_out, enc_mask,
+            past_key_values=past_self_kvs,
+            cached_cross_kvs=cached_cross_kvs,
+            offset=offset,
+        )
+        mx.eval(logits)
+
+        # Cache cross-attention KVs after the first step (they never change)
+        if cached_cross_kvs is None:
+            cached_cross_kvs = new_cross_kvs
+            mx.eval(cached_cross_kvs)
+
+        past_self_kvs = new_self_kvs
+        offset += inp.shape[1]
+
+        # Sample
+        last_logits = np.array(logits[0, -1, :])
+        tok_id = sample_token(last_logits, temperature, top_k, top_p, generated, rep_penalty)
+
+        if tok_id == END_OF_SPEECH_TOKEN_ID:
+            print(f"🛑 EOS at step {step}")
+            break
+
+        generated.append(tok_id)
+        dec_ids = mx.array([[tok_id]])
+
+        if step % 100 == 0 and step > 0:
+            print(f"  step {step}: {len(generated)} tokens (~{len(generated) / CODEC_TOKENS_PER_SEC:.1f}s audio)")
+
+    if not generated:
+        return None
+
+    tokens = np.array(generated)
+    audio_mask = (tokens >= AUDIO_OFFSET) & (tokens < AUDIO_OFFSET + NUM_AUDIO_TOKENS)
+    return tokens[audio_mask] - AUDIO_OFFSET
+
+
+# ── Full Pipeline ──────────────────────────────────────────────
+
+def synthesize(checkpoint: str, text: str, output: str = "output.wav",
+               speaker_id: int = 0, temperature: float = 0.25,
+               top_k: int = 50, top_p: float = 0.8, rep_penalty: float = 1.1,
+               max_tokens: int = 2000):
+
+    print(f"\n🎤 Text: '{text[:80]}'")
+    print(f"   speaker={speaker_id}, T={temperature}, top_k={top_k}, top_p={top_p}")
+
+    # Load model — auto-detect: safetensors repo or PyTorch checkpoint?
+    if os.path.exists(os.path.join(checkpoint, "model.safetensors")):
+        model = load_from_safetensors(checkpoint)
+    else:
+        model = load_from_pytorch_checkpoint(checkpoint)
+    model.eval()
+
+    # Load tokenizer
+    tokenizer = TTSTokenizer()
+
+    # Generate tokens
+    tokens = generate(model, tokenizer, text, speaker_id, max_tokens,
+                      temperature, top_k, top_p, rep_penalty)
+
+    if tokens is None or len(tokens) == 0:
+        print("❌ No audio generated!")
+        return
+
+    # Trim to a whole number of frames (CODEC_NUM_CODEBOOKS tokens per frame)
+    tokens = tokens[:len(tokens) - len(tokens) % CODEC_NUM_CODEBOOKS]
+    num_frames = len(tokens) // CODEC_NUM_CODEBOOKS
+    print(f"🔊 {len(tokens)} tokens → {num_frames} frames → ~{num_frames / CODEC_FRAME_RATE:.1f}s audio")
+
+    # Decode with nanocodec-mlx
+    print("🎵 Decoding with NanoCodec MLX...")
+    from nanocodec_mlx.models.audio_codec import AudioCodecModel
+
+    codec = AudioCodecModel.from_pretrained(NANOCODEC_MODEL_NAME)
+
+    # Reshape tokens: [num_codebooks, num_frames]
+    codes = tokens.reshape(num_frames, CODEC_NUM_CODEBOOKS).T  # [4, T]
+    codes_mx = mx.array(codes.astype(np.int32))[None, :, :]    # [1, 4, T]
+    tokens_len = mx.array([num_frames], dtype=mx.int32)
+
+    wav_mx, _ = codec.decode(codes_mx, tokens_len)
+    mx.eval(wav_mx)
+
+    # Save
+    wav_np = np.array(wav_mx[0, 0, :])
+    sf.write(output, wav_np, CODEC_SAMPLE_RATE)
+    duration = len(wav_np) / CODEC_SAMPLE_RATE
+    print(f"✅ Saved: {output} ({duration:.2f}s)")
+    return wav_np
+
+
+if __name__ == "__main__":
+    import argparse
+    p = argparse.ArgumentParser()
+    p.add_argument("--checkpoint", required=True)
+    p.add_argument("--text", required=True)
+    p.add_argument("--output", default="output.wav")
+    p.add_argument("--speaker", type=int, default=0)
+    p.add_argument("--temperature", type=float, default=0.25)
+    p.add_argument("--top-k", type=int, default=50)
+    p.add_argument("--top-p", type=float, default=0.8)
+    p.add_argument("--rep-penalty", type=float, default=1.1)
+    p.add_argument("--max-tokens", type=int, default=2000)
+    a = p.parse_args()
+    synthesize(a.checkpoint, a.text, a.output, a.speaker,
+               a.temperature, a.top_k, a.top_p, a.rep_penalty, a.max_tokens)
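
Note that `synthesize()` reloads the model and codec on every call. When generating many utterances it can be cheaper to load once and loop — a minimal sketch that mirrors the calls `synthesize()` makes internally (the checkpoint path `"."` is a placeholder for wherever this repo is checked out):

```python
import numpy as np
import mlx.core as mx
import soundfile as sf

from nanocodec_mlx.models.audio_codec import AudioCodecModel
from tts_mlx.config import CODEC_NUM_CODEBOOKS, CODEC_SAMPLE_RATE, NANOCODEC_MODEL_NAME
from tts_mlx.inference import generate, load_from_safetensors
from tts_mlx.tokenizer import TTSTokenizer

# Load everything once, reuse for every sentence
model = load_from_safetensors(".")
model.eval()
tokenizer = TTSTokenizer()
codec = AudioCodecModel.from_pretrained(NANOCODEC_MODEL_NAME)

for i, text in enumerate(["Здравей, свят!", "Как си днес?"]):
    tokens = generate(model, tokenizer, text, speaker_id=0)  # 0-based codec codes
    if tokens is None or len(tokens) == 0:
        continue
    # Same post-processing as synthesize(): trim to whole frames, reshape to [1, 4, T]
    tokens = tokens[: len(tokens) - len(tokens) % CODEC_NUM_CODEBOOKS]
    frames = len(tokens) // CODEC_NUM_CODEBOOKS
    codes = mx.array(tokens.reshape(frames, CODEC_NUM_CODEBOOKS).T.astype(np.int32))[None]
    wav, _ = codec.decode(codes, mx.array([frames], dtype=mx.int32))
    sf.write(f"out_{i}.wav", np.array(wav[0, 0, :]), CODEC_SAMPLE_RATE)
```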
tts_mlx/model.py
ADDED
@@ -0,0 +1,375 @@
+"""
+MLX Model — Encoder-Decoder TTS
+================================
+Port of tts_v5/model.py from PyTorch to MLX.
+Inference-only (no training, no dropout, no CTC head needed).
+"""
+
+import math
+import mlx.core as mx
+import mlx.nn as nn
+from dataclasses import dataclass
+from typing import Optional
+
+from .config import (
+    ENCODER_VOCAB_SIZE, DECODER_VOCAB_SIZE,
+    ENC_D_MODEL, ENC_N_HEADS, ENC_N_LAYERS, ENC_D_FF,
+    DEC_D_MODEL, DEC_N_HEADS, DEC_N_LAYERS, DEC_D_FF,
+    MAX_TEXT_LEN, MAX_AUDIO_LEN,
+)
+
+
+# ── Shared Components ──────────────────────────────────────────
+
+class RMSNorm(nn.Module):
+    def __init__(self, dim: int, eps: float = 1e-6):
+        super().__init__()
+        self.eps = eps
+        self.weight = mx.ones((dim,))
+
+    def __call__(self, x: mx.array) -> mx.array:
+        norm = mx.rsqrt(mx.mean(x * x, axis=-1, keepdims=True) + self.eps)
+        return x * norm * self.weight
+
+
+class SwiGLUFFN(nn.Module):
+    def __init__(self, d_model: int, d_ff: int):
+        super().__init__()
+        self.gate_proj = nn.Linear(d_model, d_ff, bias=False)
+        self.up_proj = nn.Linear(d_model, d_ff, bias=False)
+        self.down_proj = nn.Linear(d_ff, d_model, bias=False)
+
+    def __call__(self, x: mx.array) -> mx.array:
+        return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
+
+
+def build_rope_cache(max_seq_len: int, head_dim: int, base: float = 10000.0):
+    """Precompute RoPE cos/sin tables."""
+    inv_freq = 1.0 / (base ** (mx.arange(0, head_dim, 2).astype(mx.float32) / head_dim))
+    t = mx.arange(max_seq_len, dtype=mx.float32)
+    freqs = mx.outer(t, inv_freq)
+    emb = mx.concatenate([freqs, freqs], axis=-1)
+    return mx.cos(emb), mx.sin(emb)
+
+
+def rotate_half(x: mx.array) -> mx.array:
+    half = x.shape[-1] // 2
+    x1, x2 = x[..., :half], x[..., half:]
+    return mx.concatenate([-x2, x1], axis=-1)
+
+
+def apply_rope(q: mx.array, k: mx.array, cos: mx.array, sin: mx.array):
+    cos = cos[None, None, :, :]  # [1, 1, T, head_dim]
+    sin = sin[None, None, :, :]
+    q = q * cos + rotate_half(q) * sin
+    k = k * cos + rotate_half(k) * sin
+    return q, k
+
+
+# ── Encoder (Bidirectional) ────────────────────────────────────
+
+class EncoderSelfAttention(nn.Module):
+    def __init__(self, d_model: int, n_heads: int):
+        super().__init__()
+        self.n_heads = n_heads
+        self.head_dim = d_model // n_heads
+        self.q_proj = nn.Linear(d_model, d_model, bias=False)
+        self.k_proj = nn.Linear(d_model, d_model, bias=False)
+        self.v_proj = nn.Linear(d_model, d_model, bias=False)
+        self.o_proj = nn.Linear(d_model, d_model, bias=False)
+
+    def __call__(self, x: mx.array, mask: Optional[mx.array] = None) -> mx.array:
+        B, T, _ = x.shape
+        q = self.q_proj(x).reshape(B, T, self.n_heads, self.head_dim).transpose(0, 2, 1, 3)
+        k = self.k_proj(x).reshape(B, T, self.n_heads, self.head_dim).transpose(0, 2, 1, 3)
+        v = self.v_proj(x).reshape(B, T, self.n_heads, self.head_dim).transpose(0, 2, 1, 3)
+
+        scale = 1.0 / math.sqrt(self.head_dim)
+        scores = (q @ k.transpose(0, 1, 3, 2)) * scale  # [B, H, T, T]
+
+        if mask is not None:
+            scores = scores + mask
+
+        attn = mx.softmax(scores.astype(mx.float32), axis=-1).astype(x.dtype)
+        out = (attn @ v).transpose(0, 2, 1, 3).reshape(B, T, -1)
+        return self.o_proj(out)
+
+
+class EncoderBlock(nn.Module):
+    def __init__(self, d_model: int, n_heads: int, d_ff: int):
+        super().__init__()
+        self.attn_norm = RMSNorm(d_model)
+        self.attention = EncoderSelfAttention(d_model, n_heads)
+        self.ffn_norm = RMSNorm(d_model)
+        self.ffn = SwiGLUFFN(d_model, d_ff)
+
+    def __call__(self, x: mx.array, mask: Optional[mx.array] = None) -> mx.array:
+        x = x + self.attention(self.attn_norm(x), mask)
+        x = x + self.ffn(self.ffn_norm(x))
+        return x
+
+
+class TextEncoder(nn.Module):
+    def __init__(self, vocab_size=ENCODER_VOCAB_SIZE, d_model=ENC_D_MODEL,
+                 n_heads=ENC_N_HEADS, n_layers=ENC_N_LAYERS, d_ff=ENC_D_FF,
+                 max_len=MAX_TEXT_LEN):
+        super().__init__()
+        self.d_model = d_model
+        self.token_embedding = nn.Embedding(vocab_size, d_model)
+        self.pos_embedding = nn.Embedding(max_len, d_model)
+        self.layers = [EncoderBlock(d_model, n_heads, d_ff) for _ in range(n_layers)]
+        self.final_norm = RMSNorm(d_model)
+
+    def __call__(self, input_ids: mx.array, attention_mask: Optional[mx.array] = None) -> mx.array:
+        B, T = input_ids.shape
+        pos = mx.arange(T)[None, :]  # [1, T]
+        h = self.token_embedding(input_ids) + self.pos_embedding(pos)
+
+        # Build padding mask: [B, 1, 1, T], -inf on pad positions
+        attn_mask = None
+        if attention_mask is not None:
+            # attention_mask: [B, T], 1=real 0=pad
+            pad = (attention_mask == 0).astype(mx.float32)  # [B, T]
+            attn_mask = pad[:, None, None, :] * -1e9  # [B, 1, 1, T]
+
+        for layer in self.layers:
+            h = layer(h, attn_mask)
+
+        return self.final_norm(h)
+
+
+# ── Decoder (Causal with Cross-Attention) ──────────────────────
+
+class DecoderSelfAttention(nn.Module):
+    def __init__(self, d_model: int, n_heads: int, max_len: int, tokens_per_frame: int = 1):
+        super().__init__()
+        self.n_heads = n_heads
+        self.head_dim = d_model // n_heads
+        self.tokens_per_frame = tokens_per_frame
+        self.q_proj = nn.Linear(d_model, d_model, bias=False)
+        self.k_proj = nn.Linear(d_model, d_model, bias=False)
+        self.v_proj = nn.Linear(d_model, d_model, bias=False)
+        self.o_proj = nn.Linear(d_model, d_model, bias=False)
+        # Precompute RoPE
+        cos, sin = build_rope_cache(max_len * 2, self.head_dim)
+        self.rope_cos = cos
+        self.rope_sin = sin
+
+    def __call__(self, x: mx.array, past_kv=None, offset: int = 0):
+        """
+        x: [B, T, d_model]
+        past_kv: (k_cache, v_cache) or None
+        offset: number of already-generated tokens (for RoPE position)
+        Returns: (output, new_k, new_v)
+        """
+        B, T, _ = x.shape
+        q = self.q_proj(x).reshape(B, T, self.n_heads, self.head_dim).transpose(0, 2, 1, 3)
+        k = self.k_proj(x).reshape(B, T, self.n_heads, self.head_dim).transpose(0, 2, 1, 3)
+        v = self.v_proj(x).reshape(B, T, self.n_heads, self.head_dim).transpose(0, 2, 1, 3)
+
+        # Apply RoPE with frame-level positions
+        if self.tokens_per_frame > 1:
+            frame_offset = offset // self.tokens_per_frame
+            frame_positions = mx.arange(T) // self.tokens_per_frame + frame_offset
+        else:
+            frame_positions = mx.arange(T) + offset
+
+        cos = self.rope_cos[frame_positions]  # [T, head_dim]
+        sin = self.rope_sin[frame_positions]
+        q, k = apply_rope(q, k, cos, sin)
+
+        # Append to KV cache
+        if past_kv is not None:
+            k = mx.concatenate([past_kv[0], k], axis=2)
+            v = mx.concatenate([past_kv[1], v], axis=2)
+
+        new_k, new_v = k, v
+
+        # Causal mask only during prefill (T > 1, no cache)
+        scale = 1.0 / math.sqrt(self.head_dim)
+        scores = (q @ k.transpose(0, 1, 3, 2)) * scale
+
+        if past_kv is None and T > 1:
+            # Build causal mask
+            causal = mx.triu(mx.full((T, k.shape[2]), -1e9), k=1)
+            scores = scores + causal[None, None, :, :]
+
+        attn = mx.softmax(scores.astype(mx.float32), axis=-1).astype(x.dtype)
+        out = (attn @ v).transpose(0, 2, 1, 3).reshape(B, T, -1)
+        return self.o_proj(out), new_k, new_v
+
+
+class CrossAttention(nn.Module):
+    def __init__(self, dec_d_model: int, enc_d_model: int, n_heads: int):
+        super().__init__()
+        self.n_heads = n_heads
+        self.head_dim = dec_d_model // n_heads
+        self.q_proj = nn.Linear(dec_d_model, dec_d_model, bias=False)
+        self.k_proj = nn.Linear(enc_d_model, dec_d_model, bias=False)
+        self.v_proj = nn.Linear(enc_d_model, dec_d_model, bias=False)
+        self.o_proj = nn.Linear(dec_d_model, dec_d_model, bias=False)
+
+    def __call__(self, x: mx.array, encoder_output: mx.array,
+                 encoder_mask: Optional[mx.array] = None,
+                 cached_kv=None):
+        """
+        cached_kv: precomputed (k, v) from encoder — computed once, reused every step.
+        """
+        B, T, _ = x.shape
+        q = self.q_proj(x).reshape(B, T, self.n_heads, self.head_dim).transpose(0, 2, 1, 3)
+
+        if cached_kv is not None:
+            k, v = cached_kv
+        else:
+            T_enc = encoder_output.shape[1]
+            k = self.k_proj(encoder_output).reshape(B, T_enc, self.n_heads, self.head_dim).transpose(0, 2, 1, 3)
+            v = self.v_proj(encoder_output).reshape(B, T_enc, self.n_heads, self.head_dim).transpose(0, 2, 1, 3)
+
+        scale = 1.0 / math.sqrt(self.head_dim)
+        scores = (q @ k.transpose(0, 1, 3, 2)) * scale  # [B, H, T, T_enc]
+
+        if encoder_mask is not None:
+            # encoder_mask: [B, T_enc], 1=real 0=pad
+            pad = (encoder_mask == 0).astype(mx.float32)
+            scores = scores + pad[:, None, None, :] * -1e9
+
+        attn = mx.softmax(scores.astype(mx.float32), axis=-1).astype(x.dtype)
+        out = (attn @ v).transpose(0, 2, 1, 3).reshape(B, T, -1)
+        return self.o_proj(out), (k, v)
+
+
+class DecoderBlock(nn.Module):
+    def __init__(self, dec_d_model: int, enc_d_model: int, n_heads: int,
+                 d_ff: int, max_len: int, tokens_per_frame: int = 1):
+        super().__init__()
+        self.self_attn_norm = RMSNorm(dec_d_model)
+        self.self_attention = DecoderSelfAttention(dec_d_model, n_heads, max_len, tokens_per_frame)
+        self.cross_attn_norm = RMSNorm(dec_d_model)
+        self.cross_attention = CrossAttention(dec_d_model, enc_d_model, n_heads)
+        self.ffn_norm = RMSNorm(dec_d_model)
+        self.ffn = SwiGLUFFN(dec_d_model, d_ff)
+
+    def __call__(self, x: mx.array, encoder_output: mx.array,
+                 encoder_mask=None, past_self_kv=None, cached_cross_kv=None,
+                 offset: int = 0):
+        # 1. Causal self-attention
+        h = self.self_attn_norm(x)
+        sa_out, new_k, new_v = self.self_attention(h, past_self_kv, offset)
+        x = x + sa_out
+
+        # 2. Cross-attention (encoder KV cached after first call)
+        h = self.cross_attn_norm(x)
+        ca_out, cross_kv = self.cross_attention(h, encoder_output, encoder_mask, cached_cross_kv)
+        x = x + ca_out
+
+        # 3. FFN
+        x = x + self.ffn(self.ffn_norm(x))
+
+        return x, (new_k, new_v), cross_kv
+
+
+class AudioDecoder(nn.Module):
+    def __init__(self, vocab_size=DECODER_VOCAB_SIZE, d_model=DEC_D_MODEL,
+                 enc_d_model=DEC_D_MODEL, n_heads=DEC_N_HEADS,
+                 n_layers=DEC_N_LAYERS, d_ff=DEC_D_FF,
+                 max_len=MAX_AUDIO_LEN, tokens_per_frame=1):
+        super().__init__()
+        self.tokens_per_frame = tokens_per_frame
+        self.token_embedding = nn.Embedding(vocab_size, d_model)
+        self.layers = [
+            DecoderBlock(d_model, enc_d_model, n_heads, d_ff, max_len, tokens_per_frame)
+            for _ in range(n_layers)
+        ]
+        self.final_norm = RMSNorm(d_model)
+        # LM head is tied to token_embedding (projection applied in __call__)
+
+    def __call__(self, input_ids: mx.array, encoder_output: mx.array,
+                 encoder_mask=None, past_key_values=None, cached_cross_kvs=None,
+                 offset: int = 0):
+        """
+        input_ids: [B, T]
+        encoder_output: [B, T_enc, d]
+        past_key_values: list of (k, v) per layer, or None
+        cached_cross_kvs: list of (k, v) per layer from encoder, or None
+        offset: token offset for RoPE (number of past tokens)
+        """
+        h = self.token_embedding(input_ids)
+
+        new_self_kvs = []
+        new_cross_kvs = []
+
+        for i, layer in enumerate(self.layers):
+            past_self_kv = past_key_values[i] if past_key_values else None
+            cached_cross_kv = cached_cross_kvs[i] if cached_cross_kvs else None
+
+            h, new_self_kv, new_cross_kv = layer(
+                h, encoder_output, encoder_mask,
+                past_self_kv, cached_cross_kv, offset
+            )
+            new_self_kvs.append(new_self_kv)
+            new_cross_kvs.append(new_cross_kv)
+
+        h = self.final_norm(h)
+        # Tied embedding projection
+        logits = h @ self.token_embedding.weight.T
+        return logits, new_self_kvs, new_cross_kvs
+
+
+# ── Full Model ─────────────────────────────────────────────────
+
+@dataclass
+class V5Config:
+    enc_vocab_size: int = ENCODER_VOCAB_SIZE
+    enc_d_model: int = ENC_D_MODEL
+    enc_n_heads: int = ENC_N_HEADS
+    enc_n_layers: int = ENC_N_LAYERS
+    enc_d_ff: int = ENC_D_FF
+    max_text_len: int = MAX_TEXT_LEN
+    dec_vocab_size: int = DECODER_VOCAB_SIZE
+    dec_d_model: int = DEC_D_MODEL
+    dec_n_heads: int = DEC_N_HEADS
+    dec_n_layers: int = DEC_N_LAYERS
+    dec_d_ff: int = DEC_D_FF
+    max_audio_len: int = MAX_AUDIO_LEN
+    dropout: float = 0.0
+    ctc_weight: float = 0.0
+    tokens_per_frame: int = 1
+
+
+class TTSEncoderDecoder(nn.Module):
+    def __init__(self, config: V5Config):
+        super().__init__()
+        self.config = config
+
+        self.encoder = TextEncoder(
+            vocab_size=config.enc_vocab_size,
+            d_model=config.enc_d_model,
+            n_heads=config.enc_n_heads,
+            n_layers=config.enc_n_layers,
+            d_ff=config.enc_d_ff,
+            max_len=config.max_text_len,
+        )
+
+        if config.enc_d_model != config.dec_d_model:
+            self.enc_projection = nn.Linear(config.enc_d_model, config.dec_d_model, bias=False)
+        else:
+            self.enc_projection = None
+
+        self.decoder = AudioDecoder(
+            vocab_size=config.dec_vocab_size,
+            d_model=config.dec_d_model,
+            enc_d_model=config.dec_d_model,
+            n_heads=config.dec_n_heads,
+            n_layers=config.dec_n_layers,
+            d_ff=config.dec_d_ff,
+            max_len=config.max_audio_len,
+            tokens_per_frame=config.tokens_per_frame,
+        )
+
+    def encode(self, enc_ids: mx.array, enc_mask=None) -> mx.array:
+        """Run encoder + projection once. Returns [B, T_enc, dec_d_model]."""
+        enc_out = self.encoder(enc_ids, enc_mask)
+        if self.enc_projection is not None:
+            enc_out = self.enc_projection(enc_out)
+        return enc_out
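
For orientation, the wiring above can be smoke-tested with random weights and the default `V5Config` (a sketch only — shapes follow from the constants in `tts_mlx/config.py`, not from the trained model):

```python
import mlx.core as mx
from tts_mlx.model import TTSEncoderDecoder, V5Config

cfg = V5Config(tokens_per_frame=4)   # matches the released checkpoint's config.json
model = TTSEncoderDecoder(cfg)
model.eval()

enc_ids = mx.zeros((1, 12), dtype=mx.int32)           # dummy text tokens
enc_out = model.encode(enc_ids, mx.ones_like(enc_ids))
print(enc_out.shape)                                   # (1, 12, 768) after the 512→768 projection

dec_ids = mx.array([[3]])                              # <start_of_speech>
logits, self_kvs, cross_kvs = model.decoder(dec_ids, enc_out, mx.ones((1, 12)))
print(logits.shape)                                    # (1, 1, 16283)
```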
tts_mlx/tokenizer.py
ADDED
@@ -0,0 +1,48 @@
+"""
+Tokenizer — identical to original tts_v5/tokenizer.py
+Pure Python + numpy, no PyTorch dependency.
+"""
+
+import re
+import numpy as np
+from typing import Optional
+
+from .config import (
+    TEXT_CHARS, TEXT_OFFSET, AUDIO_OFFSET,
+    SPECIAL_TOKENS, NUM_SPECIAL_TOKENS,
+    TOTAL_VOCAB_SIZE, CODEC_CODEBOOK_SIZE,
+    PAD_TOKEN_ID, START_OF_TEXT_TOKEN_ID, END_OF_TEXT_TOKEN_ID,
+    START_OF_SPEECH_TOKEN_ID, END_OF_SPEECH_TOKEN_ID,
+    is_audio_token, is_special_token, is_text_token,
+)
+
+
+class TTSTokenizer:
+    def __init__(self):
+        self.char2id: dict[str, int] = {}
+        self.id2char: dict[int, str] = {}
+        for i, ch in enumerate(TEXT_CHARS):
+            tid = TEXT_OFFSET + i
+            self.char2id[ch] = tid
+            self.id2char[tid] = ch
+
+        self._special_id_to_name = {v: k for k, v in SPECIAL_TOKENS.items()}
+        self.vocab_size = TOTAL_VOCAB_SIZE
+        self.text_vocab_size = len(TEXT_CHARS)
+
+    def normalize_text(self, text: str) -> str:
+        text = re.sub(r'\s+', ' ', text).strip()
+        text = re.sub(r'[–—]', '-', text)
+        text = re.sub(r'[«»„""]', '"', text)
+        return text
+
+    def encode_text(self, text: str) -> list[int]:
+        text = self.normalize_text(text)
+        return [self.char2id[ch] for ch in text if ch in self.char2id]
+
+    def build_encoder_input(self, text: str, speaker_id: int = 0) -> np.ndarray:
+        """Encoder input: <sot> text_chars <eot> <spk_X>"""
+        text_ids = self.encode_text(text)
+        spk = SPECIAL_TOKENS[f"<spk_{speaker_id}>"]
+        seq = [START_OF_TEXT_TOKEN_ID] + text_ids + [END_OF_TEXT_TOKEN_ID, spk]
+        return np.array(seq, dtype=np.int32)
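
The resulting encoder sequence is `<start_of_text> · character IDs · <end_of_text> · <spk_X>`. A small illustration (the token IDs follow from `SPECIAL_TOKENS` in `tts_mlx/config.py`):

```python
from tts_mlx.tokenizer import TTSTokenizer

tok = TTSTokenizer()
ids = tok.build_encoder_input("Здравей!", speaker_id=1)
print(ids[0], ids[-2], ids[-1])  # 1 (<start_of_text>), 2 (<end_of_text>), 6 (<spk_1>)
print(len(ids))                  # 8 characters + 3 markers = 11
```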