# TTS-Hub / app.py — Hugging Face Space by Nymbo (commit 0289739)
import gradio as gr
import io
import os
import tempfile
import time
import wave
import struct
import numpy as np
from openai import OpenAI
from elevenlabs.client import ElevenLabs
from elevenlabs import stream, play
from elevenlabs.core.api_error import ApiError
# Optional imports for Kokoro TTS (lazy load, CPU-only).
# Both torch and kokoro may be absent; the Kokoro code paths check for None
# and raise a friendly gr.Error in _init_kokoro instead of failing at import.
try:
    import torch  # type: ignore
except Exception:  # pragma: no cover
    torch = None  # type: ignore
try:
    from kokoro import KModel, KPipeline  # type: ignore
except Exception:  # pragma: no cover
    KModel = None  # type: ignore
    KPipeline = None  # type: ignore
# ==========================
# All backend TTS and helper functions remain the same.
# No changes are needed in this section.
# ==========================
def pad_buffer(audio):
    """Zero-pad raw audio bytes so their length is a multiple of the int16 sample size."""
    sample_width = np.dtype(np.int16).itemsize
    remainder = len(audio) % sample_width
    if remainder:
        # Append null bytes to reach the next int16 boundary.
        audio = audio + b'\0' * (sample_width - remainder)
    return audio
def openai_tts(text, model, voice, api_key):
    """Generate speech using OpenAI's TTS API.

    Args:
        text: Text to synthesize.
        model: OpenAI TTS model id (e.g. 'tts-1', 'tts-1-hd').
        voice: OpenAI voice name (e.g. 'nova').
        api_key: User-supplied OpenAI API key.

    Returns:
        Path to a temporary .mp3 file containing the generated audio.
        The caller is responsible for deleting the file.

    Raises:
        gr.Error: If the API key is missing or the API call fails.
    """
    # Fix: treat None/empty/whitespace keys uniformly (original only caught '').
    if not api_key or not api_key.strip():
        raise gr.Error('Please enter your OpenAI API Key')
    try:
        client = OpenAI(api_key=api_key)
        response = client.audio.speech.create(
            model=model,
            voice=voice,
            input=text,
        )
        # Persist to a named temp file so Gradio can serve it by path.
        with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as temp_file:
            temp_file.write(response.content)
        return temp_file.name
    except Exception as error:
        # Surface any SDK/network failure as a user-facing Gradio error.
        raise gr.Error(f"An error occurred with OpenAI TTS: {str(error)}")
def elevenlabs_tts(text, voice_id, api_key):
    """Generate speech using ElevenLabs' TTS API.

    Args:
        text: Text to synthesize (truncated to 4000 characters before sending).
        voice_id: ElevenLabs voice ID.
        api_key: User-supplied ElevenLabs API key.

    Returns:
        Path to a temporary .mp3 file containing the generated audio.
        The caller is responsible for deleting the file.

    Raises:
        gr.Error: With a user-friendly message for missing/invalid keys,
            quota exhaustion, or any other API failure.
    """
    # Fix: treat None/empty/whitespace keys uniformly (original only caught '').
    if not api_key or not api_key.strip():
        raise gr.Error('Please enter your ElevenLabs API Key')
    try:
        client = ElevenLabs(api_key=api_key)
        # convert() returns an iterator of MP3 byte chunks.
        audio = client.text_to_speech.convert(
            text=text[:4000],
            voice_id=voice_id,
            model_id="eleven_multilingual_v2",
            output_format="mp3_44100_128"
        )
        audio_bytes = b''.join(audio)
        with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as temp_file:
            temp_file.write(audio_bytes)
        return temp_file.name
    except ApiError as e:
        # Map API status codes to actionable user-facing messages.
        if e.status_code == 401:
            # 401 covers both bad keys and the free-tier "unusual activity" block.
            if "detected_unusual_activity" in str(e):
                raise gr.Error("To use ElevenLabs, you'll need a paid ElevenLabs subscription.")
            else:
                raise gr.Error("Invalid ElevenLabs API key. Please check your API key and try again.")
        elif e.status_code == 429:
            raise gr.Error("You've reached your ElevenLabs usage limit. Please upgrade your plan or wait for your quota to reset.")
        else:
            raise gr.Error(f"ElevenLabs API error (status {e.status_code}): {str(e)[:200]}...")
    except Exception as e:
        raise gr.Error(f"Unexpected error with ElevenLabs TTS: {str(e)[:200]}...")
