bnewton-genmedlabs committed on
Commit
1667f3a
·
verified ·
1 Parent(s): 4688879

Add full XTTS v2 architecture header with GPT-2 encoder and HiFi-GAN vocoder

Browse files
Files changed (1) hide show
  1. cpp/xtts_v2_full.h +396 -0
cpp/xtts_v2_full.h ADDED
@@ -0,0 +1,396 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // xtts_v2_full.h - Full XTTS v2 GGUF Implementation
2
+ #ifndef XTTS_V2_FULL_H
3
+ #define XTTS_V2_FULL_H
4
+
5
+ #include <ggml.h>
6
+ #include <ggml-alloc.h>
7
+ #include <ggml-backend.h>
8
+ #include <cstdint>
9
+ #include <string>
10
+ #include <vector>
11
+ #include <memory>
12
+ #include <unordered_map>
13
+
14
+ namespace xtts_v2 {
15
+
16
// XTTS v2 Architecture Constants
//
// Hyper-parameters for the three sub-networks (GPT-2 text encoder, latent
// diffusion decoder, HiFi-GAN vocoder) plus audio framing settings. These
// must match the checkpoint stored in the GGUF file; they are compile-time
// constants so tensor shapes can be validated at load time.
struct XTTSConfig {
    // GPT-2 Text Encoder
    static constexpr int GPT_N_VOCAB      = 6681;  // BPE vocab size
    static constexpr int GPT_N_CTX        = 402;   // Max context length
    static constexpr int GPT_N_EMBD       = 1024;  // Hidden dimension
    static constexpr int GPT_N_HEAD       = 16;    // Attention heads
    static constexpr int GPT_N_LAYER      = 30;    // Transformer layers
    static constexpr int GPT_INTERMEDIATE = 4096;  // FFN intermediate size

    // Latent Diffusion Decoder
    static constexpr int LATENT_DIM         = 1024; // Latent vector dimension
    static constexpr int MEL_CHANNELS       = 80;   // Mel spectrogram bins
    static constexpr int DECODER_LAYERS     = 12;   // Decoder depth
    static constexpr int REF_ENCODER_LAYERS = 6;    // Reference encoder layers

    // HiFi-GAN Vocoder
    static constexpr int HIFIGAN_UPSAMPLE_RATES[4] = {8, 8, 2, 2};
    static constexpr int HIFIGAN_KERNEL_SIZES[4]   = {16, 16, 4, 4};
    static constexpr int HIFIGAN_CHANNELS          = 512;
    static constexpr int HIFIGAN_RESBLOCK_KERNELS[3] = {3, 7, 11};
    static constexpr int HIFIGAN_RESBLOCK_DILATIONS[3][3] = {
        {1, 3, 5}, {1, 3, 5}, {1, 3, 5}
    };

    // Audio settings
    static constexpr int SAMPLE_RATE = 24000;
    static constexpr int HOP_LENGTH  = 256;
    static constexpr int WIN_LENGTH  = 1024;

    // Languages (17 supported)
    static constexpr int N_LANGUAGES           = 17;
    static constexpr int SPEAKER_EMBEDDING_DIM = 512;

    // Conditioning
    static constexpr int COND_LATENT_DIM  = 1024;
    static constexpr int MAX_MEL_LENGTH   = 605;
    static constexpr int MAX_AUDIO_LENGTH = 155520; // ~6.5 seconds @ 24kHz
};

// Compile-time consistency checks: these relationships are relied on by the
// attention kernels and the vocoder, so fail the build if the constants above
// are ever edited out of sync.
static_assert(XTTSConfig::GPT_N_EMBD % XTTSConfig::GPT_N_HEAD == 0,
              "hidden dimension must divide evenly across attention heads");
static_assert(XTTSConfig::HIFIGAN_UPSAMPLE_RATES[0] *
              XTTSConfig::HIFIGAN_UPSAMPLE_RATES[1] *
              XTTSConfig::HIFIGAN_UPSAMPLE_RATES[2] *
              XTTSConfig::HIFIGAN_UPSAMPLE_RATES[3] == XTTSConfig::HOP_LENGTH,
              "vocoder upsampling must produce exactly one hop of audio per mel frame");
static_assert(XTTSConfig::WIN_LENGTH >= XTTSConfig::HOP_LENGTH,
              "STFT window must cover at least one hop");
55
+
56
// Full XTTS v2 Model Components
//
// Pure weight container: every ggml_tensor* member is a NON-owning handle
// whose storage lives in the ggml context / backend buffer held at the
// bottom of this struct. Sub-structs mirror the GGUF tensor grouping.
struct XTTSv2Model {
    // Text Encoder (GPT-2 style)
    struct TextEncoder {
        ggml_tensor* wte;  // Token embeddings [n_vocab, n_embd]
        ggml_tensor* wpe;  // Position embeddings [n_ctx, n_embd]

        // Per-layer components
        struct Layer {
            // Attention
            ggml_tensor* ln1_g;     // LayerNorm gain
            ggml_tensor* ln1_b;     // LayerNorm bias
            ggml_tensor* attn_qkv;  // Combined QKV projection
            ggml_tensor* attn_proj; // Output projection

            // FFN
            ggml_tensor* ln2_g;   // LayerNorm gain
            ggml_tensor* ln2_b;   // LayerNorm bias
            ggml_tensor* ffn_fc1; // FFN first layer
            ggml_tensor* ffn_fc2; // FFN second layer
        };
        // Expected to hold XTTSConfig::GPT_N_LAYER entries once loaded —
        // confirm against the loader in the .cpp.
        std::vector<Layer> layers;

        ggml_tensor* ln_f_g; // Final LayerNorm gain
        ggml_tensor* ln_f_b; // Final LayerNorm bias
    } text_encoder;

    // Reference Encoder (for voice cloning): maps a reference mel to a
    // speaker embedding via conv stack -> GRU -> projection.
    struct ReferenceEncoder {
        ggml_tensor* mel_conv1; // Initial mel convolution

        struct ConvBlock {
            ggml_tensor* conv;   // conv weights
            ggml_tensor* norm_g; // normalization gain
            ggml_tensor* norm_b; // normalization bias
        };
        std::vector<ConvBlock> conv_blocks;

        ggml_tensor* gru_ih;   // GRU input-hidden weights
        ggml_tensor* gru_hh;   // GRU hidden-hidden weights
        ggml_tensor* gru_bias; // GRU bias

        ggml_tensor* speaker_proj; // Project to speaker embedding
    } ref_encoder;

    // Latent Diffusion Decoder: transformer with self-attention plus
    // cross-attention into the text-encoder output.
    struct LatentDecoder {
        ggml_tensor* latent_proj; // Project latents to hidden

        struct DecoderLayer {
            // Self-attention
            ggml_tensor* sa_ln_g;
            ggml_tensor* sa_ln_b;
            ggml_tensor* sa_qkv;
            ggml_tensor* sa_proj;

            // Cross-attention (to text); note Q is separate from the fused
            // KV projection because queries come from the latent stream.
            ggml_tensor* ca_ln_g;
            ggml_tensor* ca_ln_b;
            ggml_tensor* ca_q;
            ggml_tensor* ca_kv;
            ggml_tensor* ca_proj;

            // FFN
            ggml_tensor* ffn_ln_g;
            ggml_tensor* ffn_ln_b;
            ggml_tensor* ffn_fc1;
            ggml_tensor* ffn_fc2;
        };
        // Expected size: XTTSConfig::DECODER_LAYERS — confirm at load.
        std::vector<DecoderLayer> layers;

        ggml_tensor* mel_head;  // Project to mel spectrogram
        ggml_tensor* stop_head; // Predict stop token
    } decoder;

    // HiFi-GAN Vocoder: mel [80 ch] -> waveform via transposed-conv
    // upsampling with residual blocks.
    struct Vocoder {
        ggml_tensor* conv_pre; // Pre-conv [80, 512, 7]

        struct UpsampleBlock {
            ggml_tensor* conv_transpose; // Transposed convolution

            struct ResBlock {
                ggml_tensor* conv1;
                ggml_tensor* conv2;
            };
            std::vector<ResBlock> res_blocks;
        };
        std::vector<UpsampleBlock> upsample_blocks;

        ggml_tensor* conv_post; // Post-conv [512, 1, 7]
    } vocoder;

    // Conditioning layers
    struct Conditioning {
        ggml_tensor* speaker_embedding;  // Speaker lookup table
        ggml_tensor* language_embedding; // Language embeddings
        ggml_tensor* style_embedding;    // Style tokens (optional)
    } conditioning;

    // Model context — these members OWN the weight memory; every tensor
    // pointer above aliases into `buffer`. Released by whoever owns this
    // struct (see XTTSv2Inference).
    ggml_context* ctx = nullptr;
    ggml_backend_t backend = nullptr;
    ggml_backend_buffer_t buffer = nullptr;
    size_t buffer_size = 0;
};
162
+
163
+ // KV Cache for autoregressive generation
164
+ struct XTTSKVCache {
165
+ // Text encoder cache
166
+ struct {
167
+ ggml_tensor* k[30]; // K cache per layer
168
+ ggml_tensor* v[30]; // V cache per layer
169
+ int n_cached = 0;
170
+ } text_cache;
171
+
172
+ // Decoder cache
173
+ struct {
174
+ ggml_tensor* k[12]; // K cache per layer
175
+ ggml_tensor* v[12]; // V cache per layer
176
+ ggml_tensor* cross_k[12]; // Cross-attention K cache
177
+ ggml_tensor* cross_v[12]; // Cross-attention V cache
178
+ int n_cached = 0;
179
+ } decoder_cache;
180
+ };
181
+
182
+ // Main XTTS v2 Inference Engine
183
+ class XTTSv2Inference {
184
+ public:
185
+ XTTSv2Inference();
186
+ ~XTTSv2Inference();
187
+
188
+ // Load model from GGUF file
189
+ bool load_model(const std::string& model_path, bool use_mmap = true);
190
+
191
+ // High-level TTS interface
192
+ std::vector<float> synthesize(
193
+ const std::string& text,
194
+ const std::string& language = "en",
195
+ const std::vector<float>& speaker_wav = {}, // Optional reference audio
196
+ float temperature = 0.65f,
197
+ float length_penalty = 1.0f,
198
+ float repetition_penalty = 2.0f,
199
+ float top_k = 50,
200
+ float top_p = 0.85f,
201
+ float speed = 1.0f
202
+ );
203
+
204
+ // Component-wise inference (for debugging/testing)
205
+ struct InferenceComponents {
206
+ std::vector<int32_t> tokens; // BPE tokens
207
+ ggml_tensor* text_embeddings; // Text encoder output
208
+ ggml_tensor* speaker_embedding; // Speaker embedding
209
+ ggml_tensor* latents; // Decoder latents
210
+ ggml_tensor* mel_spectrogram; // Generated mel
211
+ std::vector<float> audio; // Final audio
212
+ };
213
+
214
+ InferenceComponents synthesize_components(
215
+ const std::string& text,
216
+ const std::string& language = "en",
217
+ const std::vector<float>& speaker_wav = {}
218
+ );
219
+
220
+ // Streaming interface
221
+ class Stream {
222
+ public:
223
+ Stream(XTTSv2Inference* parent, const std::string& text,
224
+ const std::string& language, const std::vector<float>& speaker_wav);
225
+ ~Stream();
226
+
227
+ std::vector<float> get_chunk(size_t max_samples = 4096);
228
+ bool is_done() const { return done; }
229
+
230
+ private:
231
+ XTTSv2Inference* parent;
232
+ InferenceComponents components;
233
+ size_t mel_offset = 0;
234
+ size_t audio_offset = 0;
235
+ bool done = false;
236
+
237
+ void generate_next_mel_chunk();
238
+ std::vector<float> vocoder_chunk(ggml_tensor* mel_chunk);
239
+ };
240
+
241
+ std::unique_ptr<Stream> create_stream(
242
+ const std::string& text,
243
+ const std::string& language = "en",
244
+ const std::vector<float>& speaker_wav = {}
245
+ );
246
+
247
+ private:
248
+ XTTSConfig config;
249
+ XTTSv2Model model;
250
+ XTTSKVCache kv_cache;
251
+
252
+ // GGUF file handling
253
+ struct gguf_context* gguf_ctx = nullptr;
254
+ void* mapped_memory = nullptr;
255
+ size_t mapped_size = 0;
256
+
257
+ // Computation graph
258
+ ggml_cgraph* gf = nullptr;
259
+ ggml_gallocr* allocr = nullptr;
260
+
261
+ // Tokenizer
262
+ std::unordered_map<std::string, int32_t> bpe_vocab;
263
+ std::vector<std::pair<std::string, std::string>> bpe_merges;
264
+
265
+ // Internal methods
266
+ bool load_gguf_weights(const std::string& path, bool use_mmap);
267
+ void init_model_architecture();
268
+
269
+ // Text processing
270
+ std::vector<int32_t> tokenize(const std::string& text);
271
+ std::vector<std::string> bpe_encode(const std::string& text);
272
+
273
+ // Model forward passes
274
+ ggml_tensor* text_encoder_forward(
275
+ const std::vector<int32_t>& tokens,
276
+ const std::string& language
277
+ );
278
+
279
+ ggml_tensor* reference_encoder_forward(
280
+ const std::vector<float>& audio_wav
281
+ );
282
+
283
+ ggml_tensor* decoder_forward(
284
+ ggml_tensor* text_embeddings,
285
+ ggml_tensor* speaker_embedding,
286
+ float temperature,
287
+ float length_penalty
288
+ );
289
+
290
+ std::vector<float> vocoder_forward(
291
+ ggml_tensor* mel_spectrogram
292
+ );
293
+
294
+ // Attention mechanisms
295
+ ggml_tensor* multi_head_attention(
296
+ ggml_tensor* q, ggml_tensor* k, ggml_tensor* v,
297
+ int n_heads, bool use_cache = true
298
+ );
299
+
300
+ ggml_tensor* cross_attention(
301
+ ggml_tensor* queries,
302
+ ggml_tensor* keys,
303
+ ggml_tensor* values,
304
+ int n_heads
305
+ );
306
+
307
+ // Helper functions
308
+ ggml_tensor* layer_norm(ggml_tensor* x, ggml_tensor* g, ggml_tensor* b, float eps = 1e-5f);
309
+ ggml_tensor* gelu(ggml_tensor* x);
310
+ ggml_tensor* conv1d(ggml_tensor* x, ggml_tensor* w, ggml_tensor* b, int stride, int padding);
311
+ ggml_tensor* conv_transpose1d(ggml_tensor* x, ggml_tensor* w, ggml_tensor* b, int stride, int padding);
312
+
313
+ // Sampling
314
+ std::vector<int32_t> sample_latents(
315
+ ggml_tensor* logits,
316
+ float temperature,
317
+ float top_k,
318
+ float top_p,
319
+ float repetition_penalty
320
+ );
321
+ };
322
+
323
// NEON-optimized kernels for ARM
//
// Hand-written quantized kernels, declared only when building for a
// NEON-capable target; call sites must be guarded by the same #ifdef.
// NOTE(review): quantized buffers are paired with per-block scale arrays;
// the exact quantization block layout is defined by the implementation —
// confirm against the .cpp.
namespace kernels {
#ifdef __ARM_NEON

// GEMM with a 4-bit-quantized A operand: C = dequant(A, scales) * B,
// dimensions m x k times k x n (presumably row-major — verify).
void gemm_q4_neon(
    const uint8_t* a_q4,
    const float* b,
    float* c,
    int m, int k, int n,
    const float* scales
);

// 1-D convolution with 8-bit-quantized input and kernel, float output.
void conv1d_q8_neon(
    const uint8_t* input_q8,
    const uint8_t* kernel_q8,
    float* output,
    int batch, int in_c, int out_c,
    int length, int kernel_size,
    int stride, int padding,
    const float* input_scale,
    const float* kernel_scale
);

// Attention over 4-bit-quantized Q/K/V tensors, float output.
void attention_q4_neon(
    const uint8_t* q_q4,
    const uint8_t* k_q4,
    const uint8_t* v_q4,
    float* output,
    int seq_len, int n_heads, int head_dim,
    const float* q_scale,
    const float* k_scale,
    const float* v_scale
);

#endif // __ARM_NEON
} // namespace kernels
359
+
360
// C API for React Native / FFI
//
// Opaque-handle wrapper around the C++ engine for consumers that cannot
// link C++ directly. NOTE(review): this extern "C" block is nested inside
// namespace xtts_v2 — linkage names are unaffected by the namespace, but
// the declarations are only visible to C++ code that opens it; confirm
// this nesting is intended for the FFI header consumers.
extern "C" {
// Create an engine from a GGUF file; returns an opaque handle
// (presumably a heap-allocated XTTSv2Inference — confirm in the .cpp).
// Release with xtts_v2_free().
void* xtts_v2_init(const char* model_path, bool use_mmap);

// Synthesize `text`; returns a float buffer of *out_len samples that the
// caller must release with xtts_v2_free_audio(). speaker_wav may be null
// (with speaker_wav_len == 0) for the default voice.
float* xtts_v2_synthesize(
    void* model,
    const char* text,
    const char* language,
    const float* speaker_wav,
    size_t speaker_wav_len,
    float temperature,
    float speed,
    size_t* out_len
);

// Begin a streaming synthesis session; release with xtts_v2_stream_free().
// The parent `model` handle must outlive the returned stream handle.
void* xtts_v2_stream_init(
    void* model,
    const char* text,
    const char* language,
    const float* speaker_wav,
    size_t speaker_wav_len
);

// Fetch the next audio chunk (up to chunk_size samples); *out_len receives
// the actual sample count. Free the buffer with xtts_v2_free_audio().
float* xtts_v2_stream_chunk(
    void* stream,
    size_t chunk_size,
    size_t* out_len
);

void xtts_v2_stream_free(void* stream); // destroy a stream handle
void xtts_v2_free(void* model);         // destroy an engine handle
void xtts_v2_free_audio(float* audio);  // release a buffer returned by the calls above
}
393
+
394
+ } // namespace xtts_v2
395
+
396
+ #endif // XTTS_V2_FULL_H