Spaces:
Running
Running
| """ | |
| 🎙️ Multi-Engine TTS – Zero-GPU edition | |
| Kokoro │ Veena │ pyttsx3 (fallback) | |
| Routes every synthesis to an idle A100. | |
| """ | |
| import os, tempfile, subprocess, numpy as np | |
| import gradio as gr | |
| import soundfile as sf | |
| import spaces # << Zero-GPU helper | |
# ------------------------------------------------------------------
# 1. Engine availability flags
# ------------------------------------------------------------------
# Each optional backend is probed once at import time. A failing import only
# disables that engine (the reason is printed); the app itself keeps running.
KOKORO_OK = False
VEENA_OK = False
PYT_OK = False

try:
    from kokoro import KPipeline
except Exception as e:
    print("Kokoro unavailable:", e)
else:
    KOKORO_OK = True

try:
    # Veena needs the HF transformers stack plus the SNAC codec package.
    import torch
    import transformers
    import snac
    from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
    from snac import SNAC
except Exception as e:
    print("Veena deps unavailable:", e)
else:
    VEENA_OK = True

try:
    import pyttsx3
except Exception as e:
    print("pyttsx3 unavailable:", e)
else:
    PYT_OK = True
# ------------------------------------------------------------------
# 2. Lazy model loaders (each model is loaded at most once per process)
# ------------------------------------------------------------------
# Module-level caches filled on first use by load_kokoro() / load_veena();
# None means "not loaded yet".
kokoro_pipe = veena_model = veena_tok = veena_snac = None
def load_kokoro():
    """Return the shared Kokoro pipeline, building it on the first call.

    Returns None when the kokoro package could not be imported.
    """
    global kokoro_pipe
    if KOKORO_OK and kokoro_pipe is None:
        # lang_code='a' — presumably Kokoro's American-English pipeline, which
        # would match the af_* voices offered in the UI; confirm in kokoro docs.
        kokoro_pipe = KPipeline(lang_code='a')
    return kokoro_pipe
def load_veena():
    """Load the Veena LLM, its tokenizer and the SNAC decoder once per process.

    The three objects are first loaded into locals and only published to the
    module-level cache after all of them succeeded. Previously ``veena_model``
    (the "already loaded" sentinel) was assigned first, so a failure while
    loading the tokenizer or SNAC left a half-initialised cache that every later
    call would skip, crashing on ``veena_tok``/``veena_snac`` being None.

    Returns:
        The cached Veena model, or None when the Veena dependencies are missing.
    """
    global veena_model, veena_tok, veena_snac
    if veena_model is not None or not VEENA_OK:
        return veena_model

    # 4-bit NF4 quantisation with bfloat16 compute.
    bnb = BitsAndBytesConfig(load_in_4bit=True,
                             bnb_4bit_quant_type="nf4",
                             bnb_4bit_compute_dtype=torch.bfloat16)
    model = AutoModelForCausalLM.from_pretrained(
        "maya-research/veena-tts",
        quantization_config=bnb,
        device_map="auto",
        trust_remote_code=True)
    tok = AutoTokenizer.from_pretrained("maya-research/veena-tts",
                                        trust_remote_code=True)
    codec = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").eval()
    if torch.cuda.is_available():
        codec = codec.cuda()

    # Publish the helpers first and the model last: veena_model is the
    # sentinel checked at the top of this function.
    veena_tok, veena_snac = tok, codec
    veena_model = model
    return veena_model
