"""Gradio demo app for the Svara multilingual TTS model.

Loads the Svara causal LM and the SNAC audio codec at import time, turns a
text prompt into SNAC codes via ``model.generate``, and decodes those codes
into a waveform that is returned to a Gradio audio component.
"""

# NOTE(review): `spaces` is kept as the first import, exactly as in the
# original file — confirm whether the hosting runtime requires this ordering
# before regrouping imports.
import spaces
import torch
import gradio as gr
import json
import random
import re
from snac import SNAC
from transformers import AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import snapshot_download

# --------------------------
# Device / dtype
# --------------------------
device = "cuda" if torch.cuda.is_available() else "cpu"
# Prefer bfloat16 on GPUs that support it, float16 on other GPUs, float32 on CPU.
dtype = (
    torch.bfloat16
    if (device == "cuda" and torch.cuda.is_bf16_supported())
    else (torch.float16 if device == "cuda" else torch.float32)
)
SR = 24_000  # SNAC sample rate

# --------------------------
# Load models
# --------------------------
print("Loading SNAC model...")
snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device)

model_name = "kenpath/svara-tts-v1"
print(f"Loading Svara model: {model_name}")

# Prefetch safetensors to speed up first run
snapshot_download(
    repo_id=model_name,
    allow_patterns=["config.json", "*.safetensors", "model.safetensors.index.json"],
    ignore_patterns=["optimizer.pt", "pytorch_model.bin", "training_args.bin", "scheduler.pt"],
)

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=dtype).to(device)
model.eval()
print(f"Svara model loaded to {device} with dtype={dtype}")

# --------------------------
# Load examples from JSON
# --------------------------
with open("examples.json", "r", encoding="utf-8") as f:
    EXAMPLES_DATA = json.load(f)
print(f"Loaded {len(EXAMPLES_DATA)} examples from examples.json")

# --------------------------
# Languages & genders (19 total: 18 Indic + English)
# --------------------------
# Maps the label shown in the dropdown to the plain name used in the prompt.
LANGUAGES = {
    "Assamese (অসমীয়া)": "Assamese",
    "Bengali (বাংলা)": "Bengali",
    "Bhojpuri (भोजपुरी)": "Bhojpuri",
    "Bodo (बर’/बड़ो)": "Bodo",
    "Chhattisgarhi (छत्तीसगढ़ी)": "Chhattisgarhi",
    "Dogri (डोगरी)": "Dogri",
    "Gujarati (ગુજરાતી)": "Gujarati",
    "Hindi (हिन्दी)": "Hindi",
    "Kannada (ಕನ್ನಡ)": "Kannada",
    "Maithili (मैथिली)": "Maithili",
    "Magahi (मगही)": "Magahi",
    "Malayalam (മലയാളം)": "Malayalam",
    "Marathi (मराठी)": "Marathi",
    "Nepali (नेपाली)": "Nepali",
    "Punjabi (ਪੰਜਾਬੀ)": "Punjabi",
    "Sanskrit (संस्कृतम्)": "Sanskrit",
    "Tamil (தமிழ்)": "Tamil",
    "Telugu (తెలుగు)": "Telugu",
    "English (Indian)": "English",
}
GENDERS = ["Male", "Female"]

# Create reverse mapping: simple name -> display format
LANGUAGE_DISPLAY_MAP = {v: k for k, v in LANGUAGES.items()}

# Style tags recognised inside the input text; compiled once and shared by
# the search and the removal in process_prompt.
_STYLE_TAG_RE = re.compile(
    r'<(neutral|formal|chat|clear|happy|surprise|sad|fear|anger|disgust)>'
)

# Generated tokens are consumed in groups of this many codes per frame.
_CODES_PER_FRAME = 7
# Subtracted from generated token IDs to obtain raw code values.
_AUDIO_TOKEN_OFFSET = 128266
# Slot k of a frame (k = 0..6) has k * _SLOT_STRIDE subtracted from its code.
_SLOT_STRIDE = 4096


# --------------------------
# Prompt preparation (keep your IDs/format)
# --------------------------
def process_prompt(language, gender, text):
    """Build the model input IDs and attention mask for one request.

    Args:
        language: Display label from ``LANGUAGES``; unknown labels fall back
            to ``"English"``.
        gender: Speaker gender string interpolated into the prompt.
        text: User text, optionally containing a style tag such as
            ``<happy>`` or ``<sad>``.

    Returns:
        A ``(input_ids, attention_mask)`` tuple of tensors moved to ``device``.
    """
    lang_label = LANGUAGES.get(language, "English")

    # Extract style tag from text (if present).
    # Tags are like <happy>, <sad>, <anger>, etc. Only the first match is
    # used as the style, but every tag occurrence is stripped from the text.
    style_match = _STYLE_TAG_RE.search(text)
    style_tag = f"<{style_match.group(1)}>" if style_match else ""

    # Remove the tag from text for processing
    text_without_tag = _STYLE_TAG_RE.sub('', text).strip()

    # Only append a style if it's present and NOT neutral.
    # (The previous guard compared against "" and so never filtered neutral.)
    tail = f" {style_tag}" if style_tag and style_tag != "<neutral>" else ""
    prompt = f"{lang_label} ({gender}): {text_without_tag}{tail}"

    input_ids = tokenizer(prompt, return_tensors="pt").input_ids

    # Special tokens (your working IDs) wrapped around the tokenized prompt.
    start_token = torch.tensor([[128259]], dtype=torch.int64)
    end_tokens = torch.tensor([[128009, 128260]], dtype=torch.int64)

    modified_input_ids = torch.cat([start_token, input_ids, end_tokens], dim=1)
    attention_mask = torch.ones_like(modified_input_ids)
    return modified_input_ids.to(device), attention_mask.to(device)


