Spaces:
Sleeping
Sleeping
Minte
fix: refactor model loading and enhance ASR and translation functionality with SeamlessM4T integration
cb4630e
import subprocess
import tempfile
import traceback

import gradio as gr
import numpy as np
import resampy
import soundfile as sf
import torch
from transformers import (
    SeamlessM4TModel,
    SeamlessM4Tv2Model,
    AutoProcessor,
    pipeline,
    VitsModel,
    AutoTokenizer,
)
# --- Load SeamlessM4T model for ASR and translation ---
# NOTE: "facebook/seamless-m4t-v2-large" is a *v2* checkpoint and must be
# loaded with SeamlessM4Tv2Model. The v1 SeamlessM4TModel class has a
# mismatched config class for this checkpoint, so the original load failed
# and left `model`/`processor` as None, disabling ASR and translation.
try:
    model_id = "facebook/seamless-m4t-v2-large"
    processor = AutoProcessor.from_pretrained(model_id)
    model = SeamlessM4Tv2Model.from_pretrained(model_id).to("cpu")
    print("[INFO] SeamlessM4T model loaded for ASR and translation.")
except Exception as e:
    # On any failure the globals stay None; downstream functions check for
    # this and return error strings instead of crashing the app.
    print("[ERROR] Failed to load SeamlessM4T model:", e)
    traceback.print_exc()
    model = None
    processor = None
# --- Load chat model ---
# Small instruction-tuned seq2seq model that produces the English reply in
# the middle of the pipeline. On failure `chat_model` stays None and
# generate_chat_response() reports the error instead of crashing.
try:
    chat_model = pipeline("text2text-generation", model="google/flan-t5-base")
    print("[INFO] Chat model loaded successfully.")
except Exception as e:
    print("[ERROR] Failed to load chat model:", e)
    traceback.print_exc()
    chat_model = None
# --- Load TTS model (Facebook MMS for Amharic) ---
# VITS-based MMS checkpoint for Amharic speech synthesis. Both globals are
# set to None on failure so generate_tts() can detect the condition and the
# pipeline can fall through to the gTTS / sine-wave fallbacks.
try:
    tts_tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-amh")
    tts_model = VitsModel.from_pretrained("facebook/mms-tts-amh").to("cpu")
    print("[INFO] Facebook MMS TTS model for Amharic loaded successfully.")
except Exception as e:
    print("[ERROR] Failed to load Facebook MMS TTS model:", e)
    traceback.print_exc()
    tts_tokenizer = None
    tts_model = None
# --- Romanization helper ---
def romanize(text):
    """Transliterate *text* to Latin script with the external ``uroman`` tool.

    Falls back to returning the input unchanged when the tool is missing,
    cannot be executed, or exits with a non-zero status (the original code
    ignored the exit status and could return empty/partial output).
    """
    try:
        result = subprocess.run(
            ["uroman"],
            input=text.encode("utf-8"),
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,  # capture diagnostics instead of spamming the console
        )
        if result.returncode != 0:
            print("[ERROR] uroman exited with code", result.returncode)
            return text  # fallback
        return result.stdout.decode("utf-8").strip()
    except Exception as e:
        print("[ERROR] Romanization failed:", e)
        return text  # fallback
# --- ASR with SeamlessM4T ---
def transcribe_amharic(audio_file):
    """Transcribe an Amharic audio file to Amharic text with SeamlessM4T.

    Returns the transcription string, or a short error string on failure.
    """
    if model is None or processor is None:
        return "ASR Model loading failed"
    try:
        audio, sr = sf.read(audio_file)
        # Down-mix multi-channel recordings to mono.
        if audio.ndim > 1:
            audio = audio.mean(axis=1)
        # SeamlessM4T expects 16 kHz input.
        audio = resampy.resample(audio, sr, 16000)
        inputs = processor(audio=audio, sampling_rate=16000, return_tensors="pt")
        with torch.no_grad():
            generated_ids = model.generate(
                **inputs,
                tgt_lang="amh",
                generate_speech=False,
            )
        # generate(..., generate_speech=False) returns a tuple whose first
        # element is a (batch, seq_len) tensor; decode the first sequence.
        # The original passed the 2-D tensor straight to decode(), which the
        # tokenizer cannot handle.
        transcription = processor.decode(
            generated_ids[0].tolist()[0], skip_special_tokens=True
        )
        return transcription.strip()
    except Exception as e:
        print("[ERROR] ASR transcription failed:", e)
        traceback.print_exc()
        return f"ASR failed: {str(e)[:50]}..."
