// xtts-gguf / react-native / XTTSModule.cpp
// Author: bnewton-genmedlabs
// Commit 4688879 (verified): "Initial GGUF implementation with C++ inference engine"
// XTTSModule.cpp - React Native TurboModule for XTTS GGUF
#include <array>
#include <cmath>
#include <condition_variable>
#include <cstddef>
#include <exception>
#include <limits>
#include <map>
#include <memory>
#include <mutex>
#include <optional>
#include <queue>
#include <string>
#include <thread>
#include <utility>
#include <vector>

#include <jsi/jsi.h>
#include <ReactCommon/CallInvoker.h>
#include <ReactCommon/TurboModule.h>

#include "../cpp/xtts_inference.h"

using namespace facebook;
namespace xtts_rn {
// TurboModule implementation for XTTS
class XTTSModule : public react::TurboModule {
public:
static constexpr auto kModuleName = "XTTSModule";
explicit XTTSModule(std::shared_ptr<react::CallInvoker> jsInvoker)
: TurboModule(kModuleName, jsInvoker) {
}
~XTTSModule() {
cleanup();
}
// Initialize model from GGUF file
jsi::Value initialize(
jsi::Runtime& runtime,
const jsi::String& modelPath,
const jsi::Value& options
) {
std::string path = modelPath.utf8(runtime);
bool use_mmap = true;
bool use_gpu = false;
int n_threads = 4;
// Parse options
if (options.isObject()) {
auto opts = options.asObject(runtime);
if (opts.hasProperty(runtime, "useMmap")) {
use_mmap = opts.getProperty(runtime, "useMmap").getBool();
}
if (opts.hasProperty(runtime, "useGPU")) {
use_gpu = opts.getProperty(runtime, "useGPU").getBool();
}
if (opts.hasProperty(runtime, "threads")) {
n_threads = static_cast<int>(
opts.getProperty(runtime, "threads").getNumber()
);
}
}
// Clean up previous model if exists
cleanup();
// Initialize new model
model_ptr = xtts::xtts_init(path.c_str(), use_mmap);
if (!model_ptr) {
return jsi::Value(false);
}
// Get model info
auto* model = static_cast<xtts::XTTSInference*>(model_ptr);
auto params = model->get_params();
// Return model info
auto info = jsi::Object(runtime);
info.setProperty(runtime, "initialized", jsi::Value(true));
info.setProperty(runtime, "sampleRate", jsi::Value(params.sample_rate));
info.setProperty(runtime, "nLanguages", jsi::Value(params.n_languages));
info.setProperty(runtime, "memoryMB",
jsi::Value(static_cast<double>(model->get_memory_usage()) / (1024*1024))
);
return info;
}
// Generate speech synchronously
jsi::Value generate(
jsi::Runtime& runtime,
const jsi::String& text,
const jsi::Value& options
) {
if (!model_ptr) {
throw jsi::JSError(runtime, "Model not initialized");
}
std::string text_str = text.utf8(runtime);
int language = 0; // Default to English
int speaker_id = 0;
float temperature = 0.8f;
float speed = 1.0f;
// Parse options
if (options.isObject()) {
auto opts = options.asObject(runtime);
if (opts.hasProperty(runtime, "language")) {
auto lang = opts.getProperty(runtime, "language").asString(runtime).utf8(runtime);
language = languageFromString(lang);
}
if (opts.hasProperty(runtime, "speaker")) {
speaker_id = static_cast<int>(
opts.getProperty(runtime, "speaker").getNumber()
);
}
if (opts.hasProperty(runtime, "temperature")) {
temperature = static_cast<float>(
opts.getProperty(runtime, "temperature").getNumber()
);
}
if (opts.hasProperty(runtime, "speed")) {
speed = static_cast<float>(
opts.getProperty(runtime, "speed").getNumber()
);
}
}
// Generate audio
size_t audio_length = 0;
float* audio_data = xtts::xtts_generate(
model_ptr,
text_str.c_str(),
language,
speaker_id,
temperature,
speed,
&audio_length
);
if (!audio_data) {
return jsi::Value::null();
}
// Convert to JS array
auto audio_array = jsi::Array(runtime, audio_length);
for (size_t i = 0; i < audio_length; ++i) {
audio_array.setValueAtIndex(runtime, i, jsi::Value(audio_data[i]));
}
// Clean up
xtts::xtts_free_audio(audio_data);
return audio_array;
}
// Generate speech asynchronously with promise
jsi::Value generateAsync(
jsi::Runtime& runtime,
const jsi::String& text,
const jsi::Value& options
) {
auto promise = runtime.global()
.getPropertyAsFunction(runtime, "Promise")
.callAsConstructor(
runtime,
jsi::Function::createFromHostFunction(
runtime,
jsi::PropNameID::forAscii(runtime, "executor"),
2,
[this, text, options](
jsi::Runtime& rt,
const jsi::Value& thisValue,
const jsi::Value* args,
size_t count
) -> jsi::Value {
auto resolve = std::make_shared<jsi::Function>(
args[0].asObject(rt).asFunction(rt)
);
auto reject = std::make_shared<jsi::Function>(
args[1].asObject(rt).asFunction(rt)
);
// Capture parameters
std::string text_str = text.utf8(rt);
int language = 0;
int speaker_id = 0;
float temperature = 0.8f;
float speed = 1.0f;
if (options.isObject()) {
auto opts = options.asObject(rt);
if (opts.hasProperty(rt, "language")) {
auto lang = opts.getProperty(rt, "language")
.asString(rt).utf8(rt);
language = languageFromString(lang);
}
// Parse other options...
}
// Run generation in background thread
std::thread([
this,
resolve,
reject,
text_str,
language,
speaker_id,
temperature,
speed
]() {
if (!model_ptr) {
jsInvoker_->invokeAsync([reject]() {
// reject->call(rt, "Model not initialized");
});
return;
}
size_t audio_length = 0;
float* audio_data = xtts::xtts_generate(
model_ptr,
text_str.c_str(),
language,
speaker_id,
temperature,
speed,
&audio_length
);
if (!audio_data) {
jsInvoker_->invokeAsync([reject]() {
// reject->call(rt, "Generation failed");
});
return;
}
// Convert to vector for thread safety
std::vector<float> audio_vec(
audio_data,
audio_data + audio_length
);
xtts::xtts_free_audio(audio_data);
// Resolve on JS thread
jsInvoker_->invokeAsync([resolve, audio_vec]() {
// Create array and resolve
// This needs proper JSI context
});
}).detach();
return jsi::Value::undefined();
}
)
);
return promise;
}
// Stream generation
jsi::Value createStream(
jsi::Runtime& runtime,
const jsi::String& text,
const jsi::Value& options
) {
if (!model_ptr) {
throw jsi::JSError(runtime, "Model not initialized");
}
std::string text_str = text.utf8(runtime);
int language = 0;
if (options.isObject()) {
auto opts = options.asObject(runtime);
if (opts.hasProperty(runtime, "language")) {
auto lang = opts.getProperty(runtime, "language")
.asString(runtime).utf8(runtime);
language = languageFromString(lang);
}
}
// Create stream
void* stream = xtts::xtts_stream_init(
model_ptr,
text_str.c_str(),
language
);
if (!stream) {
return jsi::Value::null();
}
// Store stream pointer and return handle
size_t stream_id = next_stream_id++;
active_streams[stream_id] = stream;
auto stream_obj = jsi::Object(runtime);
stream_obj.setProperty(runtime, "id", jsi::Value(static_cast<double>(stream_id)));
stream_obj.setProperty(runtime, "active", jsi::Value(true));
return stream_obj;
}
// Get next chunk from stream
jsi::Value getStreamChunk(
jsi::Runtime& runtime,
const jsi::Value& streamHandle,
const jsi::Value& chunkSize
) {
if (!streamHandle.isObject()) {
throw jsi::JSError(runtime, "Invalid stream handle");
}
auto handle = streamHandle.asObject(runtime);
if (!handle.hasProperty(runtime, "id")) {
throw jsi::JSError(runtime, "Stream handle missing id");
}
size_t stream_id = static_cast<size_t>(
handle.getProperty(runtime, "id").getNumber()
);
auto it = active_streams.find(stream_id);
if (it == active_streams.end()) {
return jsi::Value::null();
}
size_t chunk_samples = 8192; // Default chunk size
if (chunkSize.isNumber()) {
chunk_samples = static_cast<size_t>(chunkSize.getNumber());
}
size_t audio_length = 0;
float* audio_data = xtts::xtts_stream_next(
it->second,
chunk_samples,
&audio_length
);
if (!audio_data || audio_length == 0) {
// Stream finished
handle.setProperty(runtime, "active", jsi::Value(false));
return jsi::Value::null();
}
// Convert to JS array
auto audio_array = jsi::Array(runtime, audio_length);
for (size_t i = 0; i < audio_length; ++i) {
audio_array.setValueAtIndex(runtime, i, jsi::Value(audio_data[i]));
}
xtts::xtts_free_audio(audio_data);
return audio_array;
}
// Close stream
jsi::Value closeStream(
jsi::Runtime& runtime,
const jsi::Value& streamHandle
) {
if (!streamHandle.isObject()) {
return jsi::Value(false);
}
auto handle = streamHandle.asObject(runtime);
if (!handle.hasProperty(runtime, "id")) {
return jsi::Value(false);
}
size_t stream_id = static_cast<size_t>(
handle.getProperty(runtime, "id").getNumber()
);
auto it = active_streams.find(stream_id);
if (it != active_streams.end()) {
xtts::xtts_stream_free(it->second);
active_streams.erase(it);
return jsi::Value(true);
}
return jsi::Value(false);
}
// Get supported languages
jsi::Value getSupportedLanguages(jsi::Runtime& runtime) {
auto languages = jsi::Array(runtime, 17);
const char* lang_codes[] = {
"en", "es", "fr", "de", "it", "pt", "pl", "tr",
"ru", "nl", "cs", "ar", "zh", "ja", "ko", "hu", "hi"
};
for (int i = 0; i < 17; ++i) {
languages.setValueAtIndex(
runtime, i,
jsi::String::createFromUtf8(runtime, lang_codes[i])
);
}
return languages;
}
// Release model resources
jsi::Value cleanup(jsi::Runtime& runtime) {
cleanup();
return jsi::Value(true);
}
private:
void* model_ptr = nullptr;
std::map<size_t, void*> active_streams;
size_t next_stream_id = 1;
void cleanup() {
// Close all active streams
for (auto& [id, stream] : active_streams) {
xtts::xtts_stream_free(stream);
}
active_streams.clear();
// Free model
if (model_ptr) {
xtts::xtts_free(model_ptr);
model_ptr = nullptr;
}
}
int languageFromString(const std::string& lang) {
static const std::map<std::string, int> lang_map = {
{"en", 0}, {"es", 1}, {"fr", 2}, {"de", 3},
{"it", 4}, {"pt", 5}, {"pl", 6}, {"tr", 7},
{"ru", 8}, {"nl", 9}, {"cs", 10}, {"ar", 11},
{"zh", 12}, {"ja", 13}, {"ko", 14}, {"hu", 15}, {"hi", 16}
};
auto it = lang_map.find(lang);
return it != lang_map.end() ? it->second : 0;
}
};
// Module provider
/// Factory used by the host app to instantiate the XTTS TurboModule.
/// @param jsInvoker invoker for scheduling work back onto the JS thread.
/// @return a new XTTSModule, exposed through its TurboModule base.
std::shared_ptr<react::TurboModule> XTTSModuleProvider(
    std::shared_ptr<react::CallInvoker> jsInvoker) {
  std::shared_ptr<react::TurboModule> module =
      std::make_shared<XTTSModule>(jsInvoker);
  return module;
}
} // namespace xtts_rn