# ------------------------------------------------------------------
# 3. Generation helpers
# ------------------------------------------------------------------
# Veena token-id layout. Audio codes start at AUDIO_CODE_BASE_OFFSET; each of
# the 7 positions inside an audio frame owns its own block of 4096 ids (see
# the offsets computed in decode_snac). The control tokens below are the six
# consecutive ids 128257..128262.
AUDIO_CODE_BASE_OFFSET = 128266
(
    START_OF_SPEECH_TOKEN,   # 128257
    END_OF_SPEECH_TOKEN,     # 128258
    START_OF_HUMAN_TOKEN,    # 128259
    END_OF_HUMAN_TOKEN,      # 128260
    START_OF_AI_TOKEN,       # 128261
    END_OF_AI_TOKEN,         # 128262
) = range(128257, 128263)
def decode_snac(tokens):
    """Decode a flat list of Veena audio token ids into a waveform.

    Veena emits audio as frames of 7 tokens. Frame position ``k`` carries a
    SNAC code shifted by ``AUDIO_CODE_BASE_OFFSET + k * 4096``. After removing
    the shift, the positions are redistributed over SNAC's three codebook
    levels: ``[0] -> level 0``, ``[1, 4] -> level 1``, ``[2, 3, 5, 6] -> level 2``.

    Args:
        tokens: audio token ids; expected to be a whole number of 7-token frames.

    Returns:
        A 1-D numpy waveform clamped to [-1, 1] (the codec loaded by
        ``load_veena`` is the ``snac_24khz`` checkpoint). Returns ``None`` when
        ``tokens`` is empty or its length is not a multiple of 7.

    Raises:
        ValueError: if a de-offset code falls outside the 0..4095 codebook range
            (e.g. frames misaligned by a stray token).
    """
    if not tokens or len(tokens) % 7:
        return None

    offsets = [AUDIO_CODE_BASE_OFFSET + i * 4096 for i in range(7)]
    codes = [[], [], []]
    for i in range(0, len(tokens), 7):
        frame = [tokens[i + k] - offsets[k] for k in range(7)]
        codes[0].append(frame[0])
        codes[1].extend((frame[1], frame[4]))
        codes[2].extend((frame[2], frame[3], frame[5], frame[6]))

    # SNAC is a plain torch.nn.Module, which has no ``.device`` attribute
    # (only HF PreTrainedModel adds one); read it from the parameters instead.
    device = next(veena_snac.parameters()).device
    hierarchical = []
    for level in codes:
        t = torch.tensor(level, dtype=torch.int32, device=device).unsqueeze(0)
        if torch.any((t < 0) | (t > 4095)):
            raise ValueError("Veena produced an out-of-range SNAC code")
        hierarchical.append(t)

    with torch.no_grad():
        wav = veena_snac.decode(hierarchical)
    return wav.squeeze().clamp(-1, 1).cpu().numpy()
def tts_veena(text, speaker, temperature, top_p):
    """Synthesise ``text`` with the Veena model and return a waveform.

    Args:
        text: text to speak.
        speaker: speaker name, injected into the prompt as ``<spk_{speaker}>``.
        temperature: sampling temperature passed to ``generate``.
        top_p: nucleus-sampling threshold passed to ``generate``.

    Returns:
        The numpy waveform produced by ``decode_snac``.

    Raises:
        RuntimeError: if the model produced no complete 7-token audio frame.
    """
    load_veena()
    prompt = f"<spk_{speaker}> {text}"
    prompt_ids = veena_tok.encode(prompt, add_special_tokens=False)
    ids = ([START_OF_HUMAN_TOKEN] + prompt_ids
           + [END_OF_HUMAN_TOKEN, START_OF_AI_TOKEN, START_OF_SPEECH_TOKEN])
    input_ids = torch.tensor([ids], device=veena_model.device)

    # Budget: int(1.3 * chars) frames of 7 tokens plus 21 tokens (3 frames) of
    # slack, capped at 700 tokens (100 frames).
    max_new = min(int(len(text) * 1.3) * 7 + 21, 700)
    out = veena_model.generate(
        input_ids,
        attention_mask=torch.ones_like(input_ids),  # single unpadded sequence
        max_new_tokens=max_new,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        repetition_penalty=1.05,
        pad_token_id=veena_tok.pad_token_id,
        eos_token_id=[END_OF_SPEECH_TOKEN, END_OF_AI_TOKEN])

    generated = out[0, input_ids.shape[1]:].tolist()
    audio_hi = AUDIO_CODE_BASE_OFFSET + 7 * 4096
    snac_toks = [t for t in generated if AUDIO_CODE_BASE_OFFSET <= t < audio_hi]
    # Generation may stop mid-frame (e.g. max_new_tokens reached). Drop the
    # incomplete trailing frame; otherwise decode_snac rejects the whole
    # sequence (returns None) and the caller fails writing a None waveform.
    snac_toks = snac_toks[:len(snac_toks) - len(snac_toks) % 7]
    if not snac_toks:
        raise RuntimeError("No audio tokens produced")
    return decode_snac(snac_toks)
