|
|
|
|
|
import { NativeModules, Platform } from 'react-native'; |
|
|
import RNFS from 'react-native-fs'; |
|
|
|
|
|
|
|
|
/**
 * Contract of the native `XTTSModule` object that the `XTTS` wrapper pulls out
 * of React Native's `NativeModules`.
 *
 * Only the signatures are declared here; the behaviour behind each method lives
 * in the native implementation, which is not visible from this file.
 */
interface XTTSNativeModule {
  /** Loads the model file at `modelPath` and resolves with the resulting model description. */
  initialize(modelPath: string, options?: InitOptions): Promise<ModelInfo>;

  /**
   * Produces audio samples for `text`.
   * NOTE(review): how this differs from `generateAsync` is not visible here —
   * both resolve with a `Float32Array`; confirm against the native side.
   */
  generate(text: string, options?: GenerateOptions): Promise<Float32Array>;

  /** Produces audio samples for `text`; this is the variant `XTTS.speak` calls. */
  generateAsync(text: string, options?: GenerateOptions): Promise<Float32Array>;

  /** Opens a streaming synthesis session for `text` and returns its handle. */
  createStream(text: string, options?: StreamOptions): StreamHandle;

  /**
   * Fetches the next chunk of a stream.
   * `XTTSStream` treats a `null` result as the end of the stream.
   */
  getStreamChunk(stream: StreamHandle, chunkSize?: number): Float32Array | null;

  /** Closes a stream previously returned by `createStream`. */
  closeStream(stream: StreamHandle): boolean;

  /** Lists the language codes the native module reports as supported. */
  getSupportedLanguages(): string[];

  /** Asks the native module to release whatever it holds for the loaded model. */
  cleanup(): boolean;
}
|
|
|
|
|
|
|
|
/**
 * Optional settings handed unchanged to the native `initialize` call.
 *
 * Every field is optional; defaults (if any) are chosen by the native module
 * and are not visible from this file.
 */
export interface InitOptions {
  /** Presumably enables memory-mapped loading of the model file — confirm against the native implementation. */
  useMmap?: boolean;
  /** Presumably requests GPU acceleration — confirm against the native implementation. */
  useGPU?: boolean;
  /** Presumably the number of worker threads to use — confirm against the native implementation. */
  threads?: number;
}
|
|
|
|
|
/**
 * Description of a loaded model, as resolved by the native `initialize` call
 * and cached by the `XTTS` wrapper.
 */
export interface ModelInfo {
  /** Whether the model is ready; `XTTS.speak` and `XTTS.createStream` refuse to run unless this is true. */
  initialized: boolean;
  /** Sample rate of the model; logged by `XTTS.initialize` with an "Hz" suffix. */
  sampleRate: number;
  /** Number of languages reported for the model. */
  nLanguages: number;
  /** Memory usage figure; logged by `XTTS.initialize` with an "MB" suffix. */
  memoryMB: number;
}
|
|
|
|
|
/**
 * Optional per-request settings forwarded to the native generation calls.
 *
 * Value ranges and defaults are defined by the native module and are not
 * visible from this file.
 */
export interface GenerateOptions {
  /** Language code; `XTTS.speak` rejects values the native module does not list as supported. */
  language?: string;
  /** Presumably an index selecting the speaker voice — confirm against the native implementation. */
  speaker?: number;
  /** Presumably the sampling temperature — confirm against the native implementation. */
  temperature?: number;
  /** Presumably a speech-rate multiplier — confirm against the native implementation. */
  speed?: number;
}
|
|
|
|
|
/** Optional settings forwarded to the native `createStream` call. */
export interface StreamOptions {
  /** Language code for the streamed utterance. */
  language?: string;
  /** Presumably the native-side buffer size — unit and default are not visible here; confirm against the native implementation. */
  bufferSize?: number;
}
|
|
|
|
|
/**
 * Handle for a streaming session, as returned by the native `createStream`
 * call and passed back to `getStreamChunk` / `closeStream`.
 */
