// xtts-gguf / cpp / xtts_inference.h
// bnewton-genmedlabs
// Initial GGUF implementation with C++ inference engine
// 4688879 verified
// xtts_inference.h - XTTS GGUF Inference Engine Header
#ifndef XTTS_INFERENCE_H
#define XTTS_INFERENCE_H
#include <cstdint>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include <ggml.h>
#include <ggml-alloc.h>
#include <ggml-backend.h>

// Opaque ggml/gguf handle types used by the declarations below. They are
// (re)declared at global scope so that `struct ggml_tensor*` etc. inside
// namespace xtts always name the library's own types, regardless of which
// ggml header happens to declare gguf_context.
struct ggml_context;
struct ggml_tensor;
struct gguf_context;
namespace xtts {
/// Hyperparameters of the model, pre-populated with the XTTS v2 values.
/// Plain data holder: every field carries its default and may be overwritten.
struct XTTSHyperParams {
    int32_t n_vocab{256};          ///< Byte-level vocabulary size
    int32_t n_ctx_text{402};       ///< Maximum text context
    int32_t n_ctx_audio{605};      ///< Maximum audio context
    int32_t n_embd{1024};          ///< Embedding dimension
    int32_t n_head{16};            ///< Number of attention heads
    int32_t n_layer{24};           ///< Number of GPT layers
    int32_t n_mel_channels{80};    ///< Mel spectrogram channels
    int32_t n_audio_tokens{1026};  ///< Audio codebook size
    int32_t sample_rate{24000};    ///< Audio sample rate
    int32_t n_languages{17};       ///< Number of supported languages
    int32_t speaker_emb_dim{512};  ///< Speaker embedding dimension
};
/// Language identifiers.
/// The numeric values are spelled out explicitly because they cross the C
/// bridge as plain `int`; they must not be renumbered.
enum Language {
    LANG_EN = 0,   ///< English
    LANG_ES = 1,   ///< Spanish
    LANG_FR = 2,   ///< French
    LANG_DE = 3,   ///< German
    LANG_IT = 4,   ///< Italian
    LANG_PT = 5,   ///< Portuguese
    LANG_PL = 6,   ///< Polish
    LANG_TR = 7,   ///< Turkish
    LANG_RU = 8,   ///< Russian
    LANG_NL = 9,   ///< Dutch
    LANG_CS = 10,  ///< Czech
    LANG_AR = 11,  ///< Arabic
    LANG_ZH = 12,  ///< Chinese
    LANG_JA = 13,  ///< Japanese
    LANG_KO = 14,  ///< Korean
    LANG_HU = 15,  ///< Hungarian
    LANG_HI = 16   ///< Hindi
};
// ggml_context, ggml_tensor and gguf_context are deliberately NOT
// forward-declared at this point. A `struct X;` declaration written inside
// namespace xtts introduces a brand-new type xtts::X that hides the library's
// ::X, so every `struct ggml_tensor*` member below would name a type unrelated
// to (and not convertible from) the pointers ggml functions return. The
// global-scope declarations visible from the top of this file are used instead.
// XTTS model weights structure.
//
// Groups the tensor handles of every model component together with the ggml
// context, backend and backend-buffer handles. Every handle starts out null,
// so a default-constructed (or only partially loaded) model never exposes an
// indeterminate pointer.
//
// NOTE(review): ~XTTSModel() is defined outside this header. If it releases
// ctx/backend/buffer, the implicitly generated copy operations would let two
// objects release the same handles; consider deleting copy construction and
// copy assignment once the implementation has been checked.
struct XTTSModel {
    // Text encoder
    struct ggml_tensor* text_embedding = nullptr;      // [n_vocab, n_embd]
    struct ggml_tensor* language_embedding = nullptr;  // [n_languages, n_embd]
    struct ggml_tensor* pos_encoding = nullptr;        // [n_ctx_text, n_embd]

