import traceback
import soundfile as sf
import torch
import numpy as np
from transformers import (
    SeamlessM4Tv2Model,
    AutoProcessor,
    pipeline,
    VitsModel,
    AutoTokenizer
)
import gradio as gr
import resampy
import tempfile
import subprocess

# --- Load SeamlessM4T model for ASR and translation ---
try:
    model_id = "facebook/seamless-m4t-v2-large"
    processor = AutoProcessor.from_pretrained(model_id)
    # BUGFIX: the *-v2-large checkpoint must be loaded with the v2 model
    # class; SeamlessM4TModel (v1) does not match the v2 architecture.
    model = SeamlessM4Tv2Model.from_pretrained(model_id).to("cpu")
    print("[INFO] SeamlessM4T model loaded for ASR and translation.")
except Exception as e:
    print("[ERROR] Failed to load SeamlessM4T model:", e)
    traceback.print_exc()
    model = None
    processor = None

# --- Load chat model ---
try:
    chat_model = pipeline("text2text-generation", model="google/flan-t5-base")
    print("[INFO] Chat model loaded successfully.")
except Exception as e:
    print("[ERROR] Failed to load chat model:", e)
    traceback.print_exc()
    chat_model = None

# --- Load TTS model (Facebook MMS for Amharic) ---
try:
    tts_tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-amh")
    tts_model = VitsModel.from_pretrained("facebook/mms-tts-amh").to("cpu")
    print("[INFO] Facebook MMS TTS model for Amharic loaded successfully.")
except Exception as e:
    print("[ERROR] Failed to load Facebook MMS TTS model:", e)
    traceback.print_exc()
    tts_tokenizer = None
    tts_model = None


# --- Romanization helper ---
def romanize(text):
    """Romanize text via the external ``uroman`` tool.

    Returns the romanized string, or the input unchanged if the tool is
    missing or fails. (Not called by the pipeline below; kept as a utility.)
    """
    try:
        result = subprocess.run(
            ["uroman"], input=text.encode("utf-8"), stdout=subprocess.PIPE
        )
        return result.stdout.decode("utf-8").strip()
    except Exception as e:
        print("[ERROR] Romanization failed:", e)
        return text  # fallback


# --- ASR with SeamlessM4T ---
def transcribe_amharic(audio_file):
    """Transcribe an Amharic audio file to Amharic text.

    Args:
        audio_file: path to an audio file readable by soundfile.

    Returns:
        The transcription string, or a short error message on failure.
    """
    if model is None or processor is None:
        return "ASR Model loading failed"
    try:
        audio, sr = sf.read(audio_file)
        if audio.ndim > 1:
            # Downmix multi-channel audio to mono before resampling.
            audio = audio.mean(axis=1)
        # SeamlessM4T expects 16 kHz input.
        audio = resampy.resample(audio, sr, 16000)

        # Direct Amharic transcription
        inputs = processor(audio=audio, sampling_rate=16000, return_tensors="pt")
        with torch.no_grad():
            generated_ids = model.generate(
                **inputs, tgt_lang="amh", generate_speech=False
            )
        # BUGFIX: generate(..., generate_speech=False) returns a tuple whose
        # first element is a (batch, seq) token tensor; decode the first
        # sequence's id list, not the whole 2-D tensor.
        transcription = processor.decode(
            generated_ids[0].tolist()[0], skip_special_tokens=True
        )
        return transcription.strip()
    except Exception as e:
        print("[ERROR] ASR transcription failed:", e)
        traceback.print_exc()
        return f"ASR failed: {str(e)[:50]}..."


# --- Translation with SeamlessM4T (Amharic to English) ---
def translate_am_to_en(amharic_text):
    """Translate Amharic text to English using SeamlessM4T.

    Returns the English translation, or a short error message on failure.
    """
    if model is None or processor is None:
        return "Translation model not loaded"
    try:
        text_inputs = processor(
            text=amharic_text, src_lang="amh", return_tensors="pt"
        )
        with torch.no_grad():
            output_tokens = model.generate(
                **text_inputs, tgt_lang="eng", generate_speech=False
            )
        # BUGFIX: decode the first id sequence (see transcribe_amharic).
        translated_text = processor.decode(
            output_tokens[0].tolist()[0], skip_special_tokens=True
        )
        return translated_text.strip()
    except Exception as e:
        print("[ERROR] Translation failed:", e)
        traceback.print_exc()
        return f"Translation failed: {str(e)[:50]}..."


