File size: 8,884 Bytes
1d84a6c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5330d99
1d84a6c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
"""
🎙️ Multi-Engine TTS – Zero-GPU edition
Kokoro │ Veena │ pyttsx3  (fallback)
Routes every synthesis to an idle A100.
"""

import os, tempfile, subprocess, numpy as np
import gradio as gr
import soundfile as sf
import spaces                               # << Zero-GPU helper

# ------------------------------------------------------------------
# 1.  Engine availability flags
# ------------------------------------------------------------------
# Each flag flips to True only when the corresponding import chain
# succeeds, so the UI can hide engines that are missing on this Space.
KOKORO_OK = False
VEENA_OK = False
PYT_OK = False

try:
    from kokoro import KPipeline
except Exception as e:
    print("Kokoro unavailable:", e)
else:
    KOKORO_OK = True

try:
    import torch, transformers, snac
    from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
    from snac import SNAC
except Exception as e:
    print("Veena deps unavailable:", e)
else:
    VEENA_OK = True

try:
    import pyttsx3
except Exception as e:
    print("pyttsx3 unavailable:", e)
else:
    PYT_OK = True

# ------------------------------------------------------------------
# 2.  Lazy model loader (runs once per GPU worker)
# ------------------------------------------------------------------
# Module-level caches, filled on first use by load_kokoro()/load_veena()
# so each worker pays the (expensive) model-load cost only once.
kokoro_pipe = None   # Kokoro KPipeline instance
veena_model = None   # Veena causal LM (4-bit quantized)
veena_tok   = None   # Veena tokenizer
veena_snac  = None   # SNAC vocoder used to decode Veena audio tokens

def load_kokoro():
    """Create the Kokoro pipeline on first call and memoize it.

    Returns the shared KPipeline instance, or None when Kokoro is not
    installed (KOKORO_OK is False).
    """
    global kokoro_pipe
    if KOKORO_OK and kokoro_pipe is None:
        kokoro_pipe = KPipeline(lang_code='a')
    return kokoro_pipe

def load_veena():
    """Load the Veena model, tokenizer, and SNAC vocoder once per worker.

    Populates the module-level caches on first call; subsequent calls
    (or missing dependencies) return immediately.  Returns the model
    handle, which is None when Veena is unavailable.
    """
    global veena_model, veena_tok, veena_snac
    if not VEENA_OK or veena_model is not None:
        return veena_model
    repo = "maya-research/veena-tts"
    # 4-bit NF4 quantization keeps the model within Zero-GPU memory limits.
    quant_cfg = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    veena_model = AutoModelForCausalLM.from_pretrained(
        repo,
        quantization_config=quant_cfg,
        device_map="auto",
        trust_remote_code=True,
    )
    veena_tok = AutoTokenizer.from_pretrained(repo, trust_remote_code=True)
    veena_snac = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").eval()
    if torch.cuda.is_available():
        veena_snac = veena_snac.cuda()
    return veena_model

# ------------------------------------------------------------------
# 3.  Generation helpers (CPU→GPU off-load)
# ------------------------------------------------------------------
# Special token ids for the Veena LLM's conversational/audio framing.
# Audio codes occupy 7 interleaved codebook slots of 4096 ids each,
# starting at AUDIO_CODE_BASE_OFFSET (see decode_snac below).
AUDIO_CODE_BASE_OFFSET = 128266
START_OF_SPEECH_TOKEN  = 128257
END_OF_SPEECH_TOKEN    = 128258
START_OF_HUMAN_TOKEN   = 128259
END_OF_HUMAN_TOKEN     = 128260
START_OF_AI_TOKEN      = 128261
END_OF_AI_TOKEN        = 128262

def decode_snac(tokens):
    """Decode a flat list of Veena audio-code token ids into a waveform.

    Tokens come in frames of 7, one per interleaved codebook slot, each
    slot offset by AUDIO_CODE_BASE_OFFSET + slot*4096.  Returns a 1-D
    numpy float array clamped to [-1, 1], or None when no complete
    frame is present.

    Fix: the previous version returned None whenever len(tokens) was
    not a multiple of 7, discarding the whole utterance (and the caller
    would later crash writing a None waveform).  Generation can stop
    mid-frame on max_new_tokens, so we drop only the trailing partial
    frame instead.
    """
    usable = len(tokens) - len(tokens) % 7
    if not usable:
        return None
    tokens = tokens[:usable]
    # De-interleave into SNAC's 3-level hierarchy:
    # 1 coarse, 2 mid, and 4 fine codes per 7-token frame.
    codes = [[] for _ in range(3)]
    offsets = [AUDIO_CODE_BASE_OFFSET + i*4096 for i in range(7)]
    for i in range(0, len(tokens), 7):
        codes[0].append(tokens[i]   - offsets[0])
        codes[1].extend([tokens[i+1]-offsets[1], tokens[i+4]-offsets[4]])
        codes[2].extend([tokens[i+2]-offsets[2], tokens[i+3]-offsets[3],
                         tokens[i+5]-offsets[5], tokens[i+6]-offsets[6]])
    device = veena_snac.device
    hierarchical = [torch.tensor(c, dtype=torch.int32, device=device).unsqueeze(0)
                    for c in codes]
    with torch.no_grad():
        wav = veena_snac.decode(hierarchical).squeeze().clamp(-1,1).cpu().numpy()
    return wav

