Spaces:
Sleeping
Sleeping
| import os | |
| import streamlit as st | |
| from groq import Groq, APIConnectionError, AuthenticationError | |
| from transformers import ( | |
| pipeline, | |
| AutoTokenizer, | |
| AutoModelForQuestionAnswering, | |
| AutoProcessor, | |
| AutoModelForSpeechSeq2Seq, | |
| ) | |
| from espnet2.bin.tts_inference import Text2Speech | |
| from PIL import Image | |
| import easyocr | |
| import soundfile as sf | |
| from pydub import AudioSegment | |
| import io | |
| from streamlit_webrtc import webrtc_streamer, WebRtcMode, AudioProcessorBase | |
| import av | |
| import numpy as np | |
# Read the Groq API key from the environment (on Hugging Face Spaces it is a
# Space secret). Nothing else in the app can talk to Groq without it, so show
# an error and halt the script right away when it is absent or empty.
GROQ_API_KEY = os.environ.get("GROQ_API_KEY")
if GROQ_API_KEY is None or GROQ_API_KEY == "":
    st.error("Groq API key not found. Please add it to the Hugging Face Space Secrets.")
    st.stop()

# One shared client object, used by every Groq chat request in this app.
groq_client = Groq(api_key=GROQ_API_KEY)
# OCR Function
@st.cache_resource
def _get_ocr_reader(languages=("en",)):
    """Create an EasyOCR reader for *languages* once and reuse it.

    Constructing a ``Reader`` loads detection/recognition model weights, and
    Streamlit re-executes the script on every interaction, so the reader is
    kept in Streamlit's resource cache instead of being rebuilt per call.
    """
    return easyocr.Reader(list(languages))


def extract_text_from_image(image):
    """Run English OCR on *image* and return the recognised text.

    Args:
        image: A PIL image (as returned by ``Image.open``) or any input that
            EasyOCR's ``readtext`` accepts directly (path, bytes, ndarray).

    Returns:
        The recognised text fragments joined by single spaces, in the order
        EasyOCR reports them; an empty string when nothing is detected.
    """
    if isinstance(image, Image.Image):
        # EasyOCR only special-cases JPEG PIL images (e.g. PNG uploads are
        # rejected), so hand it an RGB ndarray. convert("RGB") also
        # normalises RGBA / palette / greyscale modes.
        image = np.array(image.convert("RGB"))
    result = _get_ocr_reader().readtext(image)
    # With default options each detection is a sequence whose element 1 is the text.
    return " ".join(detection[1] for detection in result)
# Question Answering Function (DistilBERT)
@st.cache_resource
def load_qa_model():
    """Build the extractive question-answering pipeline (cached).

    Uses the cased DistilBERT checkpoint distilled on SQuAD. The app calls
    this from UI code that Streamlit re-executes on every interaction;
    ``st.cache_resource`` keeps one pipeline instead of reloading the
    tokenizer and model weights each time.

    Returns:
        A transformers ``question-answering`` pipeline.
    """
    model_name = "distilbert/distilbert-base-cased-distilled-squad"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForQuestionAnswering.from_pretrained(model_name)
    return pipeline("question-answering", model=model, tokenizer=tokenizer)
def answer_question(context, question, qa_model):
    """Return the answer that *qa_model* extracts from *context* for *question*.

    ``qa_model`` is invoked with a single ``{"question": ..., "context": ...}``
    mapping and must return a mapping that has an ``"answer"`` entry.
    """
    query = {"question": question, "context": context}
    prediction = qa_model(query)
    return prediction["answer"]
# Load models for voice chatbot
@st.cache_resource
def load_voice_models():
    """Load the speech-to-text and text-to-speech models (cached).

    Streamlit re-executes the script on every interaction, so the models are
    kept in Streamlit's resource cache rather than reloaded on each use.

    Returns:
        tuple: ``(stt_pipe, tts_model)``. ``stt_pipe`` is a Whisper-small
        ``automatic-speech-recognition`` pipeline that returns timestamps;
        ``tts_model`` is an ESPnet2 ``Text2Speech`` instance.
    """
    # Speech-to-Text
    processor = AutoProcessor.from_pretrained("openai/whisper-small")
    stt_model = AutoModelForSpeechSeq2Seq.from_pretrained("openai/whisper-small")
    stt_pipe = pipeline(
        "automatic-speech-recognition",
        model=stt_model,
        tokenizer=processor.tokenizer,
        feature_extractor=processor.feature_extractor,
        return_timestamps=True,  # enable timestamps for long-form audio
    )
    # Text-to-Speech
    tts_model = Text2Speech.from_pretrained("espnet/espnet_tts_vctk_espnet_spk_voxceleb12_rawnet")
    return stt_pipe, tts_model
