xtts-gguf / react-native /XTTSModule.ts
bnewton-genmedlabs's picture
Initial GGUF implementation with C++ inference engine
4688879 verified
// XTTSModule.ts - TypeScript interface for XTTS React Native module
import { NativeModules, Platform } from 'react-native';
import RNFS from 'react-native-fs';
/**
 * Shape of the native `XTTSModule` object that the wrapper classes in this
 * file call into.
 */
interface XTTSNativeModule {
  /** Loads the model at `modelPath` and resolves with its reported info. */
  initialize(modelPath: string, options?: InitOptions): Promise<ModelInfo>;

  /** Synthesizes `text` and resolves with the audio samples. */
  generate(text: string, options?: GenerateOptions): Promise<Float32Array>;

  /**
   * Same signature as `generate`; this is the entry point `XTTS.speak` uses.
   */
  generateAsync(text: string, options?: GenerateOptions): Promise<Float32Array>;

  /** Opens a stream for `text` and returns its handle. */
  createStream(text: string, options?: StreamOptions): StreamHandle;

  /**
   * Fetches the next chunk for `stream`. `XTTSStream` treats a `null` result
   * as the end of the stream.
   */
  getStreamChunk(stream: StreamHandle, chunkSize?: number): Float32Array | null;

  /** Closes `stream`. */
  closeStream(stream: StreamHandle): boolean;

  /** Lists the language codes the module reports as supported. */
  getSupportedLanguages(): string[];

  /** Releases native resources. */
  cleanup(): boolean;
}
// Type definitions
/** Options passed through to the native module when loading a model. */
export interface InitOptions {
  /** Use memory-mapped loading (default: true). */
  useMmap?: boolean;
  /** Use GPU acceleration if available (default: false). */
  useGPU?: boolean;
  /** Number of threads to use (default: 4). */
  threads?: number;
}
/** Information the native module reports after a model has been loaded. */
export interface ModelInfo {
  /** Whether the model is ready; checked before generating or streaming. */
  initialized: boolean;
  /** Output sample rate, logged in Hz. */
  sampleRate: number;
  /** Number of languages reported by the model. */
  nLanguages: number;
  /** Memory usage, logged in MB. */
  memoryMB: number;
}
/** Per-request options for speech generation. */
export interface GenerateOptions {
  /** Language code (e.g., 'en', 'es', 'fr'). */
  language?: string;
  /** Speaker ID (0-9). */
  speaker?: number;
  /** Sampling temperature (0.1-2.0, default: 0.8). */
  temperature?: number;
  /** Speech speed (0.5-2.0, default: 1.0). */
  speed?: number;
}
/** Options for creating a streaming generation session. */
export interface StreamOptions {
  /** Language code for the streamed text. */
  language?: string;
  /** Audio buffer size in samples. */
  bufferSize?: number;
}
/** Handle returned by the native module for an open stream. */
export interface StreamHandle {
  /** Stream identifier (presumably assigned by the native side; confirm). */
  id: number;
  /** Set to `false` by `XTTSStream` once the stream ends or is stopped. */
  active: boolean;
}
/** Language codes this wrapper's typings recognize. */
export type Language =
  | 'en' | 'es' | 'fr' | 'de' | 'it' | 'pt'
  | 'pl' | 'tr' | 'ru' | 'nl' | 'cs' | 'ar'
  | 'zh' | 'ja' | 'ko' | 'hu' | 'hi';
/**
 * High-level wrapper around the native `XTTSModule`.
 *
 * Handles fetching a GGUF model file, initializing the native engine, and
 * exposing one-shot (`speak`) and streaming (`createStream`) generation.
 */
export class XTTS {
  private nativeModule: XTTSNativeModule;
  private modelInfo: ModelInfo | null = null;
  private modelPath: string | null = null;

  /**
   * @throws Error if `NativeModules.XTTSModule` is not available.
   */
  constructor() {
    const { XTTSModule } = NativeModules;
    if (!XTTSModule) {
      throw new Error(
        'XTTSModule not found. Make sure the native module is properly linked.'
      );
    }
    this.nativeModule = XTTSModule;
  }