# --- Back translation with SeamlessM4T (English to Amharic) ---
def back_translate_en_to_am(en_text):
    """Translate English text back to Amharic using SeamlessM4T.

    Returns the Amharic translation, or a short error message on failure.
    """
    if model is None or processor is None:
        return "Back translation model not loaded"
    try:
        text_inputs = processor(text=en_text, src_lang="eng", return_tensors="pt")
        with torch.no_grad():
            output_tokens = model.generate(
                **text_inputs, tgt_lang="amh", generate_speech=False
            )
        # BUGFIX: decode the first id sequence (see transcribe_amharic).
        am_response = processor.decode(
            output_tokens[0].tolist()[0], skip_special_tokens=True
        )
        return am_response.strip()
    except Exception as e:
        print("[ERROR] Back translation failed:", e)
        traceback.print_exc()
        return f"Back translation failed: {str(e)[:50]}..."
# --- Chat response ---
def generate_chat_response(text):
    """Generate a conversational English reply with the FLAN-T5 pipeline.

    Returns the reply string, or a short error message on failure.
    """
    if chat_model is None:
        return "Chat model not loaded"
    try:
        # Add context to make responses more meaningful
        prompt = f"Respond to this in a helpful and conversational way: {text}"
        response = chat_model(
            prompt, max_length=150, num_beams=5, temperature=0.7, do_sample=True
        )[0]['generated_text']
        return response.strip()
    except Exception as e:
        print("[ERROR] Chat generation failed:", e)
        return f"Chat failed: {str(e)[:50]}..."


# --- TTS with Facebook MMS ---
def generate_tts(text):
    """Synthesize Amharic speech with the MMS VITS model.

    Returns:
        (audio_array, sampling_rate) with audio normalized to [-1, 1],
        or None if the model is unavailable, the text is blank, or
        synthesis fails.
    """
    if tts_model is None or tts_tokenizer is None:
        print("[ERROR] TTS model not loaded")
        return None
    try:
        if not text.strip():
            return None
        # Tokenize text and generate speech
        inputs = tts_tokenizer(text, return_tensors="pt")
        with torch.no_grad():
            output = tts_model(**inputs)
        speech = output.waveform
        # Convert to numpy and peak-normalize (guard against silence).
        audio_data = speech.cpu().numpy().squeeze()
        max_val = np.max(np.abs(audio_data))
        if max_val > 0:
            audio_data = audio_data / max_val
        return audio_data, tts_model.config.sampling_rate
    except Exception as e:
        print("[ERROR] MMS TTS generation failed:", e)
        traceback.print_exc()
        return None


# --- Alternative TTS using gTTS (fallback) ---
def generate_tts_gtts(text):
    """Fallback Amharic TTS via Google's gTTS web service.

    Returns (audio_array, sample_rate) or None on failure.
    NOTE(review): gTTS emits MP3; sf.read of that buffer requires a
    libsndfile build with MP3 support — confirm in deployment.
    """
    try:
        from gtts import gTTS
        import io
        tts = gTTS(text=text, lang='am')
        fp = io.BytesIO()
        tts.write_to_fp(fp)
        fp.seek(0)
        # Convert to numpy array for consistency
        audio, sr = sf.read(fp)
        return audio, sr
    except Exception as e:
        print("[ERROR] gTTS failed:", e)
        return None