# Groq API Function
def groq_chat(prompt):
    """Send *prompt* to Groq as one user message and return the reply text.

    Errors are not raised. They come back as a readable string prefixed with
    the kind of failure, so callers can display the return value directly.
    """
    conversation = [{"role": "user", "content": prompt}]
    try:
        completion = groq_client.chat.completions.create(
            messages=conversation,
            model="llama-3.3-70b-versatile",
        )
        reply = completion.choices[0].message.content
    except APIConnectionError as err:
        return f"Groq API Connection Error: {err}"
    except AuthenticationError as err:
        return f"Groq API Authentication Error: {err}"
    except Exception as err:
        return f"General Groq API Error: {err}"
    return reply
# Streamlit App
def main():
    """Render the multi-modal chatbot app.

    The sidebar selects between two modes:
    - OCR plus extractive question answering over an uploaded image.
    - A voice chatbot: Whisper speech-to-text, a Groq LLM reply, then ESPnet
      text-to-speech.

    A free-form Groq chat box is rendered below either mode.
    """
    st.title("Multi-Modal Chatbot: Image Text & Voice")
    # Sidebar for mode selection
    mode = st.sidebar.radio("Select Mode", ["Image Text & QA", "Voice Chatbot"])
    if mode == "Image Text & QA":
        _render_image_qa()
    elif mode == "Voice Chatbot":
        _render_voice_chatbot()
    # Groq Chat Section (common for both modes)
    _render_general_chat()


def _render_image_qa():
    """Upload an image, OCR it, and answer questions about the extracted text."""
    st.header("Image Text Extraction & Question Answering")
    uploaded_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
    if uploaded_file is None:
        return
    image = Image.open(uploaded_file)
    st.image(image, caption="Uploaded Image", use_container_width=True)

    # Streamlit reruns the script on every interaction, and a button is True
    # only on the rerun right after it is clicked. If the OCR result were not
    # kept in session_state, clicking "Answer" would rerun with the extract
    # button False and the QA branch could never execute. The stored text is
    # dropped whenever a different file is uploaded.
    file_key = (uploaded_file.name, uploaded_file.size)
    if st.session_state.get("ocr_file_key") != file_key:
        st.session_state["ocr_file_key"] = file_key
        st.session_state.pop("extracted_text", None)

    if st.button("Extract Text and Enable Question Answering"):
        with st.spinner("Extracting text..."):
            st.session_state["extracted_text"] = extract_text_from_image(image)

    extracted_text = st.session_state.get("extracted_text")
    if extracted_text is None:
        return
    st.write("Extracted Text:")
    st.write(extracted_text)
    qa_model = load_qa_model()
    question = st.text_input("Ask a question about the image text:")
    if st.button("Answer"):
        if not question:
            st.warning("Please enter a question.")
        elif not extracted_text.strip():
            st.warning("No text was extracted from the image, so there is nothing to answer from.")
        else:
            with st.spinner("Answering..."):
                answer = answer_question(extracted_text, question, qa_model)
                st.write("Answer:", answer)


def _format_timestamp(seconds):
    """Format a Whisper chunk timestamp; ``None`` (no closing timestamp) becomes "?"."""
    return "?" if seconds is None else f"{seconds:.2f}"


