# Hugging Face Space header (page residue, commented out so the file parses):
# sdkbro's picture
# Create app.py
# 490f6f1 verified
import gradio as gr
import whisper
from transformers import MarianMTModel, MarianTokenizer
from transformers import AutoModelForCausalLM, AutoTokenizer
from TTS.api import TTS
import torch
import os
import tempfile
# -----------------------------
# Model Loading Section
# -----------------------------
# All models are loaded once at import time so every Gradio request reuses them.
# Load Whisper model (speech-to-text + language detection)
print("Loading Whisper model...")
stt_model = whisper.load_model("tiny") # Use "tiny" for faster performance and lower resource usage
print("Whisper model loaded.")
# Function to load MarianMT models dynamically and cache them
# Cache of (tokenizer, model) pairs keyed by "src-tgt" language-pair string,
# filled lazily by get_translation_model() below.
translation_models = {}
def get_translation_model(src_lang, tgt_lang):
    """
    Return a (tokenizer, model) pair for translating src_lang -> tgt_lang.

    Pairs are memoized in the module-level ``translation_models`` dict so
    each Helsinki-NLP checkpoint is fetched at most once. Returns
    (None, None) when no checkpoint exists for the language pair.
    """
    cache_key = f"{src_lang}-{tgt_lang}"
    cached = translation_models.get(cache_key)
    if cached is not None:
        return cached
    checkpoint = f'Helsinki-NLP/opus-mt-{src_lang}-{tgt_lang}'
    try:
        tok = MarianTokenizer.from_pretrained(checkpoint)
        mdl = MarianMTModel.from_pretrained(checkpoint)
    except Exception as e:
        # Not every language pair has a published OPUS-MT checkpoint.
        print(f"Translation model {checkpoint} not found. Error: {e}")
        return None, None
    translation_models[cache_key] = (tok, mdl)
    print(f"Loaded translation model: {checkpoint}")
    return tok, mdl
# Load Language Model (GPT-Neo) used to generate the English reply text
print("Loading Language Model...")
lm_model_name = "EleutherAI/gpt-neo-125M" # Smaller model suitable for free tier
lm_tokenizer = AutoTokenizer.from_pretrained(lm_model_name)
lm_model = AutoModelForCausalLM.from_pretrained(lm_model_name)
print("Language model loaded.")
# Load TTS model (English-only LJSpeech voice; CPU inference for the free tier)
print("Loading TTS model...")
tts_model = TTS(model_name="tts_models/en/ljspeech/tacotron2-DDC", progress_bar=False, gpu=False)
print("TTS model loaded.")
# -----------------------------
# Function Definitions
# -----------------------------
def speech_to_text(audio_path):
    """
    Transcribe the audio file at ``audio_path`` with Whisper.

    Returns a ``(text, language_code)`` tuple; the spoken language is
    auto-detected by the model.
    """
    transcription = stt_model.transcribe(audio_path)
    text, detected_lang = transcription["text"], transcription["language"]
    print(f"Transcribed Text: {text}")
    print(f"Detected Language: {detected_lang}")
    return text, detected_lang
def translate_text(text, src_lang, tgt_lang='en'):
    """
    Translate ``text`` from src_lang to tgt_lang with a MarianMT checkpoint.

    Degrades gracefully: the input is returned unchanged when the source
    and target languages match, or when no model exists for the pair.
    """
    if src_lang == tgt_lang:
        print("No translation needed.")
        return text
    tokenizer, model = get_translation_model(src_lang, tgt_lang)
    if tokenizer is None or model is None:
        print(f"No translation model found for {src_lang} to {tgt_lang}. Returning original text.")
        # Pass the text through untranslated rather than failing the pipeline.
        return text
    encoded = tokenizer(text, return_tensors="pt", padding=True)
    output_ids = model.generate(**encoded)
    result = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    print(f"Translated Text ({src_lang} -> {tgt_lang}): {result}")
    return result
