# app.py
import os
import numpy as np
import gradio as gr
from transformers import pipeline, SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan
from langdetect import detect, LangDetectException
import torch
import soundfile as sf
from datasets import load_dataset
# Initialize models only once
print("Loading ASR model...")
asr_pipe = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-small",
    chunk_length_s=30
)
print("Loading grammar correction model...")
grammar_pipe = pipeline(
    "text2text-generation",
    model="pszemraj/flan-t5-large-grammar-synthesis"
)
print("Loading chat model...")
chat_pipe = pipeline(
    "text-generation",
    model="microsoft/DialoGPT-medium"
)
print("Loading TTS components...")
# Initialize TTS components
tts_processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
tts_model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts")
tts_vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")
print("Loading speaker embeddings...")
# Load speaker embeddings for male/female voices
embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
speaker_embeddings = {
    "male": torch.tensor(embeddings_dataset[7306]["xvector"]).unsqueeze(0),
    "female": torch.tensor(embeddings_dataset[0]["xvector"]).unsqueeze(0)
}
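# The indices above select two specific CMU ARCTIC speakers; any other row of
# the validation split can be swapped in for a different voice.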
print("All models loaded successfully!")
def process_audio(audio_path, voice_choice, conversation_history):
    """Transcribe the user's speech, correct it, and reply with synthesized audio."""
    # Transcribe audio
    try:
        result = asr_pipe(audio_path)
        user_input = result["text"].strip()
    except Exception as e:
        print(f"ASR error: {e}")
        # Four return values, with None placeholders for the missing outputs
        return None, "Could not process audio. Please try again.", None, conversation_history

    # Check that the input is English
    try:
        if detect(user_input) != "en":
            return user_input, "You must try to speak in English for me to respond.", None, conversation_history
    except LangDetectException:
        return user_input, "Could not detect language. Please speak clearly.", None, conversation_history

    # Grammar correction
    corrected_input = grammar_pipe(user_input, max_length=256)[0]["generated_text"]

    # Record the student's (corrected) turn
    conversation_history.append(f"Student: {corrected_input}")

    # Build the prompt from the last 4 turns and cue the model to answer as the
    # teacher; return_full_text=False keeps the prompt out of the generated text.
    chat_input = "\n".join(conversation_history[-4:]) + "\nTeacher:"
    response = chat_pipe(
        chat_input,
        max_new_tokens=64,
        pad_token_id=chat_pipe.tokenizer.eos_token_id,
        return_full_text=False
    )
    # Keep only the first generated line so the model cannot continue the dialogue itself
    response_text = response[0]["generated_text"].strip().split("\n")[0]

    conversation_history.append(f"Teacher: {response_text}")

    # Generate speech (SpeechT5 synthesizes at 16 kHz)
    inputs = tts_processor(text=response_text, return_tensors="pt")
    speech = tts_model.generate_speech(
        inputs["input_ids"],
        speaker_embeddings[voice_choice],
        vocoder=tts_vocoder
    )

    # Save audio output
    output_audio = "response.wav"
    sf.write(output_audio, speech.numpy(), samplerate=16000)

    # Four return values, matching the four Gradio output components
    return user_input, response_text, output_audio, conversation_history
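# Minimal smoke test for process_audio, assuming a short 16 kHz English
# recording named sample.wav (hypothetical file) sits next to app.py:
#
#   heard, reply, wav_path, history = process_audio("sample.wav", "female", [])
#   print("Heard:  ", heard)
#   print("Teacher:", reply)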
########################################################################
# Gradio interface
with gr.Blocks(title="Audio English Teacher") as demo:
    gr.Markdown("# 🎓 Audio English Teacher")
    gr.Markdown("Practice English conversation with AI correction and feedback!")

    with gr.Row():
        voice_choice = gr.Radio(
            ["male", "female"],
            label="Select Voice",
            value="female"
        )
        audio_input = gr.Audio(
            sources=["microphone"],
            type="filepath",
            label="Speak in English"
        )

    history_state = gr.State([])

    with gr.Column():
        original_text = gr.Textbox(label="What you said")
        corrected_output = gr.Textbox(label="Corrected English")
        audio_output = gr.Audio(label="Teacher's Response", autoplay=True)
    audio_input.stop_recording(
        fn=process_audio,
        inputs=[audio_input, voice_choice, history_state],
        outputs=[original_text, corrected_output, audio_output, history_state]
    )
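    # Note: history_state is wired as both an input and an output above, so the
    # running conversation persists across recordings within one session.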
    # The example sentences just pre-fill the text box and voice selector;
    # process_audio needs a recorded audio file, so no fn is attached here.
    gr.Examples(
        examples=[
            ["I goes to school yesterday", "male"],
            ["She don't like apples", "female"],
            ["We was happy for the results", "male"]
        ],
        inputs=[original_text, voice_choice]
    )
if __name__ == "__main__":
    demo.launch()  # share=True is unnecessary on Hugging Face Spaces