|
|
import { pipeline, TextToAudioPipeline } from "@huggingface/transformers"; |
|
|
import { split } from "./splitter"; |
|
|
import type { RawAudio } from "@huggingface/transformers"; |
|
|
|
|
|
// Hugging Face model repository for the Supertonic TTS ONNX export.
const MODEL_ID = "onnx-community/Supertonic-TTS-ONNX";

// Base URL for the per-speaker voice-embedding binaries shipped alongside the model.
const VOICES_URL = `https://huggingface.co/${MODEL_ID}/resolve/main/voices/`;

// Lazily-initialized caches. Each loader stores its in-flight promise so that
// concurrent callers share a single download instead of starting duplicates.
let pipelinePromise: Promise<TextToAudioPipeline> | null = null;

let embeddingsPromise: Promise<Record<string, Float32Array>> | null = null;
|
|
|
|
|
export async function loadPipeline(progressCallback: (info: any) => void) { |
|
|
return pipelinePromise ??= (async () => { |
|
|
|
|
|
const tts = (await pipeline("text-to-speech", MODEL_ID, { |
|
|
device: "webgpu", |
|
|
progress_callback: progressCallback, |
|
|
})) as TextToAudioPipeline; |
|
|
|
|
|
|
|
|
await tts("Hello", { |
|
|
speaker_embeddings: new Float32Array(1 * 101 * 128), |
|
|
num_inference_steps: 1, |
|
|
speed: 1.0, |
|
|
}); |
|
|
|
|
|
return tts; |
|
|
})(); |
|
|
} |
|
|
|
|
|
export async function loadEmbeddings() { |
|
|
return (embeddingsPromise ??= (async () => { |
|
|
const [female, male] = await Promise.all([ |
|
|
fetch(`${VOICES_URL}F1.bin`).then((r) => r.arrayBuffer()), |
|
|
fetch(`${VOICES_URL}M1.bin`).then((r) => r.arrayBuffer()), |
|
|
]); |
|
|
return { |
|
|
Female: new Float32Array(female), |
|
|
Male: new Float32Array(male), |
|
|
}; |
|
|
})()); |
|
|
} |
|
|
|
|
|
/**
 * One synthesized segment yielded by the TTS streaming generator.
 */
export interface StreamResult {
  /** Timestamp (`performance.now()`) captured when this segment was yielded. */
  time: number;
  /** Synthesized audio for this segment (non-final segments carry trailing silence). */
  audio: RawAudio;
  /** The text chunk that produced this audio. */
  text: string;
  /** 1-based position of this segment in the stream. */
  index: number;
  /** Total number of segments in the stream. */
  total: number;
}
|
|
|
|
|
function splitWithConstraints(text: string, { minCharacters = 1, maxCharacters = Infinity } = {}): string[] { |
|
|
if (!text) return []; |
|
|
const rawLines = split(text); |
|
|
const result: string[] = []; |
|
|
let currentBuffer = ""; |
|
|
|
|
|
for (const rawLine of rawLines) { |
|
|
const line = rawLine.trim(); |
|
|
if (!line) continue; |
|
|
if (line.length > maxCharacters) { |
|
|
throw new Error(`A single segment exceeds the maximum character limit of ${maxCharacters} characters.`); |
|
|
} |
|
|
|
|
|
if (currentBuffer) currentBuffer += " "; |
|
|
currentBuffer += line; |
|
|
|
|
|
while (currentBuffer.length > maxCharacters) { |
|
|
result.push(currentBuffer.slice(0, maxCharacters)); |
|
|
currentBuffer = currentBuffer.slice(maxCharacters); |
|
|
} |
|
|
if (currentBuffer.length >= minCharacters) { |
|
|
result.push(currentBuffer); |
|
|
currentBuffer = ""; |
|
|
} |
|
|
} |
|
|
if (currentBuffer) result.push(currentBuffer); |
|
|
return result; |
|
|
} |
|
|
|
|
|
export async function* streamTTS( |
|
|
text: string, |
|
|
tts: TextToAudioPipeline, |
|
|
speaker_embeddings: Float32Array, |
|
|
quality: number, |
|
|
speed: number, |
|
|
): AsyncGenerator<StreamResult> { |
|
|
const chunks = splitWithConstraints(text, { |
|
|
minCharacters: 100, |
|
|
maxCharacters: 1000, |
|
|
}); |
|
|
|
|
|
if (chunks.length === 0) chunks.push(text); |
|
|
|
|
|
for (let i = 0; i < chunks.length; ++i) { |
|
|
const chunk = chunks[i]; |
|
|
if (!chunk.trim()) continue; |
|
|
|
|
|
const output = (await tts(chunk, { |
|
|
speaker_embeddings, |
|
|
num_inference_steps: quality, |
|
|
speed, |
|
|
})) as RawAudio; |
|
|
|
|
|
if (i < chunks.length - 1) { |
|
|
|
|
|
const silenceSamples = Math.floor(0.5 * output.sampling_rate); |
|
|
const padded = new Float32Array(output.audio.length + silenceSamples); |
|
|
padded.set(output.audio); |
|
|
output.audio = padded; |
|
|
} |
|
|
yield { |
|
|
time: performance.now(), |
|
|
audio: output, |
|
|
text: chunk, |
|
|
index: i + 1, |
|
|
total: chunks.length, |
|
|
}; |
|
|
} |
|
|
} |
|
|
|
|
|
export function createAudioBlob(chunks: Float32Array[], sampling_rate: number): Blob { |
|
|
const totalLength = chunks.reduce((acc, chunk) => acc + chunk.length, 0); |
|
|
|
|
|
|
|
|
const buffer = new ArrayBuffer(44); |
|
|
const view = new DataView(buffer); |
|
|
|
|
|
|
|
|
writeString(view, 0, "RIFF"); |
|
|
view.setUint32(4, 36 + totalLength * 4, true); |
|
|
writeString(view, 8, "WAVE"); |
|
|
|
|
|
|
|
|
writeString(view, 12, "fmt "); |
|
|
view.setUint32(16, 16, true); |
|
|
view.setUint16(20, 3, true); |
|
|
view.setUint16(22, 1, true); |
|
|
view.setUint32(24, sampling_rate, true); |
|
|
view.setUint32(28, sampling_rate * 4, true); |
|
|
view.setUint16(32, 4, true); |
|
|
view.setUint16(34, 32, true); |
|
|
|
|
|
|
|
|
writeString(view, 36, "data"); |
|
|
view.setUint32(40, totalLength * 4, true); |
|
|
|
|
|
return new Blob([buffer, ...chunks as any], { type: "audio/wav" }); |
|
|
} |
|
|
|
|
|
function writeString(view: DataView, offset: number, string: string) { |
|
|
for (let i = 0; i < string.length; i++) { |
|
|
view.setUint8(offset + i, string.charCodeAt(i)); |
|
|
} |
|
|
} |
|
|
|