raditotev committed
Commit bcdd523 · verified · 1 Parent(s): 470a4ae

Upload folder using huggingface_hub

README.md CHANGED
@@ -1,3 +1,53 @@
- ---
- license: mit
- ---
+ ---
+ language:
+ - bg
+ tags:
+ - text-to-speech
+ - bulgarian
+ - mlx
+ - apple-silicon
+ library_name: mlx
+ license: mit
+ ---
+
+ # 🇧🇬 BG-TTS V5 — MLX (Apple Silicon)
+
+ Native MLX port of [beleata74/bg-tts-v5](https://huggingface.co/beleata74/bg-tts-v5) for Apple Silicon (M1/M2/M3/M4).
+
+ No CUDA, no NeMo, no PyTorch required. Runs fully on Apple Silicon via MLX.
+
+ ## Requirements
+
+ ```bash
+ pip install mlx soundfile numpy
+ pip install "nanocodec-mlx @ git+https://github.com/nineninesix-ai/nanocodec-mlx.git"
+ ```
+
+ ## Quick Start
+
+ ```python
+ from tts_mlx.inference import synthesize
+
+ synthesize(
+     checkpoint=".",    # path to this repo
+     text="Здравейте, аз съм българска система за синтез на реч.",
+     output="output.wav",
+     speaker_id=0,      # 0 = AI voice, 1 = audiobook narrator
+     temperature=0.25,
+     top_k=50,
+     top_p=0.8,
+ )
+ ```
+
+ ## Speakers
+
+ | Speaker | Description | Best text length |
+ |---------|-------------|------------------|
+ | 0 | AI-generated, clear & fast | Any (20–500+ chars) |
+ | 1 | Real female, audiobook narrator | 250–320 chars |
+
+ ## Credits
+
+ Original model by [beleata74](https://huggingface.co/beleata74/bg-tts-v5), created with Claude.
+ MLX port by Radi Totev.
+ NanoCodec MLX by [nineninesix-ai](https://github.com/nineninesix-ai/nanocodec-mlx).
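The inference module also has an argparse entry point (see `tts_mlx/inference.py` below), so the same pipeline can be driven from the shell. A minimal invocation, assuming you run from the repo root so that `tts_mlx` is importable:

```bash
python -m tts_mlx.inference \
  --checkpoint . \
  --text "Здравейте, аз съм българска система за синтез на реч." \
  --output output.wav \
  --speaker 0 --temperature 0.25 --top-k 50 --top-p 0.8
```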
config.json ADDED
@@ -0,0 +1,33 @@
+ {
+   "model_type": "bg-tts-v5-mlx",
+   "framework": "mlx",
+   "language": "bg",
+   "encoder": {
+     "vocab_size": 155,
+     "d_model": 512,
+     "n_heads": 8,
+     "n_layers": 6,
+     "d_ff": 2048,
+     "max_len": 512
+   },
+   "decoder": {
+     "vocab_size": 16283,
+     "d_model": 768,
+     "n_heads": 12,
+     "n_layers": 18,
+     "d_ff": 3072,
+     "max_len": 2048,
+     "tokens_per_frame": 4
+   },
+   "codec": {
+     "model": "nineninesix/nemo-nano-codec-22khz-0.6kbps-12.5fps-MLX",
+     "sample_rate": 22050,
+     "num_codebooks": 4,
+     "codebook_size": 4032,
+     "frame_rate": 12.5
+   },
+   "speakers": {
+     "0": "AI-generated female voice, clear and fast",
+     "1": "Real female voice, audiobook narrator (use 250-320 chars)"
+   }
+ }
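A quick consistency check on these numbers: 4 codebooks per frame at 12.5 frames/s means 4 × 12.5 = 50 codec tokens per second, so the decoder's `max_len` of 2048 token positions caps a single generation at roughly 41 s of audio. A small sketch (values copied from the fields above, not read from the file):

```python
# Derived from config.json: codec.num_codebooks, codec.frame_rate, decoder.max_len
num_codebooks = 4
frame_rate = 12.5            # codec frames per second
max_len = 2048               # decoder token positions

tokens_per_sec = num_codebooks * frame_rate   # 50.0 — matches CODEC_TOKENS_PER_SEC
max_audio_sec = max_len / tokens_per_sec      # 40.96 s upper bound per generation
print(tokens_per_sec, round(max_audio_sec, 2))
```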
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0e473ad925047a7300f80fdb98afa9c80f7d1ab6b4e0f81a726e05d08738d38d
+ size 1003201311
tts_mlx/__init__.py ADDED
@@ -0,0 +1,3 @@
+ from .inference import synthesize, load_from_pytorch_checkpoint
+ from .model import TTSEncoderDecoder, V5Config
+ from .tokenizer import TTSTokenizer
tts_mlx/config.py ADDED
@@ -0,0 +1,94 @@
+ # Identical to original tts_v5/config.py — no changes needed
+
+ NANOCODEC_MODEL_NAME = "nineninesix/nemo-nano-codec-22khz-0.6kbps-12.5fps-MLX"  # MLX version
+ CODEC_SAMPLE_RATE = 22_050
+ CODEC_NUM_CODEBOOKS = 4
+ CODEC_CODEBOOK_SIZE = 4_032
+ CODEC_FRAME_RATE = 12.5
+ CODEC_TOKENS_PER_SEC = 50
+ TOKENS_PER_FRAME = 4
+
+ BG_LOWER = "абвгдежзийклмнопрстуфхцчшщъьюя"
+ BG_UPPER = "АБВГДЕЖЗИЙКЛМНОПРСТУФХЦЧШЩЪЬЮЯ"
+ EN_LOWER = "abcdefghijklmnopqrstuvwxyz"
+ EN_UPPER = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
+ DIGITS = "0123456789"
+ PUNCT = '.,!?;:-–—…"\'()[]{}«»„"" '
+ EXTRA = "\n\t"
+
+ _ALL_CHARS: list[str] = []
+ _seen: set[str] = set()
+ for _src in [BG_LOWER, BG_UPPER, EN_LOWER, EN_UPPER, DIGITS, PUNCT, EXTRA]:
+     for _ch in _src:
+         if _ch not in _seen:
+             _ALL_CHARS.append(_ch)
+             _seen.add(_ch)
+
+ SPECIAL_TOKENS = {
+     "<pad>": 0,
+     "<start_of_text>": 1,
+     "<end_of_text>": 2,
+     "<start_of_speech>": 3,
+     "<end_of_speech>": 4,
+     "<spk_0>": 5,
+     "<spk_1>": 6,
+     "<spk_2>": 7,
+     "<spk_3>": 8,
+ }
+ NUM_SPECIAL_TOKENS = len(SPECIAL_TOKENS)
+
+ TEXT_CHARS = _ALL_CHARS
+ TEXT_VOCAB_SIZE = len(TEXT_CHARS)
+ TEXT_OFFSET = NUM_SPECIAL_TOKENS
+ AUDIO_OFFSET = TEXT_OFFSET + TEXT_VOCAB_SIZE
+ NUM_AUDIO_TOKENS = CODEC_NUM_CODEBOOKS * CODEC_CODEBOOK_SIZE
+ TOTAL_VOCAB_SIZE = AUDIO_OFFSET + NUM_AUDIO_TOKENS
+
+ ENCODER_VOCAB_SIZE = AUDIO_OFFSET
+ DECODER_VOCAB_SIZE = TOTAL_VOCAB_SIZE
+
+ PAD_TOKEN_ID = SPECIAL_TOKENS["<pad>"]
+ START_OF_TEXT_TOKEN_ID = SPECIAL_TOKENS["<start_of_text>"]
+ END_OF_TEXT_TOKEN_ID = SPECIAL_TOKENS["<end_of_text>"]
+ START_OF_SPEECH_TOKEN_ID = SPECIAL_TOKENS["<start_of_speech>"]
+ END_OF_SPEECH_TOKEN_ID = SPECIAL_TOKENS["<end_of_speech>"]
+ SPK_0_TOKEN_ID = SPECIAL_TOKENS["<spk_0>"]
+ SPK_1_TOKEN_ID = SPECIAL_TOKENS["<spk_1>"]
+
+ def audio_token_id(codebook: int, code: int) -> int:
+     return AUDIO_OFFSET + codebook * CODEC_CODEBOOK_SIZE + code
+
+ def decode_audio_token(token_id: int) -> tuple[int, int]:
+     offset = token_id - AUDIO_OFFSET
+     return offset // CODEC_CODEBOOK_SIZE, offset % CODEC_CODEBOOK_SIZE
+
+ def is_audio_token(token_id: int) -> bool:
+     return AUDIO_OFFSET <= token_id < AUDIO_OFFSET + NUM_AUDIO_TOKENS
+
+ def is_special_token(token_id: int) -> bool:
+     return 0 <= token_id < NUM_SPECIAL_TOKENS
+
+ def is_text_token(token_id: int) -> bool:
+     return TEXT_OFFSET <= token_id < AUDIO_OFFSET
+
+ ENC_D_MODEL = 512
+ ENC_N_HEADS = 8
+ ENC_N_LAYERS = 6
+ ENC_D_FF = 2048
+
+ DEC_D_MODEL = 768
+ DEC_N_HEADS = 12
+ DEC_N_LAYERS = 18
+ DEC_D_FF = 3072
+
+ MAX_TEXT_LEN = 512
+ MAX_AUDIO_LEN = 2048
+ DROPOUT = 0.10
+ CTC_WEIGHT = 0.1
+
+ BATCH_SIZE = 8
+ GRAD_ACCUM = 2
+ LR = 3e-4
+ WEIGHT_DECAY = 0.1
+ WARMUP_STEPS = 500
+ NUM_EPOCHS = 3
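As a sanity check, the derived vocabulary layout in this module lines up with `config.json`: 9 special tokens, then the deduplicated character set, then 4 × 4032 audio codes. A sketch (the asserts assume the character set dedupes to 146 entries, which is what `"vocab_size": 155` in `config.json` implies):

```python
from tts_mlx.config import (
    NUM_SPECIAL_TOKENS, TEXT_VOCAB_SIZE, AUDIO_OFFSET,
    NUM_AUDIO_TOKENS, ENCODER_VOCAB_SIZE, DECODER_VOCAB_SIZE,
)

assert AUDIO_OFFSET == NUM_SPECIAL_TOKENS + TEXT_VOCAB_SIZE  # specials, then text chars
assert ENCODER_VOCAB_SIZE == 9 + 146                         # = 155, encoder.vocab_size
assert NUM_AUDIO_TOKENS == 4 * 4032                          # = 16128 audio code ids
assert DECODER_VOCAB_SIZE == 155 + 16128                     # = 16283, decoder.vocab_size
```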
tts_mlx/inference.py ADDED
@@ -0,0 +1,303 @@
+ """
+ MLX Inference — Encoder-Decoder TTS
+ =====================================
+ 1. Load PyTorch checkpoint weights → convert to MLX arrays
+ 2. Encode text with encoder (once, bidirectional)
+ 3. Cache cross-attention KVs from encoder (computed once per layer)
+ 4. Autoregressively decode audio tokens
+ 5. Decode tokens → wav using nanocodec-mlx
+ """
+
+ import os
+ import numpy as np
+ import mlx.core as mx
+ import soundfile as sf
+
+ from .config import (
+     AUDIO_OFFSET, NUM_AUDIO_TOKENS, END_OF_SPEECH_TOKEN_ID,
+     START_OF_SPEECH_TOKEN_ID, CODEC_NUM_CODEBOOKS, NANOCODEC_MODEL_NAME,
+ )
+ from .tokenizer import TTSTokenizer
+ from .model import TTSEncoderDecoder, V5Config
+
+
+ # ── Weight Loading ─────────────────────────────────────────────
+
+ def _pt_to_mx(t):
+     """Convert a PyTorch tensor to an MLX array."""
+     return mx.array(t.float().numpy())
+
+
+ def load_from_pytorch_checkpoint(checkpoint_path: str) -> TTSEncoderDecoder:
+     """Load a PyTorch checkpoint and convert its weights to MLX."""
+     import torch  # only needed when loading from a PyTorch checkpoint
+     ckpt_file = os.path.join(checkpoint_path, "checkpoint.pt")
+     print(f"📂 Loading checkpoint: {ckpt_file}")
+     ckpt = torch.load(ckpt_file, map_location="cpu", weights_only=False)
+
+     cfg = ckpt["config"]
+     config = V5Config(
+         enc_vocab_size=cfg["enc_vocab_size"],
+         enc_d_model=cfg["enc_d_model"],
+         enc_n_heads=cfg["enc_n_heads"],
+         enc_n_layers=cfg["enc_n_layers"],
+         enc_d_ff=cfg["enc_d_ff"],
+         max_text_len=cfg["max_text_len"],
+         dec_vocab_size=cfg["dec_vocab_size"],
+         dec_d_model=cfg["dec_d_model"],
+         dec_n_heads=cfg["dec_n_heads"],
+         dec_n_layers=cfg["dec_n_layers"],
+         dec_d_ff=cfg["dec_d_ff"],
+         max_audio_len=cfg["max_audio_len"],
+         dropout=0.0,
+         ctc_weight=0.0,
+         tokens_per_frame=cfg.get("tokens_per_frame", 1),
+     )
+
+     model = TTSEncoderDecoder(config)
+     state = ckpt["model_state_dict"]
+
+     # Build MLX weight dict by mapping PyTorch keys → MLX parameter paths
+     mlx_weights = {}
+     for key, val in state.items():
+         # Skip the CTC head (not needed for inference)
+         if key.startswith("ctc_head"):
+             continue
+         # Key format matches 1:1 — e.g. "encoder.layers.0.attention.q_proj.weight"
+         mlx_weights[key] = _pt_to_mx(val)
+
+     model.load_weights(list(mlx_weights.items()), strict=False)
+     mx.eval(model.parameters())
+
+     step = ckpt.get("step", "?")
+     loss = ckpt.get("loss", 0.0)
+     print(f"✅ Loaded! step={step}, loss={loss:.4f}, tpf={config.tokens_per_frame}")
+     return model
+
+
+ def load_from_safetensors(repo_path: str) -> TTSEncoderDecoder:
+     """Load the MLX model from safetensors — no PyTorch required."""
+     import json
+     weights_file = os.path.join(repo_path, "model.safetensors")
+     config_file = os.path.join(repo_path, "config.json")
+
+     with open(config_file) as f:
+         cfg = json.load(f)
+
+     config = V5Config(
+         enc_vocab_size=cfg["encoder"]["vocab_size"],
+         enc_d_model=cfg["encoder"]["d_model"],
+         enc_n_heads=cfg["encoder"]["n_heads"],
+         enc_n_layers=cfg["encoder"]["n_layers"],
+         enc_d_ff=cfg["encoder"]["d_ff"],
+         max_text_len=cfg["encoder"]["max_len"],
+         dec_vocab_size=cfg["decoder"]["vocab_size"],
+         dec_d_model=cfg["decoder"]["d_model"],
+         dec_n_heads=cfg["decoder"]["n_heads"],
+         dec_n_layers=cfg["decoder"]["n_layers"],
+         dec_d_ff=cfg["decoder"]["d_ff"],
+         max_audio_len=cfg["decoder"]["max_len"],
+         tokens_per_frame=cfg["decoder"]["tokens_per_frame"],
+         dropout=0.0,
+         ctc_weight=0.0,
+     )
+     model = TTSEncoderDecoder(config)
+     model.load_weights(weights_file, strict=False)
+     mx.eval(model.parameters())
+     print("✅ Loaded from safetensors!")
+     return model
+
+
+ # ── Generation ─────────────────────────────────────────────────
+
+ def sample_token(logits: np.ndarray, temperature: float, top_k: int, top_p: float,
+                  recent_tokens: list, rep_penalty: float) -> int:
+     """Sample the next token from a single logits vector (numpy)."""
+     # Mask: only audio tokens + end-of-speech allowed
+     mask_np = np.full(logits.shape, -1e9, dtype=np.float32)
+     mask_np[AUDIO_OFFSET: AUDIO_OFFSET + NUM_AUDIO_TOKENS] = 0.0
+     mask_np[END_OF_SPEECH_TOKEN_ID] = 0.0
+     logits_np = np.array(logits) + mask_np
+
+     # Repetition penalty over recently generated audio tokens
+     if rep_penalty != 1.0 and recent_tokens:
+         for tid in set(recent_tokens[-200:]):
+             if AUDIO_OFFSET <= tid < AUDIO_OFFSET + NUM_AUDIO_TOKENS:
+                 logits_np[tid] /= rep_penalty
+
+     # Temperature
+     logits_np = logits_np / temperature
+
+     # Top-k
+     if top_k > 0:
+         k = min(top_k, len(logits_np))
+         kth_val = np.partition(logits_np, -k)[-k]
+         logits_np[logits_np < kth_val] = -1e9
+
+     # Top-p (nucleus)
+     if top_p < 1.0:
+         sorted_idx = np.argsort(logits_np)[::-1]
+         sorted_logits = logits_np[sorted_idx]
+         probs = np.exp(sorted_logits - sorted_logits[0])
+         probs /= probs.sum()
+         cum = np.cumsum(probs)
+         remove = cum > top_p
+         remove[1:] = remove[:-1].copy()
+         remove[0] = False
+         logits_np[sorted_idx[remove]] = -1e9
+
+     # Sample from the final distribution
+     probs = np.exp(logits_np - logits_np.max())
+     probs /= probs.sum()
+     return int(np.random.choice(len(probs), p=probs))
+
+
+ def generate(model: TTSEncoderDecoder, tokenizer: TTSTokenizer,
+              text: str, speaker_id: int = 0,
+              max_new_tokens: int = 2000, temperature: float = 0.25,
+              top_k: int = 50, top_p: float = 0.8, rep_penalty: float = 1.1):
+     """Generate audio codec tokens from text."""
+
+     # 1. Encode text (once)
+     enc_ids_np = tokenizer.build_encoder_input(text, speaker_id)
+     enc_ids = mx.array(enc_ids_np[None, :])  # [1, T_enc]
+     enc_mask = mx.ones_like(enc_ids)
+
+     enc_out = model.encode(enc_ids, enc_mask)  # [1, T_enc, dec_d]
+     mx.eval(enc_out)
+     print(f"📝 Encoded: {enc_ids.shape[1]} tokens → enc_out {enc_out.shape}")
+
+     # 2. Autoregressive decode
+     dec_ids = mx.array([[START_OF_SPEECH_TOKEN_ID]])
+     past_self_kvs = None
+     cached_cross_kvs = None
+     generated = []
+     offset = 0
+
+     for step in range(max_new_tokens):
+         inp = dec_ids[:, -1:] if past_self_kvs is not None else dec_ids
+
+         logits, new_self_kvs, new_cross_kvs = model.decoder(
+             inp, enc_out, enc_mask,
+             past_key_values=past_self_kvs,
+             cached_cross_kvs=cached_cross_kvs,
+             offset=offset,
+         )
+         mx.eval(logits)
+
+         # Cache cross-attention KVs after the first step (they never change)
+         if cached_cross_kvs is None:
+             cached_cross_kvs = new_cross_kvs
+             mx.eval(cached_cross_kvs)
+
+         past_self_kvs = new_self_kvs
+         offset += inp.shape[1]
+
+         # Sample the next token
+         last_logits = np.array(logits[0, -1, :])
+         tok_id = sample_token(last_logits, temperature, top_k, top_p, generated, rep_penalty)
+
+         if tok_id == END_OF_SPEECH_TOKEN_ID:
+             print(f"🛑 EOS at step {step}")
+             break
+
+         generated.append(tok_id)
+         dec_ids = mx.array([[tok_id]])
+
+         if step % 100 == 0 and step > 0:
+             print(f"   step {step}: {len(generated)} tokens (~{len(generated)/50:.1f}s audio)")
+
+     if not generated:
+         return None
+
+     # Keep only audio tokens and shift them back into codec code space
+     tokens = np.array(generated)
+     audio_mask = (tokens >= AUDIO_OFFSET) & (tokens < AUDIO_OFFSET + NUM_AUDIO_TOKENS)
+     return tokens[audio_mask] - AUDIO_OFFSET
+
+
+ # ── Full Pipeline ──────────────────────────────────────────────
+
+ def synthesize(checkpoint: str, text: str, output: str = "output.wav",
+                speaker_id: int = 0, temperature: float = 0.25,
+                top_k: int = 50, top_p: float = 0.8, rep_penalty: float = 1.1,
+                max_tokens: int = 2000):
+
+     print(f"\n🎤 Text: '{text[:80]}'")
+     print(f"   speaker={speaker_id}, T={temperature}, top_k={top_k}, top_p={top_p}")
+
+     # Load model — auto-detect: safetensors repo or PyTorch checkpoint?
+     if os.path.exists(os.path.join(checkpoint, "model.safetensors")):
+         model = load_from_safetensors(checkpoint)
+     else:
+         model = load_from_pytorch_checkpoint(checkpoint)
+     model.eval()
+
+     # Load tokenizer
+     tokenizer = TTSTokenizer()
+
+     # Generate codec tokens
+     tokens = generate(model, tokenizer, text, speaker_id, max_tokens,
+                       temperature, top_k, top_p, rep_penalty)
+
+     if tokens is None or len(tokens) == 0:
+         print("❌ No audio generated!")
+         return
+
+     # Trim to a whole number of frames (4 codebooks per frame)
+     tokens = tokens[:len(tokens) - len(tokens) % CODEC_NUM_CODEBOOKS]
+     print(f"🔊 {len(tokens)} tokens → {len(tokens)//4} frames → ~{len(tokens)//4/12.5:.1f}s audio")
+
+     # Decode with nanocodec-mlx
+     print("🎵 Decoding with NanoCodec MLX...")
+     from nanocodec_mlx.models.audio_codec import AudioCodecModel
+
+     codec = AudioCodecModel.from_pretrained(NANOCODEC_MODEL_NAME)
+
+     # Reshape interleaved tokens to [num_codebooks, num_frames]
+     num_frames = len(tokens) // CODEC_NUM_CODEBOOKS
+     codes = tokens.reshape(num_frames, CODEC_NUM_CODEBOOKS).T  # [4, T]
+     codes_mx = mx.array(codes.astype(np.int32))[None, :, :]    # [1, 4, T]
+     tokens_len = mx.array([num_frames], dtype=mx.int32)
+
+     wav_mx, _ = codec.decode(codes_mx, tokens_len)
+     mx.eval(wav_mx)
+
+     # Save
+     wav_np = np.array(wav_mx[0, 0, :])
+     sf.write(output, wav_np, 22050)
+     duration = len(wav_np) / 22050
+     print(f"✅ Saved: {output} ({duration:.2f}s)")
+     return wav_np
+
+
+ if __name__ == "__main__":
+     import argparse
+     p = argparse.ArgumentParser()
+     p.add_argument("--checkpoint", required=True)
+     p.add_argument("--text", required=True)
+     p.add_argument("--output", default="output.wav")
+     p.add_argument("--speaker", type=int, default=0)
+     p.add_argument("--temperature", type=float, default=0.25)
+     p.add_argument("--top-k", type=int, default=50)
+     p.add_argument("--top-p", type=float, default=0.8)
+     p.add_argument("--rep-penalty", type=float, default=1.1)
+     p.add_argument("--max-tokens", type=int, default=2000)
+     a = p.parse_args()
+     synthesize(a.checkpoint, a.text, a.output, a.speaker,
+                a.temperature, a.top_k, a.top_p, a.rep_penalty, a.max_tokens)
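One detail worth spelling out: `generate()` returns a flat, frame-major stream of codec codes (four per frame, one per codebook), and `synthesize()` de-interleaves it with `reshape(num_frames, 4).T`. A toy numpy illustration of that layout (the code values are made up):

```python
import numpy as np

# Flat stream as generate() emits it: frame 0's four codebooks, then frame 1's.
flat = np.array([10, 11, 12, 13,    # frame 0, codebooks 0..3
                 20, 21, 22, 23])   # frame 1, codebooks 0..3
codes = flat.reshape(-1, 4).T       # [num_codebooks, num_frames] = [4, 2]
print(codes)
# [[10 20]
#  [11 21]
#  [12 22]
#  [13 23]]
```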
tts_mlx/model.py ADDED
@@ -0,0 +1,375 @@
+ """
+ MLX Model — Encoder-Decoder TTS
+ ================================
+ Port of tts_v5/model.py from PyTorch to MLX.
+ Inference-only (no training, no dropout, no CTC head needed).
+ """
+
+ import math
+ import mlx.core as mx
+ import mlx.nn as nn
+ from dataclasses import dataclass
+ from typing import Optional
+
+ from .config import (
+     TOTAL_VOCAB_SIZE, ENCODER_VOCAB_SIZE, DECODER_VOCAB_SIZE,
+     ENC_D_MODEL, ENC_N_HEADS, ENC_N_LAYERS, ENC_D_FF,
+     DEC_D_MODEL, DEC_N_HEADS, DEC_N_LAYERS, DEC_D_FF,
+     MAX_TEXT_LEN, MAX_AUDIO_LEN,
+     PAD_TOKEN_ID, NUM_AUDIO_TOKENS, AUDIO_OFFSET,
+ )
+
+
+ # ── Shared Components ──────────────────────────────────────────
+
+ class RMSNorm(nn.Module):
+     def __init__(self, dim: int, eps: float = 1e-6):
+         super().__init__()
+         self.eps = eps
+         self.weight = mx.ones((dim,))
+
+     def __call__(self, x: mx.array) -> mx.array:
+         norm = mx.rsqrt(mx.mean(x * x, axis=-1, keepdims=True) + self.eps)
+         return x * norm * self.weight
+
+
+ class SwiGLUFFN(nn.Module):
+     def __init__(self, d_model: int, d_ff: int):
+         super().__init__()
+         self.gate_proj = nn.Linear(d_model, d_ff, bias=False)
+         self.up_proj = nn.Linear(d_model, d_ff, bias=False)
+         self.down_proj = nn.Linear(d_ff, d_model, bias=False)
+
+     def __call__(self, x: mx.array) -> mx.array:
+         return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
+
+
+ def build_rope_cache(max_seq_len: int, head_dim: int, base: float = 10000.0):
+     """Precompute RoPE cos/sin tables."""
+     inv_freq = 1.0 / (base ** (mx.arange(0, head_dim, 2).astype(mx.float32) / head_dim))
+     t = mx.arange(max_seq_len, dtype=mx.float32)
+     freqs = mx.outer(t, inv_freq)
+     emb = mx.concatenate([freqs, freqs], axis=-1)
+     return mx.cos(emb), mx.sin(emb)
+
+
+ def rotate_half(x: mx.array) -> mx.array:
+     half = x.shape[-1] // 2
+     x1, x2 = x[..., :half], x[..., half:]
+     return mx.concatenate([-x2, x1], axis=-1)
+
+
+ def apply_rope(q: mx.array, k: mx.array, cos: mx.array, sin: mx.array):
+     cos = cos[None, None, :, :]  # [1, 1, T, head_dim]
+     sin = sin[None, None, :, :]
+     q = q * cos + rotate_half(q) * sin
+     k = k * cos + rotate_half(k) * sin
+     return q, k
+
+
+ # ── Encoder (Bidirectional) ────────────────────────────────────
+
+ class EncoderSelfAttention(nn.Module):
+     def __init__(self, d_model: int, n_heads: int):
+         super().__init__()
+         self.n_heads = n_heads
+         self.head_dim = d_model // n_heads
+         self.q_proj = nn.Linear(d_model, d_model, bias=False)
+         self.k_proj = nn.Linear(d_model, d_model, bias=False)
+         self.v_proj = nn.Linear(d_model, d_model, bias=False)
+         self.o_proj = nn.Linear(d_model, d_model, bias=False)
+
+     def __call__(self, x: mx.array, mask: Optional[mx.array] = None) -> mx.array:
+         B, T, _ = x.shape
+         q = self.q_proj(x).reshape(B, T, self.n_heads, self.head_dim).transpose(0, 2, 1, 3)
+         k = self.k_proj(x).reshape(B, T, self.n_heads, self.head_dim).transpose(0, 2, 1, 3)
+         v = self.v_proj(x).reshape(B, T, self.n_heads, self.head_dim).transpose(0, 2, 1, 3)
+
+         scale = 1.0 / math.sqrt(self.head_dim)
+         scores = (q @ k.transpose(0, 1, 3, 2)) * scale  # [B, H, T, T]
+
+         if mask is not None:
+             scores = scores + mask
+
+         attn = mx.softmax(scores.astype(mx.float32), axis=-1).astype(x.dtype)
+         out = (attn @ v).transpose(0, 2, 1, 3).reshape(B, T, -1)
+         return self.o_proj(out)
+
+
+ class EncoderBlock(nn.Module):
+     def __init__(self, d_model: int, n_heads: int, d_ff: int):
+         super().__init__()
+         self.attn_norm = RMSNorm(d_model)
+         self.attention = EncoderSelfAttention(d_model, n_heads)
+         self.ffn_norm = RMSNorm(d_model)
+         self.ffn = SwiGLUFFN(d_model, d_ff)
+
+     def __call__(self, x: mx.array, mask: Optional[mx.array] = None) -> mx.array:
+         x = x + self.attention(self.attn_norm(x), mask)
+         x = x + self.ffn(self.ffn_norm(x))
+         return x
+
+
+ class TextEncoder(nn.Module):
+     def __init__(self, vocab_size=ENCODER_VOCAB_SIZE, d_model=ENC_D_MODEL,
+                  n_heads=ENC_N_HEADS, n_layers=ENC_N_LAYERS, d_ff=ENC_D_FF,
+                  max_len=MAX_TEXT_LEN):
+         super().__init__()
+         self.d_model = d_model
+         self.token_embedding = nn.Embedding(vocab_size, d_model)
+         self.pos_embedding = nn.Embedding(max_len, d_model)
+         self.layers = [EncoderBlock(d_model, n_heads, d_ff) for _ in range(n_layers)]
+         self.final_norm = RMSNorm(d_model)
+
+     def __call__(self, input_ids: mx.array, attention_mask: Optional[mx.array] = None) -> mx.array:
+         B, T = input_ids.shape
+         pos = mx.arange(T)[None, :]  # [1, T]
+         h = self.token_embedding(input_ids) + self.pos_embedding(pos)
+
+         # Build padding mask: [B, 1, 1, T], -inf on pad positions
+         attn_mask = None
+         if attention_mask is not None:
+             # attention_mask: [B, T], 1=real 0=pad
+             pad = (attention_mask == 0).astype(mx.float32)  # [B, T]
+             attn_mask = pad[:, None, None, :] * -1e9        # [B, 1, 1, T]
+
+         for layer in self.layers:
+             h = layer(h, attn_mask)
+
+         return self.final_norm(h)
+
+
+ # ── Decoder (Causal with Cross-Attention) ──────────────────────
+
+ class DecoderSelfAttention(nn.Module):
+     def __init__(self, d_model: int, n_heads: int, max_len: int, tokens_per_frame: int = 1):
+         super().__init__()
+         self.n_heads = n_heads
+         self.head_dim = d_model // n_heads
+         self.tokens_per_frame = tokens_per_frame
+         self.q_proj = nn.Linear(d_model, d_model, bias=False)
+         self.k_proj = nn.Linear(d_model, d_model, bias=False)
+         self.v_proj = nn.Linear(d_model, d_model, bias=False)
+         self.o_proj = nn.Linear(d_model, d_model, bias=False)
+         # Precompute RoPE tables
+         cos, sin = build_rope_cache(max_len * 2, self.head_dim)
+         self.rope_cos = cos
+         self.rope_sin = sin
+
+     def __call__(self, x: mx.array, past_kv=None, offset: int = 0):
+         """
+         x: [B, T, d_model]
+         past_kv: (k_cache, v_cache) or None
+         offset: number of already-generated tokens (for RoPE position)
+         Returns: (output, new_k, new_v)
+         """
+         B, T, _ = x.shape
+         q = self.q_proj(x).reshape(B, T, self.n_heads, self.head_dim).transpose(0, 2, 1, 3)
+         k = self.k_proj(x).reshape(B, T, self.n_heads, self.head_dim).transpose(0, 2, 1, 3)
+         v = self.v_proj(x).reshape(B, T, self.n_heads, self.head_dim).transpose(0, 2, 1, 3)
+
+         # Apply RoPE with frame-level positions
+         if self.tokens_per_frame > 1:
+             frame_offset = offset // self.tokens_per_frame
+             frame_positions = mx.arange(T) // self.tokens_per_frame + frame_offset
+         else:
+             frame_positions = mx.arange(T) + offset
+
+         cos = self.rope_cos[frame_positions]  # [T, head_dim]
+         sin = self.rope_sin[frame_positions]
+         q, k = apply_rope(q, k, cos, sin)
+
+         # Append to KV cache
+         if past_kv is not None:
+             k = mx.concatenate([past_kv[0], k], axis=2)
+             v = mx.concatenate([past_kv[1], v], axis=2)
+
+         new_k, new_v = k, v
+
+         # Causal mask only during prefill (T > 1, no cache)
+         scale = 1.0 / math.sqrt(self.head_dim)
+         scores = (q @ k.transpose(0, 1, 3, 2)) * scale
+
+         if past_kv is None and T > 1:
+             causal = mx.triu(mx.full((T, k.shape[2]), -1e9), k=1)
+             scores = scores + causal[None, None, :, :]
+
+         attn = mx.softmax(scores.astype(mx.float32), axis=-1).astype(x.dtype)
+         out = (attn @ v).transpose(0, 2, 1, 3).reshape(B, T, -1)
+         return self.o_proj(out), new_k, new_v
+
+
+ class CrossAttention(nn.Module):
+     def __init__(self, dec_d_model: int, enc_d_model: int, n_heads: int):
+         super().__init__()
+         self.n_heads = n_heads
+         self.head_dim = dec_d_model // n_heads
+         self.q_proj = nn.Linear(dec_d_model, dec_d_model, bias=False)
+         self.k_proj = nn.Linear(enc_d_model, dec_d_model, bias=False)
+         self.v_proj = nn.Linear(enc_d_model, dec_d_model, bias=False)
+         self.o_proj = nn.Linear(dec_d_model, dec_d_model, bias=False)
+
+     def __call__(self, x: mx.array, encoder_output: mx.array,
+                  encoder_mask: Optional[mx.array] = None,
+                  cached_kv=None):
+         """
+         cached_kv: precomputed (k, v) from encoder — computed once, reused every step.
+         """
+         B, T, _ = x.shape
+         q = self.q_proj(x).reshape(B, T, self.n_heads, self.head_dim).transpose(0, 2, 1, 3)
+
+         if cached_kv is not None:
+             k, v = cached_kv
+         else:
+             T_enc = encoder_output.shape[1]
+             k = self.k_proj(encoder_output).reshape(B, T_enc, self.n_heads, self.head_dim).transpose(0, 2, 1, 3)
+             v = self.v_proj(encoder_output).reshape(B, T_enc, self.n_heads, self.head_dim).transpose(0, 2, 1, 3)
+
+         scale = 1.0 / math.sqrt(self.head_dim)
+         scores = (q @ k.transpose(0, 1, 3, 2)) * scale  # [B, H, T, T_enc]
+
+         if encoder_mask is not None:
+             # encoder_mask: [B, T_enc], 1=real 0=pad
+             pad = (encoder_mask == 0).astype(mx.float32)
+             scores = scores + pad[:, None, None, :] * -1e9
+
+         attn = mx.softmax(scores.astype(mx.float32), axis=-1).astype(x.dtype)
+         out = (attn @ v).transpose(0, 2, 1, 3).reshape(B, T, -1)
+         return self.o_proj(out), (k, v)
+
+
+ class DecoderBlock(nn.Module):
+     def __init__(self, dec_d_model: int, enc_d_model: int, n_heads: int,
+                  d_ff: int, max_len: int, tokens_per_frame: int = 1):
+         super().__init__()
+         self.self_attn_norm = RMSNorm(dec_d_model)
+         self.self_attention = DecoderSelfAttention(dec_d_model, n_heads, max_len, tokens_per_frame)
+         self.cross_attn_norm = RMSNorm(dec_d_model)
+         self.cross_attention = CrossAttention(dec_d_model, enc_d_model, n_heads)
+         self.ffn_norm = RMSNorm(dec_d_model)
+         self.ffn = SwiGLUFFN(dec_d_model, d_ff)
+
+     def __call__(self, x: mx.array, encoder_output: mx.array,
+                  encoder_mask=None, past_self_kv=None, cached_cross_kv=None,
+                  offset: int = 0):
+         # 1. Causal self-attention
+         h = self.self_attn_norm(x)
+         sa_out, new_k, new_v = self.self_attention(h, past_self_kv, offset)
+         x = x + sa_out
+
+         # 2. Cross-attention (encoder KV cached after first call)
+         h = self.cross_attn_norm(x)
+         ca_out, cross_kv = self.cross_attention(h, encoder_output, encoder_mask, cached_cross_kv)
+         x = x + ca_out
+
+         # 3. FFN
+         x = x + self.ffn(self.ffn_norm(x))
+
+         return x, (new_k, new_v), cross_kv
+
+
+ class AudioDecoder(nn.Module):
+     def __init__(self, vocab_size=DECODER_VOCAB_SIZE, d_model=DEC_D_MODEL,
+                  enc_d_model=DEC_D_MODEL, n_heads=DEC_N_HEADS,
+                  n_layers=DEC_N_LAYERS, d_ff=DEC_D_FF,
+                  max_len=MAX_AUDIO_LEN, tokens_per_frame=1):
+         super().__init__()
+         self.tokens_per_frame = tokens_per_frame
+         self.token_embedding = nn.Embedding(vocab_size, d_model)
+         self.layers = [
+             DecoderBlock(d_model, enc_d_model, n_heads, d_ff, max_len, tokens_per_frame)
+             for _ in range(n_layers)
+         ]
+         self.final_norm = RMSNorm(d_model)
+         # LM head tied to token_embedding (projection applied in __call__)
+
+     def __call__(self, input_ids: mx.array, encoder_output: mx.array,
+                  encoder_mask=None, past_key_values=None, cached_cross_kvs=None,
+                  offset: int = 0):
+         """
+         input_ids: [B, T]
+         encoder_output: [B, T_enc, d]
+         past_key_values: list of (k, v) per layer, or None
+         cached_cross_kvs: list of (k, v) per layer from encoder, or None
+         offset: token offset for RoPE (number of past tokens)
+         """
+         h = self.token_embedding(input_ids)
+
+         new_self_kvs = []
+         new_cross_kvs = []
+
+         for i, layer in enumerate(self.layers):
+             past_self_kv = past_key_values[i] if past_key_values else None
+             cached_cross_kv = cached_cross_kvs[i] if cached_cross_kvs else None
+
+             h, new_self_kv, new_cross_kv = layer(
+                 h, encoder_output, encoder_mask,
+                 past_self_kv, cached_cross_kv, offset
+             )
+             new_self_kvs.append(new_self_kv)
+             new_cross_kvs.append(new_cross_kv)
+
+         h = self.final_norm(h)
+         # Tied embedding projection
+         logits = h @ self.token_embedding.weight.T
+         return logits, new_self_kvs, new_cross_kvs
+
+
+ # ── Full Model ─────────────────────────────────────────────────
+
+ @dataclass
+ class V5Config:
+     enc_vocab_size: int = ENCODER_VOCAB_SIZE
+     enc_d_model: int = ENC_D_MODEL
+     enc_n_heads: int = ENC_N_HEADS
+     enc_n_layers: int = ENC_N_LAYERS
+     enc_d_ff: int = ENC_D_FF
+     max_text_len: int = MAX_TEXT_LEN
+     dec_vocab_size: int = DECODER_VOCAB_SIZE
+     dec_d_model: int = DEC_D_MODEL
+     dec_n_heads: int = DEC_N_HEADS
+     dec_n_layers: int = DEC_N_LAYERS
+     dec_d_ff: int = DEC_D_FF
+     max_audio_len: int = MAX_AUDIO_LEN
+     dropout: float = 0.0
+     ctc_weight: float = 0.0
+     tokens_per_frame: int = 1
+
+
+ class TTSEncoderDecoder(nn.Module):
+     def __init__(self, config: V5Config):
+         super().__init__()
+         self.config = config
+
+         self.encoder = TextEncoder(
+             vocab_size=config.enc_vocab_size,
+             d_model=config.enc_d_model,
+             n_heads=config.enc_n_heads,
+             n_layers=config.enc_n_layers,
+             d_ff=config.enc_d_ff,
+             max_len=config.max_text_len,
+         )
+
+         if config.enc_d_model != config.dec_d_model:
+             self.enc_projection = nn.Linear(config.enc_d_model, config.dec_d_model, bias=False)
+         else:
+             self.enc_projection = None
+
+         self.decoder = AudioDecoder(
+             vocab_size=config.dec_vocab_size,
+             d_model=config.dec_d_model,
+             enc_d_model=config.dec_d_model,
+             n_heads=config.dec_n_heads,
+             n_layers=config.dec_n_layers,
+             d_ff=config.dec_d_ff,
+             max_len=config.max_audio_len,
+             tokens_per_frame=config.tokens_per_frame,
+         )
+
+     def encode(self, enc_ids: mx.array, enc_mask=None) -> mx.array:
+         """Run encoder + projection once. Returns [B, T_enc, dec_d_model]."""
+         enc_out = self.encoder(enc_ids, enc_mask)
+         if self.enc_projection is not None:
+             enc_out = self.enc_projection(enc_out)
+         return enc_out
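The frame-level RoPE in `DecoderSelfAttention` is the one non-obvious piece: with `tokens_per_frame=4`, all four codebook tokens of a codec frame share one rotary position, and `offset` advances positions by whole frames. A minimal numpy stand-in for the position math:

```python
import numpy as np

tokens_per_frame = 4
offset = 8                          # 8 tokens already generated = 2 complete frames
T = 8                               # 8 new tokens in this forward pass

frame_offset = offset // tokens_per_frame                      # 2
positions = np.arange(T) // tokens_per_frame + frame_offset    # per-token RoPE index
print(positions)                    # [2 2 2 2 3 3 3 3]
```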
tts_mlx/tokenizer.py ADDED
@@ -0,0 +1,48 @@
+ """
+ Tokenizer — identical to original tts_v5/tokenizer.py
+ Pure Python + numpy, no PyTorch dependency.
+ """
+
+ import re
+ import numpy as np
+
+ from .config import (
+     TEXT_CHARS, TEXT_OFFSET, AUDIO_OFFSET,
+     SPECIAL_TOKENS, NUM_SPECIAL_TOKENS,
+     TOTAL_VOCAB_SIZE, CODEC_CODEBOOK_SIZE,
+     PAD_TOKEN_ID, START_OF_TEXT_TOKEN_ID, END_OF_TEXT_TOKEN_ID,
+     START_OF_SPEECH_TOKEN_ID, END_OF_SPEECH_TOKEN_ID,
+     is_audio_token, is_special_token, is_text_token,
+ )
+
+
+ class TTSTokenizer:
+     def __init__(self):
+         self.char2id: dict[str, int] = {}
+         self.id2char: dict[int, str] = {}
+         for i, ch in enumerate(TEXT_CHARS):
+             tid = TEXT_OFFSET + i
+             self.char2id[ch] = tid
+             self.id2char[tid] = ch
+
+         self._special_id_to_name = {v: k for k, v in SPECIAL_TOKENS.items()}
+         self.vocab_size = TOTAL_VOCAB_SIZE
+         self.text_vocab_size = len(TEXT_CHARS)
+
+     def normalize_text(self, text: str) -> str:
+         text = re.sub(r'\s+', ' ', text).strip()
+         text = re.sub(r'[–—]', '-', text)
+         text = re.sub(r'[«»„""]', '"', text)
+         return text
+
+     def encode_text(self, text: str) -> list[int]:
+         text = self.normalize_text(text)
+         return [self.char2id[ch] for ch in text if ch in self.char2id]
+
+     def build_encoder_input(self, text: str, speaker_id: int = 0) -> np.ndarray:
+         """Encoder input: <sot> text_chars <eot> <spk_X>"""
+         text_ids = self.encode_text(text)
+         spk = SPECIAL_TOKENS[f"<spk_{speaker_id}>"]
+         seq = [START_OF_TEXT_TOKEN_ID] + text_ids + [END_OF_TEXT_TOKEN_ID, spk]
+         return np.array(seq, dtype=np.int32)
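For reference, a short usage sketch of the tokenizer (the boundary and speaker ids come from `SPECIAL_TOKENS` in `tts_mlx/config.py`):

```python
from tts_mlx.tokenizer import TTSTokenizer

tok = TTSTokenizer()
ids = tok.build_encoder_input("Здравей!", speaker_id=0)
# Layout: <start_of_text> per-character ids <end_of_text> <spk_0>
print(ids[0], ids[-2], ids[-1])   # 1 2 5
print(len(ids))                   # 8 characters + 3 specials = 11
```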