import os
import tempfile

import torch
import numpy as np
import gradio as gr
import scipy.io.wavfile as wavfile
from pydub import AudioSegment
from transformers import VitsModel, AutoTokenizer
# ---------- Configuration --------------------------------------------------
# Define available TTS models here. Add new entries as needed.
TTS_MODELS = {
    "Swahili": {
        "tokenizer": "FarmerlineML/swahili-tts-2025",
        "checkpoint": "FarmerlineML/Swahili-tts-2025_part4",
    },
    "Krio": {
        "tokenizer": "FarmerlineML/Krio-TTS",
        "checkpoint": "FarmerlineML/Krio-TTS",
    },
    "Ewe": {
        "tokenizer": "FarmerlineML/Ewe-tts-2025",
        "checkpoint": "FarmerlineML/Ewe-tts-2025",
    },
}

device = "cuda" if torch.cuda.is_available() else "cpu"
# ---------- Load all models & tokenizers -----------------------------------
models = {}
tokenizers = {}
for name, paths in TTS_MODELS.items():
    print(f"Loading {name} model...")
    model = VitsModel.from_pretrained(paths["checkpoint"]).to(device)
    model.eval()
    # Apply clear-speech inference parameters (tweak per model if desired;
    # see the commented example after this loop).
    model.noise_scale = 0.8
    model.noise_scale_duration = 0.667
    model.speaking_rate = 0.75
    models[name] = model
    tokenizers[name] = AutoTokenizer.from_pretrained(paths["tokenizer"])
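
# A sketch of per-model overrides, using the same model attributes set in the
# loop above. The values here are illustrative, not tuned; uncomment and
# adjust per language if one voice sounds too fast or too noisy.
# models["Krio"].speaking_rate = 0.85
# models["Ewe"].noise_scale = 0.6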
# ---------- Utility: WAV ➔ MP3 Conversion -----------------------------------
def _wav_to_mp3(wave_np: np.ndarray, sr: int) -> str:
    """Convert an int16 numpy waveform to an MP3 temp file and return its path."""
    # pydub expects int16 PCM; VITS outputs float in roughly [-1, 1],
    # so clip before scaling to avoid integer wrap-around on overshoots.
    if wave_np.dtype != np.int16:
        wave_np = (np.clip(wave_np, -1.0, 1.0) * 32767).astype(np.int16)
    # Create the temp file first, then write after the handle is closed,
    # so this also works on platforms that lock open temp files.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tf:
        wav_path = tf.name
    wavfile.write(wav_path, sr, wave_np)
    mp3_path = wav_path.replace(".wav", ".mp3")
    AudioSegment.from_wav(wav_path).export(mp3_path, format="mp3", bitrate="64k")
    os.remove(wav_path)
    return mp3_path
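
# Usage sketch (illustrative, not executed at import): one second of silence
# at 16 kHz becomes a small MP3 in the system temp directory.
#   silence = np.zeros(16000, dtype=np.int16)
#   mp3_path = _wav_to_mp3(silence, 16000)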
# ---------- TTS Generation ---------------------------------------------------
def tts_generate(model_name: str, text: str):
    """Generate speech for `text` using the selected model."""
    if not text:
        return None
    model = models[model_name]
    tokenizer = tokenizers[model_name]
    inputs = tokenizer(text, return_tensors="pt").to(device)
    with torch.no_grad():
        wave = model(**inputs).waveform[0].cpu().numpy()
    return _wav_to_mp3(wave, model.config.sampling_rate)
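
# Quick smoke test outside the Gradio UI (uncomment to try; the sentence is
# one of the Swahili examples defined below):
#   print(tts_generate("Swahili", "Tafadhali hakikisha umefunga mlango kabla ya kuondoka."))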
# ---------- Gradio Interface ------------------------------------------------
examples = [
    ["Swahili", "zao kusaidia kuondoa umaskini na kujenga kampeni za mwamko wa virusi vya ukimwi amezitembelea"],
    ["Swahili", "Kidole hiki ni tofauti na vidole vingine kwa sababu mwelekeo wake ni wa pekee."],
    ["Swahili", "Tafadhali hakikisha umefunga mlango kabla ya kuondoka."],
    ["Krio", "Wetin na yu nem?"],
    ["Krio", "aw yu de du"],
    ["Krio", "A de go skul"],
]
demo = gr.Interface(
    fn=tts_generate,
    inputs=[
        gr.Dropdown(choices=list(TTS_MODELS.keys()), value="Swahili", label="Choose TTS Model"),
        gr.Textbox(lines=3, placeholder="Enter text here", label="Input Text"),
    ],
    outputs=gr.Audio(type="filepath", label="Audio", autoplay=True),
    title="Multi-Model Text-to-Speech",
    description="Select a TTS model from the dropdown and enter text to generate speech.",
    examples=examples,
    cache_examples=True,
)
if __name__ == "__main__":
    demo.launch()