export interface StreamHandle {
  /** Identifier of the stream. */
  id: number;
  /** Whether the stream is still live; `XTTSStream` flips this to `false` when the stream ends or is stopped. */
  active: boolean;
}
|
|
|
|
|
/** Language codes this wrapper knows about at the type level. */
export type Language =
  | 'en'
  | 'es'
  | 'fr'
  | 'de'
  | 'it'
  | 'pt'
  | 'pl'
  | 'tr'
  | 'ru'
  | 'nl'
  | 'cs'
  | 'ar'
  | 'zh'
  | 'ja'
  | 'ko'
  | 'hu'
  | 'hi';
|
|
|
|
|
|
|
|
/**
 * JavaScript-side wrapper around the native `XTTSModule`.
 *
 * Handles fetching a model file, initializing the native module with it, and
 * exposing one-shot (`speak`) and streaming (`createStream`) synthesis.
 */
export class XTTS {
  private nativeModule: XTTSNativeModule;
  private modelInfo: ModelInfo | null = null;
  private modelPath: string | null = null;

  /**
   * Looks up `XTTSModule` on React Native's `NativeModules`.
   *
   * @throws Error when the native module is not registered.
   */
  constructor() {
    const { XTTSModule } = NativeModules;
    if (!XTTSModule) {
      throw new Error(
        'XTTSModule not found. Make sure the native module is properly linked.'
      );
    }
    this.nativeModule = XTTSModule;
  }

  /**
   * Downloads a model variant into the app's document directory, unless a file
   * for that variant is already there.
   *
   * The download is written to a temporary `.part` file and only moved to its
   * final name after a 200 response, so a failed or interrupted download can
   * never be mistaken for a complete model by a later call.
   *
   * @param variant Which model file to fetch; defaults to `'q4_k'`.
   * @param progressCallback Receives the downloaded fraction (0..1). It is not
   *   called while the total size is unknown (content length <= 0).
   * @returns Path of the model file on disk.
   * @throws Error when the server answers with a status other than 200, or
   *   whatever error the underlying download/move rejects with.
   */
  async downloadModel(
    variant: 'q4_k' | 'q8' | 'f16' = 'q4_k',
    progressCallback?: (progress: number) => void
  ): Promise<string> {
    const HF_REPO = 'GenMedLabs/xtts-gguf';
    const HF_BASE = `https://huggingface.co/${HF_REPO}/resolve/main`;

    const modelFile = `gguf/xtts_v2_${variant}.gguf`;
    const url = `${HF_BASE}/${modelFile}?download=true`;
    const destPath = `${RNFS.DocumentDirectoryPath}/xtts_${variant}.gguf`;
    const partialPath = `${destPath}.part`;

    if (await RNFS.exists(destPath)) {
      console.log(`Model already downloaded at ${destPath}`);
      return destPath;
    }

    console.log(`Downloading XTTS ${variant} model...`);

    // A previous run may have died mid-download; start from a clean slate.
    await this.removeQuietly(partialPath);

    try {
      const download = RNFS.downloadFile({
        fromUrl: url,
        toFile: partialPath,
        // `background` / `discretionary` are platform download hints defined by
        // react-native-fs; see its documentation for their exact effect.
        background: true,
        discretionary: true,
        progressDivider: 1,
        progress: (res) => {
          // Without a positive content length the ratio would be NaN, Infinity
          // or negative, so skip the callback until the size is known.
          if (res.contentLength > 0) {
            progressCallback?.(Math.min(res.bytesWritten / res.contentLength, 1));
          }
        },
      });

      const result = await download.promise;
      if (result.statusCode !== 200) {
        throw new Error(`Failed to download model: ${result.statusCode}`);
      }

      await RNFS.moveFile(partialPath, destPath);
    } catch (error: unknown) {
      // Never leave a truncated file behind.
      await this.removeQuietly(partialPath);
      throw error;
    }

    console.log(`Model downloaded to ${destPath}`);
    return destPath;
  }

