Spaces:
Sleeping
Sleeping
| import os | |
| import torch | |
| import gradio as gr | |
| import soundfile as sf | |
| import tempfile | |
| from transformers import AutoModel, AutoTokenizer | |
| from huggingface_hub import login | |
# ==========================================
# 🔐 AUTHENTICATION
# ==========================================
# Optional Hugging Face access token, read from the environment (HF Space
# secrets). An unset, empty, or whitespace-only variable is normalized to None,
# so we never call login("") or pass an empty token to from_pretrained().
# Surrounding whitespace (e.g. a trailing newline pasted into the secret) is
# stripped because it would make an otherwise valid token fail.
HF_TOKEN = (os.getenv("HF_TOKEN") or "").strip() or None
if HF_TOKEN is not None:
    login(token=HF_TOKEN)
# ==========================================
# CONFIG
# ==========================================
MODEL_NAME = "ai4bharat/vits_rasa_13"

# Run on GPU when one is available. Half precision is used only there;
# CPU inference keeps float32 weights.
if torch.cuda.is_available():
    device = "cuda"
    _weights_dtype = torch.float16
else:
    device = "cpu"
    _weights_dtype = torch.float32
# ==========================================
# LOAD MODEL
# ==========================================
# Options shared by both Hub downloads. trust_remote_code lets transformers run
# the repository's own model/tokenizer code. HF_TOKEN may be None (no secret).
_hub_kwargs = dict(trust_remote_code=True, token=HF_TOKEN)

model = AutoModel.from_pretrained(
    MODEL_NAME,
    torch_dtype=_weights_dtype,
    **_hub_kwargs,
).to(device)

# Patch missing attribute: give the config pad_token_id = 0 only when the
# attribute does not exist at all. An existing value (even None) is left as is.
_absent = object()
if getattr(model.config, "pad_token_id", _absent) is _absent:
    model.config.pad_token_id = 0
model.eval()

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, **_hub_kwargs)
# ==========================================
# LANGUAGE → SPEAKER MAP
# ==========================================
# One row per language: (display name, label prefix, ((voice suffix, speaker id), ...)).
# Row order is the order the language dropdown shows.
_SPEAKER_INVENTORY = (
    ("Assamese", "ASM", (("F", 0), ("M", 1))),
    ("Bengali", "BEN", (("F", 2), ("M", 3))),
    ("Bodo", "BRX", (("F", 4), ("M", 5))),
    ("Dogri", "DOI", (("F", 6), ("M", 7))),
    ("Kannada", "KAN", (("F", 8), ("M", 9))),
    ("Maithili", "MAI", (("M", 10),)),
    ("Malayalam", "MAL", (("F", 11),)),
    ("Marathi", "MAR", (("F", 12), ("M", 13))),
    ("Nepali", "NEP", (("F", 14),)),
    ("Punjabi", "PAN", (("F", 15), ("M", 16))),
    ("Sanskrit", "SAN", (("M", 17),)),
    ("Tamil", "TAM", (("F", 18),)),
    ("Telugu", "TEL", (("F", 19),)),
)

# Language name -> list of Gradio (label, value) choices, e.g. ("PAN_M (16)", 16).
LANG_SPEAKERS = {
    language: [(f"{prefix}_{suffix} ({speaker})", speaker) for suffix, speaker in voices]
    for language, prefix, voices in _SPEAKER_INVENTORY
}
# ==========================================
# INFERENCE
# ==========================================
def generate_tts(text, language, speaker_id, style_id):
    """Synthesize ``text`` with the loaded model and return a WAV file path.

    Args:
        text: Input text from the textbox. ``None``, empty, or whitespace-only
            text produces no audio.
        language: Selected language name. Not used here; only ``speaker_id``
            and ``style_id`` are passed to the model.
        speaker_id: Speaker index forwarded to the model as ``speaker_id``.
        style_id: Style/emotion index forwarded to the model as ``emotion_id``.

    Returns:
        Path to a temporary ``.wav`` file (created with ``delete=False``, so it
        is not removed automatically), or ``None`` when there is no text.
    """
    if text is None or not text.strip():
        return None

    inputs = tokenizer(text=text, return_tensors="pt").to(device)
    with torch.inference_mode():
        outputs = model(
            inputs["input_ids"],
            speaker_id=speaker_id,
            emotion_id=style_id,
        )

    # On CUDA the model is loaded in float16. soundfile only writes
    # float32/float64/int16/int32 arrays, so upcast before writing.
    audio = outputs.waveform.squeeze().float().cpu().numpy()

    # Reserve a unique path and close our handle right away. This avoids
    # leaking a descriptor per request, and soundfile reopens the path itself.
    # Writing by name to a still-open NamedTemporaryFile fails on Windows.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
        wav_path = tmp.name
    sf.write(wav_path, audio, model.config.sampling_rate)
    return wav_path
# ==========================================
# UI
# ==========================================
# Style / emotion presets shown in the UI as (name, emotion id). Ids 9, 11 and
# 13 are not offered; whether the model defines them is not visible here.
_STYLE_PRESETS = (
    ("ALEXA", 0),
    ("ANGER", 1),
    ("BB", 2),
    ("BOOK", 3),
    ("CONV", 4),
    ("DIGI", 5),
    ("DISGUST", 6),
    ("FEAR", 7),
    ("HAPPY", 8),
    ("NEWS", 10),
    ("SAD", 12),
    ("SURPRISE", 14),
    ("UMANG", 15),
    ("WIKI", 16),
)
_DEFAULT_LANGUAGE = "Punjabi"


def update_speakers(selected_language):
    """Rebuild the speaker dropdown for ``selected_language``.

    Lists that language's speakers and selects the first one.
    """
    choices = LANG_SPEAKERS[selected_language]
    first_speaker_id = choices[0][1]
    return gr.Dropdown(choices=choices, value=first_speaker_id)


with gr.Blocks() as demo:
    # Components are created in display order.
    gr.Markdown("## 🎙️ ai4bharat/vits_rasa_13")
    text_input = gr.Textbox(label="Enter Text", lines=6)
    language_dropdown = gr.Dropdown(
        choices=list(LANG_SPEAKERS),
        value=_DEFAULT_LANGUAGE,
        label="Language",
    )
    speaker_dropdown = gr.Dropdown(
        choices=LANG_SPEAKERS[_DEFAULT_LANGUAGE],
        value=16,  # "PAN_M (16)"
        label="Speaker",
    )
    style_dropdown = gr.Dropdown(
        choices=[(f"{name} ({emotion})", emotion) for name, emotion in _STYLE_PRESETS],
        value=0,
        label="Style / Emotion",
    )
    generate_btn = gr.Button("Generate Speech")
    audio_output = gr.Audio(label="Output Audio", type="filepath")

    # Changing the language repopulates the speaker list.
    language_dropdown.change(
        fn=update_speakers,
        inputs=language_dropdown,
        outputs=speaker_dropdown,
    )
    generate_btn.click(
        fn=generate_tts,
        inputs=[text_input, language_dropdown, speaker_dropdown, style_dropdown],
        outputs=audio_output,
    )

demo.launch()