# GihonTech / app.py — by Minte
# Commit cb4630e: refactor model loading and enhance ASR and translation
# functionality with SeamlessM4T integration
import traceback
import soundfile as sf
import torch
import numpy as np
from transformers import (
SeamlessM4TModel, AutoProcessor,
pipeline, VitsModel, AutoTokenizer
)
import gradio as gr
import resampy
import tempfile
import subprocess
# --- Load SeamlessM4T model for ASR and translation ---
# Both the processor and the model must load for ASR/translation to work;
# on any failure both globals are left as None and the app degrades gracefully.
try:
    _seamless_id = "facebook/seamless-m4t-v2-large"
    processor = AutoProcessor.from_pretrained(_seamless_id)
    model = SeamlessM4TModel.from_pretrained(_seamless_id).to("cpu")
    print("[INFO] SeamlessM4T model loaded for ASR and translation.")
except Exception as e:
    print("[ERROR] Failed to load SeamlessM4T model:", e)
    traceback.print_exc()
    model = None
    processor = None
# --- Load chat model ---
# FLAN-T5 handles the English chat turn; chat_model stays None on failure.
try:
    chat_model = pipeline("text2text-generation", model="google/flan-t5-base")
    print("[INFO] Chat model loaded successfully.")
except Exception as e:
    print("[ERROR] Failed to load chat model:", e)
    traceback.print_exc()
    chat_model = None
# --- Load TTS model (Facebook MMS for Amharic) ---
# VITS-based Amharic TTS; tokenizer and model are both required, and both
# globals are reset to None if either fails to load.
try:
    _mms_id = "facebook/mms-tts-amh"
    tts_tokenizer = AutoTokenizer.from_pretrained(_mms_id)
    tts_model = VitsModel.from_pretrained(_mms_id).to("cpu")
    print("[INFO] Facebook MMS TTS model for Amharic loaded successfully.")
except Exception as e:
    print("[ERROR] Failed to load Facebook MMS TTS model:", e)
    traceback.print_exc()
    tts_tokenizer = None
    tts_model = None
# --- Romanization helper ---
def romanize(text):
    """Romanize *text* via the external ``uroman`` command-line tool.

    Returns the romanized string, or *text* unchanged if uroman is not
    installed or exits with an error.
    """
    try:
        # check=True: a nonzero exit from uroman raises instead of silently
        # returning partial/empty stdout (the original ignored the exit code).
        # text=True lets subprocess handle UTF-8 encoding/decoding for us.
        result = subprocess.run(
            ["uroman"],
            input=text,
            capture_output=True,
            text=True,
            check=True,
        )
        return result.stdout.strip()
    except Exception as e:
        print("[ERROR] Romanization failed:", e)
        return text  # fallback
# --- ASR with SeamlessM4T ---
def transcribe_amharic(audio_file):
    """Transcribe an Amharic audio file to Amharic text with SeamlessM4T.

    audio_file: path to any file readable by soundfile.
    Returns the transcription, or a short error message on failure.
    """
    if processor is None or model is None:
        return "ASR Model loading failed"
    try:
        waveform, rate = sf.read(audio_file)
        # Down-mix multi-channel audio to mono before resampling.
        if waveform.ndim > 1:
            waveform = waveform.mean(axis=1)
        # SeamlessM4T expects 16 kHz input.
        waveform = resampy.resample(waveform, rate, 16000)
        # Direct Amharic transcription (text output only, no speech).
        features = processor(audio=waveform, sampling_rate=16000, return_tensors="pt")
        with torch.no_grad():
            token_ids = model.generate(
                **features,
                tgt_lang="amh",
                generate_speech=False,
            )
        return processor.decode(token_ids[0], skip_special_tokens=True).strip()
    except Exception as e:
        print("[ERROR] ASR transcription failed:", e)
        traceback.print_exc()
        return f"ASR failed: {str(e)[:50]}..."
# --- Translation with SeamlessM4T (Amharic to English) ---
def translate_am_to_en(amharic_text):
    """Translate Amharic text to English using the SeamlessM4T model.

    Returns the English translation, or a short error message on failure.
    """
    if processor is None or model is None:
        return "Translation model not loaded"
    try:
        encoded = processor(text=amharic_text, src_lang="amh", return_tensors="pt")
        with torch.no_grad():
            tokens = model.generate(
                **encoded,
                tgt_lang="eng",
                generate_speech=False,
            )
        return processor.decode(tokens[0], skip_special_tokens=True).strip()
    except Exception as e:
        print("[ERROR] Translation failed:", e)
        traceback.print_exc()
        return f"Translation failed: {str(e)[:50]}..."