  /**
   * Initializes the native module with a model file.
   *
   * @param modelPath Path to the model file. When omitted, the `'q4_k'` variant
   *   is downloaded (or reused if already present) via `downloadModel`.
   * @param options Passed through unchanged to the native `initialize` call.
   * @returns The model description reported by the native module.
   * @throws Error when `modelPath` does not exist on disk.
   */
  async initialize(
    modelPath?: string,
    options?: InitOptions
  ): Promise<ModelInfo> {
    const resolvedPath = modelPath ? modelPath : await this.downloadModel('q4_k');

    if (!(await RNFS.exists(resolvedPath))) {
      throw new Error(`Model file not found: ${resolvedPath}`);
    }

    const stat = await RNFS.stat(resolvedPath);
    // `Number(...)` tolerates react-native-fs versions that report size as a string.
    const sizeMB = Number(stat.size) / (1024 * 1024);
    console.log(`Loading model: ${sizeMB.toFixed(1)}MB`);

    const info = await this.nativeModule.initialize(resolvedPath, options);
    this.modelInfo = info;
    this.modelPath = resolvedPath;

    console.log(`Model initialized:`);
    console.log(`  Sample rate: ${info.sampleRate}Hz`);
    console.log(`  Languages: ${info.nLanguages}`);
    console.log(`  Memory usage: ${info.memoryMB}MB`);

    return info;
  }

  /**
   * Synthesizes `text` in one shot via the native `generateAsync` call.
   *
   * @returns The generated audio samples.
   * @throws Error when the model is not initialized or `options.language` is
   *   not in the native module's supported list.
   */
  async speak(
    text: string,
    options?: GenerateOptions
  ): Promise<Float32Array> {
    this.assertInitialized();
    this.assertLanguageSupported(options?.language);
    return this.nativeModule.generateAsync(text, options);
  }

  /**
   * Opens a streaming synthesis session for `text`.
   *
   * Applies the same language validation as `speak`, so an unsupported language
   * fails here rather than somewhere inside the native stream.
   *
   * @returns A stream wrapper; call `start()` on it to begin receiving chunks.
   * @throws Error when the model is not initialized or `options.language` is
   *   not in the native module's supported list.
   */
  createStream(
    text: string,
    options?: StreamOptions
  ): XTTSStream {
    this.assertInitialized();
    this.assertLanguageSupported(options?.language);
    const handle = this.nativeModule.createStream(text, options);
    return new XTTSStream(this.nativeModule, handle);
  }

  /** Returns the language codes the native module reports as supported. */
  getSupportedLanguages(): Language[] {
    return this.nativeModule.getSupportedLanguages() as Language[];
  }

  /** Tells whether `lang` is in the native module's supported-language list. */
  isValidLanguage(lang: string): boolean {
    return this.getSupportedLanguages().includes(lang as Language);
  }

  /** Returns the description of the loaded model, or `null` before `initialize` / after `cleanup`. */
  getModelInfo(): ModelInfo | null {
    return this.modelInfo;
  }

  /** Calls the native `cleanup` and forgets the cached model description and path. */
  cleanup(): void {
    this.nativeModule.cleanup();
    this.modelInfo = null;
    this.modelPath = null;
  }

  /** Throws unless `initialize` has completed with an initialized model. */
  private assertInitialized(): void {
    if (!this.modelInfo?.initialized) {
      throw new Error('Model not initialized. Call initialize() first.');
    }
  }

  /** Throws when a language was requested that the native module does not list. */
  private assertLanguageSupported(language: string | undefined): void {
    if (language && !this.isValidLanguage(language)) {
      throw new Error(`Unsupported language: ${language}`);
    }
  }

  /** Best-effort delete used for temporary download files; never throws. */
  private async removeQuietly(path: string): Promise<void> {
    try {
      if (await RNFS.exists(path)) {
        await RNFS.unlink(path);
      }
    } catch {
      // Deliberately ignored: failing to clean up must not mask the real error.
    }
  }
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/**
 * Pull-based wrapper around a native streaming session.
 *
 * After `start()`, the stream repeatedly asks the native module for the next
 * chunk, forwards each one to the `onData` callback and keeps a copy so the
 * full audio can be retrieved with `getBuffer()`. A `null` chunk from the
 * native module is treated as the end of the stream.
 */
export class XTTSStream {
  /** Chunk size requested from the native module on every poll. */
  private static readonly CHUNK_SIZE = 8192;
  /** Pause between two polls, in milliseconds. */
  private static readonly POLL_INTERVAL_MS = 10;