def generate_response(prompt):
    """
    Generate a conversational reply to ``prompt`` with the causal LM.

    Samples up to 150 *new* tokens (temperature 0.7) and returns only the
    generated continuation, decoded without the prompt.
    """
    inputs = lm_tokenizer(prompt, return_tensors="pt")
    outputs = lm_model.generate(
        inputs.input_ids,
        attention_mask=inputs.attention_mask,      # avoid the missing-mask warning / ambiguity
        max_new_tokens=150,                        # max_length counted prompt tokens too, starving long prompts
        do_sample=True,
        temperature=0.7,
        pad_token_id=lm_tokenizer.eos_token_id,    # GPT-Neo has no dedicated pad token
    )
    # Slice at the token level instead of response[len(prompt):]: a decode is
    # not guaranteed to reproduce the prompt byte-for-byte, so char slicing
    # could clip or leak prompt text.
    new_tokens = outputs[0][inputs.input_ids.shape[-1]:]
    response = lm_tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
    print(f"AI Response: {response}")
    return response
def text_to_speech(text, lang='en'):
    """
    Synthesize ``text`` to a temporary WAV file and return its path.

    Only the English LJSpeech voice is loaded, so any non-English ``lang``
    falls back to English with a console notice. The caller owns (and is
    responsible for deleting) the returned file.
    """
    if lang != 'en':
        # Extend with multilingual TTS models as needed
        print(f"TTS for language '{lang}' not implemented. Using English TTS.")
    out = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
    with out:
        tts_model.tts_to_file(text=text, file_path=out.name)
        print(f"Generated TTS audio at: {out.name}")
        return out.name
def process_audio(audio):
    """
    Full pipeline: Speech-to-Text -> Translate to English -> Generate
    Response -> Translate Back -> Text-to-Speech.

    ``audio`` is the uploaded file object from Gradio. Returns the path of
    the generated reply WAV (``gr.Audio(type="file")`` serves a path), or
    None when the upload is too large or any stage fails.
    """
    max_bytes = 10 * 1024 * 1024  # reject uploads over 10MB

    # Uploaded file-like objects do not reliably expose ``.size``; when it
    # is missing, measure by seeking to the end and back.
    size = getattr(audio, "size", None)
    if size is None:
        audio.seek(0, os.SEEK_END)
        size = audio.tell()
        audio.seek(0)
    if size > max_bytes:
        print("Uploaded audio file is too large.")
        return None  # Or return an error message/audio

    # Persist the upload so Whisper can read it from disk.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
        tmp.write(audio.read())
        tmp_path = tmp.name
    print(f"Audio saved to temporary file: {tmp_path}")

    try:
        # Step 1: Speech-to-Text (also detects the spoken language)
        user_text, detected_lang = speech_to_text(tmp_path)
        # Step 2: Translate to English for the English-only language model
        translated_text = translate_text(user_text, src_lang=detected_lang, tgt_lang='en')
        # Step 3: Generate Response
        ai_response = generate_response(translated_text)
        # Step 4: Translate Back to User's Language
        translated_response = translate_text(ai_response, src_lang='en', tgt_lang=detected_lang)
        # Step 5: Text-to-Speech
        response_audio_path = text_to_speech(translated_response, lang=detected_lang)
    except Exception as e:
        print(f"Error during processing: {e}")
        # Optionally, return an error message or a default audio response
        return None
    finally:
        # Clean up the uploaded copy. The generated reply WAV is kept on
        # disk: gr.Audio(type="file") needs a path it can serve — the old
        # code deleted it and returned raw bytes, which that output
        # component cannot play.
        os.remove(tmp_path)
        print("Temporary files cleaned up.")
    return response_audio_path
# -----------------------------
# Gradio Interface Definition
# -----------------------------
# Single audio-in / audio-out interface wrapping the full pipeline.
# NOTE(review): `source=` and `type="file"` are Gradio 3.x arguments; they
# were removed in Gradio 4 — confirm the Space pins gradio<4.
iface = gr.Interface(
fn=process_audio,
inputs=gr.Audio(source="upload", type="file", label="Upload Your Audio"),
outputs=gr.Audio(type="file", label="AI Response"),
title="Multilingual Voice Interaction",
description="Upload an audio file in any supported language. The system will respond with an audio reply in the same language.",
examples=[
# To add examples, upload example audio files to your Space and reference their paths here
# ["example1.wav"],
# ["example2.wav"],
],
allow_flagging="never", # Disable flagging to prevent misuse
)
# Launch the web app only when run as a script (Spaces executes this file directly).
if __name__ == "__main__":
iface.launch()