ssasio beleata74 commited on
Commit
7eecd1a
·
0 Parent(s):

Duplicate from beleata74/Ani-Voice-API

Browse files

Co-authored-by: none <beleata74@users.noreply.huggingface.co>

.gitattributes ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ demo1_conversation.wav filter=lfs diff=lfs merge=lfs -text
37
+ demo2_numbers.wav filter=lfs diff=lfs merge=lfs -text
38
+ demo3_expressive.wav filter=lfs diff=lfs merge=lfs -text
BgTTS/.gitattributes ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ samples/sample_female_bg1.wav filter=lfs diff=lfs merge=lfs -text
37
+ samples/sample_female_bg2.wav filter=lfs diff=lfs merge=lfs -text
38
+ samples/sample_female_bg3.wav filter=lfs diff=lfs merge=lfs -text
39
+ samples/sample_female_en1.wav filter=lfs diff=lfs merge=lfs -text
40
+ samples/sample_female_en2.wav filter=lfs diff=lfs merge=lfs -text
41
+ samples/sample_female_en3.wav filter=lfs diff=lfs merge=lfs -text
42
+ samples/sample_male2_bg1.wav filter=lfs diff=lfs merge=lfs -text
43
+ samples/sample_male2_bg2.wav filter=lfs diff=lfs merge=lfs -text
44
+ samples/sample_male2_bg3.wav filter=lfs diff=lfs merge=lfs -text
45
+ samples/sample_male2_en1.wav filter=lfs diff=lfs merge=lfs -text
46
+ samples/sample_male2_en2.wav filter=lfs diff=lfs merge=lfs -text
47
+ samples/sample_male2_en3.wav filter=lfs diff=lfs merge=lfs -text
48
+ samples/sample_male_bg1.wav filter=lfs diff=lfs merge=lfs -text
49
+ samples/sample_male_bg2.wav filter=lfs diff=lfs merge=lfs -text
50
+ samples/sample_male_bg3.wav filter=lfs diff=lfs merge=lfs -text
51
+ samples/sample_male_en1.wav filter=lfs diff=lfs merge=lfs -text
52
+ samples/sample_male_en2.wav filter=lfs diff=lfs merge=lfs -text
53
+ samples/sample_male_en3.wav filter=lfs diff=lfs merge=lfs -text
BgTTS/README.md ADDED
@@ -0,0 +1,270 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: apache-2.0
3
+ language:
4
+ - bg
5
+ - en
6
+ pipeline_tag: text-to-speech
7
+ tags:
8
+ - tts
9
+ - bulgarian
10
+ - miocodec
11
+ - encoder-decoder
12
+ - voice-cloning
13
+ - speech-synthesis
14
+ library_name: pytorch
15
+ ---
16
+
17
+ # BgTTS-38M V2 — Bulgarian Text-to-Speech with Voice Cloning
18
+
19
+ A lightweight **38M parameter** encoder-decoder TTS model for **Bulgarian and English** speech synthesis with **zero-shot voice cloning** via [MioCodec](https://huggingface.co/Aratako/MioCodec-25Hz-24kHz).
20
+
21
+ **V2 improvements over V1:**
22
+ - **Speaker normalization** — stable voice quality across all reference audio files
23
+ - **Larger training dataset** — 1,537 hours (vs 1,172h in V1)
24
+ - **BF16 training** — more stable gradients, no GradScaler needed
25
+ - **Zero dropout** — better utilization of model capacity
26
+ - **20 epochs** with careful LR scheduling
27
+
28
+ ## Audio Samples
29
+
30
+ ### Female Voice (Bulgarian)
31
+
32
+ <audio controls src="https://huggingface.co/beleata74/BgTTS-38M-V2/resolve/main/samples/sample_female_bg1.wav"></audio>
33
+
34
+ ### Female Voice (English)
35
+
36
+ <audio controls src="https://huggingface.co/beleata74/BgTTS-38M-V2/resolve/main/samples/sample_female_en1.wav"></audio>
37
+
38
+ ### Male Voice 1 (Bulgarian)
39
+
40
+ <audio controls src="https://huggingface.co/beleata74/BgTTS-38M-V2/resolve/main/samples/sample_male_bg1.wav"></audio>
41
+
42
+ ### Male Voice 1 (English)
43
+
44
+ <audio controls src="https://huggingface.co/beleata74/BgTTS-38M-V2/resolve/main/samples/sample_male_en1.wav"></audio>
45
+
46
+ ### Male Voice 2 (Bulgarian)
47
+
48
+ <audio controls src="https://huggingface.co/beleata74/BgTTS-38M-V2/resolve/main/samples/sample_male2_bg1.wav"></audio>
49
+
50
+ ### Male Voice 2 (English)
51
+
52
+ <audio controls src="https://huggingface.co/beleata74/BgTTS-38M-V2/resolve/main/samples/sample_male2_en1.wav"></audio>
53
+
54
+ ## Key Features
55
+
56
+ - **Bilingual**: Native Bulgarian + English in a single model
57
+ - **Voice cloning**: Zero-shot — just provide 3-10 seconds of reference audio
58
+ - **Tiny footprint**: 146 MB inference checkpoint, runs on CPU
59
+ - **Fast**: RTF ~0.3 on both GPU and CPU (3.3× faster than real-time)
60
+ - **Speaker-stable**: V2's normalized speaker embedding ensures consistent quality regardless of reference audio
61
+
62
+ ## 🎙️ Voice Cloning
63
+
64
+ This model supports zero-shot voice cloning — it can generate speech in any voice given just a short reference audio clip. No fine-tuning needed.
65
+
66
+ ### How it Works
67
+
68
+ 1. Provide a reference audio (3-10 seconds of clear speech, WAV format, ideally 24kHz)
69
+ 2. MioCodec extracts a 128-dimensional speaker embedding (`global_embedding`)
70
+ 3. The embedding is **L2-normalized** and scaled by a learned parameter (`spk_scale`) before being added to the decoder
71
+ 4. The same embedding is used for MioCodec waveform reconstruction
72
+
73
+ ### V2 Improvement: Speaker Normalization
74
+
75
+ In V1, the speaker embedding had 7× larger norm than content tokens, causing the model to over-rely on the reference audio for pronunciation quality. V2 normalizes the speaker vector to unit norm, ensuring:
76
+ - **Consistent quality** across all reference voices
77
+ - The model learns speech patterns from data, not from speaker shortcuts
78
+ - Reference audio only affects **timbre**, not articulation
79
+
80
+ ## Model Architecture
81
+
82
+ | Component | Details |
83
+ |---|---|
84
+ | Text Encoder | 4-layer bidirectional Transformer (d=384, 6 heads, ff=1536) |
85
+ | Audio Decoder | 8-layer causal Transformer (d=384, 6 heads, ff=1536) with cross-attention |
86
+ | Speaker Injection | L2-normalized Linear(128 → 384) with learned scale, additive bias |
87
+ | Audio Codec | [MioCodec](https://huggingface.co/Aratako/MioCodec-25Hz-24kHz) 25Hz, 1 codebook, 12800 codes, 24kHz output |
88
+ | Total Parameters | 38.2M (Encoder: 9.6M, Decoder: 28.6M) |
89
+ | Activations | SwiGLU |
90
+ | Normalization | RMSNorm (pre-norm) |
91
+ | Positional Encoding | Learned (encoder), RoPE (decoder) |
92
+ | Embeddings | Tied decoder (lm_head = token_embedding) |
93
+ | KV-Cache | Yes (for fast autoregressive inference) |
94
+
95
+ ### Tokenizer
96
+
97
+ Character-level tokenizer supporting 146 characters:
98
+ - Bulgarian Cyrillic (А-Я, а-я)
99
+ - English Latin (A-Z, a-z)
100
+ - Digits, punctuation, whitespace
101
+
102
+ Total vocabulary: **12,955 tokens** (9 special + 146 text + 12,800 audio codes)
103
+
104
+ ## Training
105
+
106
+ | Parameter | Value |
107
+ |---|---|
108
+ | **Data** | 728K samples, **1,537 hours** total |
109
+ | Bulgarian | ~620K samples (~1,368 hours) |
110
+ | English | ~108K samples (~169 hours) |
111
+ | **Epochs** | 20 |
112
+ | **LR Schedule** | Cosine decay, peak 7e-5, warmup 2 epochs, min 5e-6 |
113
+ | **Batch Size** | 64 |
114
+ | **Optimizer** | AdamW (betas=0.9, 0.999), weight decay 0.01 |
115
+ | **Precision** | BF16 (no GradScaler) |
116
+ | **Dropout** | 0.0 (unnecessary — model is 38M, data is 1,537h) |
117
+ | **Final Loss** | 5.04 |
118
+ | **Hardware** | NVIDIA RTX 5090 (32GB VRAM) |
119
+
120
+ ### Why Zero Dropout?
121
+
122
+ With only 38M parameters and 138M audio tokens (1,537 hours), the model has **0.28 parameters per token**. Overfitting is mathematically impossible — the model is severely underfitting the data. Dropout only slows convergence without providing any regularization benefit.
123
+
124
+ ## Quick Start
125
+
126
+ ### Requirements
127
+
128
+ ```bash
129
+ pip install torch torchaudio soundfile miocodec
130
+ ```
131
+
132
+ ### Python API
133
+
134
+ ```python
135
+ import torch
136
+ from model import load_for_inference
137
+ from tokenizer import TTSTokenizer
138
+ from codec import CodecV6
139
+ from inference import generate
140
+
141
+ device = "cuda" # or "cpu"
142
+
143
+ # Load model
144
+ model = load_for_inference("checkpoint_inference.pt", device=device)
145
+ tokenizer = TTSTokenizer()
146
+ codec = CodecV6(device=device)
147
+
148
+ # Get speaker embedding from reference audio
149
+ ref = codec.encode("reference_speaker.wav")
150
+ speaker_emb = ref["global_embedding"].to(device)
151
+
152
+ # Generate
153
+ codes = generate(
154
+ model, tokenizer,
155
+ text="Здравейте, как сте днес?",
156
+ speaker_emb=speaker_emb,
157
+ temperature=0.3,
158
+ top_k=250,
159
+ max_new_tokens=512,
160
+ device=device,
161
+ )
162
+
163
+ # Decode to audio
164
+ if codes is not None:
165
+ wav = codec.tokens_to_wav(codes, speaker_emb, "output.wav")
166
+ ```
167
+
168
+ ### CLI
169
+
170
+ ```bash
171
+ python inference.py \
172
+ --checkpoint checkpoint_inference.pt \
173
+ --text "Здравейте, как сте днес?" \
174
+ --speaker-wav reference.wav \
175
+ --output output.wav \
176
+ --temperature 0.3
177
+ ```
178
+
179
+ ### Web UI (Gradio)
180
+
181
+ ```bash
182
+ python server.py
183
+ # Opens at http://localhost:7860
184
+ ```
185
+
186
+ ### Parameters
187
+
188
+ | Parameter | Default | Description |
189
+ |---|---|---|
190
+ | `--temperature` | 0.3 | Sampling temperature (lower = stable, higher = expressive) |
191
+ | `--top-k` | 250 | Top-k filtering |
192
+ | `--top-p` | 0.95 | Nucleus sampling threshold |
193
+ | `--rep-penalty` | 1.1 | Repetition penalty on recent tokens |
194
+ | `--max-tokens` | 512 | Maximum decoder steps (~20 seconds) |
195
+
196
+ **Recommended temperature: 0.3** for clean, stable output. Use 0.5-0.7 for more expressive/varied speech.
197
+
198
+ ## ⚠️ Important: Sentence Length
199
+
200
+ > The encoder supports up to **256 characters** (~18 seconds of audio). For longer texts, `inference.py` automatically splits by sentence boundaries and concatenates the audio. No manual splitting needed.
201
+
202
+ ## Files
203
+
204
+ ```
205
+ checkpoint_inference.pt # Model weights only (146 MB)
206
+ checkpoint.pt # Full checkpoint with optimizer state (438 MB, for continued training)
207
+ config.py # Model configuration
208
+ model.py # Architecture (TTSEncoderDecoder + speaker normalization)
209
+ tokenizer.py # Character-level tokenizer
210
+ codec.py # MioCodec wrapper
211
+ inference.py # Inference pipeline with KV-cache + sentence splitting
212
+ train.py # Training script (BF16)
213
+ server.py # Gradio web UI
214
+ samples/ # Audio samples (3 voices × 2 languages × 3 texts)
215
+ ```
216
+
217
+ ## Performance
218
+
219
+ ### Benchmarks
220
+
221
+ | Hardware | RTF | Speed | Notes |
222
+ |---|---|---|---|
223
+ | **Intel i3-9100F (CPU)** | **0.30** | **3.3× real-time** | **Windows 10, CPU-only, no GPU** |
224
+
225
+ ### CPU-only Deployment (Tested on Windows 10)
226
+
227
+ | Component | Disk Space |
228
+ |---|---|
229
+ | Python venv (PyTorch CPU + deps) | 654 MB |
230
+ | BgTTS-38M-V2 (checkpoint + code) | 146 MB |
231
+ | MioCodec (auto-downloaded, cached) | 499 MB |
232
+ | WavLM base+ (auto-downloaded, cached) | 872 MB |
233
+ | **Total** | **2.12 GB** |
234
+
235
+ No NVIDIA GPU, no CUDA, no special drivers needed. Works on any x86-64 machine with Python 3.8+.
236
+
237
+ ## Comparison with Other Models
238
+
239
+ | Model | Parameters | Size | Languages | Voice Cloning | Open Source |
240
+ |---|---|---|---|---|---|
241
+ | **BgTTS-38M V2** | **38M** | **146 MB** | BG + EN | ✅ | ✅ |
242
+ | Kokoro-82M | 82M | ~200 MB | Multi | ❌ | ✅ |
243
+ | XTTS-v2 | ~467M | ~1.8 GB | 16 | ✅ | ✅ |
244
+ | CSM-1B | 1B | ~4 GB | EN | ✅ | ✅ |
245
+ | Dia-1.6B | 1.6B | ~6.4 GB | EN | ✅ | ✅ |
246
+
247
+ BgTTS-38M V2 is the **smallest TTS model with voice cloning** we are aware of, and the **only** open-source TTS model with native Bulgarian language support.
248
+
249
+ ## Limitations
250
+
251
+ - Best with sentences up to ~18 seconds. Longer texts are auto-split by `inference.py`.
252
+ - Bulgarian quality is superior to English (82% of training data is Bulgarian).
253
+ - Voice cloning quality depends on reference audio clarity — use clean recordings without background noise.
254
+ - No explicit prosody control (pitch, speed) — these are implicitly learned from data.
255
+ - Character-level tokenizer may struggle with rare Unicode characters outside the supported set.
256
+
257
+ ## License
258
+
259
+ Apache 2.0
260
+
261
+ ## Citation
262
+
263
+ ```bibtex
264
+ @misc{bgtts38mv2,
265
+ title={BgTTS-38M V2: Bulgarian Text-to-Speech with Voice Cloning and Speaker Normalization},
266
+ author={beleata74},
267
+ year={2026},
268
+ url={https://huggingface.co/beleata74/BgTTS-38M-V2}
269
+ }
270
+ ```
BgTTS/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """BG-TTS V6 — Encoder-Decoder with MioCodec + Speaker Embedding"""
BgTTS/checkpoint_inference.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b072815b1b915f2df60dc38d83bd9d524e9f67b76b64b91c36521dd59045a8ef
3
+ size 152965750
BgTTS/codec.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ V6 Codec — MioCodec 25Hz wrapper
3
+ ==================================
4
+ Single codebook, 12800 codes, 25fps, 24kHz.
5
+ Supports global_embedding for voice cloning.
6
+ """
7
+
8
+ import torch
9
+ import numpy as np
10
+ import soundfile as sf
11
+ from pathlib import Path
12
+ from typing import Optional, Union
13
+
14
+ from config import (
15
+ CODEC_MODEL_NAME, CODEC_SAMPLE_RATE,
16
+ CODEC_CODEBOOK_SIZE, CODEC_FRAME_RATE,
17
+ )
18
+
19
+
20
+ class CodecV6:
21
+ def __init__(self, device: str = "cuda"):
22
+ self.device = device
23
+ self.sample_rate = CODEC_SAMPLE_RATE # 24000
24
+ self.codebook_size = CODEC_CODEBOOK_SIZE # 12800
25
+ self.frame_rate = CODEC_FRAME_RATE # 25.0
26
+ self._load_model()
27
+
28
+ def _load_model(self):
29
+ from miocodec import MioCodecModel
30
+ self.model = MioCodecModel.from_pretrained(CODEC_MODEL_NAME)
31
+ self.model = self.model.to(self.device).eval()
32
+ print(f"MioCodec loaded: {CODEC_MODEL_NAME}, {self.sample_rate}Hz, "
33
+ f"{self.frame_rate}fps, {self.codebook_size} codes")
34
+
35
+ @torch.no_grad()
36
+ def encode(self, wav_path: str | Path) -> dict:
37
+ """
38
+ Encode wav file → MioCodec codes + global_embedding.
39
+ """
40
+ data, sr = sf.read(str(wav_path), dtype='float32')
41
+ waveform = torch.from_numpy(data)
42
+ return self.encode_waveform(waveform, sr)
43
+
44
+ @torch.no_grad()
45
+ def encode_waveform(self, waveform: torch.Tensor, sr: int) -> dict:
46
+ """
47
+ Encode directly from waveform tensor.
48
+ waveform: [samples] or [channels, samples]
49
+ sr: int
50
+ """
51
+ if waveform.dim() == 2: # stereo
52
+ waveform = waveform.mean(1)
53
+ if waveform.dim() == 1:
54
+ waveform = waveform.unsqueeze(0) # [1, samples]
55
+
56
+ if sr != self.sample_rate:
57
+ import torchaudio
58
+ waveform = torchaudio.functional.resample(waveform, sr, self.sample_rate)
59
+
60
+ audio = waveform.to(self.device).float()
61
+
62
+ # MioCodec encode returns (content_token_indices, global_embedding)
63
+ result = self.model.encode(audio)
64
+ codes = result.content_token_indices.squeeze().cpu() # [num_frames]
65
+ global_emb = result.global_embedding.squeeze().cpu() # [128]
66
+
67
+ return {
68
+ 'codes': codes,
69
+ 'global_embedding': global_emb,
70
+ }
71
+
72
+ @torch.no_grad()
73
+ def decode(self, codes: torch.Tensor,
74
+ global_embedding: torch.Tensor) -> torch.Tensor:
75
+ """
76
+ Decode MioCodec codes → waveform.
77
+
78
+ Args:
79
+ codes: [num_frames] — token indices in [0, 12799]
80
+ global_embedding: [128] — speaker embedding
81
+
82
+ Returns:
83
+ waveform: [samples] float32
84
+ """
85
+ codes = codes.to(self.device)
86
+ global_embedding = global_embedding.to(self.device)
87
+
88
+ # MioCodec expects flat tensors: codes [num_frames], emb [128]
89
+ if codes.dim() > 1:
90
+ codes = codes.squeeze()
91
+ if global_embedding.dim() > 1:
92
+ global_embedding = global_embedding.squeeze()
93
+
94
+ audio = self.model.decode(
95
+ global_embedding=global_embedding,
96
+ content_token_indices=codes,
97
+ )
98
+ return audio.squeeze().cpu().float()
99
+
100
+ def encode_to_tokens(self, wav_path: str) -> dict:
101
+ """Convenience: encode and return codes + embedding."""
102
+ return self.encode(wav_path)
103
+
104
+ def tokens_to_wav(self, codes: torch.Tensor,
105
+ global_embedding: torch.Tensor,
106
+ output: Optional[str] = None) -> torch.Tensor:
107
+ """Decode tokens to wav, optionally save."""
108
+ wav = self.decode(codes, global_embedding)
109
+ if output:
110
+ sf.write(output, wav.numpy(), self.sample_rate)
111
+ return wav
112
+
113
+ def get_stats(self, wav_path: str) -> dict:
114
+ """Get encoding stats for a wav file."""
115
+ result = self.encode(wav_path)
116
+ data, sr = sf.read(str(wav_path), dtype='float32')
117
+ dur = len(data) / sr if data.ndim == 1 else data.shape[0] / sr
118
+ n_tokens = len(result['codes'])
119
+ return {
120
+ "duration_sec": dur,
121
+ "num_tokens": n_tokens,
122
+ "tokens_per_sec": n_tokens / dur if dur > 0 else 0,
123
+ "global_emb_shape": tuple(result['global_embedding'].shape),
124
+ }
BgTTS/config.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ V6 Config — Encoder-Decoder TTS with MioCodec + Speaker Embedding
3
+ ==================================================================
4
+ Vocab layout:
5
+ [0..8] = 9 special tokens
6
+ [9..154] = ~146 text chars (BG + EN + digits + punct)
7
+ [155..12954] = 12,800 audio tokens (MioCodec, 1 codebook)
8
+ Total = 12,955
9
+
10
+ Architecture:
11
+ Encoder: 4L bidirectional, d=384, 6 heads — text understanding
12
+ Decoder: 8L causal + cross-attention, d=384, 6 heads — audio generation
13
+ Speaker: 128-dim global_embedding → Linear(128, 384) → added to decoder
14
+
15
+ Key differences from V5:
16
+ - MioCodec (25fps, 1CB, 12800) instead of NanoCodec (12.5fps, 4CB, 16128)
17
+ - d=384 for both encoder and decoder (V5: enc=512, dec=768)
18
+ - 8 decoder layers (V5: 18)
19
+ - Speaker embedding injection (V5: discrete speaker tokens)
20
+ - max_text=256, max_audio=512 (V5: 512/2048)
21
+ - ~40M params (V5: 250M)
22
+ - Expected RTF ~0.15-0.25 (V5: 1.1)
23
+ """
24
+
25
+ # ── MioCodec 25Hz ──────────────────────────────────────────────
26
+ CODEC_MODEL_NAME = "Aratako/MioCodec-25Hz-24kHz"
27
+ CODEC_SAMPLE_RATE = 24_000
28
+ CODEC_NUM_CODEBOOKS = 1
29
+ CODEC_CODEBOOK_SIZE = 12_800
30
+ CODEC_FRAME_RATE = 25.0
31
+ CODEC_TOKENS_PER_SEC = 25 # 25fps × 1 codebook
32
+ TOKENS_PER_FRAME = 1
33
+ SPEAKER_EMB_DIM = 128 # MioCodec global_embedding dimension
34
+
35
+ # ── Character set (same as V5) ─────────────────────────────────
36
+ BG_LOWER = "абвгдежзийклмнопрстуфхцчшщъьюя"
37
+ BG_UPPER = "АБВГДЕЖЗИЙКЛМНОПРСТУФХЦЧШЩЪЬЮЯ"
38
+ EN_LOWER = "abcdefghijklmnopqrstuvwxyz"
39
+ EN_UPPER = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
40
+ DIGITS = "0123456789"
41
+ PUNCT = '.,!?;:-–—…"\'()[]{}«»„"" '
42
+ EXTRA = "\n\t"
43
+
44
+ _ALL_CHARS: list[str] = []
45
+ _seen: set[str] = set()
46
+ for _src in [BG_LOWER, BG_UPPER, EN_LOWER, EN_UPPER, DIGITS, PUNCT, EXTRA]:
47
+ for _ch in _src:
48
+ if _ch not in _seen:
49
+ _ALL_CHARS.append(_ch)
50
+ _seen.add(_ch)
51
+
52
+ # ── Special tokens (indices 0..8) ──────────────────────────────
53
+ SPECIAL_TOKENS = {
54
+ "<pad>": 0,
55
+ "<start_of_text>": 1,
56
+ "<end_of_text>": 2,
57
+ "<start_of_speech>": 3,
58
+ "<end_of_speech>": 4,
59
+ "<spk_0>": 5, # kept for compatibility, but speaker embedding is primary
60
+ "<spk_1>": 6,
61
+ "<spk_2>": 7,
62
+ "<spk_3>": 8,
63
+ }
64
+ NUM_SPECIAL_TOKENS = len(SPECIAL_TOKENS) # 9
65
+
66
+ # ── Vocab offsets ───────────────────────────────────────────────
67
+ TEXT_CHARS = _ALL_CHARS
68
+ TEXT_VOCAB_SIZE = len(TEXT_CHARS) # ~146
69
+ TEXT_OFFSET = NUM_SPECIAL_TOKENS # 9
70
+ AUDIO_OFFSET = TEXT_OFFSET + TEXT_VOCAB_SIZE # 155
71
+ NUM_AUDIO_TOKENS = CODEC_CODEBOOK_SIZE # 12,800
72
+ TOTAL_VOCAB_SIZE = AUDIO_OFFSET + NUM_AUDIO_TOKENS # 12,955
73
+
74
+ # Encoder needs only text vocab; decoder needs full vocab
75
+ ENCODER_VOCAB_SIZE = AUDIO_OFFSET # 155 (special + text)
76
+ DECODER_VOCAB_SIZE = TOTAL_VOCAB_SIZE # 12,955 (full)
77
+
78
+ # ── Convenience IDs ─────────────────────────────────────────────
79
+ PAD_TOKEN_ID = SPECIAL_TOKENS["<pad>"]
80
+ START_OF_TEXT_TOKEN_ID = SPECIAL_TOKENS["<start_of_text>"]
81
+ END_OF_TEXT_TOKEN_ID = SPECIAL_TOKENS["<end_of_text>"]
82
+ START_OF_SPEECH_TOKEN_ID = SPECIAL_TOKENS["<start_of_speech>"]
83
+ END_OF_SPEECH_TOKEN_ID = SPECIAL_TOKENS["<end_of_speech>"]
84
+ SPK_0_TOKEN_ID = SPECIAL_TOKENS["<spk_0>"]
85
+ SPK_1_TOKEN_ID = SPECIAL_TOKENS["<spk_1>"]
86
+
87
+ # ── Helper functions ────────────────────────────────────────────
88
+ def audio_token_id(code: int) -> int:
89
+ """MioCodec code → global token ID."""
90
+ return AUDIO_OFFSET + code
91
+
92
+ def decode_audio_token(token_id: int) -> int:
93
+ """Global token ID → MioCodec code."""
94
+ return token_id - AUDIO_OFFSET
95
+
96
+ def is_audio_token(token_id: int) -> bool:
97
+ return AUDIO_OFFSET <= token_id < AUDIO_OFFSET + NUM_AUDIO_TOKENS
98
+
99
+ def is_special_token(token_id: int) -> bool:
100
+ return 0 <= token_id < NUM_SPECIAL_TOKENS
101
+
102
+ def is_text_token(token_id: int) -> bool:
103
+ return TEXT_OFFSET <= token_id < AUDIO_OFFSET
104
+
105
+ # ── V6 Model Config ────────────────────────────────────────────
106
+ # Encoder: 4 bidirectional layers
107
+ ENC_D_MODEL = 384
108
+ ENC_N_HEADS = 6
109
+ ENC_N_LAYERS = 4
110
+ ENC_D_FF = 1536
111
+
112
+ # Decoder: 8 causal layers with cross-attention
113
+ DEC_D_MODEL = 384
114
+ DEC_N_HEADS = 6
115
+ DEC_N_LAYERS = 8
116
+ DEC_D_FF = 1536
117
+
118
+ MAX_TEXT_LEN = 256 # Max text tokens (chars) — covers ~17s speech
119
+ MAX_AUDIO_LEN = 512 # Max audio tokens — 512/25 = 20.5s
120
+ DROPOUT = 0.0
121
+
122
+ # ── Training defaults ──────────────────────────────────────────
123
+ BATCH_SIZE = 16 # Smaller model = bigger batch
124
+ GRAD_ACCUM = 4 # effective = 64
125
+ LR = 3e-4
126
+ WEIGHT_DECAY = 0.1
127
+ WARMUP_STEPS = 1000
128
+ NUM_EPOCHS = 5
129
+
130
+ # ── Print summary ──────────────────────────────────────────────
131
+ if __name__ == "__main__":
132
+ print(f"V6 Vocab Layout:")
133
+ print(f" Special: [0, {NUM_SPECIAL_TOKENS-1}] ({NUM_SPECIAL_TOKENS} tokens)")
134
+ print(f" Text: [{TEXT_OFFSET}, {AUDIO_OFFSET-1}] ({TEXT_VOCAB_SIZE} chars)")
135
+ print(f" Audio: [{AUDIO_OFFSET}, {TOTAL_VOCAB_SIZE-1}] ({NUM_AUDIO_TOKENS} tokens)")
136
+ print(f" TOTAL: {TOTAL_VOCAB_SIZE}")
137
+ print()
138
+ print(f"V6 Encoder: d={ENC_D_MODEL}, heads={ENC_N_HEADS}, L={ENC_N_LAYERS}, ff={ENC_D_FF}")
139
+ print(f"V6 Decoder: d={DEC_D_MODEL}, heads={DEC_N_HEADS}, L={DEC_N_LAYERS}, ff={DEC_D_FF}")
140
+ print(f"V6 Codec: MioCodec {CODEC_FRAME_RATE}fps, {CODEC_NUM_CODEBOOKS}CB × {CODEC_CODEBOOK_SIZE}")
141
+ print(f"V6 Speaker: {SPEAKER_EMB_DIM}-dim global_embedding")
142
+ print(f"V6 Limits: max_text={MAX_TEXT_LEN}, max_audio={MAX_AUDIO_LEN}")
BgTTS/inference.py ADDED
@@ -0,0 +1,241 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ V6 Inference — encoder-decoder TTS with MioCodec + speaker cloning
3
+ ===================================================================
4
+ 1. Encode text with encoder (bidirectional, once)
5
+ 2. Autoregressively decode audio tokens with decoder + speaker embedding
6
+ 3. Decode tokens with MioCodec using global_embedding
7
+ """
8
+
9
+ import torch
10
+ import argparse
11
+ import time
12
+ from pathlib import Path
13
+ from config import (
14
+ AUDIO_OFFSET, NUM_AUDIO_TOKENS, END_OF_SPEECH_TOKEN_ID,
15
+ START_OF_SPEECH_TOKEN_ID, CODEC_SAMPLE_RATE, CODEC_FRAME_RATE,
16
+ )
17
+ from tokenizer import TTSTokenizer
18
+ from codec import CodecV6
19
+ from model import load_for_inference
20
+
21
+
22
+ def _split_text(text, tokenizer, max_len=250):
23
+ """Split text into chunks that fit within encoder max_text_len."""
24
+ import re
25
+ sentences = re.split(r'(?<=[.!?;:,])\s+', text)
26
+ chunks = []
27
+ current = ""
28
+ for sent in sentences:
29
+ candidate = (current + " " + sent).strip() if current else sent
30
+ enc_len = len(tokenizer.build_encoder_input(candidate))
31
+ if enc_len <= max_len:
32
+ current = candidate
33
+ else:
34
+ if current:
35
+ chunks.append(current)
36
+ # If single sentence is too long, split by words
37
+ if len(tokenizer.build_encoder_input(sent)) > max_len:
38
+ words = sent.split()
39
+ current = ""
40
+ for w in words:
41
+ cand = (current + " " + w).strip() if current else w
42
+ if len(tokenizer.build_encoder_input(cand)) <= max_len:
43
+ current = cand
44
+ else:
45
+ if current:
46
+ chunks.append(current)
47
+ current = w
48
+ else:
49
+ current = sent
50
+ if current:
51
+ chunks.append(current)
52
+ return chunks
53
+
54
+
55
+ @torch.no_grad()
56
+ def generate(model, tokenizer, text, speaker_emb,
57
+ max_new_tokens=512, temperature=0.7, top_k=250,
58
+ top_p=0.95, rep_penalty=1.1, device="cuda"):
59
+ """
60
+ Generate audio tokens from text.
61
+
62
+ Args:
63
+ model: TTSEncoderDecoder
64
+ tokenizer: TTSTokenizer
65
+ text: input text string
66
+ speaker_emb: [128] MioCodec global_embedding
67
+ max_new_tokens: max decoder steps
68
+ temperature: sampling temperature
69
+ top_k: top-k filtering
70
+ top_p: nucleus sampling threshold
71
+ rep_penalty: repetition penalty on recent tokens
72
+ device: cuda/cpu
73
+
74
+ Returns:
75
+ torch.Tensor of MioCodec codes [num_frames], or None
76
+ """
77
+ # 1. Encode text (one shot, bidirectional)
78
+ enc_ids = tokenizer.build_encoder_input(text).unsqueeze(0).to(device)
79
+ enc_mask = torch.ones_like(enc_ids)
80
+
81
+ enc_out = model.encode(enc_ids, enc_mask) # [1, T_enc, d_model]
82
+
83
+ # 2. Prepare speaker embedding
84
+ spk = speaker_emb.unsqueeze(0).to(device) # [1, 128]
85
+
86
+ # 3. Start decoder with <sos>
87
+ dec_ids = torch.tensor([[START_OF_SPEECH_TOKEN_ID]], device=device)
88
+ past = None
89
+ generated_tokens = []
90
+
91
+ for step in range(max_new_tokens):
92
+ inp = dec_ids[:, -1:] if past is not None else dec_ids
93
+
94
+ # Only pass speaker_emb on first step (already baked into embeddings)
95
+ # Actually, with KV-cache, we only process new tokens, so speaker
96
+ # needs to be added each time. The model handles this correctly.
97
+ dec_out = model.decoder(
98
+ input_ids=inp,
99
+ encoder_output=enc_out,
100
+ encoder_mask=enc_mask,
101
+ speaker_emb=spk,
102
+ past_key_values=past,
103
+ use_cache=True,
104
+ )
105
+ past = dec_out["past_key_values"]
106
+ logits = dec_out["logits"][:, -1, :]
107
+
108
+ # Mask: only allow audio tokens + end_of_speech
109
+ mask = torch.full_like(logits, float("-inf"))
110
+ mask[:, AUDIO_OFFSET:AUDIO_OFFSET + NUM_AUDIO_TOKENS] = 0
111
+ mask[:, END_OF_SPEECH_TOKEN_ID] = 0
112
+ logits = logits + mask
113
+
114
+ # Repetition penalty on recent tokens
115
+ if rep_penalty != 1.0 and generated_tokens:
116
+ recent = set(generated_tokens[-100:])
117
+ for tid in recent:
118
+ if AUDIO_OFFSET <= tid < AUDIO_OFFSET + NUM_AUDIO_TOKENS:
119
+ logits[:, tid] /= rep_penalty
120
+
121
+ logits = logits / temperature
122
+
123
+ # Top-k
124
+ if top_k > 0:
125
+ kth = torch.topk(logits, min(top_k, logits.shape[-1])).values[:, -1:]
126
+ logits[logits < kth] = float("-inf")
127
+
128
+ # Top-p (nucleus)
129
+ if top_p < 1.0:
130
+ sorted_l, sorted_i = torch.sort(logits, descending=True)
131
+ cum = torch.cumsum(torch.softmax(sorted_l, -1), -1)
132
+ remove = cum > top_p
133
+ remove[:, 1:] = remove[:, :-1].clone()
134
+ remove[:, 0] = False
135
+ logits[remove.scatter(1, sorted_i, remove)] = float("-inf")
136
+
137
+ next_tok = torch.multinomial(torch.softmax(logits, -1), 1)
138
+ tok_id = next_tok.item()
139
+
140
+ if tok_id == END_OF_SPEECH_TOKEN_ID:
141
+ break
142
+
143
+ generated_tokens.append(tok_id)
144
+ dec_ids = torch.cat([dec_ids, next_tok], dim=-1)
145
+
146
+ if not generated_tokens:
147
+ return None
148
+
149
+ result = torch.tensor(generated_tokens, dtype=torch.long)
150
+ audio_mask = (result >= AUDIO_OFFSET) & (result < AUDIO_OFFSET + NUM_AUDIO_TOKENS)
151
+ return result[audio_mask] - AUDIO_OFFSET
152
+
153
+
154
+ def synthesize(checkpoint, text, output="output.wav",
155
+ speaker_wav=None, speaker_emb_path=None,
156
+ temperature=0.7, top_k=250, top_p=0.95,
157
+ rep_penalty=1.1, max_tokens=512, device="cuda"):
158
+ """
159
+ Full TTS pipeline: text → audio file.
160
+
161
+ Speaker can be provided as:
162
+ 1. speaker_wav: path to reference audio (will encode with MioCodec)
163
+ 2. speaker_emb_path: path to saved .pt embedding
164
+ """
165
+ print(f"'{text[:80]}' | T={temperature}")
166
+ model = load_for_inference(checkpoint, device=device)
167
+ tokenizer = TTSTokenizer()
168
+ codec = CodecV6(device=device)
169
+
170
+ # Get speaker embedding
171
+ if speaker_emb_path:
172
+ import numpy as np
173
+ if speaker_emb_path.endswith('.npy'):
174
+ speaker_emb = torch.from_numpy(np.load(speaker_emb_path)).to(device)
175
+ else:
176
+ speaker_emb = torch.load(speaker_emb_path, map_location=device, weights_only=False)
177
+ if isinstance(speaker_emb, dict):
178
+ speaker_emb = speaker_emb.get("global_embedding",
179
+ speaker_emb.get("embedding"))
180
+ if speaker_emb.dim() > 1:
181
+ speaker_emb = speaker_emb.squeeze()
182
+ print(f"Speaker from preset: {speaker_emb.shape}")
183
+ elif speaker_wav:
184
+ result = codec.encode(speaker_wav)
185
+ speaker_emb = result['global_embedding'].to(device)
186
+ print(f"Speaker from wav: {speaker_wav}")
187
+ else:
188
+ raise ValueError("Provide speaker_wav or speaker_emb_path")
189
+
190
+ # Split long text into chunks that fit encoder max_text_len
191
+ chunks = _split_text(text, tokenizer, max_len=250)
192
+ print(f"Text split into {len(chunks)} chunk(s)")
193
+
194
+ t0 = time.time()
195
+ all_codes = []
196
+ for i, chunk in enumerate(chunks):
197
+ enc_len = len(tokenizer.build_encoder_input(chunk))
198
+ print(f" [{i+1}/{len(chunks)}] {enc_len} enc tokens: '{chunk[:60]}...'")
199
+ codes = generate(model, tokenizer, chunk, speaker_emb, max_tokens,
200
+ temperature, top_k, top_p, rep_penalty, device)
201
+ if codes is not None and len(codes) > 0:
202
+ all_codes.append(codes)
203
+ gen_time = time.time() - t0
204
+
205
+ if not all_codes:
206
+ print("No audio generated!")
207
+ return
208
+
209
+ codes = torch.cat(all_codes)
210
+ audio_dur = len(codes) / CODEC_FRAME_RATE
211
+ rtf = gen_time / audio_dur if audio_dur > 0 else float('inf')
212
+ print(f"{len(codes)} tokens ({audio_dur:.1f}s audio, {gen_time:.2f}s gen, RTF={rtf:.3f})")
213
+
214
+ # Decode to wav
215
+ wav = codec.tokens_to_wav(codes, speaker_emb, output)
216
+ print(f"Saved: {output} ({len(wav)/CODEC_SAMPLE_RATE:.2f}s)")
217
+ return wav
218
+
219
+
220
+ def main():
221
+ p = argparse.ArgumentParser(description="V6 TTS Inference")
222
+ p.add_argument("--checkpoint", required=True)
223
+ p.add_argument("--text", required=True)
224
+ p.add_argument("--output", default="output.wav")
225
+ p.add_argument("--speaker-wav", help="Reference audio for voice cloning")
226
+ p.add_argument("--speaker-emb", help="Path to saved speaker embedding .pt")
227
+ p.add_argument("--temperature", type=float, default=0.7)
228
+ p.add_argument("--top-k", type=int, default=250)
229
+ p.add_argument("--top-p", type=float, default=0.95)
230
+ p.add_argument("--rep-penalty", type=float, default=1.1)
231
+ p.add_argument("--max-tokens", type=int, default=512)
232
+ a = p.parse_args()
233
+ synthesize(a.checkpoint, a.text, a.output,
234
+ speaker_wav=a.speaker_wav,
235
+ speaker_emb_path=a.speaker_emb,
236
+ temperature=a.temperature, top_k=a.top_k,
237
+ top_p=a.top_p, rep_penalty=a.rep_penalty,
238
+ max_tokens=a.max_tokens)
239
+
240
+ if __name__ == "__main__":
241
+ main()
BgTTS/model.py ADDED
@@ -0,0 +1,600 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ V6 Model — Encoder-Decoder TTS with MioCodec + Speaker Embedding
3
+ =================================================================
4
+ Architecture (V6 Small):
5
+ - Text Encoder: 4-layer bidirectional Transformer (d=384, 6 heads, ff=1536)
6
+ Learned positional embeddings, RMSNorm, SwiGLU
7
+ - Audio Decoder: 8-layer causal Transformer (d=384, 6 heads, ff=1536)
8
+ RoPE, cross-attention to encoder at every layer, RMSNorm, SwiGLU
9
+ - Speaker Projection: Linear(128, 384) — MioCodec global_embedding → decoder dim
10
+
11
+ Key design:
12
+ - enc_d == dec_d == 384 → no projection layer needed
13
+ - Speaker embedding (128-dim) injected into decoder as additive bias
14
+ - Tied decoder embeddings (lm_head = token_embedding.weight)
15
+ - Gradient checkpointing in decoder during training
16
+ - KV-cache for inference
17
+ - ~38M params total
18
+
19
+ Target inference: RTF ~0.25-0.30 on RTX 5090
20
+ """
21
+
22
+ import math
23
+ import os
24
+ import torch
25
+ import torch.nn as nn
26
+ import torch.nn.functional as F
27
+ from typing import Optional, Tuple, Dict
28
+ from dataclasses import dataclass
29
+
30
+ from config import (
31
+ TOTAL_VOCAB_SIZE, ENCODER_VOCAB_SIZE, DECODER_VOCAB_SIZE,
32
+ ENC_D_MODEL, ENC_N_HEADS, ENC_N_LAYERS, ENC_D_FF,
33
+ DEC_D_MODEL, DEC_N_HEADS, DEC_N_LAYERS, DEC_D_FF,
34
+ MAX_TEXT_LEN, MAX_AUDIO_LEN, DROPOUT,
35
+ PAD_TOKEN_ID, NUM_AUDIO_TOKENS, AUDIO_OFFSET,
36
+ SPEAKER_EMB_DIM,
37
+ )
38
+
39
+
40
+ # ── Shared Components ──────────────────────────────────────────
41
+
42
+ class RMSNorm(nn.Module):
43
+ def __init__(self, dim: int, eps: float = 1e-6):
44
+ super().__init__()
45
+ self.eps = eps
46
+ self.weight = nn.Parameter(torch.ones(dim))
47
+
48
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
49
+ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight
50
+
51
+
52
+ class RotaryPositionalEmbedding(nn.Module):
53
+ def __init__(self, dim: int, max_seq_len: int = 4096, base: float = 10000.0):
54
+ super().__init__()
55
+ self.dim = dim
56
+ self.max_seq_len = max_seq_len
57
+ inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
58
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
59
+ self._build_cache(max_seq_len)
60
+
61
+ def _build_cache(self, seq_len: int):
62
+ t = torch.arange(seq_len, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
63
+ freqs = torch.outer(t, self.inv_freq)
64
+ emb = torch.cat((freqs, freqs), dim=-1)
65
+ self.register_buffer("cos_cached", emb.cos(), persistent=False)
66
+ self.register_buffer("sin_cached", emb.sin(), persistent=False)
67
+
68
+ def forward(self, seq_len: int) -> Tuple[torch.Tensor, torch.Tensor]:
69
+ if seq_len > self.max_seq_len:
70
+ self._build_cache(seq_len)
71
+ self.max_seq_len = seq_len
72
+ return self.cos_cached[:seq_len], self.sin_cached[:seq_len]
73
+
74
+
75
+ def rotate_half(x: torch.Tensor) -> torch.Tensor:
76
+ x1, x2 = x.chunk(2, dim=-1)
77
+ return torch.cat((-x2, x1), dim=-1)
78
+
79
+
80
+ def apply_rotary_pos_emb(q, k, cos, sin):
81
+ cos = cos.unsqueeze(0).unsqueeze(0)
82
+ sin = sin.unsqueeze(0).unsqueeze(0)
83
+ return (q * cos + rotate_half(q) * sin,
84
+ k * cos + rotate_half(k) * sin)
85
+
86
+
87
+ class SwiGLUFFN(nn.Module):
88
+ def __init__(self, d_model: int, d_ff: int, dropout: float):
89
+ super().__init__()
90
+ self.gate_proj = nn.Linear(d_model, d_ff, bias=False)
91
+ self.up_proj = nn.Linear(d_model, d_ff, bias=False)
92
+ self.down_proj = nn.Linear(d_ff, d_model, bias=False)
93
+ self.dropout = nn.Dropout(dropout)
94
+
95
+ def forward(self, x):
96
+ return self.dropout(self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x)))
97
+
98
+
99
+ # ── Encoder (Bidirectional) ────────────────────────────────────
100
+
101
+ class EncoderSelfAttention(nn.Module):
102
+ """Bidirectional self-attention for text encoder (NO causal mask)."""
103
+ def __init__(self, d_model: int, n_heads: int, dropout: float):
104
+ super().__init__()
105
+ self.d_model = d_model
106
+ self.n_heads = n_heads
107
+ self.head_dim = d_model // n_heads
108
+ assert d_model % n_heads == 0
109
+
110
+ self.q_proj = nn.Linear(d_model, d_model, bias=False)
111
+ self.k_proj = nn.Linear(d_model, d_model, bias=False)
112
+ self.v_proj = nn.Linear(d_model, d_model, bias=False)
113
+ self.o_proj = nn.Linear(d_model, d_model, bias=False)
114
+ self.resid_dropout = nn.Dropout(dropout)
115
+
116
+ def forward(self, x, key_padding_mask=None):
117
+ B, T, _ = x.shape
118
+ q = self.q_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
119
+ k = self.k_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
120
+ v = self.v_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
121
+
122
+ attn_mask = None
123
+ if key_padding_mask is not None:
124
+ attn_mask = key_padding_mask.unsqueeze(1).unsqueeze(2) # [B, 1, 1, T]
125
+ attn_mask = attn_mask.float() * torch.finfo(q.dtype).min
126
+
127
+ attn_out = F.scaled_dot_product_attention(
128
+ q, k, v,
129
+ attn_mask=attn_mask,
130
+ dropout_p=self.resid_dropout.p if self.training else 0.0,
131
+ is_causal=False,
132
+ )
133
+ attn_out = attn_out.transpose(1, 2).contiguous().view(B, -1, self.d_model)
134
+ return self.resid_dropout(self.o_proj(attn_out))
135
+
136
+
137
+ class EncoderBlock(nn.Module):
138
+ def __init__(self, d_model: int, n_heads: int, d_ff: int, dropout: float):
139
+ super().__init__()
140
+ self.attn_norm = RMSNorm(d_model)
141
+ self.attention = EncoderSelfAttention(d_model, n_heads, dropout)
142
+ self.ffn_norm = RMSNorm(d_model)
143
+ self.ffn = SwiGLUFFN(d_model, d_ff, dropout)
144
+
145
+ def forward(self, x, key_padding_mask=None):
146
+ x = x + self.attention(self.attn_norm(x), key_padding_mask)
147
+ x = x + self.ffn(self.ffn_norm(x))
148
+ return x
149
+
150
+
151
+ class TextEncoder(nn.Module):
152
+ """
153
+ Bidirectional Transformer encoder for text.
154
+ Input: text token IDs (special + chars, vocab 155)
155
+ Output: contextualized text representations [B, T_text, d_model]
156
+ """
157
+ def __init__(self, vocab_size=ENCODER_VOCAB_SIZE, d_model=ENC_D_MODEL,
158
+ n_heads=ENC_N_HEADS, n_layers=ENC_N_LAYERS, d_ff=ENC_D_FF,
159
+ max_len=MAX_TEXT_LEN, dropout=DROPOUT):
160
+ super().__init__()
161
+ self.d_model = d_model
162
+ self.token_embedding = nn.Embedding(vocab_size, d_model, padding_idx=PAD_TOKEN_ID)
163
+ self.pos_embedding = nn.Embedding(max_len, d_model)
164
+ self.embed_dropout = nn.Dropout(dropout)
165
+
166
+ self.layers = nn.ModuleList([
167
+ EncoderBlock(d_model, n_heads, d_ff, dropout)
168
+ for _ in range(n_layers)
169
+ ])
170
+ self.final_norm = RMSNorm(d_model)
171
+
172
+ def forward(self, input_ids, attention_mask=None):
173
+ B, T = input_ids.shape
174
+ pos = torch.arange(T, device=input_ids.device).unsqueeze(0)
175
+ h = self.embed_dropout(self.token_embedding(input_ids) + self.pos_embedding(pos))
176
+
177
+ key_padding_mask = None
178
+ if attention_mask is not None:
179
+ key_padding_mask = (attention_mask == 0)
180
+
181
+ for layer in self.layers:
182
+ h = layer(h, key_padding_mask)
183
+
184
+ return self.final_norm(h)
185
+
186
+
187
+ # ── Decoder (Causal with Cross-Attention + Speaker) ────────────
188
+
189
+ class DecoderSelfAttention(nn.Module):
190
+ """Causal self-attention with RoPE and KV-cache."""
191
+ def __init__(self, d_model: int, n_heads: int, dropout: float, max_len: int):
192
+ super().__init__()
193
+ self.d_model = d_model
194
+ self.n_heads = n_heads
195
+ self.head_dim = d_model // n_heads
196
+ assert d_model % n_heads == 0
197
+
198
+ self.q_proj = nn.Linear(d_model, d_model, bias=False)
199
+ self.k_proj = nn.Linear(d_model, d_model, bias=False)
200
+ self.v_proj = nn.Linear(d_model, d_model, bias=False)
201
+ self.o_proj = nn.Linear(d_model, d_model, bias=False)
202
+ self.resid_dropout = nn.Dropout(dropout)
203
+ self.rope = RotaryPositionalEmbedding(self.head_dim, max_len)
204
+
205
+ def forward(self, x, past_kv=None, use_cache=False):
206
+ B, T, _ = x.shape
207
+ q = self.q_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
208
+ k = self.k_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
209
+ v = self.v_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
210
+
211
+ # RoPE
212
+ if past_kv is not None:
213
+ offset = past_kv[0].shape[2]
214
+ cos, sin = self.rope(offset + T)
215
+ cos, sin = cos[offset:offset + T], sin[offset:offset + T]
216
+ else:
217
+ cos, sin = self.rope(T)
218
+
219
+ q, k = apply_rotary_pos_emb(q, k, cos, sin)
220
+
221
+ if past_kv is not None:
222
+ k = torch.cat([past_kv[0], k], dim=2)
223
+ v = torch.cat([past_kv[1], v], dim=2)
224
+
225
+ new_kv = (k, v) if use_cache else None
226
+
227
+ is_causal = (past_kv is None) and (T > 1)
228
+ attn_out = F.scaled_dot_product_attention(
229
+ q, k, v,
230
+ dropout_p=self.resid_dropout.p if self.training else 0.0,
231
+ is_causal=is_causal,
232
+ )
233
+ attn_out = attn_out.transpose(1, 2).contiguous().view(B, -1, self.d_model)
234
+ return self.resid_dropout(self.o_proj(attn_out)), new_kv
235
+
236
+
237
+ class CrossAttention(nn.Module):
238
+ """Cross-attention: decoder queries attend to encoder keys/values."""
239
+ def __init__(self, d_model: int, n_heads: int, dropout: float):
240
+ super().__init__()
241
+ self.d_model = d_model
242
+ self.n_heads = n_heads
243
+ self.head_dim = d_model // n_heads
244
+ assert d_model % n_heads == 0
245
+
246
+ # Q from decoder, K/V from encoder — same dim since enc_d == dec_d
247
+ self.q_proj = nn.Linear(d_model, d_model, bias=False)
248
+ self.k_proj = nn.Linear(d_model, d_model, bias=False)
249
+ self.v_proj = nn.Linear(d_model, d_model, bias=False)
250
+ self.o_proj = nn.Linear(d_model, d_model, bias=False)
251
+ self.resid_dropout = nn.Dropout(dropout)
252
+
253
+ def forward(self, x, encoder_output, encoder_mask=None, cached_kv=None, use_cache=False):
254
+ B, T, _ = x.shape
255
+ q = self.q_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
256
+
257
+ if cached_kv is not None:
258
+ k, v = cached_kv
259
+ else:
260
+ T_enc = encoder_output.shape[1]
261
+ k = self.k_proj(encoder_output).view(B, T_enc, self.n_heads, self.head_dim).transpose(1, 2)
262
+ v = self.v_proj(encoder_output).view(B, T_enc, self.n_heads, self.head_dim).transpose(1, 2)
263
+
264
+ new_kv = (k, v) if use_cache else None
265
+
266
+ attn_mask = None
267
+ if encoder_mask is not None:
268
+ attn_mask = (encoder_mask == 0).unsqueeze(1).unsqueeze(2)
269
+ attn_mask = attn_mask.float() * torch.finfo(q.dtype).min
270
+
271
+ attn_out = F.scaled_dot_product_attention(
272
+ q, k, v,
273
+ attn_mask=attn_mask,
274
+ dropout_p=self.resid_dropout.p if self.training else 0.0,
275
+ is_causal=False,
276
+ )
277
+ attn_out = attn_out.transpose(1, 2).contiguous().view(B, -1, self.d_model)
278
+ return self.resid_dropout(self.o_proj(attn_out)), new_kv
279
+
280
+
281
+ class DecoderBlock(nn.Module):
282
+ """Decoder block: self-attention → cross-attention → FFN"""
283
+ def __init__(self, d_model: int, n_heads: int, d_ff: int,
284
+ dropout: float, max_len: int):
285
+ super().__init__()
286
+ self.self_attn_norm = RMSNorm(d_model)
287
+ self.self_attention = DecoderSelfAttention(d_model, n_heads, dropout, max_len)
288
+
289
+ self.cross_attn_norm = RMSNorm(d_model)
290
+ self.cross_attention = CrossAttention(d_model, n_heads, dropout)
291
+
292
+ self.ffn_norm = RMSNorm(d_model)
293
+ self.ffn = SwiGLUFFN(d_model, d_ff, dropout)
294
+
295
+ def forward(self, x, encoder_output, encoder_mask=None,
296
+ past_self_kv=None, past_cross_kv=None, use_cache=False):
297
+ # 1. Causal self-attention
298
+ h = self.self_attn_norm(x)
299
+ attn_out, new_self_kv = self.self_attention(h, past_self_kv, use_cache)
300
+ x = x + attn_out
301
+
302
+ # 2. Cross-attention to encoder
303
+ h = self.cross_attn_norm(x)
304
+ cross_out, new_cross_kv = self.cross_attention(
305
+ h, encoder_output, encoder_mask, past_cross_kv, use_cache)
306
+ x = x + cross_out
307
+
308
+ # 3. FFN
309
+ x = x + self.ffn(self.ffn_norm(x))
310
+
311
+ return x, new_self_kv, new_cross_kv
312
+
313
+
314
+ class AudioDecoder(nn.Module):
315
+ """
316
+ Causal Transformer decoder with cross-attention + speaker embedding.
317
+ Speaker embedding is added once to the token embeddings (like a global bias).
318
+ """
319
+ def __init__(self, vocab_size=DECODER_VOCAB_SIZE, d_model=DEC_D_MODEL,
320
+ n_heads=DEC_N_HEADS, n_layers=DEC_N_LAYERS, d_ff=DEC_D_FF,
321
+ max_len=MAX_AUDIO_LEN, dropout=DROPOUT,
322
+ speaker_emb_dim=SPEAKER_EMB_DIM):
323
+ super().__init__()
324
+ self.config_d_model = d_model
325
+ self.token_embedding = nn.Embedding(vocab_size, d_model)
326
+ self.embed_dropout = nn.Dropout(dropout)
327
+
328
+ # Speaker embedding projection: 128 → d_model (normalized)
329
+ self.speaker_proj = nn.Linear(speaker_emb_dim, d_model, bias=False)
330
+ self.register_buffer('spk_scale', torch.ones(1)) # fixed scale, not learnable
331
+
332
+ self.layers = nn.ModuleList([
333
+ DecoderBlock(d_model, n_heads, d_ff, dropout, max_len)
334
+ for _ in range(n_layers)
335
+ ])
336
+ self.final_norm = RMSNorm(d_model)
337
+
338
+ # LM head — tied with token embedding
339
+ self.lm_head = None # tied
340
+
341
+ def forward(self, input_ids, encoder_output, encoder_mask=None,
342
+ speaker_emb=None, labels=None,
343
+ past_key_values=None, use_cache=False):
344
+ """
345
+ input_ids: [B, T_dec]
346
+ encoder_output: [B, T_enc, d_model]
347
+ encoder_mask: [B, T_enc]
348
+ speaker_emb: [B, 128] — MioCodec global_embedding
349
+ labels: [B, T_dec] — for training
350
+ """
351
+ h = self.token_embedding(input_ids)
352
+
353
+ # Inject speaker embedding — normalized, additive, broadcast over time
354
+ if speaker_emb is not None:
355
+ spk = self.speaker_proj(speaker_emb) # [B, d_model]
356
+ spk = F.normalize(spk, dim=-1) * self.spk_scale # normalize to unit norm
357
+ h = h + spk.unsqueeze(1) # [B, 1, d_model] broadcast
358
+
359
+ h = self.embed_dropout(h)
360
+
361
+ new_kvs = [] if use_cache else None
362
+ for i, layer in enumerate(self.layers):
363
+ past_self_kv = past_key_values[i][0] if past_key_values else None
364
+ past_cross_kv = past_key_values[i][1] if past_key_values else None
365
+
366
+ if self.training and not use_cache:
367
+ h, self_kv, cross_kv = torch.utils.checkpoint.checkpoint(
368
+ layer, h, encoder_output, encoder_mask,
369
+ past_self_kv, past_cross_kv, use_cache,
370
+ use_reentrant=False)
371
+ else:
372
+ h, self_kv, cross_kv = layer(
373
+ h, encoder_output, encoder_mask,
374
+ past_self_kv, past_cross_kv, use_cache)
375
+
376
+ if use_cache:
377
+ new_kvs.append((self_kv, cross_kv))
378
+
379
+ h = self.final_norm(h)
380
+
381
+ # Tied embeddings
382
+ logits = F.linear(h, self.token_embedding.weight)
383
+
384
+ result = {"logits": logits}
385
+ if use_cache:
386
+ result["past_key_values"] = new_kvs
387
+
388
+ if labels is not None:
389
+ shift_logits = logits[:, :-1, :].contiguous()
390
+ shift_labels = labels[:, 1:].contiguous()
391
+ loss = F.cross_entropy(
392
+ shift_logits.view(-1, shift_logits.size(-1)),
393
+ shift_labels.view(-1),
394
+ ignore_index=-100,
395
+ )
396
+ result["loss"] = loss
397
+
398
+ return result
399
+
400
+
401
+ # ── Full Encoder-Decoder Model ─────────────────────────────────
402
+
403
+ @dataclass
404
+ class V6Config:
405
+ # Encoder
406
+ enc_vocab_size: int = ENCODER_VOCAB_SIZE
407
+ enc_d_model: int = ENC_D_MODEL
408
+ enc_n_heads: int = ENC_N_HEADS
409
+ enc_n_layers: int = ENC_N_LAYERS
410
+ enc_d_ff: int = ENC_D_FF
411
+ max_text_len: int = MAX_TEXT_LEN
412
+ # Decoder
413
+ dec_vocab_size: int = DECODER_VOCAB_SIZE
414
+ dec_d_model: int = DEC_D_MODEL
415
+ dec_n_heads: int = DEC_N_HEADS
416
+ dec_n_layers: int = DEC_N_LAYERS
417
+ dec_d_ff: int = DEC_D_FF
418
+ max_audio_len: int = MAX_AUDIO_LEN
419
+ # Speaker
420
+ speaker_emb_dim: int = SPEAKER_EMB_DIM
421
+ # Shared
422
+ dropout: float = DROPOUT
423
+
424
+
425
+ class TTSEncoderDecoder(nn.Module):
426
+ """
427
+ V6 Encoder-Decoder TTS with MioCodec + Speaker Embedding.
428
+
429
+ Forward flow:
430
+ 1. Text → Encoder → contextualized text representations [B, T_text, d_model]
431
+ 2. Audio tokens + speaker_emb → Decoder (with cross-attn) → logits
432
+ """
433
+ def __init__(self, config: V6Config):
434
+ super().__init__()
435
+ self.config = config
436
+
437
+ # Text encoder (bidirectional)
438
+ self.encoder = TextEncoder(
439
+ vocab_size=config.enc_vocab_size,
440
+ d_model=config.enc_d_model,
441
+ n_heads=config.enc_n_heads,
442
+ n_layers=config.enc_n_layers,
443
+ d_ff=config.enc_d_ff,
444
+ max_len=config.max_text_len,
445
+ dropout=config.dropout,
446
+ )
447
+
448
+ # enc_d == dec_d → identity projection (no extra params)
449
+ assert config.enc_d_model == config.dec_d_model, \
450
+ f"V6 requires enc_d == dec_d, got {config.enc_d_model} vs {config.dec_d_model}"
451
+
452
+ # Audio decoder (causal with cross-attention + speaker embedding)
453
+ self.decoder = AudioDecoder(
454
+ vocab_size=config.dec_vocab_size,
455
+ d_model=config.dec_d_model,
456
+ n_heads=config.dec_n_heads,
457
+ n_layers=config.dec_n_layers,
458
+ d_ff=config.dec_d_ff,
459
+ max_len=config.max_audio_len,
460
+ dropout=config.dropout,
461
+ speaker_emb_dim=config.speaker_emb_dim,
462
+ )
463
+
464
+ self.apply(self._init_weights)
465
+
466
+ def _init_weights(self, module):
467
+ if isinstance(module, nn.Linear):
468
+ nn.init.normal_(module.weight, mean=0.0, std=0.02)
469
+ if module.bias is not None:
470
+ nn.init.zeros_(module.bias)
471
+ elif isinstance(module, nn.Embedding):
472
+ nn.init.normal_(module.weight, mean=0.0, std=0.02)
473
+
474
+ def get_num_params(self) -> int:
475
+ return sum(p.numel() for p in self.parameters())
476
+
477
+ def encode(self, enc_ids, enc_mask=None):
478
+ """Run encoder. Returns [B, T_enc, d_model]."""
479
+ return self.encoder(enc_ids, enc_mask)
480
+
481
+ def forward(self, enc_ids, dec_ids, enc_mask=None, dec_labels=None,
482
+ speaker_emb=None):
483
+ """
484
+ Full forward: encoder → decoder → loss.
485
+
486
+ Args:
487
+ enc_ids: [B, T_enc] — text token IDs
488
+ dec_ids: [B, T_dec] — audio token IDs (decoder input)
489
+ enc_mask: [B, T_enc] — 1=real, 0=pad
490
+ dec_labels: [B, T_dec] — decoder labels (-100 for masked)
491
+ speaker_emb: [B, 128] — MioCodec global_embedding
492
+ """
493
+ # 1. Encode text
494
+ enc_out = self.encoder(enc_ids, enc_mask) # [B, T_enc, d_model]
495
+
496
+ # 2. Decode audio with cross-attention + speaker
497
+ dec_out = self.decoder(dec_ids, enc_out, enc_mask,
498
+ speaker_emb=speaker_emb, labels=dec_labels)
499
+
500
+ result = {"logits": dec_out["logits"]}
501
+ if "loss" in dec_out:
502
+ result["loss"] = dec_out["loss"]
503
+
504
+ return result
505
+
506
+
507
+ # ── Factory functions ──────────────────────────────────────────
508
+
509
+ def create_model(device="cuda", dropout_override=None) -> TTSEncoderDecoder:
510
+ """Create V6 encoder-decoder TTS model."""
511
+ kwargs = {}
512
+ if dropout_override is not None:
513
+ kwargs["dropout"] = dropout_override
514
+ config = V6Config(**kwargs)
515
+ model = TTSEncoderDecoder(config)
516
+
517
+ n = model.get_num_params()
518
+ enc_n = sum(p.numel() for p in model.encoder.parameters())
519
+ dec_n = sum(p.numel() for p in model.decoder.parameters())
520
+
521
+ print(f"V6 Encoder-Decoder TTS with MioCodec + Speaker Embedding")
522
+ print(f" Total params: {n:,} ({n/1e6:.1f}M)")
523
+ print(f" Encoder: {enc_n:,} ({enc_n/1e6:.1f}M)")
524
+ print(f" Decoder: {dec_n:,} ({dec_n/1e6:.1f}M)")
525
+ print(f" Enc: d={config.enc_d_model}, h={config.enc_n_heads}, "
526
+ f"L={config.enc_n_layers}, ff={config.enc_d_ff}")
527
+ print(f" Dec: d={config.dec_d_model}, h={config.dec_n_heads}, "
528
+ f"L={config.dec_n_layers}, ff={config.dec_d_ff}")
529
+ print(f" Speaker: {config.speaker_emb_dim}-dim → {config.dec_d_model}")
530
+ print(f" Dropout: {config.dropout}")
531
+
532
+ model = model.to(device)
533
+ return model
534
+
535
+
536
+ def save_checkpoint(model, optimizer, scheduler, step, loss, path, best_val_loss=None):
537
+ """Save full training checkpoint."""
538
+ os.makedirs(path, exist_ok=True)
539
+ model_to_save = model._orig_mod if hasattr(model, "_orig_mod") else model
540
+
541
+ torch.save({
542
+ "model_state_dict": model_to_save.state_dict(),
543
+ "optimizer_state_dict": optimizer.state_dict(),
544
+ "scheduler_state_dict": scheduler.state_dict() if scheduler else None,
545
+ "step": step,
546
+ "loss": loss,
547
+ "best_val_loss": best_val_loss,
548
+ "config": {
549
+ "enc_vocab_size": model_to_save.config.enc_vocab_size,
550
+ "enc_d_model": model_to_save.config.enc_d_model,
551
+ "enc_n_heads": model_to_save.config.enc_n_heads,
552
+ "enc_n_layers": model_to_save.config.enc_n_layers,
553
+ "enc_d_ff": model_to_save.config.enc_d_ff,
554
+ "max_text_len": model_to_save.config.max_text_len,
555
+ "dec_vocab_size": model_to_save.config.dec_vocab_size,
556
+ "dec_d_model": model_to_save.config.dec_d_model,
557
+ "dec_n_heads": model_to_save.config.dec_n_heads,
558
+ "dec_n_layers": model_to_save.config.dec_n_layers,
559
+ "dec_d_ff": model_to_save.config.dec_d_ff,
560
+ "max_audio_len": model_to_save.config.max_audio_len,
561
+ "speaker_emb_dim": model_to_save.config.speaker_emb_dim,
562
+ "dropout": model_to_save.config.dropout,
563
+ },
564
+ }, f"{path}/checkpoint.pt")
565
+ print(f"Saved: {path} (step {step}, loss {loss:.4f})")
566
+
567
+
568
+ def load_for_inference(checkpoint_path: str, device="cuda") -> TTSEncoderDecoder:
569
+ """Load model from checkpoint for inference."""
570
+ if os.path.isfile(checkpoint_path):
571
+ ckpt_file = checkpoint_path
572
+ else:
573
+ ckpt_file = os.path.join(checkpoint_path, "checkpoint.pt")
574
+ print(f"Loading from {ckpt_file}...")
575
+ ckpt = torch.load(ckpt_file, map_location=device, weights_only=False)
576
+
577
+ cfg = ckpt["config"]
578
+ config = V6Config(
579
+ enc_vocab_size=cfg["enc_vocab_size"],
580
+ enc_d_model=cfg["enc_d_model"],
581
+ enc_n_heads=cfg["enc_n_heads"],
582
+ enc_n_layers=cfg["enc_n_layers"],
583
+ enc_d_ff=cfg["enc_d_ff"],
584
+ max_text_len=cfg["max_text_len"],
585
+ dec_vocab_size=cfg["dec_vocab_size"],
586
+ dec_d_model=cfg["dec_d_model"],
587
+ dec_n_heads=cfg["dec_n_heads"],
588
+ dec_n_layers=cfg["dec_n_layers"],
589
+ dec_d_ff=cfg["dec_d_ff"],
590
+ max_audio_len=cfg["max_audio_len"],
591
+ speaker_emb_dim=cfg.get("speaker_emb_dim", SPEAKER_EMB_DIM),
592
+ dropout=cfg["dropout"],
593
+ )
594
+ model = TTSEncoderDecoder(config)
595
+ model.load_state_dict(ckpt["model_state_dict"])
596
+ model = model.to(device).eval()
597
+
598
+ n = model.get_num_params()
599
+ print(f"Loaded! {n/1e6:.1f}M params, step {ckpt['step']}, loss {ckpt['loss']:.4f}")
600
+ return model
BgTTS/server.py ADDED
@@ -0,0 +1,222 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ BgTTS-38M Web Server — Gradio Interface
3
+ ========================================
4
+ Voice cloning TTS with Bulgarian + English support.
5
+ """
6
+
7
+ import sys
8
+ import os
9
+ import torch
10
+ import numpy as np
11
+ import tempfile
12
+ import time
13
+ import soundfile as sf
14
+
15
+ # Add parent dir to path for imports
16
+ sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
17
+
18
+ from config import (
19
+ AUDIO_OFFSET, NUM_AUDIO_TOKENS, END_OF_SPEECH_TOKEN_ID,
20
+ START_OF_SPEECH_TOKEN_ID, CODEC_SAMPLE_RATE, CODEC_FRAME_RATE,
21
+ )
22
+ from tokenizer import TTSTokenizer
23
+ from codec import CodecV6
24
+ from model import load_for_inference
25
+ from inference import generate, _split_text
26
+
27
+ # ── Global state ──────────────────────────────────────────────
28
+ MODEL = None
29
+ TOKENIZER = None
30
+ CODEC = None
31
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
32
+ CHECKPOINT_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "checkpoint_inference.pt")
33
+
34
+
35
+ def load_model():
36
+ """Load model, tokenizer, codec once at startup."""
37
+ global MODEL, TOKENIZER, CODEC
38
+ print(f"Loading model from {CHECKPOINT_PATH} on {DEVICE}...")
39
+ MODEL = load_for_inference(CHECKPOINT_PATH, device=DEVICE)
40
+ TOKENIZER = TTSTokenizer()
41
+ CODEC = CodecV6(device=DEVICE)
42
+ print("Model loaded!")
43
+
44
+
45
+ def synthesize_speech(text, ref_audio, temperature, top_k, top_p, rep_penalty):
46
+ """
47
+ Generate speech from text using reference audio for voice cloning.
48
+
49
+ Returns: (sample_rate, audio_array) tuple for Gradio
50
+ """
51
+ if not text or not text.strip():
52
+ return None
53
+
54
+ if ref_audio is None:
55
+ return None
56
+
57
+ # Encode reference audio for speaker embedding
58
+ sr_ref, audio_ref = ref_audio
59
+ audio_ref = audio_ref.astype(np.float32)
60
+ if audio_ref.max() > 1.0 or audio_ref.min() < -1.0:
61
+ audio_ref = audio_ref / max(abs(audio_ref.max()), abs(audio_ref.min()))
62
+
63
+ waveform = torch.from_numpy(audio_ref)
64
+ if waveform.dim() == 2:
65
+ waveform = waveform.mean(1)
66
+
67
+ result = CODEC.encode_waveform(waveform, sr_ref)
68
+ speaker_emb = result['global_embedding'].to(DEVICE)
69
+
70
+ # Split text into chunks
71
+ chunks = _split_text(text, TOKENIZER, max_len=250)
72
+
73
+ t0 = time.time()
74
+ all_codes = []
75
+ for chunk in chunks:
76
+ codes = generate(
77
+ MODEL, TOKENIZER, chunk, speaker_emb,
78
+ max_new_tokens=512,
79
+ temperature=temperature,
80
+ top_k=int(top_k),
81
+ top_p=top_p,
82
+ rep_penalty=rep_penalty,
83
+ device=DEVICE
84
+ )
85
+ if codes is not None and len(codes) > 0:
86
+ all_codes.append(codes)
87
+
88
+ gen_time = time.time() - t0
89
+
90
+ if not all_codes:
91
+ return None
92
+
93
+ codes = torch.cat(all_codes)
94
+ audio_dur = len(codes) / CODEC_FRAME_RATE
95
+ rtf = gen_time / audio_dur if audio_dur > 0 else float('inf')
96
+
97
+ # Decode to waveform
98
+ wav = CODEC.decode(codes, speaker_emb)
99
+ wav_np = wav.numpy()
100
+
101
+ info = f"✅ {len(codes)} tokens | {audio_dur:.1f}s audio | {gen_time:.1f}s gen | RTF: {rtf:.3f}"
102
+
103
+ return (CODEC_SAMPLE_RATE, wav_np), info
104
+
105
+
106
+ def build_ui():
107
+ """Build Gradio interface."""
108
+ import gradio as gr
109
+
110
+ with gr.Blocks(
111
+ title="BgTTS-38M — Bulgarian Text-to-Speech",
112
+ theme=gr.themes.Soft(
113
+ primary_hue="blue",
114
+ secondary_hue="slate",
115
+ ),
116
+ css="""
117
+ .main-title { text-align: center; margin-bottom: 0.5em; }
118
+ .subtitle { text-align: center; color: #666; margin-bottom: 1.5em; }
119
+ """
120
+ ) as app:
121
+ gr.HTML('<h1 class="main-title">🎙️ BgTTS-38M</h1>')
122
+ gr.HTML('<p class="subtitle">Bulgarian + English Text-to-Speech with Voice Cloning | 38M params | 153MB</p>')
123
+
124
+ with gr.Row():
125
+ with gr.Column(scale=2):
126
+ text_input = gr.Textbox(
127
+ label="Текст / Text",
128
+ placeholder="Въведете текст на български или английски...\nEnter text in Bulgarian or English...",
129
+ lines=5,
130
+ max_lines=15,
131
+ )
132
+
133
+ ref_audio = gr.Audio(
134
+ label="🎤 Reference Voice (за клониране на глас)",
135
+ type="numpy",
136
+ sources=["upload", "microphone"],
137
+ )
138
+
139
+ with gr.Row():
140
+ generate_btn = gr.Button("🔊 Генерирай / Generate", variant="primary", size="lg")
141
+ clear_btn = gr.Button("🗑️ Изчисти", size="lg")
142
+
143
+ with gr.Column(scale=1):
144
+ with gr.Accordion("⚙️ Настройки / Settings", open=False):
145
+ temperature = gr.Slider(
146
+ minimum=0.05, maximum=1.5, value=0.3, step=0.05,
147
+ label="Temperature",
148
+ info="По-ниска = по-чисто, по-висока = по-разнообразно"
149
+ )
150
+ top_k = gr.Slider(
151
+ minimum=1, maximum=500, value=250, step=10,
152
+ label="Top-K"
153
+ )
154
+ top_p = gr.Slider(
155
+ minimum=0.1, maximum=1.0, value=0.95, step=0.05,
156
+ label="Top-P (Nucleus)"
157
+ )
158
+ rep_penalty = gr.Slider(
159
+ minimum=1.0, maximum=2.0, value=1.1, step=0.05,
160
+ label="Repetition Penalty"
161
+ )
162
+
163
+ output_audio = gr.Audio(
164
+ label="🔊 Резултат / Output",
165
+ type="numpy",
166
+ interactive=False,
167
+ )
168
+
169
+ info_text = gr.Textbox(
170
+ label="ℹ️ Информация",
171
+ interactive=False,
172
+ lines=2,
173
+ )
174
+
175
+ # Examples
176
+ gr.Examples(
177
+ examples=[
178
+ ["Българският език е изключително богат и мелодичен."],
179
+ ["Artificial intelligence has reached a fascinating stage."],
180
+ ["Когато говорим за истински multitasking, способността ми да превключвам плавно между български и English е от огромно значение."],
181
+ ["Здравейте! Казвам се Ани и мога да говоря на български и английски."],
182
+ ["The quick brown fox jumps over the lazy dog."],
183
+ ],
184
+ inputs=[text_input],
185
+ label="📝 Примери / Examples",
186
+ )
187
+
188
+ # Event handlers
189
+ generate_btn.click(
190
+ fn=synthesize_speech,
191
+ inputs=[text_input, ref_audio, temperature, top_k, top_p, rep_penalty],
192
+ outputs=[output_audio, info_text],
193
+ )
194
+
195
+ clear_btn.click(
196
+ fn=lambda: (None, None, ""),
197
+ outputs=[text_input, output_audio, info_text],
198
+ )
199
+
200
+ return app
201
+
202
+
203
+ if __name__ == "__main__":
204
+ import argparse
205
+ p = argparse.ArgumentParser()
206
+ p.add_argument("--checkpoint", default=CHECKPOINT_PATH)
207
+ p.add_argument("--host", default="0.0.0.0")
208
+ p.add_argument("--port", type=int, default=7860)
209
+ p.add_argument("--share", action="store_true")
210
+ p.add_argument("--device", default=DEVICE)
211
+ args = p.parse_args()
212
+
213
+ CHECKPOINT_PATH = args.checkpoint
214
+ DEVICE = args.device
215
+
216
+ load_model()
217
+ app = build_ui()
218
+ app.launch(
219
+ server_name=args.host,
220
+ server_port=args.port,
221
+ share=args.share,
222
+ )
BgTTS/tokenizer.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ V6 Tokenizer — char-level for Bulgarian TTS with MioCodec
3
+ ==========================================================
4
+ Same character set as V5, but adapted for:
5
+ - MioCodec single codebook (no interleaving)
6
+ - Speaker embedding (no speaker tokens in encoder input)
7
+ """
8
+
9
+ import re
10
+ import torch
11
+ from typing import Optional
12
+
13
+ from config import (
14
+ TEXT_CHARS, TEXT_OFFSET, AUDIO_OFFSET,
15
+ SPECIAL_TOKENS, NUM_SPECIAL_TOKENS, CODEC_CODEBOOK_SIZE,
16
+ TOTAL_VOCAB_SIZE,
17
+ PAD_TOKEN_ID, START_OF_TEXT_TOKEN_ID, END_OF_TEXT_TOKEN_ID,
18
+ START_OF_SPEECH_TOKEN_ID, END_OF_SPEECH_TOKEN_ID,
19
+ is_audio_token, is_special_token, is_text_token,
20
+ )
21
+
22
+
23
+ class TTSTokenizer:
24
+ def __init__(self):
25
+ self.char2id: dict[str, int] = {}
26
+ self.id2char: dict[int, str] = {}
27
+ for i, ch in enumerate(TEXT_CHARS):
28
+ tid = TEXT_OFFSET + i
29
+ self.char2id[ch] = tid
30
+ self.id2char[tid] = ch
31
+
32
+ self._special_id_to_name = {v: k for k, v in SPECIAL_TOKENS.items()}
33
+ self.vocab_size = TOTAL_VOCAB_SIZE
34
+ self.text_vocab_size = len(TEXT_CHARS)
35
+
36
+ def normalize_text(self, text: str) -> str:
37
+ text = re.sub(r'\s+', ' ', text).strip()
38
+ text = re.sub(r'[–—]', '-', text)
39
+ text = re.sub(r'[«»„""]', '"', text)
40
+ return text
41
+
42
+ def encode_text(self, text: str) -> list[int]:
43
+ text = self.normalize_text(text)
44
+ return [self.char2id[ch] for ch in text if ch in self.char2id]
45
+
46
+ def decode_text(self, ids: list[int]) -> str:
47
+ return "".join(self.id2char.get(t, "") for t in ids if is_text_token(t))
48
+
49
+ # ── Encoder-Decoder methods ──────────────────────────────
50
+
51
+ def build_encoder_input(self, text: str) -> torch.Tensor:
52
+ """
53
+ Encoder input: <sot> text_chars <eot>
54
+ No speaker token — speaker info comes from embedding.
55
+ """
56
+ text_ids = self.encode_text(text)
57
+ seq = [START_OF_TEXT_TOKEN_ID] + text_ids + [END_OF_TEXT_TOKEN_ID]
58
+ return torch.tensor(seq, dtype=torch.long)
59
+
60
+ def build_decoder_input(self, audio_codes: torch.Tensor) -> torch.Tensor:
61
+ """
62
+ Decoder input: <sos> [audio_codes + AUDIO_OFFSET] <eos>
63
+ audio_codes: raw MioCodec codes in [0, 12799]
64
+ """
65
+ seq = (
66
+ [START_OF_SPEECH_TOKEN_ID]
67
+ + (audio_codes + AUDIO_OFFSET).tolist()
68
+ + [END_OF_SPEECH_TOKEN_ID]
69
+ )
70
+ return torch.tensor(seq, dtype=torch.long)
71
+
72
+ def build_decoder_prefix(self) -> torch.Tensor:
73
+ """For inference: just <sos> to start generation."""
74
+ return torch.tensor([START_OF_SPEECH_TOKEN_ID], dtype=torch.long)
75
+
76
+ def extract_audio_codes(self, sequence: torch.Tensor) -> Optional[torch.Tensor]:
77
+ """Extract raw MioCodec codes from a token sequence."""
78
+ mask = torch.tensor([is_audio_token(t.item()) for t in sequence])
79
+ if not mask.any():
80
+ return None
81
+ return sequence[mask] - AUDIO_OFFSET
82
+
83
+ def describe(self, seq: torch.Tensor, max_tok: int = 30) -> str:
84
+ parts = []
85
+ for t in seq[:max_tok]:
86
+ tid = t.item()
87
+ if is_special_token(tid):
88
+ parts.append(self._special_id_to_name.get(tid, f"<sp_{tid}>"))
89
+ elif is_text_token(tid):
90
+ ch = self.id2char.get(tid, "?")
91
+ parts.append(ch if ch != " " else "·")
92
+ elif is_audio_token(tid):
93
+ code = tid - AUDIO_OFFSET
94
+ parts.append(f"♪{code}")
95
+ else:
96
+ parts.append(f"?{tid}")
97
+ r = " ".join(parts)
98
+ if len(seq) > max_tok:
99
+ r += f" ... [{len(seq) - max_tok} more]"
100
+ return r
BgTTS/train.py ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import glob
3
+ import math
4
+ import csv
5
+ import torch
6
+ from torch.utils.data import Dataset, DataLoader
7
+ from torch.nn.utils.rnn import pad_sequence
8
+ from tqdm import tqdm
9
+ from torch.amp import autocast
10
+
11
+ from config import (PAD_TOKEN_ID, START_OF_SPEECH_TOKEN_ID,
12
+ END_OF_SPEECH_TOKEN_ID, AUDIO_OFFSET)
13
+ from model import create_model, save_checkpoint
14
+ from tokenizer import TTSTokenizer
15
+
16
+ # ── Хиперпараметри ───────────────────────────────────────────────
17
+ PEAK_LR = 7e-5
18
+ START_LR = 0
19
+ MIN_LR = 5e-6
20
+ WEIGHT_DECAY = 0.01
21
+ EPOCHS = 20
22
+ BATCH_SIZE = 64
23
+ ACCUM_STEPS = 1 # Без accumulation
24
+ GRAD_CLIP = 1.0
25
+ CKPT_EVERY = 1000 # Checkpoint на всеки N optimizer стъпки
26
+ LOG_FILE = "train_log.csv"
27
+
28
+ # ── Dataset ──────────────────────────────────────────────────────
29
+ class ShardedTTSDataset(Dataset):
30
+ def __init__(self, data_dir):
31
+ self.shard_files = sorted(glob.glob(os.path.join(data_dir, "*.pt")))
32
+ self.samples = []
33
+ print(f"Зареждане на {len(self.shard_files)} шарда...")
34
+ for sf in self.shard_files:
35
+ self.samples.extend(torch.load(sf, weights_only=False))
36
+ print(f"Общо записи: {len(self.samples):,}")
37
+
38
+ def __len__(self):
39
+ return len(self.samples)
40
+
41
+ def __getitem__(self, idx):
42
+ item = self.samples[idx]
43
+ return {
44
+ 'text_ids': item['text_ids'].clone().detach().long(),
45
+ 'audio_codes': item['audio_codes'].clone().detach().long(),
46
+ 'speaker_emb': item['speaker_emb'].clone().detach().float(),
47
+ }
48
+
49
+ def collate_fn(batch):
50
+ enc_ids_list, dec_ids_list, labels_list, speaker_embs = [], [], [], []
51
+ for item in batch:
52
+ enc_ids_list.append(item['text_ids'])
53
+ audio_codes = item['audio_codes'] + AUDIO_OFFSET
54
+ # GPT-style: model.py вътрешно shift-ва logits[:, :-1] vs labels[:, 1:]
55
+ # Затова dec_ids и labels трябва да са подравнени, а model-ът сам измества.
56
+ dec_ids_list.append(torch.cat([torch.tensor([START_OF_SPEECH_TOKEN_ID]), audio_codes, torch.tensor([END_OF_SPEECH_TOKEN_ID])]))
57
+ labels_list.append(torch.cat([torch.tensor([-100]), audio_codes, torch.tensor([END_OF_SPEECH_TOKEN_ID])]))
58
+ speaker_embs.append(item['speaker_emb'])
59
+
60
+ enc_ids = pad_sequence(enc_ids_list, batch_first=True, padding_value=PAD_TOKEN_ID)
61
+ dec_ids = pad_sequence(dec_ids_list, batch_first=True, padding_value=PAD_TOKEN_ID)
62
+ labels = pad_sequence(labels_list, batch_first=True, padding_value=-100)
63
+ enc_mask = (enc_ids != PAD_TOKEN_ID).long()
64
+ speaker_emb = torch.stack(speaker_embs)
65
+ return enc_ids, dec_ids, enc_mask, labels, speaker_emb
66
+
67
+ # ── LR Scheduler: Warmup + Cosine Decay ─────────────────────────
68
+ def get_lr(step: int, warmup_steps: int, total_steps: int) -> float:
69
+ if step < warmup_steps:
70
+ return START_LR + (PEAK_LR - START_LR) * (step / max(1, warmup_steps))
71
+ else:
72
+ progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
73
+ cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
74
+ return MIN_LR + (PEAK_LR - MIN_LR) * cosine
75
+
76
+ # ── Основен тренировъчен цикъл ───────────────────────────────────
77
+ def train():
78
+ device = "cuda" if torch.cuda.is_available() else "cpu"
79
+ print(f"Устройство: {device}")
80
+
81
+ processed_dir = os.path.abspath("../data/processed")
82
+ if not os.path.exists(processed_dir):
83
+ print(f"[ГРЕШКА] {processed_dir} не съществува!"); return
84
+
85
+ dataset = ShardedTTSDataset(processed_dir)
86
+ dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True,
87
+ collate_fn=collate_fn, num_workers=4, pin_memory=True)
88
+
89
+ steps_per_epoch = len(dataloader) // ACCUM_STEPS # optimizer стъпки на епоха
90
+ warmup_steps = steps_per_epoch * 2 # Warmup = 2 епохи
91
+ total_steps = steps_per_epoch * EPOCHS
92
+ print(f"Батчове/епоха: {len(dataloader):,} | Optimizer стъпки/епоха: {steps_per_epoch:,} | Accum: {ACCUM_STEPS}")
93
+ print(f"Warmup: {warmup_steps:,} стъпки (2 епохи) | Общо: {total_steps:,}")
94
+ print(f"Peak LR: {PEAK_LR}, Min LR: {MIN_LR}, Weight Decay: {WEIGHT_DECAY}, Epochs: {EPOCHS}")
95
+ print(f"Ефективен batch size: {BATCH_SIZE * ACCUM_STEPS}")
96
+
97
+ model = create_model(device=device)
98
+ model.train()
99
+ optimizer = torch.optim.AdamW(model.parameters(), lr=PEAK_LR, weight_decay=WEIGHT_DECAY,
100
+ betas=(0.9, 0.999), eps=1e-8)
101
+ # BF16 — без GradScaler (не е нужен при bfloat16)
102
+
103
+ os.makedirs("checkpoints", exist_ok=True)
104
+
105
+ # CSV лог за реално наблюдение
106
+ log_path = LOG_FILE
107
+ log_f = open(log_path, "w", newline="")
108
+ writer = csv.writer(log_f)
109
+ writer.writerow(["step", "batch_loss", "avg_loss", "lr"])
110
+ log_f.flush()
111
+ print(f"Loss лог: {log_path} (следи с: tail -f {log_path})\n")
112
+
113
+ step = 0
114
+ running_loss = 0.0
115
+ running_count = 0
116
+
117
+ for epoch in range(EPOCHS):
118
+ loop = tqdm(total=steps_per_epoch, desc=f"Епоха {epoch+1}/{EPOCHS}")
119
+ epoch_loss_sum, valid_batches = 0.0, 0
120
+
121
+ optimizer.zero_grad(set_to_none=True)
122
+ for i, (enc_ids, dec_ids, enc_mask, labels, spk_emb) in enumerate(dataloader):
123
+ enc_ids = enc_ids.to(device)
124
+ dec_ids = dec_ids.to(device)
125
+ enc_mask = enc_mask.to(device)
126
+ labels = labels.to(device)
127
+ spk_emb = spk_emb.to(device)
128
+
129
+ with autocast('cuda', dtype=torch.bfloat16):
130
+ out = model(enc_ids=enc_ids, dec_ids=dec_ids,
131
+ enc_mask=enc_mask, dec_labels=labels,
132
+ speaker_emb=spk_emb)
133
+ loss = out['loss'] / ACCUM_STEPS
134
+
135
+ loss.backward()
136
+
137
+ batch_loss = loss.item() * ACCUM_STEPS # реалният loss
138
+ epoch_loss_sum += batch_loss
139
+ valid_batches += 1
140
+
141
+ if (i + 1) % ACCUM_STEPS == 0:
142
+ torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
143
+ optimizer.step()
144
+ optimizer.zero_grad(set_to_none=True)
145
+ step += 1
146
+
147
+ current_lr = get_lr(step, warmup_steps, total_steps)
148
+ for pg in optimizer.param_groups:
149
+ pg['lr'] = current_lr
150
+
151
+ running_loss += batch_loss
152
+ running_count += 1
153
+ avg_loss = running_loss / running_count
154
+
155
+ writer.writerow([step, f"{batch_loss:.4f}", f"{avg_loss:.4f}", f"{current_lr:.2e}"])
156
+ log_f.flush()
157
+
158
+ loop.update(1)
159
+ loop.set_postfix(step=step, loss=f"{batch_loss:.4f}",
160
+ avg=f"{avg_loss:.4f}", lr=f"{current_lr:.2e}")
161
+
162
+ if step % CKPT_EVERY == 0:
163
+ ckpt_dir = f"checkpoints/step_{step:06d}"
164
+ save_checkpoint(model, optimizer, None, step,
165
+ avg_loss, ckpt_dir, best_val_loss=None)
166
+ tqdm.write(f" ✓ Checkpoint запазен: {ckpt_dir} | step={step} | avg_loss={avg_loss:.4f}")
167
+
168
+ loop.close()
169
+ epoch_avg = epoch_loss_sum / max(1, valid_batches)
170
+ ckpt_dir = f"checkpoints/epoch_{epoch+1}_final"
171
+ save_checkpoint(model, optimizer, None, step, epoch_avg, ckpt_dir, best_val_loss=None)
172
+ print(f"\n✓ Епоха {epoch+1} завърши. Средна загуба: {epoch_avg:.4f}")
173
+ print(f" Checkpoint: {ckpt_dir}")
174
+
175
+ log_f.close()
176
+ print("\n[КРАЙ] Обучението приключи успешно!")
177
+
178
+ if __name__ == "__main__":
179
+ train()
README.md ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ language:
3
+ - bg
4
+ license: mit
5
+ tags:
6
+ - text-to-speech
7
+ - tts
8
+ - bulgarian
9
+ - fastapi
10
+ pipeline_tag: text-to-speech
11
+ ---
12
+ # Ani Voice API
13
+
14
+ Завършен TTS (Text-to-Speech) пакет за български език, базиран на BgTTS и Supertonic, обвит в гъвкаво FastAPI приложение.
15
+
16
+ *Проектът е създаден и разработен от **Ani-Antigravity** по идея и желание на **Наско (@beleata74)**.*
17
+
18
+ ## Инсталация
19
+
20
+ 1. Уверете се, че имате Python 3.10+
21
+ 2. Инсталирайте нужните зависимости:
22
+ ```bash
23
+ pip install -r requirements.txt
24
+ ```
25
+
26
+ ## Стартиране на API сървъра
27
+
28
+ ```bash
29
+ python api.py
30
+ ```
31
+ Сървърът ще тръгне на `http://localhost:8000`. Можете да разгледате автоматичната документация на `http://localhost:8000/docs`.
32
+
33
+ ## Използване
34
+
35
+ ### 1. Генериране на цял аудио файл
36
+ Изпраща текст и връща завършен `.wav` файл.
37
+
38
+ **Пример:**
39
+ ```bash
40
+ curl -X POST "http://localhost:8000/api/v1/synthesize" \
41
+ -H "Content-Type: application/json" \
42
+ -d '{"text": "Здравей, свят!", "voice_style": "F5", "speed": 1.6}' \
43
+ --output response.wav
44
+ ```
45
+
46
+ ### 2. Стрийминг на аудио (NDJSON)
47
+ Изпраща аудиото на малки парчета (chunks), докато се генерират, кодирани в base64. Полезно за дълги текстове, където искате да пускате аудиото веднага.
48
+
49
+ Връща редове във формат:
50
+ ```json
51
+ {"chunk_index": 0, "audio_base64": "UklGRiQAAABXQVZFZm10IBAAAAABAAEAQB8AAEA..."}
52
+ ```
53
+
54
+ Вижте файла `client_example.py` за пример как да интегрирате API-то в Python код.
55
+ Вижте файла `voice_pipeline.py` за пример на работещ клиент-демон (daemon), който комуникира с API-то и пуска звука в реално време!
56
+
57
+ ## Аудио Демонстрации
58
+
59
+ В репозиторито можете да намерите няколко предварително генерирани аудио файла, за да чуете как звучи моделът:
60
+
61
+ 1. **`demo1_conversation.wav`**
62
+ - *Транскрипция:* "Здравейте! Това е тестов запис от нашия нов български TTS модел. Надявам се да ви хареса как звучи гласът ми!"
63
+ 2. **`demo2_numbers.wav`** (Демонстрира нормализацията на числа и дати)
64
+ - *Транскрипция:* "Днес е 15 май 2026 година. Температурата навън е 23.5 градуса, а вятърът духа със скорост 5.4 километра в час. Цената е 1500 лв. за м²."
65
+ 3. **`demo3_expressive.wav`**
66
+ - *Транскрипция:* "Супер! Наистина много се радвам, че всичко най-накрая работи гладко. Усилията определено си заслужаваха!"
api.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException
2
+ from fastapi.responses import Response, StreamingResponse
3
+ from pydantic import BaseModel
4
+ import uvicorn
5
+
6
+ # Импортираме tts_engine - това автоматично ще зареди моделите в паметта при стартиране!
7
+ from tts_engine import engine
8
+
9
+ app = FastAPI(title="Ani Voice API", version="1.0.0")
10
+
11
+ class SynthesizeRequest(BaseModel):
12
+ text: str
13
+ voice_style: str = "F5"
14
+ speed: float = 1.6
15
+
16
+ @app.post("/api/v1/synthesize")
17
+ def synthesize_full_audio(request: SynthesizeRequest):
18
+ """
19
+ Генерира аудио за целия текст и го връща като един WAV файл.
20
+ Подходящо за кратки съобщения.
21
+ """
22
+ try:
23
+ audio_bytes = engine.synthesize_full(request.text, request.voice_style, request.speed)
24
+ if not audio_bytes:
25
+ raise HTTPException(status_code=400, detail="Неуспешно генериране на аудио (празен текст?).")
26
+
27
+ return Response(content=audio_bytes, media_type="audio/wav")
28
+ except Exception as e:
29
+ raise HTTPException(status_code=500, detail=str(e))
30
+
31
+ import base64
32
+ import json
33
+
34
+ @app.post("/api/v1/synthesize/stream")
35
+ def synthesize_stream_audio(request: SynthesizeRequest):
36
+ """
37
+ Стрийминг endpoint, който връща аудио на парчета (chunks).
38
+ Всеки ред е JSON обект: {"chunk_index": i, "audio_base64": "..."}
39
+ """
40
+ def generate():
41
+ try:
42
+ for i, audio_bytes in enumerate(engine.synthesize_stream(request.text, request.voice_style, request.speed)):
43
+ encoded = base64.b64encode(audio_bytes).decode("utf-8")
44
+ yield json.dumps({"chunk_index": i, "audio_base64": encoded}) + "\n"
45
+ except Exception as e:
46
+ print(f"Грешка по време на стрийминг: {e}")
47
+ yield json.dumps({"error": str(e)}) + "\n"
48
+
49
+ return StreamingResponse(generate(), media_type="application/x-ndjson")
50
+
51
+ if __name__ == "__main__":
52
+ print("Стартиране на Ani Voice API сървър на порт 8000...")
53
+ uvicorn.run("api:app", host="0.0.0.0", port=8000, reload=False)
client_example.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import requests
2
+
3
+ API_URL = "http://localhost:8000/api/v1/synthesize"
4
+
5
+ def synthesize_text(text: str, output_file: str):
6
+ """
7
+ Изпраща текст към API-то и запазва резултата като WAV файл.
8
+ """
9
+ print(f"Изпращане на заявка за: '{text}'...")
10
+
11
+ response = requests.post(API_URL, json={
12
+ "text": text,
13
+ "voice_style": "F5",
14
+ "speed": 1.6
15
+ })
16
+
17
+ if response.status_code == 200:
18
+ with open(output_file, "wb") as f:
19
+ f.write(response.content)
20
+ print(f"✅ Аудиото е запазено успешно в: {output_file}")
21
+ else:
22
+ print(f"❌ Грешка: {response.status_code} - {response.text}")
23
+
24
+ if __name__ == "__main__":
25
+ text_to_say = "Здравей! Това е тестов запис, създаден чрез новото Ani Voice API."
26
+ output_filename = "test_api_output.wav"
27
+ synthesize_text(text_to_say, output_filename)
demo1_conversation.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5395b9f0d2f01685b82c1592dd236999deec3de7e2a4e1a3ab25611a8f1d01d6
3
+ size 332204
demo2_numbers.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f7cdfc164606f698053e9264117e15093f4d15d0d2d7621eee131d03141a4130
3
+ size 762284
demo3_expressive.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ed299d60ff6f02ed78d7bce2322fb949e74daf0fbecfb3eae8775425943a36d7
3
+ size 318764
normalizer.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ from bg_text_normalizer import normalize_text as bg_norm
3
+
4
+ # Допълнителни специфични замени, които bg_text_normalizer изпуска
5
+ EXTRA_ABBREVIATIONS = {
6
+ r"\bм²\b": "квадратен метър",
7
+ r"\bкв\.м\.\b": "квадратен метър",
8
+ r"\bт\.е\.\b": "тоест",
9
+ }
10
+
11
+ def normalize_text(text: str) -> str:
12
+ """
13
+ Нормализира текста, използвайки bg-text-normalizer + наши специфични правила.
14
+ """
15
+ # 0.5 Предварителна обработка на десетични дроби: заменяме точката със запетая
16
+ # bg-text-normalizer бърка '1.4' с '1 април'. За да го чете като дроб, му трябва запетая '1,4'.
17
+ text = re.sub(r'(\d)\.(\d)', r'\1,\2', text)
18
+
19
+ # 1. Първо прилагаме библиотеката bg_text_normalizer
20
+ text = bg_norm(text)
21
+
22
+ # 2. Оправяме точките след съкращения като "лв." и "гр.", които библиотеката е превърнала в "лева."
23
+ text = text.replace("лева.", "лева")
24
+ text = text.replace("стотинки.", "стотинки")
25
+
26
+ # 3. Прилагаме нашите допълнителни правила
27
+ for pattern, replacement in EXTRA_ABBREVIATIONS.items():
28
+ text = re.sub(pattern, replacement, text, flags=re.IGNORECASE)
29
+
30
+ # Допълнително хващаме м² (без \b, защото ² не е дума)
31
+ text = text.replace("м²", "квадратен метър")
32
+
33
+ # Махане на двойни интервали
34
+ text = re.sub(r"\s+", " ", text).strip()
35
+
36
+ return text
37
+
38
+ if __name__ == "__main__":
39
+ test_text = "Цената е 1500 лв. за м² в кв. Лозенец."
40
+ print("Original:", test_text)
41
+ print("Normalized:", normalize_text(test_text))
requirements.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch==2.13.0.dev20260603+cu132
2
+ torchaudio==2.11.0.dev20260608+cu132
3
+ torchvision==0.28.0.dev20260608+cu132
4
+ numpy<2.0.0
5
+ supertonic==1.3.1
6
+ bg-text-normalizer==1.1.0
7
+ num2cyrillic==1.0.0
8
+ fastapi>=0.110.0
9
+ uvicorn>=0.29.0
10
+ pydantic>=2.7.0
11
+ requests>=2.31.0
tts_engine.py ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import io
3
+ import re
4
+ import wave
5
+ import torch
6
+ import numpy as np
7
+ import tempfile
8
+ import sys
9
+ import supertonic
10
+
11
+ # Добавяме BgTTS към sys.path, за да може вътрешните му импорти да работят
12
+ sys.path.append(os.path.join(os.path.dirname(__file__), 'BgTTS'))
13
+ from inference import synthesize
14
+ from normalizer import normalize_text
15
+
16
+ class TTSEngine:
17
+ def __init__(self):
18
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
19
+ print(f"Зареждам TTS Engine на устройство: {self.device}")
20
+
21
+ # Supertonic (Референтно аудио)
22
+ from supertonic import TTS
23
+ self.engine = TTS(auto_download=True)
24
+
25
+ # BgTTS (Основен модел)
26
+ self.bgtts_checkpoint = os.path.join(os.path.dirname(__file__), "BgTTS", "checkpoint_inference.pt")
27
+ # BgTTS inference.synthesize зарежда модела всеки път, ако не му подадем модела.
28
+ # В текущия BgTTS/inference.py synthesize() вика load_for_inference(), ако се подаде път.
29
+ # За сега ще ползваме пътя, тъй като така е написан BgTTS.
30
+ # Ако искаме пълно кеширане, може да се наложи леко пренаписване на BgTTS/inference.py.
31
+ # Но засега ще ползваме оригиналната synthesize функция.
32
+
33
+ print("TTS Engine зареден успешно.")
34
+
35
+ def split_text_for_tts(self, text: str) -> list[str]:
36
+ text = text.strip()
37
+ if not text:
38
+ return []
39
+ raw = re.split(r'(?<=[\.\!\?…])\s+|\n+', text)
40
+ chunks = []
41
+ buf = ""
42
+ for part in raw:
43
+ part = part.strip()
44
+ if not part: continue
45
+
46
+ if not buf or len(buf) < 80 or len(buf) + len(part) + 1 <= 200:
47
+ buf = (buf + " " + part).strip()
48
+ else:
49
+ chunks.append(buf)
50
+ buf = part
51
+ if buf: chunks.append(buf)
52
+ return chunks
53
+
54
+ def generate_chunk(self, chunk_text: str, voice_style: str = "F5", speed: float = 1.6) -> bytes:
55
+ """
56
+ Генерира аудио за едно изречение (chunk) и го връща като WAV байтове.
57
+ """
58
+ clean_text = chunk_text.replace('"', '').replace('„', '').replace('“', '') \
59
+ .replace("’", "'").replace("–", "-").replace("—", "-") \
60
+ .replace("*", "")
61
+
62
+ if not clean_text.strip():
63
+ return b""
64
+
65
+ # 1. Генериране на референтно аудио
66
+ # Ако voice_style е стринг (напр. "F5"), взимаме съответния обект
67
+ if isinstance(voice_style, str):
68
+ v_style = self.engine.get_voice_style(voice_name=voice_style)
69
+ else:
70
+ v_style = voice_style
71
+
72
+ wav_array, _ = self.engine.synthesize(clean_text, voice_style=v_style, lang="bg", speed=speed)
73
+ wav_data = np.asarray(wav_array).flatten()
74
+ wav_max = np.max(np.abs(wav_data))
75
+ if wav_max > 0:
76
+ wav_data = wav_data / wav_max
77
+ pcm_data = (wav_data * 32767).astype(np.int16)
78
+
79
+ # Записваме временно референтното аудио (тъй като BgTTS изисква файл)
80
+ fd, ref_path = tempfile.mkstemp(suffix=".wav")
81
+ os.close(fd)
82
+ with wave.open(ref_path, "wb") as wf:
83
+ wf.setnchannels(1)
84
+ wf.setsampwidth(2)
85
+ wf.setframerate(44100)
86
+ wf.writeframes(pcm_data.tobytes())
87
+
88
+ # 2. Генериране на крайното аудио
89
+ fd, final_path = tempfile.mkstemp(suffix=".wav")
90
+ os.close(fd)
91
+
92
+ try:
93
+ synthesize(checkpoint=self.bgtts_checkpoint,
94
+ text=clean_text,
95
+ output=final_path,
96
+ speaker_wav=ref_path,
97
+ device=self.device)
98
+
99
+ # Прочитане на резултата
100
+ with open(final_path, "rb") as f:
101
+ audio_bytes = f.read()
102
+
103
+ return audio_bytes
104
+
105
+ finally:
106
+ try:
107
+ os.remove(ref_path)
108
+ os.remove(final_path)
109
+ except OSError:
110
+ pass
111
+
112
+ def synthesize_stream(self, text: str, voice_style: str = "F5", speed: float = 1.6):
113
+ """
114
+ Генератор, който нормализира текста, цепи го на парчета и връща WAV байтове за всяко парче.
115
+ """
116
+ normalized_text = normalize_text(text)
117
+ chunks = self.split_text_for_tts(normalized_text)
118
+
119
+ for chunk in chunks:
120
+ audio_bytes = self.generate_chunk(chunk, voice_style, speed)
121
+ if audio_bytes:
122
+ yield audio_bytes
123
+
124
+ def synthesize_full(self, text: str, voice_style: str = "F5", speed: float = 1.6) -> bytes:
125
+ """
126
+ Нормализира текста, цепи го, генерира всички парчета и ги слепва в един общ WAV файл.
127
+ """
128
+ normalized_text = normalize_text(text)
129
+ chunks = self.split_text_for_tts(normalized_text)
130
+
131
+ all_frames = b""
132
+ params = None
133
+
134
+ for chunk in chunks:
135
+ audio_bytes = self.generate_chunk(chunk, voice_style, speed)
136
+ if not audio_bytes:
137
+ continue
138
+
139
+ # Парсване на WAV данните, за да можем да ги слеем без да дублираме хедъри
140
+ with wave.open(io.BytesIO(audio_bytes), "rb") as wf:
141
+ if not params:
142
+ params = wf.getparams()
143
+ all_frames += wf.readframes(wf.getnframes())
144
+
145
+ if not params:
146
+ return b""
147
+
148
+ # Създаване на крайния WAV
149
+ out_io = io.BytesIO()
150
+ with wave.open(out_io, "wb") as wf:
151
+ wf.setparams(params)
152
+ wf.writeframes(all_frames)
153
+
154
+ return out_io.getvalue()
155
+
156
+ # Глобална инстанция за по-лесно преизползване
157
+ engine = TTSEngine()
voice_pipeline.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import time
3
+ import json
4
+ import base64
5
+ import requests
6
+ import tempfile
7
+ import threading
8
+ import queue
9
+ import winsound
10
+
11
+ TRANSCRIPT_PATH = r"C:\Users\nasko\.gemini\antigravity\brain\695897cf-1c58-4886-a686-e9d8c406ebef\.system_generated\logs\transcript.jsonl"
12
+ API_URL = "http://localhost:8000/api/v1/synthesize/stream"
13
+
14
+ audio_queue = queue.Queue()
15
+
16
+ def player_worker():
17
+ """
18
+ Взима готови WAV файлове от опашката и ги пуска.
19
+ """
20
+ while True:
21
+ file_path = audio_queue.get()
22
+ if file_path is None: break
23
+
24
+ print(f"🔊 Възпроизвеждам от API...")
25
+ winsound.PlaySound(file_path, winsound.SND_FILENAME)
26
+
27
+ try:
28
+ os.remove(file_path)
29
+ except OSError:
30
+ pass
31
+
32
+ audio_queue.task_done()
33
+
34
+ def process_text(text: str):
35
+ """
36
+ Изпраща текста към API-то и чака за стрийминг на аудио парчета.
37
+ """
38
+ print(f"📡 Изпращане към API: {text[:50]}...")
39
+ try:
40
+ response = requests.post(API_URL, json={
41
+ "text": text,
42
+ "voice_style": "F5",
43
+ "speed": 1.6
44
+ }, stream=True)
45
+
46
+ if response.status_code != 200:
47
+ print(f"Грешка от API: {response.status_code} - {response.text}")
48
+ return
49
+
50
+ for line in response.iter_lines():
51
+ if line:
52
+ data = json.loads(line)
53
+ if "error" in data:
54
+ print(f"API Грешка: {data['error']}")
55
+ continue
56
+
57
+ chunk_index = data.get("chunk_index")
58
+ audio_base64 = data.get("audio_base64")
59
+
60
+ if audio_base64:
61
+ audio_bytes = base64.b64decode(audio_base64)
62
+
63
+ # Записваме временно файла и го пускаме в опашката
64
+ fd, file_path = tempfile.mkstemp(suffix=f"_chunk_{chunk_index}.wav")
65
+ os.close(fd)
66
+
67
+ with open(file_path, "wb") as f:
68
+ f.write(audio_bytes)
69
+
70
+ audio_queue.put(file_path)
71
+ except requests.exceptions.ConnectionError:
72
+ print("Не мога да се свържа с API-то! Увери се, че `api.py` работи на порт 8000.")
73
+ except Exception as e:
74
+ print(f"Грешка при комуникация с API: {e}")
75
+
76
+ def tail_file():
77
+ """
78
+ Следи чата (transcript.jsonl) за нови съобщения от модела.
79
+ """
80
+ if not os.path.exists(TRANSCRIPT_PATH):
81
+ print(f"Файлът не съществува: {TRANSCRIPT_PATH}")
82
+ return
83
+
84
+ with open(TRANSCRIPT_PATH, "r", encoding="utf-8") as f:
85
+ f.seek(0, 2)
86
+
87
+ while True:
88
+ line = f.readline()
89
+ if not line:
90
+ time.sleep(0.5)
91
+ continue
92
+
93
+ try:
94
+ data = json.loads(line)
95
+ if data.get("source") == "MODEL" and data.get("type") in ["PLANNER_RESPONSE", "GENERIC"]:
96
+ full_text = data.get("content", "")
97
+ if full_text and not full_text.startswith("Created At:"):
98
+ print("\n📝 Получен нов текст от чата.")
99
+ process_text(full_text)
100
+ except Exception as e:
101
+ pass
102
+
103
+ if __name__ == "__main__":
104
+ t_play = threading.Thread(target=player_worker, daemon=True)
105
+ t_play.start()
106
+
107
+ print("Ani Voice Client слуша за съобщения и чака API-то...")
108
+ tail_file()