  private nativeModule: XTTSNativeModule;
  private handle: StreamHandle;
  private audioBuffer: Float32Array[] = [];
  private onDataCallback?: (chunk: Float32Array) => void;
  private onEndCallback?: () => void;
  private onErrorCallback?: (error: unknown) => void;
  private polling = false;
  /** Set once `closeStream` has been issued, so it is never issued twice. */
  private closed = false;

  constructor(nativeModule: XTTSNativeModule, handle: StreamHandle) {
    this.nativeModule = nativeModule;
    this.handle = handle;
  }

  /** Registers the callback invoked with every chunk received. Returns `this` for chaining. */
  onData(callback: (chunk: Float32Array) => void): this {
    this.onDataCallback = callback;
    return this;
  }

  /** Registers the callback invoked once when the native module reports no more chunks. Returns `this` for chaining. */
  onEnd(callback: () => void): this {
    this.onEndCallback = callback;
    return this;
  }

  /**
   * Registers the callback invoked when polling fails, i.e. when the native
   * call or one of the data/end callbacks throws. Without a handler the error
   * is written to `console.error`. Returns `this` for chaining.
   */
  onError(callback: (error: unknown) => void): this {
    this.onErrorCallback = callback;
    return this;
  }

  /**
   * Begins polling for chunks. Does nothing when polling is already running or
   * the stream is no longer active (ended or stopped).
   */
  start(): void {
    if (this.polling || !this.handle.active) return;

    this.polling = true;
    // The poll loop handles its own failures, so the promise is safe to drop.
    void this.pollForChunks();
  }

  /**
   * Poll loop: fetch a chunk, dispatch it, wait, repeat — until the stream is
   * stopped, ends, or something throws.
   */
  private async pollForChunks(): Promise<void> {
    try {
      while (this.polling && this.handle.active) {
        const chunk = this.nativeModule.getStreamChunk(
          this.handle,
          XTTSStream.CHUNK_SIZE
        );

        if (!chunk) {
          // End of stream: release the native side before telling the caller.
          this.release();
          this.onEndCallback?.();
          return;
        }

        this.audioBuffer.push(chunk);
        this.onDataCallback?.(chunk);

        await new Promise<void>((resolve) =>
          setTimeout(resolve, XTTSStream.POLL_INTERVAL_MS)
        );
      }
    } catch (error: unknown) {
      // Make sure a failure cannot leave the stream looking alive.
      try {
        this.release();
      } catch (releaseError: unknown) {
        console.error('XTTS stream could not be closed:', releaseError);
      }
      this.reportError(error);
    } finally {
      this.polling = false;
    }
  }

  /**
   * Stops polling and closes the native stream. Safe to call repeatedly and
   * after the stream has already ended; the native close is issued only once.
   * The `onEnd` callback is not invoked by `stop()`.
   */
  stop(): void {
    this.release();
  }

  /** Returns all chunks received so far, concatenated into one array. */
  getBuffer(): Float32Array {
    const totalLength = this.audioBuffer.reduce(
      (sum, chunk) => sum + chunk.length,
      0
    );

    const result = new Float32Array(totalLength);
    let offset = 0;
    for (const chunk of this.audioBuffer) {
      result.set(chunk, offset);
      offset += chunk.length;
    }
    return result;
  }

  /** Tells whether the stream handle is still marked active. */
  isActive(): boolean {
    return this.handle.active;
  }

  /** Halts polling, closes the native stream at most once, and marks the handle inactive. */
  private release(): void {
    this.polling = false;
    try {
      if (!this.closed) {
        this.closed = true;
        this.nativeModule.closeStream(this.handle);
      }
    } finally {
      // Mark inactive even when the native close throws.
      this.handle.active = false;
    }
  }

  /** Delivers a polling failure to the error callback, or logs it; never throws. */
  private reportError(error: unknown): void {
    if (!this.onErrorCallback) {
      console.error('XTTS stream polling failed:', error);
      return;
    }
    try {
      this.onErrorCallback(error);
    } catch (callbackError: unknown) {
      console.error('XTTS stream error handler threw:', callbackError);
    }
  }
}
|
|
|
|
|
|
|
|
/**
 * Shared, ready-made `XTTS` instance exported as the module default.
 *
 * It is constructed when this module is first imported, so the import itself
 * throws if the native `XTTSModule` is not registered.
 */
const sharedXTTS = new XTTS();

export default sharedXTTS;