    // GPT layers (presumably one entry per layer — confirm against the loader)
    std::vector<struct ggml_tensor*> ln1_weight;  // Layer norm 1 weights
    std::vector<struct ggml_tensor*> ln1_bias;    // Layer norm 1 bias
    std::vector<struct ggml_tensor*> attn_qkv;    // Attention QKV projection
    std::vector<struct ggml_tensor*> attn_out;    // Attention output projection
    std::vector<struct ggml_tensor*> ln2_weight;  // Layer norm 2 weights
    std::vector<struct ggml_tensor*> ln2_bias;    // Layer norm 2 bias
    std::vector<struct ggml_tensor*> ffn_up;      // FFN up projection
    std::vector<struct ggml_tensor*> ffn_down;    // FFN down projection

    // Audio token predictor
    struct ggml_tensor* audio_token_predictor = nullptr;  // [n_embd, n_audio_tokens]

    // Vocoder layers (simplified HiFi-GAN)
    struct ggml_tensor* vocoder_preconv = nullptr;       // Initial convolution
    std::vector<struct ggml_tensor*> vocoder_ups;        // Upsampling layers
    std::vector<struct ggml_tensor*> vocoder_resblocks;  // Residual blocks
    struct ggml_tensor* vocoder_postconv = nullptr;      // Final convolution

    // Speaker embedding projection
    struct ggml_tensor* speaker_projection = nullptr;  // [speaker_emb_dim, n_embd]

    // Context and memory
    struct ggml_context* ctx = nullptr;
    ggml_backend_t backend = nullptr;
    ggml_backend_buffer_t buffer = nullptr;

    ~XTTSModel();
};
// KV cache for autoregressive generation.
//
// Both tensor handles default to null so that a freshly constructed cache is
// in a well-defined "nothing allocated, nothing cached" state, matching the
// zero default of n_cached.
struct XTTSKVCache {
    struct ggml_tensor* k_cache = nullptr;  // [n_layer, n_ctx, n_embd]
    struct ggml_tensor* v_cache = nullptr;  // [n_layer, n_ctx, n_embd]
    int32_t n_cached = 0;                   // Count of cached entries; starts at zero
};
// Main XTTS inference class.
//
// Public surface: load a GGUF model, synthesise a whole utterance with
// generate(), or pull audio incrementally through a StreamGenerator obtained
// from create_stream().
//
// NOTE(review): the class declares a destructor and stores raw handles
// (gguf_ctx, mapped_memory, gf, allocr), yet the implicit copy operations are
// still available and would duplicate those handles. Consider deleting copy
// construction/assignment once the implementation has been checked.
class XTTSInference {
public:
    XTTSInference();
    ~XTTSInference();

    // Load model from GGUF file.
    //   model_path - path of the GGUF file
    //   use_mmap   - memory-mapping option, enabled by default
    // Returns a bool status (presumably true on success — confirm in the
    // implementation).
    bool load_model(const std::string& model_path, bool use_mmap = true);

    // Generate speech from text.
    //   text        - input text
    //   language    - language of the text (defaults to LANG_EN)
    //   speaker_id  - speaker selector (defaults to 0)
    //   temperature - sampling temperature (defaults to 0.8)
    //   speed       - speed factor (defaults to 1.0)
    // Returns the generated audio as a vector of floats.
    std::vector<float> generate(
        const std::string& text,
        Language language = LANG_EN,
        int speaker_id = 0,
        float temperature = 0.8f,
        float speed = 1.0f
    );

    // Stream generation (for real-time synthesis).
    class StreamGenerator {
    public:
        // parent is stored as a plain pointer and is not owned by the
        // generator, so it has to stay valid for as long as the generator
        // is used.
        StreamGenerator(XTTSInference* parent, const std::string& text, Language lang);
        ~StreamGenerator();

        // Get next audio chunk (returns empty when done).
        std::vector<float> get_next_chunk(size_t chunk_samples = 8192);

        // True once the `done` flag has been set.
        bool is_done() const { return done; }