def tts_kokoro(text, voice, speed):
    """Synthesise ``text`` with Kokoro and return a single concatenated waveform.

    Calling the Kokoro pipeline yields ``(graphemes, phonemes, audio)`` tuples,
    presumably one per text segment. Previously only the first tuple's audio
    was returned, which silently cut off longer / multi-paragraph input. Every
    segment's audio is now kept and joined in order.

    Args:
        text: text to speak.
        voice: Kokoro voice id (e.g. 'af_heart').
        speed: speaking-speed multiplier.

    Returns:
        1-D float32 numpy waveform.

    Raises:
        RuntimeError: if the pipeline yielded no audio at all.
    """
    pipe = load_kokoro()
    chunks = []
    for _graphemes, _phonemes, audio in pipe(text, voice=voice, speed=speed):
        if audio is None:
            continue
        if hasattr(audio, "cpu"):  # torch tensor -> numpy
            audio = audio.detach().cpu().numpy()
        chunks.append(np.asarray(audio, dtype=np.float32).reshape(-1))
    if not chunks:
        raise RuntimeError("Kokoro generation failed")
    return np.concatenate(chunks)
def tts_pyttsx3(text, rate, volume, target_sr=24000):
    """Synthesise ``text`` with the local pyttsx3 engine.

    The engine renders to a temporary WAV file, which is read back with
    soundfile and always deleted, even when synthesis or reading fails.

    The file's sample rate (``sr``) depends on the pyttsx3 backend and may
    differ from ``target_sr``. The caller writes every engine's output at a
    fixed rate, so the audio is resampled here. Previously ``sr`` was discarded
    and the speech played at the wrong speed and pitch.

    Args:
        text: text to speak.
        rate: speech rate (the UI labels it "Words / min").
        volume: volume in [0, 1].
        target_sr: sample rate of the returned waveform; defaults to the 24 kHz
            used by ``synthesise``.

    Returns:
        1-D float32 numpy waveform at ``target_sr``.
    """
    engine = pyttsx3.init()
    engine.setProperty('rate', rate)
    engine.setProperty('volume', volume)
    fd, path = tempfile.mkstemp(suffix='.wav')
    os.close(fd)
    try:
        engine.save_to_file(text, path)
        engine.runAndWait()
        wav, sr = sf.read(path, dtype='float32')
    finally:
        os.remove(path)

    if wav.ndim > 1:  # down-mix multi-channel output to mono
        wav = wav.mean(axis=1)
    if sr != target_sr and len(wav) > 1:
        # Linear-interpolation resampling: dependency-free and adequate for speech.
        n_out = int(round(len(wav) * target_sr / sr))
        src_t = np.arange(len(wav)) / sr
        dst_t = np.arange(n_out) / target_sr
        wav = np.interp(dst_t, src_t, wav)
    return np.asarray(wav, dtype=np.float32)
# ------------------------------------------------------------------
# 4. ZERO-GPU ENTRY POINT (decorated)
# ------------------------------------------------------------------
@spaces.GPU
def synthesise(text, engine, voice, speed, speaker, temperature, top_p, rate, vol):
    """Synthesise ``text`` with the selected engine and return a WAV file path.

    This is the function bound to the "Synthesise" button. It is wrapped in
    ``spaces.GPU`` so that, on a Zero-GPU Space, each call runs with a GPU
    attached. Without the decorator (previously missing, although imported and
    advertised) model loading and generation never get a GPU there.

    Args:
        text: text to speak; must contain non-whitespace characters.
        engine: "kokoro", "veena" or "pyttsx3".
        voice, speed: Kokoro settings.
        speaker, temperature, top_p: Veena settings.
        rate, vol: pyttsx3 rate and volume.

    Returns:
        Path of a temporary WAV file written at 24 kHz.

    Raises:
        gr.Error: for empty text, an unavailable engine, a failing engine, or
            an engine that returned no audio, so the message shows in the UI.
    """
    if not text or not text.strip():
        raise gr.Error("Please enter some text.")

    available = {"kokoro": KOKORO_OK, "veena": VEENA_OK, "pyttsx3": PYT_OK}
    if not available.get(engine, False):
        raise gr.Error(f"{engine} is not available on this Space.")

    try:
        if engine == "kokoro":
            wav = tts_kokoro(text, voice=voice, speed=speed)
        elif engine == "veena":
            wav = tts_veena(text, speaker=speaker, temperature=temperature, top_p=top_p)
        else:
            wav = tts_pyttsx3(text, rate=rate, volume=vol)
    except Exception as err:
        # Surface engine failures as a readable UI error instead of a generic one.
        raise gr.Error(f"{engine} synthesis failed: {err}") from err

    if wav is None or len(wav) == 0:
        raise gr.Error(f"{engine} produced no audio.")

    fd, tmp = tempfile.mkstemp(suffix='.wav')
    os.close(fd)
    # Every engine's waveform is written at 24 kHz.
    sf.write(tmp, wav, 24000)
    return tmp
