|
|
import spaces |
|
|
import torch |
|
|
import gradio as gr |
|
|
import json |
|
|
import random |
|
|
import re |
|
|
from snac import SNAC |
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer |
|
|
from huggingface_hub import snapshot_download |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Runtime configuration: pick the accelerator and a matching float precision.
device = "cuda" if torch.cuda.is_available() else "cpu"

if device == "cuda":
    # Prefer bfloat16 where the GPU supports it (wider exponent range than
    # fp16, so fewer overflow issues); otherwise fall back to fp16.
    dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
else:
    dtype = torch.float32

# Output sample rate of the SNAC 24 kHz codec, in Hz.
SR = 24_000
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# --- Model loading (runs once at import time). Order matters: weights are
# pre-fetched with snapshot_download so from_pretrained loads from the local
# cache instead of streaming during startup.

print("Loading SNAC model...")
# Neural audio codec that decodes the LLM's discrete audio tokens to a waveform.
snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device)

model_name = "kenpath/svara-tts-v1"
print(f"Loading Svara model: {model_name}")

# Download only what inference needs; skip training checkpoints/optimizer state.
snapshot_download(
    repo_id=model_name,
    allow_patterns=["config.json", "*.safetensors", "model.safetensors.index.json"],
    ignore_patterns=["optimizer.pt", "pytorch_model.bin", "training_args.bin", "scheduler.pt"],
)

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=dtype).to(device)
model.eval()  # inference only: disables dropout / training-mode layers
print(f"Svara model loaded to {device} with dtype={dtype}")

# Example prompts for the "Randomize" button; examples.json must ship with the
# app (list of {"language", "gender", "text"} records, judging by randomize()).
with open("examples.json", "r", encoding="utf-8") as f:
    EXAMPLES_DATA = json.load(f)

print(f"Loaded {len(EXAMPLES_DATA)} examples from examples.json")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# UI display label (with native script) -> plain language name used when
# building the model prompt in process_prompt().
LANGUAGES = {
    "Assamese (অসমীয়া)": "Assamese",
    "Bengali (বাংলা)": "Bengali",
    "Bhojpuri (भोजपुरी)": "Bhojpuri",
    "Bodo (बर’/बड़ो)": "Bodo",
    "Chhattisgarhi (छत्तीसगढ़ी)": "Chhattisgarhi",
    "Dogri (डोगरी)": "Dogri",
    "Gujarati (ગુજરાતી)": "Gujarati",
    "Hindi (हिन्दी)": "Hindi",
    "Kannada (ಕನ್ನಡ)": "Kannada",
    "Maithili (मैथिली)": "Maithili",
    "Magahi (मगही)": "Magahi",
    "Malayalam (മലയാളം)": "Malayalam",
    "Marathi (मराठी)": "Marathi",
    "Nepali (नेपाली)": "Nepali",
    "Punjabi (ਪੰਜਾਬੀ)": "Punjabi",
    "Sanskrit (संस्कृतम्)": "Sanskrit",
    "Tamil (தமிழ்)": "Tamil",
    "Telugu (తెలుగు)": "Telugu",
    "English (Indian)": "English",
}

# Voice options exposed in the Gender dropdown.
GENDERS = ["Male", "Female"]

