# Text2Audio / app.py
# Author: IFMedTechdemo — commit d1aa924 ("Update app.py"), 8.88 kB
"""
🎙️ Multi-Engine TTS – Zero-GPU edition
Kokoro │ Veena │ pyttsx3 (fallback)
Routes every synthesis to an idle A100.
"""
import os, tempfile, subprocess, numpy as np
import gradio as gr
import soundfile as sf
import spaces # << Zero-GPU helper
# ------------------------------------------------------------------
# 1. Engine availability flags
# ------------------------------------------------------------------
# Each optional backend is probed by importing it. A failed import disables
# that engine (its flag is False) and prints the reason instead of crashing.
# The imported names stay module-global because the loaders below use them.
try:
    from kokoro import KPipeline
except Exception as e:
    KOKORO_OK = False
    print("Kokoro unavailable:", e)
else:
    KOKORO_OK = True

try:
    import torch
    import transformers
    import snac
    from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
    from snac import SNAC
except Exception as e:
    VEENA_OK = False
    print("Veena deps unavailable:", e)
else:
    VEENA_OK = True

try:
    import pyttsx3
except Exception as e:
    PYT_OK = False
    print("pyttsx3 unavailable:", e)
else:
    PYT_OK = True
# ------------------------------------------------------------------
# 2. Lazy model loader (runs once per GPU worker)
# ------------------------------------------------------------------
# Process-wide caches, filled on first use by load_kokoro() / load_veena().
kokoro_pipe = veena_model = veena_tok = veena_snac = None
def load_kokoro():
    """Return the shared Kokoro pipeline, building it on the first call.

    Returns:
        The cached ``KPipeline`` (American-English, ``lang_code='a'``), or
        None when the kokoro package failed to import.
    """
    global kokoro_pipe
    if kokoro_pipe is not None or not KOKORO_OK:
        return kokoro_pipe
    kokoro_pipe = KPipeline(lang_code='a')
    return kokoro_pipe
def load_veena():
    """Load the Veena LM, its tokenizer and the SNAC decoder once per process.

    All three objects are built into locals first and published to the
    module globals only after every load succeeded. Assigning them one by
    one would let a failure part-way through (e.g. fetching the tokenizer)
    leave ``veena_model`` set. Later calls would then skip loading and crash
    on a missing tokenizer or decoder.

    Returns:
        The 4-bit quantised Veena causal LM, or None when Veena's
        dependencies could not be imported.
    """
    global veena_model, veena_tok, veena_snac
    if veena_model is not None or not VEENA_OK:
        return veena_model
    bnb = BitsAndBytesConfig(load_in_4bit=True,
                             bnb_4bit_quant_type="nf4",
                             bnb_4bit_compute_dtype=torch.bfloat16)
    model = AutoModelForCausalLM.from_pretrained(
        "maya-research/veena-tts",
        quantization_config=bnb,
        device_map="auto",
        trust_remote_code=True)
    tok = AutoTokenizer.from_pretrained("maya-research/veena-tts",
                                        trust_remote_code=True)
    snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").eval()
    if torch.cuda.is_available():
        snac_model = snac_model.cuda()
    # Publish atomically: either all three are set or none are.
    veena_model, veena_tok, veena_snac = model, tok, snac_model
    return veena_model
# ------------------------------------------------------------------
# 3. Generation helpers (CPU→GPU off-load)
# ------------------------------------------------------------------
# Special token ids framing the Veena prompt/response (see tts_veena);
# they are six consecutive ids starting at 128257.
(START_OF_SPEECH_TOKEN,
 END_OF_SPEECH_TOKEN,
 START_OF_HUMAN_TOKEN,
 END_OF_HUMAN_TOKEN,
 START_OF_AI_TOKEN,
 END_OF_AI_TOKEN) = range(128257, 128263)