# --- Back translation with SeamlessM4T (English to Amharic) ---
def back_translate_en_to_am(en_text):
    """Translate English text back to Amharic using SeamlessM4T.

    Returns the Amharic translation, or a short error message on failure.
    """
    if processor is None or model is None:
        return "Back translation model not loaded"
    try:
        encoded = processor(text=en_text, src_lang="eng", return_tensors="pt")
        with torch.no_grad():
            tokens = model.generate(
                **encoded,
                tgt_lang="amh",
                generate_speech=False,
            )
        return processor.decode(tokens[0], skip_special_tokens=True).strip()
    except Exception as e:
        print("[ERROR] Back translation failed:", e)
        traceback.print_exc()
        return f"Back translation failed: {str(e)[:50]}..."
# --- Chat response ---
def generate_chat_response(text):
    """Generate an English chat reply for *text* with the FLAN-T5 pipeline.

    Returns the model's reply, or a short error message when the chat
    model is unavailable or generation fails.
    """
    if chat_model is None:
        return "Chat model not loaded"
    try:
        # Add context to make responses more meaningful
        prompt = f"Respond to this in a helpful and conversational way: {text}"
        result = chat_model(
            prompt, max_length=150, num_beams=5, temperature=0.7, do_sample=True
        )
        return result[0]["generated_text"].strip()
    except Exception as e:
        print("[ERROR] Chat generation failed:", e)
        # Print the stack trace for consistency with every other handler
        # in this file (the original omitted it only here).
        traceback.print_exc()
        return f"Chat failed: {str(e)[:50]}..."
# --- TTS with Facebook MMS ---
def generate_tts(text):
    """Synthesize Amharic speech for *text* with the MMS VITS model.

    Returns (audio_data, sampling_rate) with audio peak-normalized to
    [-1, 1], or None when the model is unavailable, the text is blank,
    or synthesis fails.
    """
    if tts_model is None or tts_tokenizer is None:
        print("[ERROR] TTS model not loaded")
        return None
    try:
        if not text.strip():
            return None
        # Tokenize and run the VITS vocoder without gradient tracking.
        encoded = tts_tokenizer(text, return_tensors="pt")
        with torch.no_grad():
            waveform = tts_model(**encoded).waveform
        # Convert to numpy and peak-normalize so playback levels are sane.
        samples = waveform.cpu().numpy().squeeze()
        peak = np.max(np.abs(samples))
        if peak > 0:
            samples = samples / peak
        return samples, tts_model.config.sampling_rate
    except Exception as e:
        print("[ERROR] MMS TTS generation failed:", e)
        traceback.print_exc()
        return None
# --- Alternative TTS using gTTS (fallback) ---
def generate_tts_gtts(text):
    """Fallback TTS via Google gTTS (needs the gtts package and network).

    Returns (audio, sample_rate) or None on any failure (missing package,
    no network, decode error).
    """
    try:
        import io

        from gtts import gTTS

        buffer = io.BytesIO()
        gTTS(text=text, lang='am').write_to_fp(buffer)
        buffer.seek(0)
        # Decode into a numpy array so the return type matches the other
        # TTS helpers.
        samples, rate = sf.read(buffer)
        return samples, rate
    except Exception as e:
        print("[ERROR] gTTS failed:", e)
        return None
# --- Simple audio fallback ---
def generate_simple_audio(text):
    """Last-resort fallback: synthesize a short sine tone for *text*.

    Tone duration scales with text length (clamped to 1-3 s). The pitch is
    derived deterministically from the text so identical input always yields
    identical audio — the original used built-in hash(), which is salted by
    PYTHONHASHSEED and therefore changed every process run.

    Returns (audio_data, sampling_rate) or None on failure.
    """
    try:
        sampling_rate = 22050
        duration = min(3.0, max(1.0, len(text) / 10))
        t = np.linspace(0, duration, int(sampling_rate * duration), endpoint=False)
        # Deterministic pseudo-pitch in [300, 600) Hz derived from the text.
        frequency = 300 + (sum(ord(c) for c in text) % 300)
        audio_data = 0.5 * np.sin(2 * np.pi * frequency * t)
        return audio_data, sampling_rate
    except Exception as e:
        print("[ERROR] Simple audio generation failed:", e)
        return None