  /**
   * Download model from Hugging Face into the app's document directory.
   *
   * The file is first written to a temporary `.part` path and only moved to
   * its final name after a successful (HTTP 200) download. This prevents a
   * failed or interrupted transfer from leaving a truncated file that a later
   * call would mistake for a complete model.
   *
   * @param variant Quantization variant to fetch (default: 'q4_k').
   * @param progressCallback Receives a fraction in [0, 1]; it is not invoked
   *   while the server has not reported a positive content length.
   * @returns Local path of the model file.
   * @throws Error if the download finishes with a non-200 status, or whatever
   *   the underlying download/move call rejects with.
   */
  async downloadModel(
    variant: 'q4_k' | 'q8' | 'f16' = 'q4_k',
    progressCallback?: (progress: number) => void
  ): Promise<string> {
    const HF_REPO = 'GenMedLabs/xtts-gguf';
    const HF_BASE = `https://huggingface.co/${HF_REPO}/resolve/main`;
    const modelFile = `gguf/xtts_v2_${variant}.gguf`;
    const url = `${HF_BASE}/${modelFile}?download=true`;
    const destPath = `${RNFS.DocumentDirectoryPath}/xtts_${variant}.gguf`;
    const partialPath = `${destPath}.part`;

    // Check if model already exists
    if (await RNFS.exists(destPath)) {
      console.log(`Model already downloaded at ${destPath}`);
      return destPath;
    }

    // Discard leftovers from an earlier attempt that never completed.
    await this.removeIfExists(partialPath);

    console.log(`Downloading XTTS ${variant} model...`);

    try {
      // Download with progress
      // NOTE(review): `background`/`discretionary` are forwarded unchanged;
      // confirm deferring the transfer is acceptable when the caller awaits it.
      const download = RNFS.downloadFile({
        fromUrl: url,
        toFile: partialPath,
        background: true,
        discretionary: true,
        progressDivider: 1,
        progress: (res) => {
          // An unknown length (0 or negative) would yield NaN/Infinity or a
          // negative fraction, so skip reporting until it is known.
          if (res.contentLength > 0) {
            progressCallback?.(
              Math.min(res.bytesWritten / res.contentLength, 1)
            );
          }
        },
      });

      const result = await download.promise;
      if (result.statusCode !== 200) {
        throw new Error(`Failed to download model: ${result.statusCode}`);
      }

      await RNFS.moveFile(partialPath, destPath);
    } catch (error: unknown) {
      await this.removeIfExists(partialPath);
      throw error;
    }

    console.log(`Model downloaded to ${destPath}`);
    return destPath;
  }

  /**
   * Initialize the model from a local file.
   *
   * @param modelPath Path to a model file. When omitted (or empty) the
   *   default 'q4_k' variant is downloaded first.
   * @param options Passed through to the native module.
   * @returns Model information reported by the native module.
   * @throws Error if the resolved file does not exist.
   */
  async initialize(
    modelPath?: string,
    options?: InitOptions
  ): Promise<ModelInfo> {
    // Use provided path or download default. `||` is deliberate: an empty
    // string is treated the same as no path.
    const resolvedPath = modelPath || (await this.downloadModel('q4_k'));

    // Verify file exists
    const exists = await RNFS.exists(resolvedPath);
    if (!exists) {
      throw new Error(`Model file not found: ${resolvedPath}`);
    }

    // Get file info
    const stat = await RNFS.stat(resolvedPath);
    console.log(`Loading model: ${stat.size / (1024*1024)}MB`);

    // Initialize native module
    const info = await this.nativeModule.initialize(resolvedPath, options);
    this.modelInfo = info;
    this.modelPath = resolvedPath;

    console.log(`Model initialized:`);
    console.log(` Sample rate: ${info.sampleRate}Hz`);
    console.log(` Languages: ${info.nLanguages}`);
    console.log(` Memory usage: ${info.memoryMB}MB`);
    return info;
  }

  /**
   * Generate speech from text.
   *
   * @returns Audio samples produced by the native module.
   * @throws Error if the model is not initialized or `options.language` is
   *   not in the list returned by `getSupportedLanguages()`.
   */
  async speak(
    text: string,
    options?: GenerateOptions
  ): Promise<Float32Array> {
    this.assertInitialized();

    // Validate options
    if (options?.language && !this.isValidLanguage(options.language)) {
      throw new Error(`Unsupported language: ${options.language}`);
    }

    // Generate audio
    const audio = await this.nativeModule.generateAsync(text, options);
    return audio;
  }

  /**
   * Create a streaming generator.
   *
   * @returns A stream wrapper; call `start()` on it to begin polling.
   * @throws Error if the model is not initialized.
   */
  createStream(
    text: string,
    options?: StreamOptions
  ): XTTSStream {
    this.assertInitialized();
    const handle = this.nativeModule.createStream(text, options);
    return new XTTSStream(this.nativeModule, handle);
  }

  /**
   * Get supported languages as reported by the native module.
   */
  getSupportedLanguages(): Language[] {
    return this.nativeModule.getSupportedLanguages() as Language[];
  }

  /**
   * Check if a language is supported.
   */
  isValidLanguage(lang: string): boolean {
    const supported = this.getSupportedLanguages();
    return supported.includes(lang as Language);
  }

