# app.py — Hugging Face Space by rajyalakshmijampani
# Last commit: "Update app.py" (698ef0c, verified) · file size 4.13 kB
import os
import torch
import gradio as gr
import soundfile as sf
import tempfile
from transformers import AutoModel, AutoTokenizer
from huggingface_hub import login
# ==========================================
# 🔐 AUTHENTICATION
# ==========================================
# Token comes from the HF Space secrets. An empty value is treated the same
# as "unset": otherwise login("") would be attempted and fail at startup, and
# an empty token would be forwarded to from_pretrained() below.
HF_TOKEN = os.getenv("HF_TOKEN") or None
if HF_TOKEN:
    login(token=HF_TOKEN)
# ==========================================
# CONFIG
# ==========================================
# Hugging Face Hub repository id of the TTS model loaded below.
MODEL_NAME = "ai4bharat/vits_rasa_13"

# Run on the GPU whenever PyTorch reports one; fall back to the CPU otherwise.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
# ==========================================
# LOAD MODEL
# ==========================================
# float16 weights on CUDA, float32 on CPU. trust_remote_code=True executes
# Python code shipped in the model repository.
model = AutoModel.from_pretrained(
    MODEL_NAME,
    trust_remote_code=True,
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    token=HF_TOKEN,
).to(device)

# Patch a missing pad token id. transformers configs normally define the
# `pad_token_id` attribute with a default of None, so a hasattr() check would
# never trigger; check the value instead.
if getattr(model.config, "pad_token_id", None) is None:
    model.config.pad_token_id = 0

# Inference only: disable training-mode behaviour such as dropout.
model.eval()

tokenizer = AutoTokenizer.from_pretrained(
    MODEL_NAME,
    trust_remote_code=True,
    token=HF_TOKEN,
)
# ==========================================
# LANGUAGE → SPEAKER MAP
# ==========================================
def _build_lang_speakers(rows):
    """Group flat ``(language, label, speaker_id)`` rows by language.

    Languages keep the order of their first appearance in ``rows``; each maps
    to a list of ``(label, speaker_id)`` tuples suitable as Dropdown choices.
    """
    grouped = {}
    for lang, label, sid in rows:
        grouped.setdefault(lang, []).append((label, sid))
    return grouped


# Language -> [(dropdown label, speaker id), ...]
LANG_SPEAKERS = _build_lang_speakers([
    ("Assamese", "ASM_F (0)", 0),
    ("Assamese", "ASM_M (1)", 1),
    ("Bengali", "BEN_F (2)", 2),
    ("Bengali", "BEN_M (3)", 3),
    ("Bodo", "BRX_F (4)", 4),
    ("Bodo", "BRX_M (5)", 5),
    ("Dogri", "DOI_F (6)", 6),
    ("Dogri", "DOI_M (7)", 7),
    ("Kannada", "KAN_F (8)", 8),
    ("Kannada", "KAN_M (9)", 9),
    ("Maithili", "MAI_M (10)", 10),
    ("Malayalam", "MAL_F (11)", 11),
    ("Marathi", "MAR_F (12)", 12),
    ("Marathi", "MAR_M (13)", 13),
    ("Nepali", "NEP_F (14)", 14),
    ("Punjabi", "PAN_F (15)", 15),
    ("Punjabi", "PAN_M (16)", 16),
    ("Sanskrit", "SAN_M (17)", 17),
    ("Tamil", "TAM_F (18)", 18),
    ("Telugu", "TEL_F (19)", 19),
])
# ==========================================
# INFERENCE
# ==========================================
def generate_tts(text, language, speaker_id, style_id):
    """Synthesize ``text`` and return the path of a temporary WAV file.

    Args:
        text: Input text from the UI textbox.
        language: Selected language name. Not used here; accepted because the
            UI passes the language dropdown value to this callback.
        speaker_id: Speaker id forwarded to the model as ``speaker_id``.
        style_id: Style/emotion id forwarded to the model as ``emotion_id``.

    Returns:
        Path to a ``.wav`` file written at ``model.config.sampling_rate``, or
        ``None`` when ``text`` is missing, empty or whitespace-only.
    """
    if not text or not text.strip():
        return None

    inputs = tokenizer(text=text, return_tensors="pt").to(device)
    with torch.inference_mode():
        outputs = model(
            inputs["input_ids"],
            speaker_id=speaker_id,
            emotion_id=style_id,
        )

    # The model is loaded in float16 on CUDA, but soundfile only writes
    # float32/float64/int16/int32 data, so upcast before converting to NumPy.
    audio = outputs.waveform.squeeze().float().cpu().numpy()

    # Close the temp-file handle before soundfile reopens the path by name
    # (an open handle blocks that on Windows); delete=False keeps the file
    # on disk so Gradio can serve it.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
        wav_path = tmp.name
    sf.write(wav_path, audio, model.config.sampling_rate)
    return wav_path
# ==========================================
# UI
# ==========================================
# (dropdown label, emotion id) pairs offered in the style selector.
# Ids 9, 11 and 13 are not listed.
_STYLE_CHOICES = [
    ("ALEXA (0)", 0),
    ("ANGER (1)", 1),
    ("BB (2)", 2),
    ("BOOK (3)", 3),
    ("CONV (4)", 4),
    ("DIGI (5)", 5),
    ("DISGUST (6)", 6),
    ("FEAR (7)", 7),
    ("HAPPY (8)", 8),
    ("NEWS (10)", 10),
    ("SAD (12)", 12),
    ("SURPRISE (14)", 14),
    ("UMANG (15)", 15),
    ("WIKI (16)", 16),
]


def update_speakers(selected_language):
    """Return a speaker dropdown repopulated for ``selected_language``.

    The first speaker listed for that language becomes the selected value.
    """
    options = LANG_SPEAKERS[selected_language]
    first_speaker_id = options[0][1]
    return gr.Dropdown(choices=options, value=first_speaker_id)


with gr.Blocks() as demo:
    gr.Markdown("## 🎙️ ai4bharat/vits_rasa_13")

    text_input = gr.Textbox(lines=6, label="Enter Text")

    language_dropdown = gr.Dropdown(
        label="Language",
        choices=[*LANG_SPEAKERS],
        value="Punjabi",
    )
    speaker_dropdown = gr.Dropdown(
        label="Speaker",
        choices=LANG_SPEAKERS["Punjabi"],
        value=16,
    )
    style_dropdown = gr.Dropdown(
        label="Style / Emotion",
        choices=_STYLE_CHOICES,
        value=0,
    )

    generate_btn = gr.Button("Generate Speech")
    audio_output = gr.Audio(type="filepath", label="Output Audio")

    # Keep the speaker list in sync with the chosen language.
    language_dropdown.change(
        fn=update_speakers,
        inputs=language_dropdown,
        outputs=speaker_dropdown,
    )
    # Component values reach generate_tts in this positional order.
    generate_btn.click(
        fn=generate_tts,
        inputs=[text_input, language_dropdown, speaker_dropdown, style_dropdown],
        outputs=audio_output,
    )

demo.launch()