def get_elevenlabs_voices(api_key):
    """Return a mapping of ElevenLabs voice names to voice IDs.

    With a truthy API key the account's voices are fetched from the API;
    on failure (or with no key) a built-in set of well-known premade
    voices is returned instead.
    """
    try:
        if api_key:
            client = ElevenLabs(api_key=api_key)
            response = client.voices.get()
            return {v.name: v.voice_id for v in response.voices}
    except Exception as exc:
        # Best-effort: log and fall through to the defaults below.
        print(f"Could not load ElevenLabs voices: {str(exc)}")
    # Fallback: classic premade ElevenLabs voices.
    return {
        "Rachel": "21m00Tcm4TlvDq8ikWAM", "Domi": "AZnzlk1XvdvUeBnXmlld",
        "Bella": "EXAVITQu4vr4xnSDxMaL", "Antoni": "ErXwobaYiN019PkySvjV",
        "Elli": "MF3mGyEYCl7XYWbV9V6O", "Josh": "TxGEqnHWrfWFTfGW9XjX",
        "Arnold": "VR6AewLTigWG4xSOukaG", "Adam": "pNInz6obpgDQGcFmaJgB",
        "Sam": "yoZ06aMxZJJ28mfd3POQ"
    }
# Kokoro TTS (CPU-only)
# Lazily-initialized singleton state for the Kokoro model and its per-language
# pipelines; populated once by _init_kokoro() and read by the generator below.
_KOKORO_STATE = { "initialized": False, "device": "cpu", "model": None, "pipelines": {} }
def _init_kokoro() -> None:
    """Build the Kokoro model and English pipeline on first use (CPU only)."""
    if _KOKORO_STATE["initialized"]:
        return
    # Kokoro and torch are optional dependencies; fail with a friendly message.
    if KModel is None or KPipeline is None:
        raise gr.Error("Kokoro is not installed. Please add 'kokoro>=0.9.4' and 'torch' to requirements and install.")
    target_device = "cpu"
    kmodel = KModel().to(target_device).eval()
    # lang_code "a" = American English; model=False defers audio generation to kmodel.
    pipes = {"a": KPipeline(lang_code="a", model=False)}
    try:
        # Teach the G2P lexicon how to pronounce the project name.
        pipes["a"].g2p.lexicon.golds["kokoro"] = "kˈOkəɹO"
    except Exception:
        pass
    _KOKORO_STATE.update({"initialized": True, "device": target_device, "model": kmodel, "pipelines": pipes})
def get_kokoro_voices():
    """Return the sorted list of available Kokoro voice IDs.

    Tries to enumerate voices/*.pt files in the hexgrad/Kokoro-82M repo;
    falls back to a hard-coded snapshot when the hub is unreachable.
    """
    try:
        from huggingface_hub import list_repo_files
        repo_files = list_repo_files('hexgrad/Kokoro-82M')
        names = [
            f.replace('voices/', '').replace('.pt', '')
            for f in repo_files
            if f.startswith('voices/') and f.endswith('.pt')
        ]
        return sorted(names) if names else ["af_nicole"]
    except Exception:
        # Offline / hub error: use the known voice snapshot.
        return [
            "af_alloy", "af_aoede", "af_bella", "af_heart", "af_jessica", "af_kore", "af_nicole", "af_nova", "af_river", "af_sarah", "af_sky",
            "am_adam", "am_echo", "am_eric", "am_fenrir", "am_liam", "am_michael", "am_onyx", "am_puck", "am_santa",
            "bf_alice", "bf_emma", "bf_isabella", "bf_lily",
            "bm_daniel", "bm_fable", "bm_george", "bm_lewis",
            "ef_dora", "em_alex", "em_santa",
            "ff_siwis",
            "hf_alpha", "hf_beta", "hm_omega", "hm_psi",
            "if_sara", "im_nicola",
            "jf_alpha", "jf_gongitsune", "jf_nezumi", "jf_tebukuro", "jm_kumo",
            "pf_dora", "pm_alex", "pm_santa",
            "zf_xiaobei", "zf_xiaoni", "zf_xiaoxiao", "zf_xiaoyi", "zm_yunjian", "zm_yunxi", "zm_yunxia", "zm_yunyang"
        ]