# --- Create WAV file ---
def create_wav_file(audio_array, sample_rate):
    """Write *audio_array* to a temporary WAV file and return its path.

    Multi-dimensional arrays are flattened to mono first. Returns None if
    the array is None or writing fails. The temp file uses delete=False, so
    it persists for Gradio to serve; nothing here cleans it up.
    """
    try:
        if audio_array is None:
            return None
        if audio_array.ndim > 1:
            audio_array = audio_array.flatten()
        # Close the handle before soundfile writes to the path: keeping a
        # NamedTemporaryFile open while reopening its name fails on Windows
        # and leaked a file descriptor on every call in the original.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
            wav_path = tmp.name
        sf.write(wav_path, audio_array, sample_rate)
        return wav_path
    except Exception as e:
        print("[ERROR] WAV file creation failed:", e)
        traceback.print_exc()
        return None
# --- Assistant pipeline ---
def assistant_pipeline(audio):
    """Full assistant flow: ASR -> translate -> chat -> back-translate -> TTS.

    audio: filepath from the Gradio audio component (falsy when empty).
    Returns (amharic_text, english_text, english_reply, amharic_reply,
    wav_path) — wav_path is None when no audio could be synthesized.
    """
    if not audio:
        return "No audio", "", "", "", None
    # Step 1: ASR with SeamlessM4T
    amharic_text = transcribe_amharic(audio)
    print(f"ASR Result: {amharic_text}")
    # Step 2: Translation with SeamlessM4T
    english_text = translate_am_to_en(amharic_text)
    print(f"English Translation: {english_text}")
    # Step 3: Chat response
    english_reply = generate_chat_response(english_text)
    print(f"Chat Response: {english_reply}")
    # Step 4: Back translation with SeamlessM4T
    amharic_reply = back_translate_en_to_am(english_reply)
    print(f"Amharic Response: {amharic_reply}")
    # Step 5: TTS — try MMS first, then gTTS, then a sine-tone fallback.
    wav_path = None
    if amharic_reply and not amharic_reply.startswith("Back translation failed"):
        synthesized = generate_tts(amharic_reply)
        if synthesized is None:
            print("[INFO] Trying gTTS fallback")
            synthesized = generate_tts_gtts(amharic_reply)
        if synthesized is None:
            print("[INFO] Using simple audio fallback")
            synthesized = generate_simple_audio(amharic_reply)
        if synthesized is not None:
            samples, rate = synthesized
            wav_path = create_wav_file(samples, rate)
            print(f"Audio generated successfully: {wav_path}")
    return amharic_text, english_text, english_reply, amharic_reply, wav_path
# --- Gradio UI ---
# Single-page Blocks layout: one audio input row, then a column of text
# outputs mirroring each pipeline stage, plus the synthesized audio.
with gr.Blocks(title="🌍 Local Language AI Assistant") as demo:
    gr.Markdown("# 🌍 Local Language AI Assistant")
    # NOTE(review): this Markdown string appears mojibake-encoded (emoji
    # bytes mis-decoded) — left byte-identical; confirm intended text.
    gr.Markdown("πŸŽ™οΈ Speak **or upload** Amharic audio and get AI responses with synthesized Amharic speech!")
    with gr.Row():
        # type="filepath" hands assistant_pipeline a path on disk.
        audio_input = gr.Audio(sources=["microphone", "upload"], type="filepath", label="🎀 Record or Upload your voice")
        submit_btn = gr.Button("Process", variant="primary")
    with gr.Row():
        with gr.Column():
            # One textbox per pipeline stage, in processing order.
            asr_output = gr.Textbox(label="ASR (Amharic text)")
            en_translation = gr.Textbox(label="Translated to English")
            en_response = gr.Textbox(label="Model Response (English)")
            am_response = gr.Textbox(label="Back Translated (Amharic)")
            audio_output = gr.Audio(label="Amharic TTS Output", type="filepath")
    # Wire the button to the pipeline; output order must match the
    # 5-tuple returned by assistant_pipeline.
    submit_btn.click(
        fn=assistant_pipeline,
        inputs=audio_input,
        outputs=[asr_output, en_translation, en_response, am_response, audio_output]
    )
if __name__ == "__main__":
    # 0.0.0.0 exposes the app on all interfaces (container/Spaces-friendly).
    demo.launch(server_name="0.0.0.0", server_port=7860)