# Reverse lookup: plain language name -> display label (used by randomize()
# to map an example's "language" field back to a dropdown value).
LANGUAGE_DISPLAY_MAP = {v: k for k, v in LANGUAGES.items()}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def process_prompt(language, gender, text):
    """Build the model prompt for (language, gender, text) and tokenize it.

    The text may embed an inline style tag such as ``<happy>``; the first tag
    found selects the style and every tag occurrence is stripped from the
    spoken text. ``<neutral>`` is the implicit default and is not re-appended.

    Returns:
        (input_ids, attention_mask): int64 tensors of shape (1, seq_len) on
        the global ``device``, with the model's special start/end tokens
        wrapped around the tokenized prompt.
    """
    lang_label = LANGUAGES.get(language, "English")

    # Single compiled pattern used for both detection and stripping
    # (previously the same literal regex appeared twice).
    style_re = re.compile(
        r'<(neutral|formal|chat|clear|happy|surprise|sad|fear|anger|disgust)>'
    )

    style_match = style_re.search(text)
    style_tag = f"<{style_match.group(1)}>" if style_match else ""
    text_without_tag = style_re.sub('', text).strip()

    # Neutral is the model's default voice, so only non-neutral tags are
    # appended to the prompt.
    tail = f" {style_tag}" if style_tag and style_tag != "<neutral>" else ""
    prompt = f"{lang_label} ({gender}): {text_without_tag}{tail}"

    input_ids = tokenizer(prompt, return_tensors="pt").input_ids

    # Special token ids bracketing the prompt.
    # NOTE(review): 128259 / 128009 / 128260 are taken verbatim from the
    # original code — presumably start-of-human / end-of-text / end-of-human;
    # confirm against the model's special-token map.
    start_token = torch.tensor([[128259]], dtype=torch.int64)
    end_tokens = torch.tensor([[128009, 128260]], dtype=torch.int64)
    modified_input_ids = torch.cat([start_token, input_ids, end_tokens], dim=1)
    attention_mask = torch.ones_like(modified_input_ids)
    return modified_input_ids.to(device), attention_mask.to(device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def parse_output(generated_ids):
    """Extract SNAC audio codes from generated token ids.

    Keeps everything after the last start-of-audio marker (128257), drops the
    end-of-audio token (128258), truncates to a whole number of 7-token
    frames, and shifts ids down by 128266 into the codec's code range.
    Only the first sequence in the batch is used.
    """
    START_MARKER, END_TOKEN = 128257, 128258
    CODE_OFFSET = 128266

    # Column positions of every start-of-audio marker in the batch.
    marker_cols = (generated_ids == START_MARKER).nonzero(as_tuple=True)[1]
    if len(marker_cols) > 0:
        audio_tokens = generated_ids[:, marker_cols[-1] + 1:]
    else:
        audio_tokens = generated_ids

    row = audio_tokens[0]
    row = row[row != END_TOKEN]

    # SNAC frames are 7 tokens wide; discard any trailing partial frame.
    usable = (row.size(0) // 7) * 7
    return [int(tok.item()) - CODE_OFFSET for tok in row[:usable]]
|
|
|
|
|
def redistribute_codes(code_list, snac_model):
    """Split a flat list of 7-token frames into SNAC's three code layers and
    decode them to audio.

    Per 7-code frame the slots map to layers [1, 2, 3, 3, 2, 3, 3], each slot
    offset by slot_index * 4096 in the model's vocabulary.

    Args:
        code_list: flat list of int codes (normally a multiple of 7 long,
            as produced by parse_output).
        snac_model: loaded SNAC codec used for decoding.

    Returns:
        1-D numpy array with the decoded waveform.
    """
    layer_1, layer_2, layer_3 = [], [], []
    # Use only complete frames. The original `(len(code_list) + 1) // 7`
    # over-counted when len % 7 == 6 and raised IndexError at `base + 6`;
    # `len // 7` is identical for multiple-of-7 input and safe otherwise.
    for i in range(len(code_list) // 7):
        base = 7 * i
        layer_1.append(code_list[base + 0])
        layer_2.append(code_list[base + 1] - 4096)
        layer_3.append(code_list[base + 2] - (2 * 4096))
        layer_3.append(code_list[base + 3] - (3 * 4096))
        layer_2.append(code_list[base + 4] - (4 * 4096))
        layer_3.append(code_list[base + 5] - (5 * 4096))
        layer_3.append(code_list[base + 6] - (6 * 4096))
    codes = [torch.tensor(x, device=device).unsqueeze(0) for x in [layer_1, layer_2, layer_3]]
    with torch.inference_mode():
        audio = snac_model.decode(codes).detach().squeeze().cpu().numpy()
    return audio
|
|
|
|
|
@spaces.GPU()
def generate_speech(language, gender, text, temperature, top_p, repetition_penalty, max_new_tokens, progress=gr.Progress()):
    """Run the full TTS pipeline: prompt -> token generation -> SNAC decode.

    Returns a (sample_rate, waveform) tuple suitable for a Gradio Audio
    component. Raises gr.Error on empty input or when generation yields no
    audio tokens.
    """
    text = (text or "").strip()
    if not text:
        raise gr.Error("Please enter some text.")

    progress(0.2, "Preparing prompt…")
    input_ids, attention_mask = process_prompt(language, gender, text)

    progress(0.5, "Generating speech tokens…")
    sampling_kwargs = dict(
        max_new_tokens=int(max_new_tokens),
        do_sample=True,
        temperature=float(temperature),
        top_p=float(top_p),
        repetition_penalty=float(repetition_penalty),
        num_return_sequences=1,
        eos_token_id=128258,  # end-of-audio token terminates generation
    )
    with torch.inference_mode():
        generated_ids = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **sampling_kwargs,
        )

    progress(0.7, "Parsing output…")
    code_list = parse_output(generated_ids)
    if not code_list:
        raise gr.Error("No audio tokens were generated. Try increasing max tokens or temperature a bit.")

    progress(0.9, "Decoding audio…")
    audio = redistribute_codes(code_list, snac_model)
    return (SR, audio)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def randomize():
    """Pick a random example and return (language display label, gender, text)
    for populating the input widgets."""
    picked = random.choice(EXAMPLES_DATA)

    # Map the example's plain language name back to its dropdown label,
    # defaulting to Hindi if the example uses an unknown language.
    display = LANGUAGE_DISPLAY_MAP.get(picked["language"], "Hindi (हिन्दी)")
    return display, picked["gender"], picked["text"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# App-wide Gradio theme: indigo/blue palette, Inter font, medium corner radius.
custom_theme = gr.themes.Soft(
    primary_hue="indigo",
    secondary_hue="blue",
    neutral_hue="slate",
    font=gr.themes.GoogleFont("Inter"),
    radius_size=gr.themes.sizes.radius_md,
).set(
    # Primary buttons darken one shade on hover.
    button_primary_background_fill="*primary_500",
    button_primary_background_fill_hover="*primary_600",
)
|
|
|
|
|
# --- UI definition: layout, widgets, and event wiring for the demo app. ---
with gr.Blocks(title="Svara Multilingual TTS", theme=custom_theme, css=".note{opacity:.85;font-size:.9em}") as demo:
    gr.Markdown("""
    # svara-tts

    *An open multilingual TTS model for expressive, human-like speech across India's languages.*

    Visit [svara-tts](https://huggingface.co/kenpath/svara-tts-v1) for more details.
    """)

    with gr.Row():
        # Left column: text/voice inputs and generation controls.
        with gr.Column(scale=3):
            with gr.Row():
                lang = gr.Dropdown(
                    choices=list(LANGUAGES.keys()),
                    value="Hindi (हिन्दी)",
                    label="Language",
                    scale=2
                )
                gender = gr.Dropdown(
                    choices=GENDERS,
                    value="Female",
                    label="Gender",
                    scale=1
                )

            text_input = gr.Textbox(
                label="Text to speak",
                placeholder="Type your text (add tags like <happy>, <sad> for emotion)…",
                lines=5
            )

            with gr.Row():
                randomize_btn = gr.Button("🎲 Randomize", variant="secondary", size="lg")

            with gr.Row():
                submit = gr.Button("🎤 Generate Speech", variant="primary", scale=3, size="lg")
                clear = gr.Button("🗑️ Clear", variant="stop", scale=1)

            # Sampling knobs; values are passed straight to generate_speech().
            with gr.Accordion("Advanced Settings", open=False):
                temperature = gr.Slider(
                    minimum=0.3,
                    maximum=1.2,
                    value=0.7,
                    step=0.1,
                    label="Temperature",
                    info="Higher = more expressive prosody; 0.6-0.9 for conversational, 0.9-1.2 for dramatic"
                )
                top_p = gr.Slider(
                    minimum=0.2,
                    maximum=1.0,
                    value=0.8,
                    step=0.1,
                    label="Top-p (nucleus sampling)",
                    info="0.6-0.8 for natural prosody, 0.8-1.0 for expressive/dramatic"
                )
                repetition_penalty = gr.Slider(
                    minimum=0.9,
                    maximum=1.9,
                    value=1.1,
                    step=0.1,
                    label="Repetition Penalty",
                    info="≥1.1 recommended for stable generation; prevents loops"
                )
                max_new_tokens = gr.Slider(
                    minimum=1000,
                    maximum=4096,
                    value=2048,
                    step=100,
                    label="Max New Tokens",
                    info="Typical range: 900-1200 for most sentences"
                )

        # Right column: playback of the generated waveform.
        with gr.Column(scale=2):
            audio_output = gr.Audio(
                label="Generated Speech",
                type="numpy",
                autoplay=True
            )

    # Generate: run the full TTS pipeline on the current inputs.
    submit.click(
        fn=generate_speech,
        inputs=[lang, gender, text_input, temperature, top_p, repetition_penalty, max_new_tokens],
        outputs=audio_output,
    )

    # Randomize: fill language/gender/text from a random example.
    randomize_btn.click(
        fn=randomize,
        inputs=[],
        outputs=[lang, gender, text_input],
    )

    def _clear():
        # Reset text and audio, and restore the sliders' default values
        # (must stay in sync with the Slider `value=` arguments above).
        return (None, None, 0.7, 0.8, 1.1, 2048)

    clear.click(
        _clear,
        inputs=[],
        outputs=[text_input, audio_output, temperature, top_p, repetition_penalty, max_new_tokens]
    )
|
|
|
|
|
if __name__ == "__main__":
    # queue() enables request queuing (required for gr.Progress updates);
    # share=False keeps the app local / Space-hosted only.
    demo.queue().launch(share=False)