import os
import tempfile

import gradio as gr
import numpy as np
import torch
from scipy.io.wavfile import write
from transformers import (
    SpeechT5ForTextToSpeech,
    SpeechT5HifiGan,
    SpeechT5Processor,
)

# =========================
# Model loading
# =========================
checkpoint = "Chithekitale/chichewa_tts_norules"

# The processor and vocoder come from the stock SpeechT5 releases; only the
# acoustic model is loaded from the fine-tuned checkpoint.
processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
model = SpeechT5ForTextToSpeech.from_pretrained(checkpoint)
vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")

# Speaker key -> path of the saved speaker-embedding array.
# Keys are the first whitespace-separated token of each SPEAKER_CHOICES label.
# NOTE(review): SPK4 reuses SPK2's file and SPK5 reuses SPK1's file, so those
# pairs will sound identical -- confirm whether this is intentional.
speaker_embeddings = {
    "SPK1": "spkemb/cmu_us_slt_arctic-wav-arctic_a0508.npy",
    "SPK2": "spkemb/cmu_us_rms_arctic-wav-arctic_b0353.npy",
    "SPK3": "spkemb/cmu_us_ksp_arctic-wav-arctic_b0087.npy",
    "SPK4": "spkemb/cmu_us_rms_arctic-wav-arctic_b0353.npy",
    "SPK5": "spkemb/cmu_us_slt_arctic-wav-arctic_a0508.npy",
}

# Labels shown in the UI; the first entry is the default selection.
SPEAKER_CHOICES = [
    "SPK1 (female)",
    "SPK2 (male)",
    "SPK3 (male)",
    "SPK4 (male)",
    "SPK5 (female)",
]

# Example rows of [text, speaker label] offered in the UI.
EXAMPLES = [
    ["Ndapita, koma ndibweranso pompano.", "SPK1 (female)"],
    ["Koma apapa zikuoneka kuti ziyenda bwino.", "SPK2 (male)"],
    ["Ineyo ndikuona kuti sizizasithanso.", "SPK3 (male)"],
    [
        "Mwina kusogolo kuno anthu ena azalimba mtima, koma panopana ndakaika.",
        "SPK4 (male)",
    ],
    ["Simungasankhe munthu oti bola linamukana.", "SPK5 (female)"],
    [
        "Kodi chimanga panopa chikugulisidwa zingati, kapena nanunso simukudziwa?",
        "SPK5 (female)",
    ],
]

# Sample rate (Hz) used for both the returned audio tuple and the WAV file.
SAMPLE_RATE = 16000


# =========================
# Helpers
# =========================
def get_speaker_key(speaker_label: str) -> str:
    """Return the speaker key from a UI label, e.g. "SPK1 (female)" -> "SPK1".

    Raises:
        ValueError: If the label is empty or contains only whitespace.
    """
    parts = speaker_label.split() if speaker_label else []
    if not parts:
        raise ValueError("Speaker label is empty.")
    return parts[0]


def load_speaker_embedding(speaker: str) -> np.ndarray:
    """Load the speaker embedding for a UI speaker label.

    The array is read with ``np.load`` and cast to float32. A 2-D array is
    averaged over axis 0, singleton dimensions are squeezed out, and the
    result must have shape ``(512,)``.

    Args:
        speaker: Speaker label such as "SPK1 (female)".

    Returns:
        A float32 array of shape ``(512,)``.

    Raises:
        ValueError: If the speaker key is unknown or the processed array does
            not have shape ``(512,)``.
        FileNotFoundError: If the embedding file cannot be loaded for any
            reason; the underlying exception is chained as ``__cause__``.
    """
    speaker_key = get_speaker_key(speaker)
    try:
        path = speaker_embeddings[speaker_key]
    except KeyError:
        raise ValueError(f"Unknown speaker key: {speaker_key}") from None

    try:
        speaker_embedding = np.load(path).astype(np.float32)
    except Exception as e:
        # Keep the historical exception type, but chain the real cause so
        # the original traceback is not lost.
        raise FileNotFoundError(
            f"Could not load speaker embedding file: {path}. \n"
            f"Error: {e}"
        ) from e

    if speaker_embedding.ndim == 2:
        speaker_embedding = speaker_embedding.mean(axis=0)
    speaker_embedding = np.squeeze(speaker_embedding)

    if speaker_embedding.shape != (512,):
        raise ValueError(
            f"Unexpected speaker embedding shape after processing: "
            f"{speaker_embedding.shape}. Expected (512,)"
        )
    return speaker_embedding


def save_audio_to_wav(audio: np.ndarray, sample_rate: int = SAMPLE_RATE) -> str:
    """Save generated int16 audio to a temporary WAV file and return its path.

    The file is created with ``delete=False`` so that it still exists after
    this function returns; nothing in this module removes it afterwards.
    """
    # Reserve a unique name and close the handle before scipy writes to it.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_file:
        wav_path = temp_file.name
    write(wav_path, sample_rate, audio)
    return wav_path