  /**
   * Get model information, or `null` before `initialize()` / after `cleanup()`.
   */
  getModelInfo(): ModelInfo | null {
    return this.modelInfo;
  }

  /**
   * Clean up resources and forget the loaded model.
   */
  cleanup(): void {
    this.nativeModule.cleanup();
    this.modelInfo = null;
    this.modelPath = null;
  }

  /** Throws unless a model has been loaded and reports `initialized`. */
  private assertInitialized(): void {
    if (!this.modelInfo?.initialized) {
      throw new Error('Model not initialized. Call initialize() first.');
    }
  }

  /** Best-effort delete of `path`; never throws. */
  private async removeIfExists(path: string): Promise<void> {
    try {
      if (await RNFS.exists(path)) {
        await RNFS.unlink(path);
      }
    } catch {
      // Cleanup must not mask the error (if any) that led here.
    }
  }
}
/**
 * Streaming audio generation.
 *
 * Polls the native module for chunks of an open stream, forwards each chunk
 * to the `onData` callback and keeps a copy so the full audio can be
 * retrieved with `getBuffer()`.
 */
export class XTTSStream {
  private nativeModule: XTTSNativeModule;
  private handle: StreamHandle;
  private audioBuffer: Float32Array[] = [];
  private onDataCallback?: (chunk: Float32Array) => void;
  private onEndCallback?: () => void;
  private onErrorCallback?: (error: unknown) => void;
  private polling = false;
  // Ensures the native stream is closed at most once by this wrapper.
  private closed = false;

  constructor(nativeModule: XTTSNativeModule, handle: StreamHandle) {
    this.nativeModule = nativeModule;
    this.handle = handle;
  }

  /**
   * Set callback for audio data
   */
  onData(callback: (chunk: Float32Array) => void): this {
    this.onDataCallback = callback;
    return this;
  }

  /**
   * Set callback for stream end
   */
  onEnd(callback: () => void): this {
    this.onEndCallback = callback;
    return this;
  }

  /**
   * Set callback for errors thrown while polling (by the native call or by
   * an `onData`/`onEnd` callback). Without one, errors are logged with
   * `console.error`.
   */
  onError(callback: (error: unknown) => void): this {
    this.onErrorCallback = callback;
    return this;
  }

  /**
   * Start streaming. Does nothing if already polling or if the handle is no
   * longer active.
   */
  start(): void {
    if (this.polling || !this.handle.active) return;
    this.polling = true;
    // pollForChunks handles its own failures, so this never rejects.
    void this.pollForChunks();
  }

  /**
   * Poll for audio chunks until the stream ends, is stopped, or fails.
   */
  private async pollForChunks(): Promise<void> {
    try {
      while (this.polling && this.handle.active) {
        const chunk = this.nativeModule.getStreamChunk(this.handle, 8192);
        if (!chunk) {
          // Stream ended
          this.handle.active = false;
          this.polling = false;
          this.onEndCallback?.();
          return;
        }
        this.audioBuffer.push(chunk);
        this.onDataCallback?.(chunk);
        // Small delay between polls
        await new Promise<void>((resolve) => setTimeout(resolve, 10));
      }
    } catch (error: unknown) {
      // Reset the flag so a later start() is not blocked by a dead loop.
      this.polling = false;
      this.reportError(error);
    }
  }

  /** Delivers a polling error to `onError`, or logs it if none is set. */
  private reportError(error: unknown): void {
    if (!this.onErrorCallback) {
      console.error('XTTSStream polling failed:', error);
      return;
    }
    try {
      this.onErrorCallback(error);
    } catch (callbackError: unknown) {
      console.error('XTTSStream onError callback threw:', callbackError);
    }
  }

  /**
   * Stop streaming and close the native stream. Safe to call repeatedly;
   * only the first call reaches the native module.
   */
  stop(): void {
    this.polling = false;
    if (this.closed) return;
    this.closed = true;
    try {
      this.nativeModule.closeStream(this.handle);
    } finally {
      // Mark inactive even if the native close throws.
      this.handle.active = false;
    }
  }

  /**
   * Get all buffered audio as one contiguous array.
   */
  getBuffer(): Float32Array {
    const totalLength = this.audioBuffer.reduce(
      (sum, chunk) => sum + chunk.length, 0
    );
    const result = new Float32Array(totalLength);
    let offset = 0;
    for (const chunk of this.audioBuffer) {
      result.set(chunk, offset);
      offset += chunk.length;
    }
    return result;
  }

  /**
   * Check if stream is active
   */
  isActive(): boolean {
    return this.handle.active;
  }
}
// Default export: a shared instance constructed when this module is first
// imported (the XTTS constructor throws if the native module is missing).
const sharedInstance = new XTTS();
export default sharedInstance;