# --------------------------
# Parse + decode (original logic)
# --------------------------
def parse_output(generated_ids):
    """Extract the list of audio codes from the generated token IDs.

    Keeps only what follows the last occurrence of token 128257 in the batch
    (or everything if it never occurs), drops every 128258 token, truncates
    the first row to a whole number of 7-code frames and subtracts the audio
    token offset.

    Args:
        generated_ids: 2-D tensor of token IDs as returned by
            ``model.generate``.

    Returns:
        A list of Python ints whose length is a multiple of 7 (possibly empty).
    """
    token_to_find, token_to_remove = 128257, 128258

    token_indices = (generated_ids == token_to_find).nonzero(as_tuple=True)
    if len(token_indices[1]) > 0:
        cropped_tensor = generated_ids[:, token_indices[1][-1] + 1:]
    else:
        cropped_tensor = generated_ids

    # Only the first sequence is decoded.
    row = cropped_tensor[0]
    row = row[row != token_to_remove]

    usable = (row.size(0) // _CODES_PER_FRAME) * _CODES_PER_FRAME
    # One vectorised subtract + tolist() instead of a per-element .item() loop.
    return (row[:usable] - _AUDIO_TOKEN_OFFSET).tolist()


def redistribute_codes(code_list, snac_model):
    """Split flat 7-code frames into three SNAC layers and decode to audio.

    Per frame, slot 0 goes to layer 1, slots 1 and 4 go to layer 2, and
    slots 2, 3, 5 and 6 go to layer 3; slot k has ``k * 4096`` subtracted.

    Args:
        code_list: Flat list of integer codes; any trailing partial frame is
            ignored.
        snac_model: Object exposing ``decode(codes)``.

    Returns:
        The decoded audio as a NumPy array.
    """
    layer_1, layer_2, layer_3 = [], [], []

    # Floor division: only complete frames are consumed. The previous
    # `(len + 1) // 7` rounded up for lengths of 7k + 6 and indexed past
    # the end of the list.
    for i in range(len(code_list) // _CODES_PER_FRAME):
        base = _CODES_PER_FRAME * i
        layer_1.append(code_list[base + 0])
        layer_2.append(code_list[base + 1] - _SLOT_STRIDE)
        layer_3.append(code_list[base + 2] - (2 * _SLOT_STRIDE))
        layer_3.append(code_list[base + 3] - (3 * _SLOT_STRIDE))
        layer_2.append(code_list[base + 4] - (4 * _SLOT_STRIDE))
        layer_3.append(code_list[base + 5] - (5 * _SLOT_STRIDE))
        layer_3.append(code_list[base + 6] - (6 * _SLOT_STRIDE))

    codes = [
        torch.tensor(x, device=device).unsqueeze(0)
        for x in [layer_1, layer_2, layer_3]
    ]
    with torch.inference_mode():
        audio = snac_model.decode(codes).detach().squeeze().cpu().numpy()
    return audio


@spaces.GPU()
def generate_speech(language, gender, text, temperature, top_p, repetition_penalty,
                    max_new_tokens, progress=gr.Progress()):
    """Generate speech for ``text`` and return it for a Gradio audio output.

    Args:
        language: Display label selected in the language dropdown.
        gender: Selected speaker gender.
        text: Text to speak; may be ``None`` or blank.
        temperature: Sampling temperature passed to ``model.generate``.
        top_p: Nucleus sampling threshold passed to ``model.generate``.
        repetition_penalty: Repetition penalty passed to ``model.generate``.
        max_new_tokens: Upper bound on newly generated tokens.
        progress: Gradio progress tracker.

    Returns:
        A ``(sample_rate, audio)`` tuple.

    Raises:
        gr.Error: If the text is empty or no complete audio frame was produced.
    """
    text = (text or "").strip()
    if not text:
        raise gr.Error("Please enter some text.")

    progress(0.2, "Preparing prompt…")
    input_ids, attention_mask = process_prompt(language, gender, text)

    progress(0.5, "Generating speech tokens…")
    with torch.inference_mode():
        generated_ids = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=int(max_new_tokens),
            do_sample=True,
            temperature=float(temperature),
            top_p=float(top_p),
            repetition_penalty=float(repetition_penalty),
            num_return_sequences=1,
            eos_token_id=128258,  # keep your eos id
        )

    progress(0.7, "Parsing output…")
    code_list = parse_output(generated_ids)
    if not code_list:
        raise gr.Error("No audio tokens were generated. Try increasing max tokens or temperature a bit.")

    progress(0.9, "Decoding audio…")
    audio = redistribute_codes(code_list, snac_model)
    return (SR, audio)


# --------------------------
# Randomize
# --------------------------
def randomize():
    """Select a random example and populate the fields.

    Returns:
        A ``(language_display_label, gender, text)`` tuple for the UI inputs.

    Raises:
        gr.Error: If no examples were loaded.
    """
    if not EXAMPLES_DATA:
        # random.choice would raise a bare IndexError on an empty sequence.
        raise gr.Error("No examples are available.")

    example = random.choice(EXAMPLES_DATA)

    # Map simple language name to display format
    lang_display = LANGUAGE_DISPLAY_MAP.get(example["language"], "Hindi (हिन्दी)")
    gender = example["gender"]
    text = example["text"]

    # Return values to populate UI fields only
    return lang_display, gender, text


# --------------------------
# UI
# --------------------------
custom_theme = gr.themes.Soft(
    primary_hue="indigo",
    secondary_hue="blue",
    neutral_hue="slate",
    font=gr.themes.GoogleFont("Inter"),
    radius_size=gr.themes.sizes.radius_md,
).set(
    button_primary_background_fill="*primary_500",
    button_primary_background_fill_hover="*primary_600",
)

with gr.Blocks(title="Svara Multilingual TTS", theme=custom_theme,
               css=".note{opacity:.85;font-size:.9em}") as demo:
    gr.Markdown("""
    # svara-tts
    *An open multilingual TTS model for expressive, human-like speech across India's languages.*

    Visit [svara-tts](https://huggingface.co/kenpath/svara-tts-v1) for more details.
    """)

    with gr.Row():
        with gr.Column(scale=3):
            with gr.Row():
                lang = gr.Dropdown(
                    choices=list(LANGUAGES.keys()),
                    value="Hindi (हिन्दी)",
                    label="Language",
                    scale=2
                )
                gender = gr.Dropdown(
                    choices=GENDERS,
                    value="Female",
                    label="Gender",
                    scale=1
                )

            text_input = gr.Textbox(
                label="Text to speak",
                placeholder="Type your text (add tags like <happy>, <sad> for emotion)…",
                lines=5
            )

            with gr.Row():
                randomize_btn = gr.Button("🎲 Randomize", variant="secondary", size="lg")

            with gr.Row():
                submit = gr.Button("🎤 Generate Speech", variant="primary", scale=3, size="lg")
                clear = gr.Button("🗑️ Clear", variant="stop", scale=1)

            with gr.Accordion("Advanced Settings", open=False):
                temperature = gr.Slider(
                    minimum=0.3, maximum=1.2, value=0.7, step=0.1,
                    label="Temperature",
                    info="Higher = more expressive prosody; 0.6-0.9 for conversational, 0.9-1.2 for dramatic"
                )
                top_p = gr.Slider(
                    minimum=0.2, maximum=1.0, value=0.8, step=0.1,
                    label="Top-p (nucleus sampling)",
                    info="0.6-0.8 for natural prosody, 0.8-1.0 for expressive/dramatic"
                )
                repetition_penalty = gr.Slider(
                    minimum=0.9, maximum=1.9, value=1.1, step=0.1,
                    label="Repetition Penalty",
                    info="≥1.1 recommended for stable generation; prevents loops"
                )
                max_new_tokens = gr.Slider(
                    minimum=1000, maximum=4096, value=2048, step=100,
                    label="Max New Tokens",
                    info="Typical range: 900-1200 for most sentences"
                )

        with gr.Column(scale=2):
            audio_output = gr.Audio(
                label="Generated Speech",
                type="numpy",
                autoplay=True
            )

    # Event handlers
    submit.click(
        fn=generate_speech,
        inputs=[lang, gender, text_input, temperature, top_p, repetition_penalty, max_new_tokens],
        outputs=audio_output,
    )

    randomize_btn.click(
        fn=randomize,
        inputs=[],
        outputs=[lang, gender, text_input],
    )

    def _clear():
        """Return reset values for the text box, audio output and sliders."""
        # Reset text, audio, and sliders to defaults
        return (None, None, 0.7, 0.8, 1.1, 2048)

    clear.click(
        _clear,
        inputs=[],
        outputs=[text_input, audio_output, temperature, top_p, repetition_penalty, max_new_tokens]
    )

if __name__ == "__main__":
    demo.queue().launch(share=False)