# --- Simple audio fallback ---
def generate_simple_audio(text):
    """Last-resort audio: a short sine tone whose pitch depends on the text.

    Returns (audio_array, sample_rate) or None on failure.
    """
    try:
        sampling_rate = 22050
        # 1-3 s, scaled with text length.
        duration = min(3.0, max(1.0, len(text) / 10))
        t = np.linspace(0, duration, int(sampling_rate * duration), endpoint=False)
        frequency = 300 + (hash(text) % 300)
        audio_data = 0.5 * np.sin(2 * np.pi * frequency * t)
        return audio_data, sampling_rate
    except Exception as e:
        print("[ERROR] Simple audio generation failed:", e)
        return None


# --- Create WAV file ---
def create_wav_file(audio_array, sample_rate):
    """Write a (mono) audio array to a temporary .wav file.

    Returns the file path, or None on failure. The file is created with
    delete=False so Gradio can serve it after this function returns.
    """
    try:
        if audio_array is None:
            return None
        if audio_array.ndim > 1:
            audio_array = audio_array.flatten()
        temp_file = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
        # BUGFIX: close our handle before sf.write reopens the path by name;
        # keeping it open leaks the descriptor and fails on Windows.
        temp_file.close()
        sf.write(temp_file.name, audio_array, sample_rate)
        return temp_file.name
    except Exception as e:
        print("[ERROR] WAV file creation failed:", e)
        traceback.print_exc()
        return None


# --- Assistant pipeline ---
def assistant_pipeline(audio):
    """Full voice-assistant pass: ASR -> translate -> chat -> back-translate -> TTS.

    Args:
        audio: filepath from the Gradio audio component (or falsy if empty).

    Returns:
        (asr_text, english_translation, english_reply, amharic_reply,
        tts_wav_path_or_None) — one value per UI output component.
    """
    if not audio:
        return "No audio", "", "", "", None

    # Step 1: ASR with SeamlessM4T
    asr_result = transcribe_amharic(audio)
    print(f"ASR Result: {asr_result}")

    # Step 2: Translation with SeamlessM4T
    en_text = translate_am_to_en(asr_result)
    print(f"English Translation: {en_text}")

    # Step 3: Chat response
    en_response = generate_chat_response(en_text)
    print(f"Chat Response: {en_response}")

    # Step 4: Back translation with SeamlessM4T
    am_response = back_translate_en_to_am(en_response)
    print(f"Amharic Response: {am_response}")

    # Step 5: TTS — try MMS first, then gTTS, then a sine-tone placeholder.
    audio_file_path = None
    if am_response and not am_response.startswith("Back translation failed"):
        tts_result = generate_tts(am_response)
        if tts_result is None:
            print("[INFO] Trying gTTS fallback")
            tts_result = generate_tts_gtts(am_response)
        if tts_result is None:
            print("[INFO] Using simple audio fallback")
            tts_result = generate_simple_audio(am_response)
        if tts_result is not None:
            audio_data, sample_rate = tts_result
            audio_file_path = create_wav_file(audio_data, sample_rate)
            print(f"Audio generated successfully: {audio_file_path}")

    return asr_result, en_text, en_response, am_response, audio_file_path


# --- Gradio UI ---
with gr.Blocks(title="🌍 Local Language AI Assistant") as demo:
    gr.Markdown("# 🌍 Local Language AI Assistant")
    gr.Markdown("🎙️ Speak **or upload** Amharic audio and get AI responses with synthesized Amharic speech!")
    with gr.Row():
        audio_input = gr.Audio(
            sources=["microphone", "upload"],
            type="filepath",
            label="🎤 Record or Upload your voice"
        )
        submit_btn = gr.Button("Process", variant="primary")
    with gr.Row():
        with gr.Column():
            asr_output = gr.Textbox(label="ASR (Amharic text)")
            en_translation = gr.Textbox(label="Translated to English")
            en_response = gr.Textbox(label="Model Response (English)")
            am_response = gr.Textbox(label="Back Translated (Amharic)")
            audio_output = gr.Audio(label="Amharic TTS Output", type="filepath")
    submit_btn.click(
        fn=assistant_pipeline,
        inputs=audio_input,
        outputs=[asr_output, en_translation, en_response, am_response, audio_output]
    )

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)