# First LM token id of the audio-code region: seven back-to-back ranges of
# 4096 ids, one per slot of a 7-token SNAC frame (see decode_snac).
AUDIO_CODE_BASE_OFFSET = 128266
def decode_snac(tokens):
if len(tokens) % 7:
return None
codes = [[] for _ in range(3)]
offsets = [AUDIO_CODE_BASE_OFFSET + i*4096 for i in range(7)]
for i in range(0, len(tokens), 7):
codes[0].append(tokens[i] - offsets[0])
codes[1].extend([tokens[i+1]-offsets[1], tokens[i+4]-offsets[4]])
codes[2].extend([tokens[i+2]-offsets[2], tokens[i+3]-offsets[3],
tokens[i+5]-offsets[5], tokens[i+6]-offsets[6]])
device = veena_snac.device
hierarchical = [torch.tensor(c, dtype=torch.int32, device=device).unsqueeze(0)
for c in codes]
with torch.no_grad():
wav = veena_snac.decode(hierarchical).squeeze().clamp(-1,1).cpu().numpy()
return wav
def tts_veena(text, speaker, temperature, top_p):
    """Synthesise *text* with the Veena LM and return the decoded waveform.

    The prompt ``<spk_{speaker}> {text}`` is wrapped in the human/AI/speech
    special tokens. The model samples audio-code tokens, and those codes
    are decoded with SNAC.

    Args:
        text: text to speak.
        speaker: Veena speaker name, inserted into the ``<spk_...>`` tag.
        temperature, top_p: sampling parameters passed to ``generate``.

    Returns:
        A 1-D numpy waveform from ``decode_snac``.

    Raises:
        RuntimeError: if the model is unavailable, or no decodable audio
            tokens were generated.
    """
    if load_veena() is None:
        raise RuntimeError("Veena model is not available")
    prompt = f"<spk_{speaker}> {text}"
    tok = veena_tok.encode(prompt, add_special_tokens=False)
    input_ids = ([START_OF_HUMAN_TOKEN] + tok
                 + [END_OF_HUMAN_TOKEN, START_OF_AI_TOKEN, START_OF_SPEECH_TOKEN])
    input_ids = torch.tensor([input_ids], device=veena_model.device)
    # Budget of int(1.3 * chars) 7-token frames plus 3 spare frames, capped at
    # 700 tokens (100 frames); the 1.3 factor looks heuristic.
    max_new = min(int(len(text) * 1.3) * 7 + 21, 700)
    out = veena_model.generate(
        input_ids,
        max_new_tokens=max_new,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        repetition_penalty=1.05,
        pad_token_id=veena_tok.pad_token_id,
        eos_token_id=[END_OF_SPEECH_TOKEN, END_OF_AI_TOKEN])
    gen = out[0, len(input_ids[0]):].tolist()
    audio_hi = AUDIO_CODE_BASE_OFFSET + 7 * 4096
    snac_toks = [t for t in gen if AUDIO_CODE_BASE_OFFSET <= t < audio_hi]
    if not snac_toks:
        raise RuntimeError("No audio tokens produced")
    wav = decode_snac(snac_toks)
    if wav is None:
        # Previously this None reached sf.write and failed obscurely.
        raise RuntimeError("Generated audio tokens could not be decoded")
    return wav
def tts_kokoro(text, voice, speed):
pipe = load_kokoro()
generator = pipe(text, voice=voice, speed=speed)
for gs, ps, audio in generator:
return audio
raise RuntimeError("Kokoro generation failed")
def tts_pyttsx3(text, rate, volume):
engine = pyttsx3.init()
engine.setProperty('rate', rate)
engine.setProperty('volume', volume)
fd, path = tempfile.mkstemp(suffix='.wav')
os.close(fd)
engine.save_to_file(text, path)
engine.runAndWait()
wav, sr = sf.read(path)
os.remove(path)
return wav
# ------------------------------------------------------------------
# 4. ZERO-GPU ENTRY POINT (decorated)
# ------------------------------------------------------------------
@spaces.GPU
def synthesise(text, engine, voice, speed, speaker, temperature, top_p, rate, vol):
    """Gradio callback: run the selected engine and return a WAV file path.

    Args:
        text: text to speak.
        engine: "kokoro", "veena" or "pyttsx3".
        voice, speed: Kokoro settings.
        speaker, temperature, top_p: Veena settings.
        rate, vol: pyttsx3 settings.

    Returns:
        Path of a temporary WAV file written at 24 kHz.

    Raises:
        gr.Error: on empty text, an unavailable engine, or empty output.
    """
    if not text or not text.strip():
        raise gr.Error("Please enter some text.")
    if engine == "kokoro" and KOKORO_OK:
        wav = tts_kokoro(text, voice=voice, speed=speed)
    elif engine == "veena" and VEENA_OK:
        wav = tts_veena(text, speaker=speaker, temperature=temperature, top_p=top_p)
    elif engine == "pyttsx3" and PYT_OK:
        wav = tts_pyttsx3(text, rate=rate, volume=vol)
    else:
        raise gr.Error(f"{engine} is not available on this Space.")
    if hasattr(wav, "detach"):  # torch tensor -> host numpy array
        wav = wav.detach().cpu().numpy()
    if wav is None or np.size(wav) == 0:
        # Guard against handing None/empty audio to soundfile.
        raise gr.Error(f"{engine} produced no audio.")
    fd, tmp = tempfile.mkstemp(suffix='.wav')
    os.close(fd)
    sf.write(tmp, np.asarray(wav), 24000)
    return tmp