def tts_veena(text, speaker, temperature, top_p):
    """Generate speech for `text` with the Veena model.

    Wraps the speaker-tagged prompt in the human/AI/speech framing
    tokens, samples audio tokens from the LLM, filters them to the
    audio-code range, and decodes them with SNAC.  Raises RuntimeError
    when the model emits no audio tokens.
    """
    load_veena()
    encoded = veena_tok.encode(f"<spk_{speaker}> {text}",
                               add_special_tokens=False)
    prompt_ids = torch.tensor(
        [[START_OF_HUMAN_TOKEN, *encoded, END_OF_HUMAN_TOKEN,
          START_OF_AI_TOKEN, START_OF_SPEECH_TOKEN]],
        device=veena_model.device)
    # Rough budget: ~1.3 audio frames per character, 7 tokens per frame,
    # capped at 700 new tokens.
    budget = min(int(len(text) * 1.3) * 7 + 21, 700)
    output = veena_model.generate(
        prompt_ids,
        max_new_tokens=budget,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        repetition_penalty=1.05,
        pad_token_id=veena_tok.pad_token_id,
        eos_token_id=[END_OF_SPEECH_TOKEN, END_OF_AI_TOKEN])
    generated = output[0, prompt_ids.shape[1]:].tolist()
    lo = AUDIO_CODE_BASE_OFFSET
    hi = lo + 7 * 4096
    audio_tokens = [t for t in generated if lo <= t < hi]
    if not audio_tokens:
        raise RuntimeError("No audio tokens produced")
    return decode_snac(audio_tokens)

def tts_kokoro(text, voice, speed):
    """Generate speech for `text` with Kokoro and return the waveform.

    Fix: the Kokoro pipeline is a generator that yields one audio
    segment per text chunk; the previous code returned only the first
    segment, silently truncating longer inputs.  All segments are now
    collected and concatenated.  Raises RuntimeError when the pipeline
    yields nothing.
    """
    pipe = load_kokoro()
    segments = [audio for _gs, _ps, audio in pipe(text, voice=voice, speed=speed)]
    if not segments:
        raise RuntimeError("Kokoro generation failed")
    return segments[0] if len(segments) == 1 else np.concatenate(segments)

def tts_pyttsx3(text, rate, volume):
    """Generate speech for `text` with the offline pyttsx3 engine.

    Renders to a temporary WAV file (pyttsx3 has no in-memory API),
    reads it back, and returns the waveform array.

    Fix: the temp file is now removed in a finally block, so a failure
    in synthesis or file reading no longer leaks it.
    """
    engine = pyttsx3.init()
    engine.setProperty('rate', rate)
    engine.setProperty('volume', volume)
    fd, path = tempfile.mkstemp(suffix='.wav')
    os.close(fd)
    try:
        engine.save_to_file(text, path)
        engine.runAndWait()
        wav, sr = sf.read(path)
    finally:
        # Clean up the scratch file even when synthesis/read raises.
        if os.path.exists(path):
            os.remove(path)
    # NOTE(review): pyttsx3 backends commonly emit 22.05 kHz audio while
    # the caller writes the result as 24 kHz — confirm or resample there.
    return wav

