File size: 4,626 Bytes
a0d4ab9
2beb552
a0d4ab9
2beb552
 
 
 
 
 
 
 
371689d
a034d9a
371689d
 
 
 
 
 
 
 
 
 
 
 
 
 
2beb552
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a0d4ab9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2beb552
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
import { pipeline, TextToAudioPipeline } from "@huggingface/transformers";
import { split } from "./splitter";
import type { RawAudio } from "@huggingface/transformers";

// Hugging Face model id passed to `pipeline()`; also used to build VOICES_URL.
const MODEL_ID = "onnx-community/Supertonic-TTS-ONNX";
// Base URL of the model repository's `voices/` directory; voice file names
// (e.g. `F1.bin`) are appended to it when fetching speaker embeddings.
const VOICES_URL = `https://huggingface.co/${MODEL_ID}/resolve/main/voices/`;

// Memoized promise set by the first `loadPipeline()` call and handed to later callers.
let pipelinePromise: Promise<TextToAudioPipeline> | null = null;
// Memoized promise set by the first `loadEmbeddings()` call and handed to later callers.
let embeddingsPromise: Promise<Record<string, Float32Array>> | null = null;

/**
 * Loads the text-to-speech pipeline on WebGPU, warms it up, and memoizes it.
 *
 * The first call starts the load; later calls (including concurrent ones)
 * receive the same promise, so `progressCallback` is only wired up for the
 * call that actually starts the load.
 *
 * If loading or the warm-up run fails, the memoized promise is cleared so the
 * next call retries instead of returning the same rejection forever.
 *
 * @param progressCallback Forwarded to `pipeline()` as `progress_callback`.
 * @returns The loaded and warmed-up pipeline.
 * @throws Whatever `pipeline()` or the warm-up inference rejects with.
 */
export async function loadPipeline(progressCallback: (info: any) => void) {
  if (pipelinePromise) return pipelinePromise;

  const pending: Promise<TextToAudioPipeline> = (async () => {
    // @ts-ignore
    const tts = (await pipeline("text-to-speech", MODEL_ID, {
      device: "webgpu",
      progress_callback: progressCallback,
    })) as TextToAudioPipeline;

    // Warm up the model to compile shaders
    await tts("Hello", {
      speaker_embeddings: new Float32Array(1 * 101 * 128), // Dummy embedding
      num_inference_steps: 1,
      speed: 1.0,
    });

    return tts;
  })();
  pipelinePromise = pending;

  try {
    return await pending;
  } catch (error) {
    // Do not keep a rejected promise cached: a transient failure (e.g. a
    // network error while downloading the model) must not be permanent.
    if (pipelinePromise === pending) pipelinePromise = null;
    throw error;
  }
}

/**
 * Downloads the `F1.bin` and `M1.bin` voice files from `VOICES_URL` in
 * parallel and memoizes the result.
 *
 * Each response is checked for an HTTP success status before its body is
 * reinterpreted as 32-bit floats, so an error page is never turned into a
 * bogus embedding. If the download fails, the memoized promise is cleared so
 * the next call retries.
 *
 * @returns An object with `Female` and `Male` keys holding the file contents
 *   viewed as `Float32Array`s.
 * @throws Error when a voice file responds with a non-2xx status; also
 *   rethrows network errors from `fetch`.
 */
export async function loadEmbeddings() {
  if (embeddingsPromise) return embeddingsPromise;

  // Fetches one voice file and views its bytes as 32-bit floats.
  const fetchVoice = async (file: string): Promise<Float32Array> => {
    const url = `${VOICES_URL}${file}`;
    const response = await fetch(url);
    if (!response.ok) {
      throw new Error(`Failed to fetch voice embedding ${url}: ${response.status} ${response.statusText}`);
    }
    return new Float32Array(await response.arrayBuffer());
  };

  const pending: Promise<Record<string, Float32Array>> = (async () => {
    const [female, male] = await Promise.all([fetchVoice("F1.bin"), fetchVoice("M1.bin")]);
    return {
      Female: female,
      Male: male,
    };
  })();
  embeddingsPromise = pending;

  try {
    return await pending;
  } catch (error) {
    // Drop the cached rejection so a later call can retry the download.
    if (embeddingsPromise === pending) embeddingsPromise = null;
    throw error;
  }
}

/** One synthesized chunk, as yielded by `streamTTS`. */
export interface StreamResult {
  /** `performance.now()` timestamp taken when the chunk is yielded. */
  time: number;
  /** Audio returned by the pipeline for this chunk; every chunk except the last carries 0.5 s of trailing silence. */
  audio: RawAudio;
  /** The text chunk that was synthesized. */
  text: string;
  /** 1-based position of this chunk in the chunk list. */
  index: number;
  /** Total number of chunks in the chunk list. */
  total: number;
}

/**
 * Groups the segments produced by `split` (imported from `./splitter`;
 * presumably sentence-like units — confirm against that module) into chunks
 * bounded by a minimum and maximum length.
 *
 * Segments are trimmed, empty ones are dropped, and consecutive segments are
 * joined with a single space until the chunk reaches `minCharacters`. A chunk
 * is always emitted on a segment boundary: if appending the next segment would
 * push the chunk past `maxCharacters`, the chunk built so far is emitted first
 * (even if shorter than `minCharacters`) rather than being cut mid-word.
 *
 * @param text Input text; an empty string yields an empty array.
 * @param options.minCharacters Length at which a chunk is emitted (default 1).
 * @param options.maxCharacters Upper bound on chunk length (default Infinity).
 * @returns The chunks, in input order; each is at most `maxCharacters` long.
 * @throws Error when a single trimmed segment is longer than `maxCharacters`.
 */