def _to_int16(speech: np.ndarray) -> np.ndarray:
    """Peak-normalize a waveform and convert it to int16 samples."""
    max_val = np.max(np.abs(speech))
    if max_val > 0:  # all-zero audio is left as-is to avoid dividing by zero
        speech = speech / max_val
    return (speech * 32767).astype(np.int16)


# =========================
# Inference
# =========================
def predict(text, speaker):
    """Synthesize speech for ``text`` using the selected speaker voice.

    Args:
        text: Input text from the UI.
        speaker: Speaker label such as "SPK1 (female)".

    Returns:
        A 3-tuple ``(audio, wav_path, status)``. On success ``audio`` is
        ``(SAMPLE_RATE, int16 ndarray)`` and ``wav_path`` is the path of the
        saved WAV file. On failure both are ``None`` and ``status`` holds the
        message to display.
    """
    # This is the UI boundary: any failure is reported in the status box
    # instead of propagating to the caller.
    try:
        if not text or not text.strip():
            return None, None, "Please enter some Chichewa text."

        inputs = processor(text=text, return_tensors="pt")
        # Truncate to the model's maximum number of text positions.
        input_ids = inputs["input_ids"][..., :model.config.max_text_positions]

        speaker_embedding = torch.tensor(
            load_speaker_embedding(speaker), dtype=torch.float32
        ).unsqueeze(0)  # add a leading dimension of size 1

        with torch.no_grad():
            speech = model.generate_speech(
                input_ids, speaker_embedding, vocoder=vocoder
            )
        speech = speech.cpu().numpy()

        if speech.size == 0:
            # np.max on an empty array raises an opaque reduction error.
            return None, None, "Error during generation: no audio was produced."

        # Normalize safely before int16 conversion
        speech = _to_int16(speech)

        # Save WAV file for downloading
        wav_path = save_audio_to_wav(speech, SAMPLE_RATE)
        status = f"Generated speech successfully using speaker: {speaker}"
        return (SAMPLE_RATE, speech), wav_path, status
    except Exception as e:
        return None, None, f"Error during generation: {str(e)}"


def clear_all():
    """Return reset values for text, speaker, audio, download file and status."""
    return "", SPEAKER_CHOICES[0], None, None, "Ready."
# =========================
# UI
# =========================
# NOTE(review): the nesting of the layout below and the markup inside
# gr.HTML were reconstructed from a whitespace-collapsed, tag-stripped copy
# of this file -- confirm against the deployed app.
custom_css = """
.gradio-container {
    max-width: 1100px !important;
    margin: 0 auto;
}
.hero {
    text-align: center;
    padding: 10px 0 0 0;
}
.section-note {
    font-size: 0.95rem;
    opacity: 0.9;
}
"""

with gr.Blocks(css=custom_css, title="Baseline Chichewa Speech Synthesis Demo") as demo:
    # Page header; the .hero and .section-note classes are styled in custom_css.
    gr.HTML(
        """
        <div class="hero">
            <h1>Baseline Chichewa Synthesis</h1>
            <p class="section-note">
                Enter Chichewa text, choose a speaker voice, and generate speech audio.
            </p>
        </div>
        """
    )

    with gr.Row():
        # Left column: inputs, action buttons and status.
        with gr.Column(scale=5):
            text_input = gr.Textbox(
                label="Input Text",
                placeholder="Type Chichewa text here...",
                lines=6,
            )
            speaker_input = gr.Radio(
                label="Speaker Voice",
                choices=SPEAKER_CHOICES,
                value="SPK1 (female)",
            )
            with gr.Row():
                generate_btn = gr.Button("Generate Speech", variant="primary")
                clear_btn = gr.Button("Clear")
            status_box = gr.Textbox(
                label="System Status",
                value="Ready.",
                interactive=False,
            )

        # Right column: playable audio and downloadable file.
        with gr.Column(scale=5):
            audio_output = gr.Audio(
                label="Generated Speech",
                type="numpy",
                autoplay=False,
            )
            download_file = gr.File(
                label="Download Audio File",
            )

    gr.Markdown("### Example Inputs")
    gr.Examples(
        examples=EXAMPLES,
        inputs=[text_input, speaker_input],
    )

    # predict returns (audio, wav_path, status), matching the outputs order.
    generate_btn.click(
        fn=predict,
        inputs=[text_input, speaker_input],
        outputs=[audio_output, download_file, status_box],
        show_progress="full",
    )

    # clear_all returns one reset value per component listed in outputs.
    clear_btn.click(
        fn=clear_all,
        inputs=[],
        outputs=[text_input, speaker_input, audio_output, download_file, status_box],
    )

demo.launch()