# ------------------------------------------------------------------
# 4.  ZERO-GPU ENTRY POINT  (decorated)
# ------------------------------------------------------------------
@spaces.GPU
def synthesise(text, engine, voice, speed, speaker, temperature, top_p, rate, vol):
    """Route a synthesis request to the selected engine on a GPU worker.

    Validates the input text, dispatches to the engine whose
    availability flag is set, writes the waveform to a temporary
    24 kHz WAV file, and returns its path for the gr.Audio output.
    Raises gr.Error (shown in the UI) for empty text, an unavailable
    engine, or an empty synthesis result.
    """
    if not text.strip():
        raise gr.Error("Please enter some text.")
    if engine == "kokoro" and KOKORO_OK:
        wav = tts_kokoro(text, voice=voice, speed=speed)
    elif engine == "veena" and VEENA_OK:
        wav = tts_veena(text, speaker=speaker, temperature=temperature, top_p=top_p)
    elif engine == "pyttsx3" and PYT_OK:
        wav = tts_pyttsx3(text, rate=rate, volume=vol)
    else:
        raise gr.Error(f"{engine} is not available on this Space.")
    # Fix: decode_snac can return None; surface a friendly error instead
    # of letting sf.write crash with an opaque traceback.
    if wav is None or len(wav) == 0:
        raise gr.Error("Synthesis produced no audio – please try again.")
    fd, tmp = tempfile.mkstemp(suffix='.wav')
    os.close(fd)
    sf.write(tmp, wav, 24000)
    return tmp

# ------------------------------------------------------------------
# 5.  Gradio UI  (unchanged visuals)
# ------------------------------------------------------------------
# Hide the Gradio footer and cap the width of the controls column.
css = """footer {visibility: hidden} #col-left {max-width: 320px}"""

with gr.Blocks(css=css, title="Multi-Engine TTS – Zero-GPU") as demo:
    gr.Markdown("## 🎙️ Multi-Engine TTS Demo – Zero-GPU  \n*Kokoro ‑ Veena ‑ pyttsx3*")

    with gr.Row():
        with gr.Column(elem_id="col-left"):
            # Offer only engines whose imports succeeded.  The flag name
            # is derived from the engine name ("kokoro" -> KOKORO_OK,
            # "veena" -> VEENA_OK); pyttsx3 is special-cased to PYT_OK.
            # Default selection prefers Kokoro, then Veena, then pyttsx3.
            engine = gr.Radio(label="Engine",
                              choices=[e for e in ["kokoro","veena","pyttsx3"]
                                       if globals().get({"pyttsx3":"PYT_OK"}.get(e,e.upper()+"_OK"), False)],
                              value="kokoro" if KOKORO_OK else
                                    "veena"  if VEENA_OK  else "pyttsx3")

            # Kokoro-specific controls (shown only when it is available).
            with gr.Group(visible=KOKORO_OK) as kokoro_box:
                voice = gr.Dropdown(label="Voice",
                                    choices=['af_heart','af_sky','af_mist','af_dusk'],
                                    value='af_heart')
                speed = gr.Slider(0.5, 2.0, 1.0, step=0.1, label="Speed")

            # Veena-specific speaker and sampling controls.
            with gr.Group(visible=VEENA_OK) as veena_box:
                speaker = gr.Dropdown(label="Speaker",
                                      choices=['kavya','agastya','maitri','vinaya'],
                                      value='kavya')
                temperature = gr.Slider(0.1, 1.0, 0.4, step=0.05, label="Temperature")
                top_p      = gr.Slider(0.1, 1.0, 0.9, step=0.05, label="Top-p")

            # pyttsx3-specific controls.
            with gr.Group(visible=PYT_OK) as pyttsx3_box:
                rate = gr.Slider(50, 300, 180, step=5, label="Words / min")
                vol  = gr.Slider(0.0, 1.0, 1.0, step=0.05, label="Volume")

        with gr.Column(scale=3):
            text = gr.Textbox(label="Text to speak",
                              placeholder="Type or paste text here …",
                              lines=6, max_lines=12)
            btn  = gr.Button("🎧 Synthesise", variant="primary")
            audio_out = gr.Audio(label="Generated speech", type="filepath")

    # show/hide panels
    def switch_panel(e):
        """Return visibility updates so only the selected engine's panel shows."""
        return (gr.update(visible=e=="kokoro"),
                gr.update(visible=e=="veena"),
                gr.update(visible=e=="pyttsx3"))
    engine.change(switch_panel, inputs=engine,
                  outputs=[kokoro_box, veena_box, pyttsx3_box])

    # binding
    # Note: synthesise ignores the settings of non-selected engines.
    btn.click(synthesise,
              inputs=[text, engine, voice, speed, speaker,
                      temperature, top_p, rate, vol],
              outputs=audio_out)

    gr.Markdown("### Tips  \n"
                "- **Kokoro** – fastest, good quality English  \n"
                "- **Veena** – multilingual, GPU-friendly (4-bit)  \n"
                "- **pyttsx3** – offline fallback, any language  \n"
                "Audio is returned as 24 kHz WAV.")

# ------------------------------------------------------------------
# 6.  Launch
# ------------------------------------------------------------------
demo.launch()