# --- Translation with SeamlessM4T (Amharic to English) ---
def translate_am_to_en(amharic_text):
    """Translate Amharic text to English with SeamlessM4T.

    Returns the English translation, or a short error string on failure.
    """
    if model is None or processor is None:
        return "Translation model not loaded"
    try:
        text_inputs = processor(text=amharic_text, src_lang="amh", return_tensors="pt")
        with torch.no_grad():
            output_tokens = model.generate(
                **text_inputs,
                tgt_lang="eng",
                generate_speech=False,
            )
        # Text-only generate() returns a tuple; element 0 is a
        # (batch, seq_len) tensor, so select the first sequence before
        # decoding (the original decoded the 2-D tensor directly).
        translated_text = processor.decode(
            output_tokens[0].tolist()[0], skip_special_tokens=True
        )
        return translated_text.strip()
    except Exception as e:
        print("[ERROR] Translation failed:", e)
        traceback.print_exc()
        return f"Translation failed: {str(e)[:50]}..."
# --- Back translation with SeamlessM4T (English to Amharic) ---
def back_translate_en_to_am(en_text):
    """Translate English text back to Amharic with SeamlessM4T.

    Returns the Amharic translation, or a short error string on failure.
    """
    if model is None or processor is None:
        return "Back translation model not loaded"
    try:
        text_inputs = processor(text=en_text, src_lang="eng", return_tensors="pt")
        with torch.no_grad():
            output_tokens = model.generate(
                **text_inputs,
                tgt_lang="amh",
                generate_speech=False,
            )
        # Text-only generate() returns a tuple; element 0 is a
        # (batch, seq_len) tensor, so select the first sequence before
        # decoding (the original decoded the 2-D tensor directly).
        am_response = processor.decode(
            output_tokens[0].tolist()[0], skip_special_tokens=True
        )
        return am_response.strip()
    except Exception as e:
        print("[ERROR] Back translation failed:", e)
        traceback.print_exc()
        return f"Back translation failed: {str(e)[:50]}..."
# --- Chat response ---
def generate_chat_response(text):
    """Generate a conversational English reply to *text* with Flan-T5.

    Returns the generated reply, or a short error string on failure.
    """
    if chat_model is None:
        return "Chat model not loaded"
    try:
        # Add context to make responses more meaningful.
        prompt = f"Respond to this in a helpful and conversational way: {text}"
        response = chat_model(
            prompt, max_length=150, num_beams=5, temperature=0.7, do_sample=True
        )[0]['generated_text']
        return response.strip()
    except Exception as e:
        print("[ERROR] Chat generation failed:", e)
        # Print the stack trace for consistency with every other error
        # handler in this file (the original silently omitted it here).
        traceback.print_exc()
        return f"Chat failed: {str(e)[:50]}..."
# --- TTS with Facebook MMS ---
def generate_tts(text):
    """Synthesize Amharic speech for *text* with the MMS VITS model.

    Returns a ``(samples, sampling_rate)`` pair, or None when the model is
    unavailable, the text is blank, or synthesis fails.
    """
    if tts_model is None or tts_tokenizer is None:
        print("[ERROR] TTS model not loaded")
        return None
    try:
        if not text.strip():
            return None
        # Tokenize, then run the VITS model without tracking gradients.
        encoded = tts_tokenizer(text, return_tensors="pt")
        with torch.no_grad():
            waveform = tts_model(**encoded).waveform
        samples = waveform.cpu().numpy().squeeze()
        # Peak-normalize so the output spans the full [-1, 1] range.
        peak = np.max(np.abs(samples))
        if peak > 0:
            samples = samples / peak
        return samples, tts_model.config.sampling_rate
    except Exception as e:
        print("[ERROR] MMS TTS generation failed:", e)
        traceback.print_exc()
        return None
# --- Alternative TTS using gTTS (fallback) ---
def generate_tts_gtts(text):
    """Fallback Amharic TTS via Google gTTS.

    Returns a ``(samples, sampling_rate)`` pair, or None when gTTS is
    unavailable or synthesis fails.
    """
    try:
        import io

        from gtts import gTTS

        # Render the MP3 into memory rather than onto disk.
        buffer = io.BytesIO()
        gTTS(text=text, lang='am').write_to_fp(buffer)
        buffer.seek(0)
        # Convert to numpy array for consistency with the other TTS paths.
        samples, rate = sf.read(buffer)
        return samples, rate
    except Exception as e:
        print("[ERROR] gTTS failed:", e)
        return None
# --- Simple audio fallback ---
def generate_simple_audio(text):
    """Last-resort audio: a short sine beep whose pitch depends on *text*.

    Returns a ``(samples, sampling_rate)`` pair, or None on failure. The
    frequency is derived from a deterministic digest of the text so the same
    input always yields the same tone; the original used the builtin
    ``hash()``, which is salted per process (PYTHONHASHSEED) and therefore
    not reproducible across runs.
    """
    try:
        sampling_rate = 22050
        # Duration scales with text length, clamped to 1–3 seconds.
        duration = min(3.0, max(1.0, len(text) / 10))
        t = np.linspace(0, duration, int(sampling_rate * duration), endpoint=False)
        digest = sum(text.encode("utf-8"))  # stable across runs, unlike hash()
        frequency = 300 + (digest % 300)
        audio_data = 0.5 * np.sin(2 * np.pi * frequency * t)
        return audio_data, sampling_rate
    except Exception as e:
        print("[ERROR] Simple audio generation failed:", e)
        return None
