Nymbo commited on
Commit
b0e1ce1
·
verified ·
1 Parent(s): 9ec85a2

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +457 -0
app.py ADDED
@@ -0,0 +1,457 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import io
3
+ import wave
4
+ import numpy as np
5
+
6
+ # Lazy imports for optional dependencies
7
+ try:
8
+ import torch # type: ignore
9
+ except Exception: # pragma: no cover
10
+ torch = None # type: ignore
11
+
12
+ try:
13
+ from pocket_tts import TTSModel # type: ignore
14
+ except Exception: # pragma: no cover
15
+ TTSModel = None # type: ignore
16
+
17
+ # Global state for lazy initialization
18
+ _POCKET_STATE = {
19
+ "initialized": False,
20
+ "model": None,
21
+ "voice_states": {},
22
+ "sample_rate": 24000,
23
+ }
24
+
25
+ # Fallback voices from kyutai/tts-voices (used if no local voices found)
26
+ _FALLBACK_VOICES = {
27
+ "alba": "hf://kyutai/tts-voices/alba-mackenna/casual.wav",
28
+ "marius": "hf://kyutai/tts-voices/voice-donations/Selfie.wav",
29
+ "javert": "hf://kyutai/tts-voices/voice-donations/Butter.wav",
30
+ "jean": "hf://kyutai/tts-voices/ears/p010/freeform_speech_01.wav",
31
+ "fantine": "hf://kyutai/tts-voices/vctk/p244_023.wav",
32
+ "cosette": "hf://kyutai/tts-voices/expresso/ex04-ex02_confused_001_channel1_499s.wav",
33
+ "eponine": "hf://kyutai/tts-voices/vctk/p262_023.wav",
34
+ "azelma": "hf://kyutai/tts-voices/vctk/p303_023.wav",
35
+ }
36
+
37
+
38
+ def _get_available_voices() -> dict[str, str]:
39
+ """Get available voices, preferring local files over HuggingFace.
40
+
41
+ Scans ./voices/ directory for audio files (WAV, MP3, etc.)
42
+ Falls back to HuggingFace preset voices if no local files found.
43
+ """
44
+ import os
45
+
46
+ voices_dir = os.path.join(os.path.dirname(__file__), "voices")
47
+ local_voices = {}
48
+
49
+ if os.path.exists(voices_dir):
50
+ for f in os.listdir(voices_dir):
51
+ # Support common audio formats
52
+ if f.lower().endswith(('.wav', '.mp3', '.flac', '.ogg', '.m4a')):
53
+ voice_name = os.path.splitext(f)[0]
54
+ local_voices[voice_name] = os.path.join(voices_dir, f)
55
+
56
+ # If we found local voices, use those exclusively
57
+ if local_voices:
58
+ print(f"Found {len(local_voices)} local voice(s): {list(local_voices.keys())}")
59
+ return local_voices
60
+
61
+ # Fall back to HuggingFace voices
62
+ print("No local voices found, using HuggingFace preset voices")
63
+ return _FALLBACK_VOICES
64
+
65
+
66
+ # Scan voices at import time
67
+ PRESET_VOICES = _get_available_voices()
68
+
69
+
70
+ def _init_pocket(
71
+ temp: float = 0.7,
72
+ lsd_decode_steps: int = 1,
73
+ noise_clamp: float | None = None,
74
+ eos_threshold: float = -4.0,
75
+ ) -> None:
76
+ """Lazy initialization of the Pocket TTS model."""
77
+ if _POCKET_STATE["initialized"]:
78
+ return
79
+
80
+ if TTSModel is None:
81
+ raise gr.Error(
82
+ "pocket-tts is not installed. Please install with: pip install pocket-tts"
83
+ )
84
+
85
+ if torch is None:
86
+ raise gr.Error("PyTorch is not installed. Please install torch>=2.5.0")
87
+
88
+ print("Initializing Pocket TTS...")
89
+
90
+ # Auto-detect device: CPU by default, CUDA if available
91
+ # Note: The pocket-tts docs mention GPU doesn't provide speedup for this model
92
+ device = "cuda" if torch.cuda.is_available() else "cpu"
93
+ print(f"Using device: {device}")
94
+
95
+ try:
96
+ model = TTSModel.load_model(
97
+ temp=float(temp),
98
+ lsd_decode_steps=int(lsd_decode_steps),
99
+ noise_clamp=float(noise_clamp) if noise_clamp is not None else None,
100
+ eos_threshold=float(eos_threshold),
101
+ )
102
+ _POCKET_STATE.update({
103
+ "initialized": True,
104
+ "model": model,
105
+ "sample_rate": model.sample_rate,
106
+ })
107
+ print(f"Pocket TTS initialized. Sample rate: {model.sample_rate} Hz")
108
+ except Exception as e:
109
+ raise gr.Error(f"Failed to initialize Pocket TTS model: {str(e)}")
110
+
111
+
112
+ def _convert_to_wav(audio_path: str) -> str:
113
+ """Convert audio file to WAV format if needed.
114
+
115
+ Returns the path to a WAV file (original if already WAV, or converted temp file).
116
+ Uses pydub for MP3 (requires ffmpeg), soundfile for other formats.
117
+ """
118
+ import tempfile
119
+
120
+ # Check if already WAV
121
+ if audio_path.lower().endswith('.wav'):
122
+ return audio_path
123
+
124
+ print(f"Converting {audio_path} to WAV format...")
125
+
126
+ # Create temp file path
127
+ import os
128
+ tmp_fd, wav_path = tempfile.mkstemp(suffix=".wav")
129
+ os.close(tmp_fd)
130
+
131
+ # Try pydub first (better MP3 support via ffmpeg)
132
+ try:
133
+ from pydub import AudioSegment
134
+ audio = AudioSegment.from_file(audio_path)
135
+ audio.export(wav_path, format="wav")
136
+ print(f"Converted via pydub to: {wav_path}")
137
+ return wav_path
138
+ except ImportError:
139
+ pass # pydub not installed, try soundfile
140
+ except Exception as e:
141
+ print(f"pydub conversion failed: {e}, trying soundfile...")
142
+
143
+ # Fall back to soundfile
144
+ try:
145
+ import soundfile as sf
146
+ audio_data, sample_rate = sf.read(audio_path)
147
+ sf.write(wav_path, audio_data, sample_rate)
148
+ print(f"Converted via soundfile to: {wav_path}")
149
+ return wav_path
150
+ except Exception as e:
151
+ raise gr.Error(f"Failed to convert audio file: {str(e)}. Please upload a WAV file directly or install pydub+ffmpeg for MP3 support.")
152
+
153
+
154
+ def _get_voice_state(voice_name: str | None, custom_audio_path: str | None):
155
+ """Get or create voice state for generation.
156
+
157
+ Args:
158
+ voice_name: Name of preset voice (alba, marius, etc.)
159
+ custom_audio_path: Path to custom audio file for voice cloning
160
+
161
+ Returns:
162
+ Voice state dict for the model
163
+ """
164
+ model = _POCKET_STATE["model"]
165
+
166
+ # Custom audio takes priority
167
+ if custom_audio_path:
168
+ print(f"Loading custom voice from: {custom_audio_path}")
169
+ # Convert to WAV if needed
170
+ wav_path = _convert_to_wav(custom_audio_path)
171
+ return model.get_state_for_audio_prompt(wav_path)
172
+
173
+ # Use preset voice
174
+ if not voice_name or voice_name not in PRESET_VOICES:
175
+ # Default to first available voice
176
+ voice_name = list(PRESET_VOICES.keys())[0] if PRESET_VOICES else None
177
+ if not voice_name:
178
+ raise gr.Error("No voices available. Add audio files to the voices/ directory.")
179
+
180
+ # Check cache
181
+ if voice_name in _POCKET_STATE["voice_states"]:
182
+ return _POCKET_STATE["voice_states"][voice_name]
183
+
184
+ # Load and cache voice state
185
+ voice_path = PRESET_VOICES[voice_name]
186
+ print(f"Loading preset voice '{voice_name}' from: {voice_path}")
187
+
188
+ # Convert to WAV if needed (local files may be MP3, etc.)
189
+ wav_path = _convert_to_wav(voice_path)
190
+ voice_state = model.get_state_for_audio_prompt(wav_path)
191
+ _POCKET_STATE["voice_states"][voice_name] = voice_state
192
+ return voice_state
193
+
194
+
195
+ def _audio_np_to_int16(audio_np: np.ndarray) -> np.ndarray:
196
+ """Convert float audio array to int16."""
197
+ audio_clipped = np.clip(audio_np, -1.0, 1.0)
198
+ return (audio_clipped * 32767.0).astype(np.int16)
199
+
200
+
201
+ def _wav_bytes_from_int16(audio_int16: np.ndarray, sample_rate: int) -> bytes:
202
+ """Create WAV bytes from int16 audio array."""
203
+ buffer = io.BytesIO()
204
+ with wave.open(buffer, "wb") as wf:
205
+ wf.setnchannels(1)
206
+ wf.setsampwidth(2)
207
+ wf.setframerate(sample_rate)
208
+ wf.writeframes(audio_int16.tobytes())
209
+ return buffer.getvalue()
210
+
211
+
212
+ def _split_into_sentences(text: str) -> list[str]:
213
+ """Split text into sentences for chunk-by-chunk generation.
214
+
215
+ Uses simple punctuation-based splitting for natural speech chunks.
216
+ """
217
+ import re
218
+ # Split on sentence-ending punctuation, keeping the punctuation
219
+ # Handle common patterns: . ! ? and combinations like "..." or "?!"
220
+ sentences = re.split(r'(?<=[.!?])\s+', text.strip())
221
+ # Filter out empty strings and strip whitespace
222
+ return [s.strip() for s in sentences if s.strip()]
223
+
224
+
225
+ def pocket_tts_stream(
226
+ text: str,
227
+ voice: str,
228
+ custom_audio,
229
+ temperature: float,
230
+ lsd_decode_steps: int,
231
+ noise_clamp: float | None,
232
+ eos_threshold: float,
233
+ frames_after_eos: int,
234
+ ):
235
+ """Generate speech with sentence-level streaming.
236
+
237
+ Splits text into sentences and yields complete audio for each sentence,
238
+ matching Kokoro's smooth streaming pattern.
239
+ """
240
+ if not text or not text.strip():
241
+ raise gr.Error("Please enter text to synthesize.")
242
+
243
+ # Initialize model with current parameters
244
+ _init_pocket(
245
+ temp=temperature,
246
+ lsd_decode_steps=lsd_decode_steps,
247
+ noise_clamp=noise_clamp if noise_clamp and noise_clamp > 0 else None,
248
+ eos_threshold=eos_threshold,
249
+ )
250
+
251
+ model = _POCKET_STATE["model"]
252
+ sample_rate = _POCKET_STATE["sample_rate"]
253
+
254
+ # Get voice state
255
+ custom_path = custom_audio if custom_audio else None
256
+ voice_state = _get_voice_state(voice, custom_path)
257
+
258
+ # Split text into sentences for natural chunking
259
+ sentences = _split_into_sentences(text)
260
+ if not sentences:
261
+ raise gr.Error("No valid sentences found in text.")
262
+
263
+ produced_any = False
264
+
265
+ # Buffer for initial audio - wait for ~5 seconds before yielding first chunk
266
+ # This prevents stuttering from short first sentences
267
+ min_initial_samples = int(sample_rate * 5) # 5 seconds of audio
268
+ audio_buffer = []
269
+ buffer_samples = 0
270
+ initial_buffer_yielded = False
271
+
272
+ try:
273
+ for idx, sentence in enumerate(sentences):
274
+ # Generate complete audio for this sentence (non-streaming per sentence)
275
+ audio = model.generate_audio(
276
+ voice_state,
277
+ sentence,
278
+ frames_after_eos=frames_after_eos if frames_after_eos > 0 else None,
279
+ copy_state=True,
280
+ )
281
+ produced_any = True
282
+
283
+ # Convert tensor to numpy
284
+ audio_np = audio.cpu().numpy() if hasattr(audio, 'cpu') else audio
285
+
286
+ if not initial_buffer_yielded:
287
+ # Accumulate in buffer until we have enough audio
288
+ audio_buffer.append(audio_np)
289
+ buffer_samples += len(audio_np)
290
+
291
+ # Check if we have enough or this is the last sentence
292
+ if buffer_samples >= min_initial_samples or idx == len(sentences) - 1:
293
+ # Yield the accumulated buffer
294
+ combined = np.concatenate(audio_buffer, axis=0)
295
+ audio_int16 = _audio_np_to_int16(combined)
296
+ yield _wav_bytes_from_int16(audio_int16, sample_rate)
297
+ audio_buffer = []
298
+ buffer_samples = 0
299
+ initial_buffer_yielded = True
300
+ else:
301
+ # After initial buffer, yield each sentence immediately
302
+ audio_int16 = _audio_np_to_int16(audio_np)
303
+ yield _wav_bytes_from_int16(audio_int16, sample_rate)
304
+
305
+ except gr.Error:
306
+ raise
307
+ except Exception as e:
308
+ raise gr.Error(f"Error during speech generation: {str(e)[:200]}...")
309
+
310
+ if not produced_any:
311
+ raise gr.Error("No audio was generated.")
312
+
313
+
314
+ def generate_tts(
315
+ text: str,
316
+ voice: str,
317
+ custom_audio,
318
+ temperature: float,
319
+ lsd_decode_steps: int,
320
+ noise_clamp: float,
321
+ eos_threshold: float,
322
+ frames_after_eos: int,
323
+ ):
324
+ """Main streaming dispatcher for Pocket TTS."""
325
+ yield from pocket_tts_stream(
326
+ text,
327
+ voice,
328
+ custom_audio,
329
+ temperature,
330
+ lsd_decode_steps,
331
+ noise_clamp,
332
+ eos_threshold,
333
+ frames_after_eos,
334
+ )
335
+
336
+
337
+ # --- Gradio UI ---
338
+ with gr.Blocks() as demo:
339
+ gr.HTML(
340
+ "<h1 style='text-align: center;'>Pocket-TTS</h1>"
341
+ "<p style='text-align: center;'>Powered by kyutai/pocket-tts | Lightweight TTS on CPU</p>"
342
+ )
343
+
344
+ with gr.Row():
345
+ with gr.Column():
346
+ # Text input
347
+ text_input = gr.Textbox(
348
+ label="Input Text",
349
+ placeholder="Enter the text you want to convert to speech here...",
350
+ lines=5,
351
+ value="Hello! This is a test of the Pocket text to speech model. It runs efficiently on CPU and supports voice cloning.",
352
+ )
353
+
354
+ # Voice selection
355
+ with gr.Group():
356
+ gr.Markdown("### Voice Selection")
357
+ gr.Markdown("Select a preset voice OR upload your own WAV file for voice cloning.")
358
+
359
+ voice_dropdown = gr.Dropdown(
360
+ choices=list(PRESET_VOICES.keys()),
361
+ label="Preset Voice",
362
+ value=list(PRESET_VOICES.keys())[0] if PRESET_VOICES else None,
363
+ info="Select a pre-loaded voice. Ignored if custom audio is uploaded.",
364
+ )
365
+
366
+ gr.Markdown("--- OR ---")
367
+
368
+ ref_audio_input = gr.Audio(
369
+ label="Custom Voice (WAV)",
370
+ type="filepath",
371
+ sources=["upload", "microphone"],
372
+ )
373
+
374
+ generate_btn = gr.Button(
375
+ "Generate Speech",
376
+ variant="primary",
377
+ )
378
+
379
+ with gr.Column():
380
+ audio_output = gr.Audio(
381
+ label="Generated Speech",
382
+ streaming=True,
383
+ autoplay=True,
384
+ buttons=["download"],
385
+ )
386
+
387
+ with gr.Accordion("Advanced Options", open=False):
388
+ temp_slider = gr.Slider(
389
+ minimum=0.1,
390
+ maximum=1.5,
391
+ value=0.7,
392
+ step=0.05,
393
+ label="Temperature",
394
+ info="Controls randomness. Higher = more varied, lower = more consistent.",
395
+ )
396
+ lsd_steps_slider = gr.Slider(
397
+ minimum=1,
398
+ maximum=10,
399
+ value=1,
400
+ step=1,
401
+ label="LSD Decode Steps",
402
+ info="Number of generation steps. Higher = potentially better quality but slower.",
403
+ )
404
+ noise_clamp_slider = gr.Slider(
405
+ minimum=0.0,
406
+ maximum=5.0,
407
+ value=0.0,
408
+ step=0.1,
409
+ label="Noise Clamp",
410
+ info="Maximum value for noise sampling. 0 = disabled.",
411
+ )
412
+ eos_threshold_slider = gr.Slider(
413
+ minimum=-10.0,
414
+ maximum=0.0,
415
+ value=-4.0,
416
+ step=0.5,
417
+ label="EOS Threshold",
418
+ info="Threshold for end-of-sequence detection. More negative = longer audio.",
419
+ )
420
+ frames_after_eos_slider = gr.Slider(
421
+ minimum=0,
422
+ maximum=10,
423
+ value=2,
424
+ step=1,
425
+ label="Frames After EOS",
426
+ info="Additional frames to generate after EOS detection.",
427
+ )
428
+
429
+ # Connect inputs
430
+ generate_inputs = [
431
+ text_input,
432
+ voice_dropdown,
433
+ ref_audio_input,
434
+ temp_slider,
435
+ lsd_steps_slider,
436
+ noise_clamp_slider,
437
+ eos_threshold_slider,
438
+ frames_after_eos_slider,
439
+ ]
440
+
441
+ generate_btn.click(
442
+ fn=generate_tts,
443
+ inputs=generate_inputs,
444
+ outputs=audio_output,
445
+ api_name="generate_speech",
446
+ )
447
+
448
+ text_input.submit(
449
+ fn=generate_tts,
450
+ inputs=generate_inputs,
451
+ outputs=audio_output,
452
+ api_name="generate_speech_enter",
453
+ )
454
+
455
+
456
+ if __name__ == "__main__":
457
+ demo.queue().launch(debug=True, theme="Nymbo/Nymbo_Theme")