Spaces:
Sleeping
Sleeping
| # app.py | |
| import os | |
| import numpy as np | |
| import gradio as gr | |
| from transformers import pipeline | |
| from langdetect import detect, LangDetectException | |
| from transformers import SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan | |
| import torch | |
| import soundfile as sf | |
| from datasets import load_dataset | |
# ---------------------------------------------------------------------------
# Initialize all models once at import time so every request reuses them.
# ---------------------------------------------------------------------------
print("Loading ASR model...")
# Whisper-small speech-to-text; chunk_length_s lets clips longer than 30 s work.
asr_pipe = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-small",
    chunk_length_s=30
)

print("Loading grammar correction model...")
# FLAN-T5 fine-tuned for grammar synthesis: text in, corrected text out.
grammar_pipe = pipeline(
    "text2text-generation",
    model="pszemraj/flan-t5-large-grammar-synthesis"
)

print("Loading chat model...")
# DialoGPT-medium generates the conversational "teacher" reply.
chat_pipe = pipeline(
    "text-generation",
    model="microsoft/DialoGPT-medium"
)

print("Loading TTS components...")
# SpeechT5 text-to-speech: the processor tokenizes text, the model produces a
# spectrogram, and the HiFi-GAN vocoder renders it into a waveform.
tts_processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
tts_model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts")
tts_vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")

print("Loading speaker embeddings...")
# x-vector speaker embeddings control the voice timbre; unsqueeze(0) adds the
# batch dimension SpeechT5 expects. NOTE(review): indices 7306 / 0 are presumed
# to be male / female CMU ARCTIC speakers — confirm against the dataset card.
embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
speaker_embeddings = {
    "male": torch.tensor(embeddings_dataset[7306]["xvector"]).unsqueeze(0),
    "female": torch.tensor(embeddings_dataset[0]["xvector"]).unsqueeze(0)
}

print("All models loaded successfully!")
#############################################################################
# NOTE(review): removed an earlier commented-out draft of process_audio that
# was kept here as dead code. It returned only three values on its error
# paths, which broke the 4-output Gradio wiring; the live implementation
# below supersedes it.
#############################################################################
def process_audio(audio_path, voice_choice, conversation_history):
    """Transcribe English speech, correct its grammar, reply as a teacher, and speak the reply.

    Args:
        audio_path: Filesystem path of the recorded clip (from gr.Audio, type="filepath").
        voice_choice: "male" or "female" — selects the TTS speaker embedding.
        conversation_history: Mutable list of prior exchange strings (gr.State).

    Returns:
        4-tuple of (transcribed text, feedback/response text, response audio
        path or None, updated conversation history). Error paths return None
        in the audio slot so the Gradio outputs stay aligned.
    """
    import tempfile  # local import: only needed on the success path

    # --- Speech-to-text ----------------------------------------------------
    try:
        user_input = asr_pipe(audio_path)["text"].strip()
    except Exception as e:
        print(f"ASR error: {e}")
        return None, "Could not process audio. Please try again.", None, conversation_history

    # Empty transcription would make langdetect raise; handle it explicitly.
    if not user_input:
        return user_input, "Could not detect language. Please speak clearly.", None, conversation_history

    # --- Reject non-English input -------------------------------------------
    try:
        if detect(user_input) != "en":
            return user_input, "You must try to speak in English for me to respond", None, conversation_history
    except LangDetectException:
        return user_input, "Could not detect language. Please speak clearly.", None, conversation_history

    # --- Grammar correction --------------------------------------------------
    corrected_input = grammar_pipe(user_input, max_length=256)[0]["generated_text"]
    conversation_history.append(corrected_input)

    # --- Conversational reply -------------------------------------------------
    chat_input = "\n".join(conversation_history[-4:])  # keep the last 4 exchanges
    # BUG FIX: max_length counts the prompt too, so a long history used to
    # truncate generation to nothing; max_new_tokens bounds only the reply.
    response = chat_pipe(
        chat_input,
        max_new_tokens=128,
        pad_token_id=chat_pipe.tokenizer.eos_token_id,
    )
    # The text-generation pipeline returns prompt + continuation. Strip the
    # prompt prefix instead of splitting on "Teacher:" — that marker is absent
    # on the first turn, so the old code echoed the user's own words back.
    generated = response[0]["generated_text"]
    response_text = generated[len(chat_input):].strip() or "Let's keep practicing! Tell me more."
    conversation_history.append(f"Teacher: {response_text}")

    # --- Text-to-speech ---------------------------------------------------------
    inputs = tts_processor(text=response_text, return_tensors="pt")
    with torch.no_grad():  # inference only — don't build autograd graphs
        speech = tts_model.generate_speech(
            inputs["input_ids"],
            speaker_embeddings[voice_choice],
            vocoder=tts_vocoder,
        )

    # BUG FIX: a fixed "response.wav" is clobbered when two users talk at once
    # on a shared Space; write each reply to its own temp file instead.
    output_audio = tempfile.NamedTemporaryFile(suffix=".wav", delete=False).name
    sf.write(output_audio, speech.numpy(), samplerate=16000)  # SpeechT5 outputs 16 kHz audio

    return user_input, response_text, output_audio, conversation_history
#############################################################################
# Gradio interface
#############################################################################
with gr.Blocks(title="Audio English Teacher") as demo:
    gr.Markdown("# 🎓 Audio English Teacher")
    gr.Markdown("Practice English conversation with AI correction and feedback!")

    with gr.Row():
        voice_choice = gr.Radio(
            ["male", "female"],
            label="Select Voice",
            value="female"
        )
        audio_input = gr.Audio(
            sources=["microphone"],
            type="filepath",
            label="Speak in English"
        )

    # Per-session conversation memory (list of utterance strings).
    history_state = gr.State([])

    with gr.Column():
        original_text = gr.Textbox(label="What you said")
        # BUG FIX: this box receives the teacher's reply / feedback text from
        # process_audio, not the grammar correction — the old "Corrected
        # English" label was misleading.
        corrected_output = gr.Textbox(label="Teacher's feedback")
        audio_output = gr.Audio(label="Teacher's Response", autoplay=True)

    # Run the full pipeline when the user stops recording.
    audio_input.stop_recording(
        fn=process_audio,
        inputs=[audio_input, voice_choice, history_state],
        outputs=[original_text, corrected_output, audio_output, history_state]
    )

    # BUG FIX: the examples used to pass fn=process_audio with only two inputs;
    # clicking one crashed, since process_audio takes three arguments and
    # expects an audio file path, not a sentence. They are display-only presets now.
    gr.Examples(
        examples=[
            ["I goes to school yesterday", "male"],
            ["She don't like apples", "female"],
            ["We was happy for the results", "male"]
        ],
        inputs=[original_text, voice_choice]
    )

if __name__ == "__main__":
    demo.launch()  # Hugging Face Spaces serves the app directly; share=True not needed