Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import uuid | |
| import os | |
| import speech_recognition as sr | |
| from gtts import gTTS | |
| from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline | |
| from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder | |
| from langchain_community.chat_message_histories import ChatMessageHistory | |
| from langchain_core.runnables.history import RunnableWithMessageHistory | |
| from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline | |
# --- Model setup -------------------------------------------------------------
# Load a small causal LM locally; swap `model_name` for any chat-capable model.
model_name = "gpt2"  # You can change this to any other suitable model
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Create a text-generation pipeline.
# FIX: use `max_new_tokens` instead of `max_length`. `max_length` is a cap on
# prompt + generation combined, and the long system prompt assembled below
# already exceeds 100 tokens — leaving the model no budget to generate a reply.
# `max_new_tokens` bounds only the generated continuation.
text_generation = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=100,
)

# Wrap the pipeline so it can be composed into a LangChain runnable chain.
llm = HuggingFacePipeline(pipeline=text_generation)
# --- Prompt & conversation memory --------------------------------------------
# System instructions + rolling history + the current user turn.
prompt = ChatPromptTemplate.from_messages([
    ("system", """
You are a helpful AI assistant. Your task is to engage in conversation with users,
answer their questions, and assist them with various tasks.
Communicate politely and maintain focus on the user's needs.
Keep responses concise, typically two to three sentences.
Always provide a complete and relevant response to the user's input.
Do not use generic greetings or incomplete phrases like "Hello?".
If you don't understand or can't answer, say so clearly and ask for clarification.
"""),
    MessagesPlaceholder(variable_name="history"),
    ("human", "{input}"),
])

runnable = prompt | llm

# BUG FIX: the original passed `lambda session_id: ChatMessageHistory()`, which
# built a brand-new, empty history on *every* invocation — so the chain never
# actually remembered earlier turns. Keep one persistent history per session id.
_session_histories = {}


def _get_session_history(session_id):
    """Return the persistent ChatMessageHistory for `session_id`, creating it once."""
    if session_id not in _session_histories:
        _session_histories[session_id] = ChatMessageHistory()
    return _session_histories[session_id]


with_message_history = RunnableWithMessageHistory(
    runnable,
    _get_session_history,
    input_messages_key="input",
    history_messages_key="history",
)
def text_to_speech(text, file_name):
    """Synthesize `text` to an MP3 file via gTTS.

    Saves the file as `file_name` in the current working directory and
    returns its absolute path, or None when `text` is blank or synthesis
    fails (the error is printed, not raised).
    """
    # Guard: gTTS rejects empty input, so bail out early on blank text.
    if not text or not text.strip():
        print("Warning: Empty text provided to text_to_speech function")
        return None
    try:
        speech = gTTS(text=text, lang='en', slow=False)
        target_path = os.path.join(os.getcwd(), file_name)
        speech.save(target_path)
    except Exception as e:
        # Best-effort: report and signal failure rather than crash the UI.
        print(f"Error in text_to_speech: {str(e)}")
        return None
    return target_path
def speech_to_text(audio):
    """Transcribe an audio file using Google's free speech-recognition API.

    `audio` is a filepath (or None). Returns the recognized text on success;
    on any failure returns a human-readable error string instead of raising.
    """
    if audio is None:
        return "No audio input received."
    recognizer = sr.Recognizer()
    try:
        with sr.AudioFile(audio) as source:
            recorded = recognizer.record(source)
        text = recognizer.recognize_google(recorded)
        print(f"Recognized text: {text}")
        return text
    except sr.UnknownValueError:
        return "Speech recognition could not understand the audio"
    except sr.RequestError:
        return "Could not request results from the speech recognition service"
    except Exception as e:
        # Covers unreadable/invalid audio files and any other failure.
        return f"Error processing audio: {str(e)}"
def chat_function(input_type, text_input=None, audio_input=None, history=None):
    """Handle one chat turn (text or voice) and return updated UI state.

    Parameters:
        input_type: "text" or "audio" — which submit button was pressed.
        text_input: the typed message (used when input_type == "text").
        audio_input: filepath of recorded audio (used when input_type == "audio").
        history: list of (user, assistant) tuples currently shown in the Chatbot.

    Returns:
        (history, history, audio_path) — the history twice because the Gradio
        wiring feeds the same list to the Chatbot as both state and display,
        plus the path of the synthesized MP3 reply (or None on failure).
    """
    if history is None:
        history = []
    if input_type == "text":
        # ROBUSTNESS FIX: the original forwarded empty/None text straight to
        # the LLM; skip the turn entirely when there is nothing to answer.
        user_input = (text_input or "").strip()
        if not user_input:
            return history, history, None
    elif input_type == "audio":
        if audio_input is not None:
            user_input = speech_to_text(audio_input)
        else:
            user_input = "No audio input received."
    else:
        # Unknown mode: leave the conversation untouched.
        return history, history, None
    print(f"User input: {user_input}")  # Debug information
    try:
        # Get LLM response (session id is fixed: single shared conversation).
        response = with_message_history.invoke(
            {"input": user_input},
            config={"configurable": {"session_id": "chat_history"}},
        )
        print(f"LLM response: {response}")  # Debug information
        # Post-process: replace empty or degenerate (< 3 words) outputs.
        if not response or not response.strip() or len(response.split()) < 3:
            response = "I apologize, but I couldn't generate a meaningful response. Could you please rephrase your question or provide more context?"
        # Synthesize the reply to a uniquely named MP3 for playback.
        audio_file = f"response_{uuid.uuid4()}.mp3"
        audio_path = text_to_speech(response, audio_file)
        # Append the turn in the (user, assistant) tuple format the Chatbot uses.
        history.append((user_input, response))
        return history, history, audio_path
    except Exception as e:
        # Keep the UI alive on any backend failure; report and return unchanged.
        print(f"Error in chat_function: {str(e)}")
        return history, history, None
# --- Gradio interface ---------------------------------------------------------
with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    with gr.Row():
        text_input = gr.Textbox(placeholder="Type your message here...")
        audio_input = gr.Audio(sources=['microphone'], type="filepath")
    with gr.Row():
        text_button = gr.Button("Send Text")
        audio_button = gr.Button("Send Audio")
    audio_output = gr.Audio()

    def on_audio_change(audio):
        # Live-transcribe a finished recording into the textbox for review.
        if audio is not None:
            return speech_to_text(audio)
        return ""

    audio_input.change(on_audio_change, inputs=[audio_input], outputs=[text_input])

    # FIX: the original built `gr.Textbox(value="text")` inline inside the
    # `inputs=` lists; components created inside a Blocks context are rendered,
    # so two stray visible textboxes appeared in the UI. `gr.State` carries the
    # constant mode string without rendering anything.
    text_mode = gr.State("text")
    audio_mode = gr.State("audio")
    text_button.click(
        chat_function,
        inputs=[text_mode, text_input, audio_input, chatbot],
        outputs=[chatbot, chatbot, audio_output],
    )
    audio_button.click(
        chat_function,
        inputs=[audio_mode, text_input, audio_input, chatbot],
        outputs=[chatbot, chatbot, audio_output],
    )

demo.launch()