# ------------------------------------------------------------------
# 5. Gradio UI (unchanged visuals)
# ------------------------------------------------------------------
css = """footer {visibility: hidden} #col-left {max-width: 320px}"""

# Engine name -> whether its backend imported successfully. Insertion order is
# the preference order used to pick the default engine.
ENGINE_AVAILABLE = {"kokoro": KOKORO_OK, "veena": VEENA_OK, "pyttsx3": PYT_OK}
AVAILABLE_ENGINES = [name for name, ok in ENGINE_AVAILABLE.items() if ok]
# No selection (None) when no engine is usable, rather than a value that is
# not among the radio's choices.
DEFAULT_ENGINE = AVAILABLE_ENGINES[0] if AVAILABLE_ENGINES else None

with gr.Blocks(css=css, title="Multi-Engine TTS – Zero-GPU") as demo:
    gr.Markdown("## 🎙️ Multi-Engine TTS Demo – Zero-GPU \n*Kokoro ‑ Veena ‑ pyttsx3*")
    with gr.Row():
        with gr.Column(elem_id="col-left"):
            engine = gr.Radio(label="Engine",
                              choices=AVAILABLE_ENGINES,
                              value=DEFAULT_ENGINE)
            # Start with only the default engine's settings visible: the same
            # state switch_panel() produces after the selection changes.
            with gr.Group(visible=DEFAULT_ENGINE == "kokoro") as kokoro_box:
                voice = gr.Dropdown(label="Voice",
                                    choices=['af_heart', 'af_sky', 'af_mist', 'af_dusk'],
                                    value='af_heart')
                speed = gr.Slider(0.5, 2.0, 1.0, step=0.1, label="Speed")
            with gr.Group(visible=DEFAULT_ENGINE == "veena") as veena_box:
                speaker = gr.Dropdown(label="Speaker",
                                      choices=['kavya', 'agastya', 'maitri', 'vinaya'],
                                      value='kavya')
                temperature = gr.Slider(0.1, 1.0, 0.4, step=0.05, label="Temperature")
                top_p = gr.Slider(0.1, 1.0, 0.9, step=0.05, label="Top-p")
            with gr.Group(visible=DEFAULT_ENGINE == "pyttsx3") as pyttsx3_box:
                rate = gr.Slider(50, 300, 180, step=5, label="Words / min")
                vol = gr.Slider(0.0, 1.0, 1.0, step=0.05, label="Volume")
        with gr.Column(scale=3):
            text = gr.Textbox(label="Text to speak",
                              placeholder="Type or paste text here …",
                              lines=6, max_lines=12)
            btn = gr.Button("🎧 Synthesise", variant="primary")
            audio_out = gr.Audio(label="Generated speech", type="filepath")

    def switch_panel(e):
        """Show only the settings panel that belongs to the selected engine."""
        return (gr.update(visible=e == "kokoro"),
                gr.update(visible=e == "veena"),
                gr.update(visible=e == "pyttsx3"))

    engine.change(switch_panel, inputs=engine,
                  outputs=[kokoro_box, veena_box, pyttsx3_box])

    # binding
    btn.click(synthesise,
              inputs=[text, engine, voice, speed, speaker,
                      temperature, top_p, rate, vol],
              outputs=audio_out)

    gr.Markdown("### Tips \n"
                "- **Kokoro** – fastest, good quality English \n"
                "- **Veena** – multilingual, GPU-friendly (4-bit) \n"
                "- **pyttsx3** – offline fallback, any language \n"
                "Audio is returned as 24 kHz WAV.")
# ------------------------------------------------------------------
# 6. Launch
# ------------------------------------------------------------------
# Guarded so that importing this module (tests, tooling) does not start a
# server; running the file as a script still launches the app.
if __name__ == "__main__":
    demo.launch()