# --- Create WAV file ---
def create_wav_file(audio_array, sample_rate):
    """Write *audio_array* to a temporary ``.wav`` file and return its path.

    Returns None when the input is None or writing fails. The caller is
    responsible for the file's lifetime (``delete=False``).
    """
    try:
        if audio_array is None:
            return None
        if audio_array.ndim > 1:
            audio_array = audio_array.flatten()
        # Close the temp-file handle before soundfile writes to the path:
        # the original kept it open, leaking a descriptor per call (and on
        # Windows an open file cannot be reopened by name).
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
            wav_path = tmp.name
        sf.write(wav_path, audio_array, sample_rate)
        return wav_path
    except Exception as e:
        print("[ERROR] WAV file creation failed:", e)
        traceback.print_exc()
        return None
# --- Assistant pipeline ---
def assistant_pipeline(audio):
    """End-to-end voice assistant: Amharic speech in, Amharic speech out.

    Steps: ASR -> am→en translation -> English chat reply -> en→am back
    translation -> TTS (MMS, then gTTS, then sine-beep fallback).

    Returns a 5-tuple ``(asr_text, english_text, english_reply,
    amharic_reply, wav_path_or_None)`` matching the Gradio outputs.
    """
    if not audio:
        return "No audio", "", "", "", None
    # Step 1: ASR with SeamlessM4T
    asr_result = transcribe_amharic(audio)
    print(f"ASR Result: {asr_result}")
    # Don't feed ASR error strings through translation/chat; bail out early
    # (mirrors the existing startswith() guard used for back translation).
    if asr_result == "ASR Model loading failed" or asr_result.startswith("ASR failed"):
        return asr_result, "", "", "", None
    # Step 2: Translation with SeamlessM4T
    en_text = translate_am_to_en(asr_result)
    print(f"English Translation: {en_text}")
    # Step 3: Chat response
    en_response = generate_chat_response(en_text)
    print(f"Chat Response: {en_response}")
    # Step 4: Back translation with SeamlessM4T
    am_response = back_translate_en_to_am(en_response)
    print(f"Amharic Response: {am_response}")
    # Step 5: TTS with cascading fallbacks
    audio_file_path = None
    if am_response and not am_response.startswith("Back translation failed"):
        # Try MMS TTS first
        tts_result = generate_tts(am_response)
        # If MMS TTS fails, try gTTS
        if tts_result is None:
            print("[INFO] Trying gTTS fallback")
            tts_result = generate_tts_gtts(am_response)
        # If both TTS methods fail, use simple audio
        if tts_result is None:
            print("[INFO] Using simple audio fallback")
            tts_result = generate_simple_audio(am_response)
        if tts_result is not None:
            audio_data, sample_rate = tts_result
            audio_file_path = create_wav_file(audio_data, sample_rate)
            print(f"Audio generated successfully: {audio_file_path}")
    return asr_result, en_text, en_response, am_response, audio_file_path
# --- Gradio UI ---
# Single-page interface: record or upload audio, press Process, and view
# every intermediate pipeline stage plus the synthesized Amharic reply.
# NOTE(review): the "π..." sequences in the UI strings look like mojibake of
# the original emoji — confirm the intended characters before changing them.
with gr.Blocks(title="π Local Language AI Assistant") as demo:
    gr.Markdown("# π Local Language AI Assistant")
    gr.Markdown("ποΈ Speak **or upload** Amharic audio and get AI responses with synthesized Amharic speech!")
    with gr.Row():
        audio_input = gr.Audio(sources=["microphone", "upload"], type="filepath", label="π€ Record or Upload your voice")
        submit_btn = gr.Button("Process", variant="primary")
    with gr.Row():
        with gr.Column():
            # One output widget per pipeline stage, in processing order.
            asr_output = gr.Textbox(label="ASR (Amharic text)")
            en_translation = gr.Textbox(label="Translated to English")
            en_response = gr.Textbox(label="Model Response (English)")
            am_response = gr.Textbox(label="Back Translated (Amharic)")
            audio_output = gr.Audio(label="Amharic TTS Output", type="filepath")
    # Wire the button to the pipeline; outputs map 1:1 to the widgets above.
    submit_btn.click(
        fn=assistant_pipeline,
        inputs=audio_input,
        outputs=[asr_output, en_translation, en_response, am_response, audio_output]
    )

if __name__ == "__main__":
    # Bind to all interfaces on the standard Hugging Face Spaces port.
    demo.launch(server_name="0.0.0.0", server_port=7860)