function splitWithConstraints(text: string, { minCharacters = 1, maxCharacters = Infinity } = {}): string[] {
  if (!text) return [];
  const result: string[] = [];
  let currentBuffer = "";

  for (const rawLine of split(text)) {
    const line = rawLine.trim();
    if (!line) continue;
    if (line.length > maxCharacters) {
      throw new Error(`A single segment exceeds the maximum character limit of ${maxCharacters} characters.`);
    }

    // Flush before appending when the joined text (+1 for the separating
    // space) would overflow, so no chunk is ever sliced in the middle of a
    // segment.
    if (currentBuffer && currentBuffer.length + 1 + line.length > maxCharacters) {
      result.push(currentBuffer);
      currentBuffer = "";
    }

    currentBuffer = currentBuffer ? `${currentBuffer} ${line}` : line;

    if (currentBuffer.length >= minCharacters) {
      result.push(currentBuffer);
      currentBuffer = "";
    }
  }

  // Whatever is left is shorter than minCharacters but must not be lost.
  if (currentBuffer) result.push(currentBuffer);
  return result;
}

/**
 * Synthesizes `text` chunk by chunk, yielding each chunk's audio as soon as
 * the pipeline returns it.
 *
 * The text is split with `splitWithConstraints` (100–1000 characters per
 * chunk); if that produces nothing, the raw text is used as the only chunk.
 * Whitespace-only chunks are skipped. Every chunk except the last gets 0.5 s
 * of silence appended; the padded samples are written back onto the pipeline
 * output's `audio` property.
 *
 * @param text Text to synthesize.
 * @param tts Pipeline used to synthesize each chunk.
 * @param speaker_embeddings Voice embedding forwarded to the pipeline.
 * @param quality Forwarded to the pipeline as `num_inference_steps`.
 * @param speed Forwarded to the pipeline as `speed`.
 * @returns An async generator of `StreamResult`s, one per synthesized chunk.
 */
export async function* streamTTS(
  text: string,
  tts: TextToAudioPipeline,
  speaker_embeddings: Float32Array,
  quality: number,
  speed: number,
): AsyncGenerator<StreamResult> {
  const segments = splitWithConstraints(text, { minCharacters: 100, maxCharacters: 1000 });
  if (segments.length === 0) {
    segments.push(text);
  }

  const total = segments.length;
  const lastPosition = total - 1;

  for (const [position, segment] of segments.entries()) {
    if (segment.trim().length === 0) continue;

    const synthesisOptions = {
      speaker_embeddings,
      num_inference_steps: quality,
      speed,
    };
    const output = (await tts(segment, synthesisOptions)) as RawAudio;

    if (position !== lastPosition) {
      // Add 0.5s silence between chunks for more natural flow
      const pauseLength = Math.floor(0.5 * output.sampling_rate);
      const withPause = new Float32Array(output.audio.length + pauseLength);
      withPause.set(output.audio);
      output.audio = withPause;
    }

    yield {
      time: performance.now(),
      audio: output,
      text: segment,
      index: position + 1,
      total,
    };
  }
}

/**
 * Packs mono 32-bit float sample chunks into a WAV blob.
 *
 * A 44-byte RIFF/WAVE header (format tag 3 = IEEE float, 1 channel, 32 bits
 * per sample) is written first; the chunks are then appended as the data
 * payload in the order given.
 *
 * @param chunks Sample buffers to concatenate.
 * @param sampling_rate Sample rate written into the header.
 * @returns A blob of type `audio/wav` containing header plus samples.
 */
export function createAudioBlob(chunks: Float32Array[], sampling_rate: number): Blob {
  const BYTES_PER_SAMPLE = 4;
  const LITTLE_ENDIAN = true;

  let sampleCount = 0;
  for (const chunk of chunks) {
    sampleCount += chunk.length;
  }
  const dataBytes = sampleCount * BYTES_PER_SAMPLE;

  const header = new ArrayBuffer(44);
  const out = new DataView(header);

  // Four-character chunk tags
  writeString(out, 0, "RIFF");
  writeString(out, 8, "WAVE");
  writeString(out, 12, "fmt ");
  writeString(out, 36, "data");

  // RIFF chunk size: everything after the first 8 bytes
  out.setUint32(4, 36 + dataBytes, LITTLE_ENDIAN);

  // fmt sub-chunk fields
  out.setUint32(16, 16, LITTLE_ENDIAN); // Subchunk1Size
  out.setUint16(20, 3, LITTLE_ENDIAN); // AudioFormat (3 = IEEE Float)
  out.setUint16(22, 1, LITTLE_ENDIAN); // NumChannels (Mono)
  out.setUint32(24, sampling_rate, LITTLE_ENDIAN); // SampleRate
  out.setUint32(28, sampling_rate * BYTES_PER_SAMPLE, LITTLE_ENDIAN); // ByteRate
  out.setUint16(32, BYTES_PER_SAMPLE, LITTLE_ENDIAN); // BlockAlign
  out.setUint16(34, 32, LITTLE_ENDIAN); // BitsPerSample

  // data sub-chunk size
  out.setUint32(40, dataBytes, LITTLE_ENDIAN); // Subchunk2Size

  const parts = [header, ...chunks as any];
  return new Blob(parts, { type: "audio/wav" });
}

/**
 * Writes each UTF-16 code unit of `string` as one byte into `view`, starting
 * at `offset`. `setUint8` keeps only the low 8 bits of each code unit.
 *
 * @param view Destination view.
 * @param offset Byte offset of the first character.
 * @param string Text to write (intended for ASCII tags such as "RIFF").
 */
function writeString(view: DataView, offset: number, string: string) {
  // split("") yields one entry per UTF-16 code unit, matching index-based charCodeAt.
  string.split("").forEach((unit, position) => {
    view.setUint8(offset + position, unit.charCodeAt(0));
  });
}