def _render_voice_chatbot():
    """Record speech, transcribe it, answer it with Groq and speak the answer."""
    st.header("Voice-Enabled Chatbot")
    # Audio recorder
    st.write("Record your voice:")
    webrtc_ctx = webrtc_streamer(
        key="audio-recorder",
        mode=WebRtcMode.SENDONLY,
        audio_processor_factory=AudioRecorder,
        media_stream_constraints={"audio": True, "video": False},
    )
    if not webrtc_ctx.audio_processor:
        return
    st.write("Recording... Press 'Stop' to finish recording.")
    if not st.button("Stop and Process Recording"):
        return

    # Snapshot the buffer: the recorder's recv() may keep appending meanwhile.
    audio_frames = list(webrtc_ctx.audio_processor.audio_frames)
    if not audio_frames:
        st.error("No audio recorded. Please try again.")
        return

    # AudioRecorder stores frame.to_ndarray() results, which are 2-D. Joining
    # them along axis 0 would stack the chunks as rows that soundfile then
    # writes as separate channels, so flatten them into one 1-D signal.
    audio_data = np.concatenate([chunk.reshape(-1) for chunk in audio_frames])
    # NOTE: assumes AudioRecorder delivers 16 kHz mono samples.
    sf.write("recorded_audio.wav", audio_data, samplerate=16000)
    st.success("Recording saved as recorded_audio.wav")

    stt_pipe, tts_model = load_voice_models()

    # Process the recorded audio
    speech, sample_rate = sf.read("recorded_audio.wav")
    with st.spinner("Transcribing..."):
        # Pass the rate explicitly instead of letting the pipeline assume it.
        output = stt_pipe({"raw": speech, "sampling_rate": sample_rate})
    transcript = output["text"]
    st.write("Transcribed Text:", transcript)
    if "chunks" in output:
        st.write("Transcribed Text with Timestamps:")
        for chunk in output["chunks"]:
            start, end = chunk["timestamp"]
            st.write(f"{_format_timestamp(start)} - {_format_timestamp(end)}: {chunk['text']}")
    if not transcript.strip():
        st.warning("No speech was recognized. Please try again.")
        return

    # Generate a response with the Groq API, then convert it to speech.
    try:
        with st.spinner("Generating response..."):
            chat_completion = groq_client.chat.completions.create(
                messages=[{"role": "user", "content": transcript}],
                # Same model as groq_chat(); mixtral-8x7b-32768 has been
                # deprecated by Groq.
                model="llama-3.3-70b-versatile",
                temperature=0.5,
                max_tokens=1024,
            )
        response = chat_completion.choices[0].message.content
        st.write("Generated Response:", response)

        # NOTE(review): espnet2's Text2Speech does not appear to expose a
        # `spembs` attribute. If so, this raises AttributeError (reported by
        # the except below) and a real speaker x-vector must be supplied.
        # Confirm against the model card.
        tts_output = tts_model(response, spembs=tts_model.spembs[0])
        # Text2Speech returns a dict of outputs. The waveform tensor is under
        # "wav", and the model's own output sample rate is `fs`.
        waveform = tts_output["wav"].view(-1).cpu().numpy()
        sf.write("response.wav", waveform, tts_model.fs)
        st.audio("response.wav")
    except Exception as e:
        st.error(f"Error generating response: {e}")


def _render_general_chat():
    """Free-form text chat with Groq, shown beneath either mode."""
    st.subheader("General Chat (Powered by Groq)")
    groq_prompt = st.text_input("Enter your message:")
    if st.button("Send"):
        if groq_prompt:
            with st.spinner("Generating response..."):
                groq_response = groq_chat(groq_prompt)
                st.write("Response:", groq_response)
        else:
            st.warning("Please enter a message.")
# Audio recorder class
class AudioRecorder(AudioProcessorBase):
    """Collect microphone audio arriving through a streamlit-webrtc stream.

    Every incoming frame is converted to 16 kHz mono signed 16-bit PCM and
    appended to ``audio_frames`` as a NumPy array. The app saves the recording
    at 16 kHz, which is also the rate Whisper expects.
    """

    # Sample rate of the arrays stored in ``audio_frames``.
    SAMPLE_RATE = 16000

    def __init__(self):
        # Recorded audio chunks, in arrival order.
        self.audio_frames = []
        # WebRTC audio is typically 48 kHz and may be stereo, so normalise
        # sample format, channel layout and rate as frames come in.
        self._resampler = av.AudioResampler(format="s16", layout="mono", rate=self.SAMPLE_RATE)

    def recv(self, frame: av.AudioFrame) -> av.AudioFrame:
        """Store a resampled copy of *frame* and pass the original through unchanged."""
        resampled = self._resampler.resample(frame)
        # PyAV >= 9 returns a list of frames; older versions return one frame or None.
        if resampled is None:
            resampled = []
        elif not isinstance(resampled, list):
            resampled = [resampled]
        for out_frame in resampled:
            self.audio_frames.append(out_frame.to_ndarray())
        return frame
# Entry point: render the app when this file is executed as the main script
# (for example via `streamlit run <file>`).
if __name__ == "__main__":
    main()