# ------------------------------------------------------------------
# 5. Gradio UI (unchanged visuals)
# ------------------------------------------------------------------
css = """footer {visibility: hidden} #col-left {max-width: 320px}"""

# Engines whose imports succeeded, in preference order (first = default).
_ENGINE_FLAGS = {"kokoro": KOKORO_OK, "veena": VEENA_OK, "pyttsx3": PYT_OK}
_AVAILABLE_ENGINES = [name for name, ok in _ENGINE_FLAGS.items() if ok]
# None when nothing is installed, rather than pre-selecting "pyttsx3",
# which would not be one of the radio's choices.
_DEFAULT_ENGINE = _AVAILABLE_ENGINES[0] if _AVAILABLE_ENGINES else None

with gr.Blocks(css=css, title="Multi-Engine TTS – Zero-GPU") as demo:
    gr.Markdown("## 🎙️ Multi-Engine TTS Demo – Zero-GPU \n*Kokoro ‑ Veena ‑ pyttsx3*")
    with gr.Row():
        with gr.Column(elem_id="col-left"):
            engine = gr.Radio(label="Engine",
                              choices=_AVAILABLE_ENGINES,
                              value=_DEFAULT_ENGINE)
            # Only the pre-selected engine's settings start visible (showing
            # every available panel contradicted the single selection);
            # switch_panel keeps this in sync afterwards.
            with gr.Group(visible=_DEFAULT_ENGINE == "kokoro") as kokoro_box:
                voice = gr.Dropdown(label="Voice",
                                    choices=['af_heart', 'af_sky', 'af_mist', 'af_dusk'],
                                    value='af_heart')
                speed = gr.Slider(0.5, 2.0, 1.0, step=0.1, label="Speed")
            with gr.Group(visible=_DEFAULT_ENGINE == "veena") as veena_box:
                speaker = gr.Dropdown(label="Speaker",
                                      choices=['kavya', 'agastya', 'maitri', 'vinaya'],
                                      value='kavya')
                temperature = gr.Slider(0.1, 1.0, 0.4, step=0.05, label="Temperature")
                top_p = gr.Slider(0.1, 1.0, 0.9, step=0.05, label="Top-p")
            with gr.Group(visible=_DEFAULT_ENGINE == "pyttsx3") as pyttsx3_box:
                rate = gr.Slider(50, 300, 180, step=5, label="Words / min")
                vol = gr.Slider(0.0, 1.0, 1.0, step=0.05, label="Volume")
        with gr.Column(scale=3):
            text = gr.Textbox(label="Text to speak",
                              placeholder="Type or paste text here …",
                              lines=6, max_lines=12)
            btn = gr.Button("🎧 Synthesise", variant="primary")
            audio_out = gr.Audio(label="Generated speech", type="filepath")

    # show/hide panels
    def switch_panel(e):
        """Return visibility updates showing only engine *e*'s settings panel."""
        return (gr.update(visible=e == "kokoro"),
                gr.update(visible=e == "veena"),
                gr.update(visible=e == "pyttsx3"))

    engine.change(switch_panel, inputs=engine,
                  outputs=[kokoro_box, veena_box, pyttsx3_box])

    # binding
    btn.click(synthesise,
              inputs=[text, engine, voice, speed, speaker,
                      temperature, top_p, rate, vol],
              outputs=audio_out)

    gr.Markdown("### Tips \n"
                "- **Kokoro** – fastest, good quality English \n"
                "- **Veena** – multilingual, GPU-friendly (4-bit) \n"
                "- **pyttsx3** – offline fallback, any language \n"
                "Audio is returned as 24 kHz WAV.")
# ------------------------------------------------------------------
# 6. Launch
# ------------------------------------------------------------------
# Starts the Gradio server as soon as this module runs; there is no
# `if __name__ == "__main__":` guard. NOTE(review): the hosting Space
# presumably relies on this — confirm before adding a guard.
demo.launch()