|
|
|
|
|
#include <cmath>
#include <condition_variable>
#include <cstddef>
#include <exception>
#include <iterator>
#include <limits>
#include <map>
#include <memory>
#include <mutex>
#include <queue>
#include <string>
#include <thread>
#include <utility>
#include <vector>

#include <ReactCommon/CallInvoker.h>
#include <ReactCommon/TurboModule.h>
#include <jsi/jsi.h>

#include "../cpp/xtts_inference.h"
|
|
|
|
|
using namespace facebook; |
|
|
|
|
|
namespace xtts_rn { |
|
|
|
|
|
|
|
|
class XTTSModule : public react::TurboModule { |
|
|
public: |
|
|
static constexpr auto kModuleName = "XTTSModule"; |
|
|
|
|
|
explicit XTTSModule(std::shared_ptr<react::CallInvoker> jsInvoker) |
|
|
: TurboModule(kModuleName, jsInvoker) { |
|
|
} |
|
|
|
|
|
~XTTSModule() { |
|
|
cleanup(); |
|
|
} |
|
|
|
|
|
|
|
|
jsi::Value initialize( |
|
|
jsi::Runtime& runtime, |
|
|
const jsi::String& modelPath, |
|
|
const jsi::Value& options |
|
|
) { |
|
|
std::string path = modelPath.utf8(runtime); |
|
|
bool use_mmap = true; |
|
|
bool use_gpu = false; |
|
|
int n_threads = 4; |
|
|
|
|
|
|
|
|
if (options.isObject()) { |
|
|
auto opts = options.asObject(runtime); |
|
|
|
|
|
if (opts.hasProperty(runtime, "useMmap")) { |
|
|
use_mmap = opts.getProperty(runtime, "useMmap").getBool(); |
|
|
} |
|
|
if (opts.hasProperty(runtime, "useGPU")) { |
|
|
use_gpu = opts.getProperty(runtime, "useGPU").getBool(); |
|
|
} |
|
|
if (opts.hasProperty(runtime, "threads")) { |
|
|
n_threads = static_cast<int>( |
|
|
opts.getProperty(runtime, "threads").getNumber() |
|
|
); |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
cleanup(); |
|
|
|
|
|
|
|
|
model_ptr = xtts::xtts_init(path.c_str(), use_mmap); |
|
|
|
|
|
if (!model_ptr) { |
|
|
return jsi::Value(false); |
|
|
} |
|
|
|
|
|
|
|
|
auto* model = static_cast<xtts::XTTSInference*>(model_ptr); |
|
|
auto params = model->get_params(); |
|
|
|
|
|
|
|
|
auto info = jsi::Object(runtime); |
|
|
info.setProperty(runtime, "initialized", jsi::Value(true)); |
|
|
info.setProperty(runtime, "sampleRate", jsi::Value(params.sample_rate)); |
|
|
info.setProperty(runtime, "nLanguages", jsi::Value(params.n_languages)); |
|
|
info.setProperty(runtime, "memoryMB", |
|
|
jsi::Value(static_cast<double>(model->get_memory_usage()) / (1024*1024)) |
|
|
); |
|
|
|
|
|
return info; |
|
|
} |
|
|
|
|
|
|
|
|
jsi::Value generate( |
|
|
jsi::Runtime& runtime, |
|
|
const jsi::String& text, |
|
|
const jsi::Value& options |
|
|
) { |
|
|
if (!model_ptr) { |
|
|
throw jsi::JSError(runtime, "Model not initialized"); |
|
|
} |
|
|
|
|
|
std::string text_str = text.utf8(runtime); |
|
|
int language = 0; |
|
|
int speaker_id = 0; |
|
|
float temperature = 0.8f; |
|
|
float speed = 1.0f; |
|
|
|
|
|
|
|
|
if (options.isObject()) { |
|
|
auto opts = options.asObject(runtime); |
|
|
|
|
|
if (opts.hasProperty(runtime, "language")) { |
|
|
auto lang = opts.getProperty(runtime, "language").asString(runtime).utf8(runtime); |
|
|
language = languageFromString(lang); |
|
|
} |
|
|
if (opts.hasProperty(runtime, "speaker")) { |
|
|
speaker_id = static_cast<int>( |
|
|
opts.getProperty(runtime, "speaker").getNumber() |
|
|
); |
|
|
} |
|
|
if (opts.hasProperty(runtime, "temperature")) { |
|
|
temperature = static_cast<float>( |
|
|
opts.getProperty(runtime, "temperature").getNumber() |
|
|
); |
|
|
} |
|
|
if (opts.hasProperty(runtime, "speed")) { |
|
|
speed = static_cast<float>( |
|
|
opts.getProperty(runtime, "speed").getNumber() |
|
|
); |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
size_t audio_length = 0; |
|
|
float* audio_data = xtts::xtts_generate( |
|
|
model_ptr, |
|
|
text_str.c_str(), |
|
|
language, |
|
|
speaker_id, |
|
|
temperature, |
|
|
speed, |
|
|
&audio_length |
|
|
); |
|
|
|
|
|
if (!audio_data) { |
|
|
return jsi::Value::null(); |
|
|
} |
|
|
|
|
|
|
|
|
auto audio_array = jsi::Array(runtime, audio_length); |
|
|
for (size_t i = 0; i < audio_length; ++i) { |
|
|
audio_array.setValueAtIndex(runtime, i, jsi::Value(audio_data[i])); |
|
|
} |
|
|
|
|
|
|
|
|
xtts::xtts_free_audio(audio_data); |
|
|
|
|
|
return audio_array; |
|
|
} |
|
|
|
|
|
|
|
|
jsi::Value generateAsync( |
|
|
jsi::Runtime& runtime, |
|
|
const jsi::String& text, |
|
|
const jsi::Value& options |
|
|
) { |
|
|
auto promise = runtime.global() |
|
|
.getPropertyAsFunction(runtime, "Promise") |
|
|
.callAsConstructor( |
|
|
runtime, |
|
|
jsi::Function::createFromHostFunction( |
|
|
runtime, |
|
|
jsi::PropNameID::forAscii(runtime, "executor"), |
|
|
2, |
|
|
[this, text, options]( |
|
|
jsi::Runtime& rt, |
|
|
const jsi::Value& thisValue, |
|
|
const jsi::Value* args, |
|
|
size_t count |
|
|
) -> jsi::Value { |
|
|
auto resolve = std::make_shared<jsi::Function>( |
|
|
args[0].asObject(rt).asFunction(rt) |
|
|
); |
|
|
auto reject = std::make_shared<jsi::Function>( |
|
|
args[1].asObject(rt).asFunction(rt) |
|
|
); |
|
|
|
|
|
|
|
|
std::string text_str = text.utf8(rt); |
|
|
int language = 0; |
|
|
int speaker_id = 0; |
|
|
float temperature = 0.8f; |
|
|
float speed = 1.0f; |
|
|
|
|
|
if (options.isObject()) { |
|
|
auto opts = options.asObject(rt); |
|
|
if (opts.hasProperty(rt, "language")) { |
|
|
auto lang = opts.getProperty(rt, "language") |
|
|
.asString(rt).utf8(rt); |
|
|
language = languageFromString(lang); |
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
std::thread([ |
|
|
this, |
|
|
resolve, |
|
|
reject, |
|
|
text_str, |
|
|
language, |
|
|
speaker_id, |
|
|
temperature, |
|
|
speed |
|
|
]() { |
|
|
if (!model_ptr) { |
|
|
jsInvoker_->invokeAsync([reject]() { |
|
|
|
|
|
}); |
|
|
return; |
|
|
} |
|
|
|
|
|
size_t audio_length = 0; |
|
|
float* audio_data = xtts::xtts_generate( |
|
|
model_ptr, |
|
|
text_str.c_str(), |
|
|
language, |
|
|
speaker_id, |
|
|
temperature, |
|
|
speed, |
|
|
&audio_length |
|
|
); |
|
|
|
|
|
if (!audio_data) { |
|
|
jsInvoker_->invokeAsync([reject]() { |
|
|
|
|
|
}); |
|
|
return; |
|
|
} |
|
|
|
|
|
|
|
|
std::vector<float> audio_vec( |
|
|
audio_data, |
|
|
audio_data + audio_length |
|
|
); |
|
|
xtts::xtts_free_audio(audio_data); |
|
|
|
|
|
|
|
|
jsInvoker_->invokeAsync([resolve, audio_vec]() { |
|
|
|
|
|
|
|
|
}); |
|
|
}).detach(); |
|
|
|
|
|
return jsi::Value::undefined(); |
|
|
} |
|
|
) |
|
|
); |
|
|
|
|
|
return promise; |
|
|
} |
|
|
|
|
|
|
|
|
jsi::Value createStream( |
|
|
jsi::Runtime& runtime, |
|
|
const jsi::String& text, |
|
|
const jsi::Value& options |
|
|
) { |
|
|
if (!model_ptr) { |
|
|
throw jsi::JSError(runtime, "Model not initialized"); |
|
|
} |
|
|
|
|
|
std::string text_str = text.utf8(runtime); |
|
|
int language = 0; |
|
|
|
|
|
if (options.isObject()) { |
|
|
auto opts = options.asObject(runtime); |
|
|
if (opts.hasProperty(runtime, "language")) { |
|
|
auto lang = opts.getProperty(runtime, "language") |
|
|
.asString(runtime).utf8(runtime); |
|
|
language = languageFromString(lang); |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
void* stream = xtts::xtts_stream_init( |
|
|
model_ptr, |
|
|
text_str.c_str(), |
|
|
language |
|
|
); |
|
|
|
|
|
if (!stream) { |
|
|
return jsi::Value::null(); |
|
|
} |
|
|
|
|
|
|
|
|
size_t stream_id = next_stream_id++; |
|
|
active_streams[stream_id] = stream; |
|
|
|
|
|
auto stream_obj = jsi::Object(runtime); |
|
|
stream_obj.setProperty(runtime, "id", jsi::Value(static_cast<double>(stream_id))); |
|
|
stream_obj.setProperty(runtime, "active", jsi::Value(true)); |
|
|
|
|
|
return stream_obj; |
|
|
} |
|
|
|
|
|
|
|
|
jsi::Value getStreamChunk( |
|
|
jsi::Runtime& runtime, |
|
|
const jsi::Value& streamHandle, |
|
|
const jsi::Value& chunkSize |
|
|
) { |
|
|
if (!streamHandle.isObject()) { |
|
|
throw jsi::JSError(runtime, "Invalid stream handle"); |
|
|
} |
|
|
|
|
|
auto handle = streamHandle.asObject(runtime); |
|
|
if (!handle.hasProperty(runtime, "id")) { |
|
|
throw jsi::JSError(runtime, "Stream handle missing id"); |
|
|
} |
|
|
|
|
|
size_t stream_id = static_cast<size_t>( |
|
|
handle.getProperty(runtime, "id").getNumber() |
|
|
); |
|
|
|
|
|
auto it = active_streams.find(stream_id); |
|
|
if (it == active_streams.end()) { |
|
|
return jsi::Value::null(); |
|
|
} |
|
|
|
|
|
size_t chunk_samples = 8192; |
|
|
if (chunkSize.isNumber()) { |
|
|
chunk_samples = static_cast<size_t>(chunkSize.getNumber()); |
|
|
} |
|
|
|
|
|
size_t audio_length = 0; |
|
|
float* audio_data = xtts::xtts_stream_next( |
|
|
it->second, |
|
|
chunk_samples, |
|
|
&audio_length |
|
|
); |
|
|
|
|
|
if (!audio_data || audio_length == 0) { |
|
|
|
|
|
handle.setProperty(runtime, "active", jsi::Value(false)); |
|
|
return jsi::Value::null(); |
|
|
} |
|
|
|
|
|
|
|
|
auto audio_array = jsi::Array(runtime, audio_length); |
|
|
for (size_t i = 0; i < audio_length; ++i) { |
|
|
audio_array.setValueAtIndex(runtime, i, jsi::Value(audio_data[i])); |
|
|
} |
|
|
|
|
|
xtts::xtts_free_audio(audio_data); |
|
|
|
|
|
return audio_array; |
|
|
} |
|
|
|
|
|
|
|
|
jsi::Value closeStream( |
|
|
jsi::Runtime& runtime, |
|
|
const jsi::Value& streamHandle |
|
|
) { |
|
|
if (!streamHandle.isObject()) { |
|
|
return jsi::Value(false); |
|
|
} |
|
|
|
|
|
auto handle = streamHandle.asObject(runtime); |
|
|
if (!handle.hasProperty(runtime, "id")) { |
|
|
return jsi::Value(false); |
|
|
} |
|
|
|
|
|
size_t stream_id = static_cast<size_t>( |
|
|
handle.getProperty(runtime, "id").getNumber() |
|
|
); |
|
|
|
|
|
auto it = active_streams.find(stream_id); |
|
|
if (it != active_streams.end()) { |
|
|
xtts::xtts_stream_free(it->second); |
|
|
active_streams.erase(it); |
|
|
return jsi::Value(true); |
|
|
} |
|
|
|
|
|
return jsi::Value(false); |
|
|
} |
|
|
|
|
|
|
|
|
jsi::Value getSupportedLanguages(jsi::Runtime& runtime) { |
|
|
auto languages = jsi::Array(runtime, 17); |
|
|
const char* lang_codes[] = { |
|
|
"en", "es", "fr", "de", "it", "pt", "pl", "tr", |
|
|
"ru", "nl", "cs", "ar", "zh", "ja", "ko", "hu", "hi" |
|
|
}; |
|
|
|
|
|
for (int i = 0; i < 17; ++i) { |
|
|
languages.setValueAtIndex( |
|
|
runtime, i, |
|
|
jsi::String::createFromUtf8(runtime, lang_codes[i]) |
|
|
); |
|
|
} |
|
|
|
|
|
return languages; |
|
|
} |
|
|
|
|
|
|
|
|
jsi::Value cleanup(jsi::Runtime& runtime) { |
|
|
cleanup(); |
|
|
return jsi::Value(true); |
|
|
} |
|
|
|
|
|
private: |
|
|
void* model_ptr = nullptr; |
|
|
std::map<size_t, void*> active_streams; |
|
|
size_t next_stream_id = 1; |
|
|
|
|
|
void cleanup() { |
|
|
|
|
|
for (auto& [id, stream] : active_streams) { |
|
|
xtts::xtts_stream_free(stream); |
|
|
} |
|
|
active_streams.clear(); |
|
|
|
|
|
|
|
|
if (model_ptr) { |
|
|
xtts::xtts_free(model_ptr); |
|
|
model_ptr = nullptr; |
|
|
} |
|
|
} |
|
|
|
|
|
int languageFromString(const std::string& lang) { |
|
|
static const std::map<std::string, int> lang_map = { |
|
|
{"en", 0}, {"es", 1}, {"fr", 2}, {"de", 3}, |
|
|
{"it", 4}, {"pt", 5}, {"pl", 6}, {"tr", 7}, |
|
|
{"ru", 8}, {"nl", 9}, {"cs", 10}, {"ar", 11}, |
|
|
{"zh", 12}, {"ja", 13}, {"ko", 14}, {"hu", 15}, {"hi", 16} |
|
|
}; |
|
|
|
|
|
auto it = lang_map.find(lang); |
|
|
return it != lang_map.end() ? it->second : 0; |
|
|
} |
|
|
}; |
|
|
|
|
|
|
|
|
std::shared_ptr<react::TurboModule> XTTSModuleProvider( |
|
|
std::shared_ptr<react::CallInvoker> jsInvoker |
|
|
) { |
|
|
return std::make_shared<XTTSModule>(jsInvoker); |
|
|
} |
|
|
|
|
|
} |