def _audio_np_to_int16(audio_np: np.ndarray) -> np.ndarray:
audio_clipped = np.clip(audio_np, -1.0, 1.0)
return (audio_clipped * 32767.0).astype(np.int16)
def _write_wav_file(audio_int16: np.ndarray, sample_rate: int = 24_000) -> str:
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
path = tmp.name
with wave.open(path, "wb") as wf:
wf.setnchannels(1)
wf.setsampwidth(2)
wf.setframerate(sample_rate)
wf.writeframes(audio_int16.tobytes())
return path
def _wav_bytes_from_int16(audio_int16: np.ndarray, sample_rate: int = 24_000) -> bytes:
buffer = io.BytesIO()
with wave.open(buffer, "wb") as wf:
wf.setnchannels(1)
wf.setsampwidth(2)
wf.setframerate(sample_rate)
wf.writeframes(audio_int16.tobytes())
return buffer.getvalue()
def _kokoro_segment_generator(text: str, speed: float, voice: str):
    """Yield one float numpy audio array per text segment synthesized by Kokoro.

    Raises gr.Error for empty input, an uninitialized pipeline, or any
    synthesis failure (segment errors include the 1-based segment number).
    """
    if not text or not text.strip():
        raise gr.Error("Please enter text to synthesize.")
    _init_kokoro()
    model = _KOKORO_STATE["model"]
    pipelines = _KOKORO_STATE["pipelines"]
    # "a" is the American English pipeline created in _init_kokoro.
    pipeline = pipelines.get("a")
    if pipeline is None:
        raise gr.Error("Kokoro English pipeline not initialized.")
    pack = pipeline.load_voice(voice)
    try:
        # The pipeline splits text into segments; ps appears to be the phoneme
        # sequence for each segment — TODO confirm against kokoro's KPipeline API.
        for idx, (_, ps, _) in enumerate(pipeline(text, voice, speed)):
            # Reference style vector selected by phoneme length; assumes pack is
            # indexable by len(ps)-1 — verify against the loaded voice pack format.
            ref_s = pack[len(ps) - 1]
            try:
                audio = model(ps, ref_s, float(speed))
                audio_np = audio.detach().cpu().numpy()
                yield audio_np
            except Exception as e:
                raise gr.Error(f"Error generating audio for segment {idx + 1}: {str(e)[:200]}...")
    except gr.Error:
        # Re-raise user-facing errors untouched.
        raise
    except Exception as e:
        raise gr.Error(f"Error during speech generation: {str(e)[:200]}...")
def kokoro_tts(text: str, speed: float, voice: str) -> str:
    """Synthesize *text* with Kokoro and return the path of one combined WAV file."""
    sample_rate = 24_000
    chunks = list(_kokoro_segment_generator(text, speed, voice))
    if not chunks:
        raise gr.Error("No audio was generated.")
    if len(chunks) > 1:
        combined = np.concatenate(chunks, axis=0)
    else:
        combined = chunks[0]
    return _write_wav_file(_audio_np_to_int16(combined), sample_rate)
def kokoro_tts_stream(text: str, speed: float, voice: str):
    """Yield one standalone WAV blob per Kokoro segment, for streaming playback."""
    sample_rate = 24_000
    emitted = 0
    for segment in _kokoro_segment_generator(text, speed, voice):
        emitted += 1
        pcm = _audio_np_to_int16(segment)
        yield _wav_bytes_from_int16(pcm, sample_rate)
    if emitted == 0:
        raise gr.Error("No audio was generated.")
# Main dispatcher function to handle all services
def _read_file_bytes(path: str) -> bytes:
with open(path, "rb") as file:
data = file.read()
return data
def generate_tts(text, service, openai_api_key, openai_model, openai_voice,
                 elevenlabs_api_key, elevenlabs_voice, voice_dict,
                 kokoro_speed, kokoro_voice):
    """Dispatch to the selected TTS backend and yield audio bytes for streaming.

    Kokoro streams multiple WAV chunks; OpenAI/ElevenLabs yield one MP3 blob.
    """
    if service == "Kokoro":
        # Forward Kokoro's per-segment WAV chunks unchanged.
        yield from kokoro_tts_stream(text, kokoro_speed, kokoro_voice)
        return
    if service == "OpenAI":
        tmp_path = openai_tts(text, openai_model, openai_voice, openai_api_key)
    elif service == "ElevenLabs":
        # Map display name back to a voice ID; fall back to the raw value.
        resolved_voice = voice_dict.get(elevenlabs_voice, elevenlabs_voice)
        tmp_path = elevenlabs_tts(text, resolved_voice, elevenlabs_api_key)
    else:
        raise gr.Error(f"Unknown service selected: {service}")
    try:
        payload = _read_file_bytes(tmp_path)
    finally:
        # The temp file is single-use; remove it even when reading fails.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
    yield payload
# Function to update ElevenLabs voices when API key changes
def update_elevenlabs_voices(api_key):
    """Refresh the ElevenLabs voice dropdown and voice map after a key change."""
    voices = get_elevenlabs_voices(api_key)
    names = list(voices)
    # Guard against an empty mapping so the dropdown always has a value.
    chosen = names[0] if names else "Rachel"
    return gr.update(choices=names, value=chosen), voices
# This simple function updates our hidden state based on which tab is selected.
def update_service_state(evt: gr.SelectData):
    """Mirror the newly selected tab's label into the hidden service-state textbox."""
    selected_tab = evt.value
    return selected_tab