    private:
        // In-class defaults guarantee that no member is ever read
        // uninitialised; the constructor is expected to overwrite
        // parent_model and language with its arguments.
        XTTSInference* parent_model = nullptr;
        std::vector<int32_t> text_tokens;
        std::vector<int32_t> audio_tokens;
        Language language = LANG_EN;
        size_t current_token = 0;
        bool done = false;

        void generate_next_tokens(size_t n_tokens);
    };

    // Create a stream generator; the caller owns the returned object.
    std::unique_ptr<StreamGenerator> create_stream(
        const std::string& text,
        Language language = LANG_EN
    );

    // Get model info.
    // get_params() returns a copy of the hyperparameters.
    XTTSHyperParams get_params() const { return hparams; }
    size_t get_memory_usage() const;

private:
    XTTSHyperParams hparams;
    XTTSModel model;
    XTTSKVCache kv_cache;

    // Model file handle (for mmap)
    struct gguf_context* gguf_ctx = nullptr;
    void* mapped_memory = nullptr;
    size_t mapped_size = 0;

    // Computation graph
    struct ggml_cgraph* gf = nullptr;
    struct ggml_gallocr* allocr = nullptr;

    // Internal methods
    bool load_gguf_file(const std::string& path, bool use_mmap);
    void create_computation_graph();

    // Text processing
    std::vector<int32_t> tokenize(const std::string& text);

    // Model forward passes
    struct ggml_tensor* encode_text(
        const std::vector<int32_t>& tokens,
        Language language,
        const std::vector<float>& speaker_embedding
    );
    std::vector<int32_t> generate_audio_tokens(
        struct ggml_tensor* text_features,
        float temperature
    );
    std::vector<float> vocoder_forward(
        const std::vector<int32_t>& audio_tokens
    );

    // Attention mechanism for the layer selected by layer_idx
    struct ggml_tensor* attention(
        struct ggml_tensor* x,
        int layer_idx,
        bool use_cache = true
    );

    // Feed-forward network for the layer selected by layer_idx
    struct ggml_tensor* ffn(
        struct ggml_tensor* x,
        int layer_idx
    );

    // Utility functions
    struct ggml_tensor* layer_norm(
        struct ggml_tensor* x,
        struct ggml_tensor* weight,
        struct ggml_tensor* bias,
        float eps = 1e-5f
    );
    int32_t sample_token(
        struct ggml_tensor* logits,
        float temperature,
        float top_p = 0.9f
    );
    std::vector<float> create_speaker_embedding(int speaker_id);
};
// C-linkage entry points for the React Native bridge.
// Model and stream objects are passed across the boundary as untyped void*
// values; audio is passed as float* together with a size_t length out-parameter.
extern "C" {

// --- Model -----------------------------------------------------------------
// Initialize a model from the file at model_path. The returned pointer is
// presumably what the other functions expect as model_ptr — confirm against
// the implementation.
void* xtts_init(const char* model_path, bool use_mmap);

// Cleanup counterpart for a model pointer.
void xtts_free(void* model_ptr);

// --- One-shot synthesis ----------------------------------------------------
// Generate speech for text. `language` is a plain int (presumably a Language
// value — confirm). The number of floats behind the returned pointer is
// reported through out_length.
float* xtts_generate(void* model_ptr, const char* text, int language, int speaker_id, float temperature, float speed, size_t* out_length);

// Cleanup counterpart for an audio pointer (presumably one returned by
// xtts_generate or xtts_stream_next — confirm).
void xtts_free_audio(float* audio_ptr);

// --- Streaming synthesis -----------------------------------------------------
// Start a stream for text on the given model.
void* xtts_stream_init(void* model_ptr, const char* text, int language);

// Fetch the next chunk of a stream; its length is reported through out_length.
float* xtts_stream_next(void* stream_ptr, size_t chunk_size, size_t* out_length);

// Cleanup counterpart for a stream pointer.
void xtts_stream_free(void* stream_ptr);

}
} // namespace xtts
#endif // XTTS_INFERENCE_H