# ==========================
# Redesigned Gradio UI with Tabs
# ==========================
with gr.Blocks(theme='Nymbo/Nymbo_Theme') as demo:
    # App title and subtitle banner.
    gr.HTML("<h1 style='text-align: center;'>TTS-Hub</h1><p style='text-align: center;'>Kokoro | OpenAI | ElevenLabs</p>")
    # Get default voices for ElevenLabs on startup (no key -> built-in defaults).
    default_voice_dict = get_elevenlabs_voices("")
    # Store the full voice dictionary (name -> id) in a hidden state component.
    voice_dict_state = gr.State(default_voice_dict)
    # This hidden textbox stores the name of the currently selected service tab.
    # It replaces the old radio button group.
    service_state = gr.Textbox("Kokoro", visible=False, label="Selected Service")
    # Use gr.Tabs to create a clean, tabbed interface for each service.
    with gr.Tabs() as tabs:
        # Tab 1: Kokoro TTS (local, CPU-only)
        with gr.Tab("Kokoro", id="Kokoro"):
            # Put all Kokoro-specific controls in this tab.
            with gr.Row(variant='panel'):
                kokoro_speed = gr.Slider(
                    minimum=0.5, maximum=2.0, value=1.2, step=0.1,
                    label='Speed'
                )
                available_voices = get_kokoro_voices()
                # Default to 'af_nicole' when available; otherwise, use first available.
                default_kokoro_voice = (
                    'af_nicole' if 'af_nicole' in available_voices
                    else (available_voices[0] if available_voices else 'af_nicole')
                )
                kokoro_voice = gr.Dropdown(
                    choices=available_voices,
                    label='Voice',
                    value=default_kokoro_voice,
                )
        # Tab 2: OpenAI TTS
        with gr.Tab("OpenAI", id="OpenAI"):
            # Put all OpenAI-specific controls in this tab.
            with gr.Column(variant='panel'):
                openai_api_key = gr.Textbox(
                    type='password',
                    label='OpenAI API Key',
                    placeholder='Enter your OpenAI API key (sk-...)',
                )
                with gr.Row():
                    openai_model = gr.Dropdown(
                        choices=['tts-1', 'tts-1-hd'],
                        label='Model',
                        value='tts-1-hd',
                    )
                    openai_voice = gr.Dropdown(
                        choices=['alloy', 'echo', 'fable', 'onyx', 'nova', 'shimmer'],
                        label='Voice',
                        value='nova',
                    )
        # Tab 3: ElevenLabs TTS
        with gr.Tab("ElevenLabs", id="ElevenLabs"):
            # Put all ElevenLabs-specific controls in this tab.
            with gr.Column(variant='panel'):
                elevenlabs_api_key = gr.Textbox(
                    type='password',
                    label='ElevenLabs API Key',
                    placeholder='Enter your ElevenLabs API key',
                )
                elevenlabs_voice = gr.Dropdown(
                    choices=list(default_voice_dict.keys()),
                    label='Voice',
                    value=list(default_voice_dict.keys())[0] if default_voice_dict else "Rachel",
                )
    # Shared components (input and output) are placed outside the tabs.
    text_input = gr.Textbox(
        label="Input Text",
        placeholder="Enter the text you want to convert to speech here...",
        lines=5,
    )
    generate_btn = gr.Button(
        "Generate Speech",
        variant="primary",
    )
    # streaming=True lets generate_tts yield chunks as they're produced.
    audio_output = gr.Audio(
        label="Generated Speech",
        streaming=True,
        autoplay=True,
        show_download_button=True,
    )
    # ==========================
    # Event Listeners
    # ==========================
    # When a tab is selected, update the hidden 'service_state' textbox.
    # This tells our backend which service to use.
    tabs.select(
        fn=update_service_state,
        inputs=None,  # The selected tab info is passed automatically in the event data.
        outputs=service_state
    )
    # When the ElevenLabs API key is entered/changed, fetch the user's custom voices.
    elevenlabs_api_key.change(
        fn=update_elevenlabs_voices,
        inputs=[elevenlabs_api_key],
        outputs=[elevenlabs_voice, voice_dict_state]
    )
    # Consolidate all inputs needed for the generation function:
    # the shared text input, the hidden service state, and all per-tab controls.
    generate_inputs = [
        text_input, service_state, openai_api_key, openai_model, openai_voice,
        elevenlabs_api_key, elevenlabs_voice, voice_dict_state,
        kokoro_speed, kokoro_voice
    ]
    # Trigger the TTS generation when the button is clicked.
    generate_btn.click(
        fn=generate_tts,
        inputs=generate_inputs,
        outputs=audio_output,
        api_name="generate_speech"
    )
    # Also trigger the TTS generation when the user presses Enter in the textbox.
    text_input.submit(
        fn=generate_tts,
        inputs=generate_inputs,
        outputs=audio_output,
        api_name="generate_speech_enter"
    )
# Launch the Gradio app; queue() is required for generator/streaming outputs.
demo.queue().launch(debug=True)