# merged_chat/app.py
import os
import streamlit as st
from groq import Groq, APIConnectionError, AuthenticationError
from transformers import (
    pipeline,
    AutoTokenizer,
    AutoModelForQuestionAnswering,
    AutoProcessor,
    AutoModelForSpeechSeq2Seq,
)
from espnet2.bin.tts_inference import Text2Speech
from PIL import Image
import easyocr
import soundfile as sf
from streamlit_webrtc import webrtc_streamer, WebRtcMode, AudioProcessorBase
import av
import numpy as np
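# NOTE: a rough dependency list implied by the imports above (assumed, not
# pinned): streamlit, groq, transformers, torch, espnet, espnet_model_zoo,
# easyocr, Pillow, soundfile, streamlit-webrtc, av, numpy.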
# Load Groq API key from environment variables
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
if not GROQ_API_KEY:
    st.error("Groq API key not found. Please add it to the Hugging Face Space Secrets.")
    st.stop()
# Initialize Groq client
groq_client = Groq(api_key=GROQ_API_KEY)
# OCR Function
@st.cache_resource
def load_ocr_reader():
    # Cache the EasyOCR reader so its detection models load only once.
    return easyocr.Reader(['en'])

def extract_text_from_image(image):
    reader = load_ocr_reader()
    # EasyOCR expects a file path, bytes, or a numpy array, not a PIL image.
    result = reader.readtext(np.array(image))
    extracted_text = " ".join([detection[1] for detection in result])
    return extracted_text
# Question Answering Function (DistilBERT)
@st.cache_resource
def load_qa_model():
    model_name = "distilbert/distilbert-base-cased-distilled-squad"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForQuestionAnswering.from_pretrained(model_name)
    nlp = pipeline('question-answering', model=model, tokenizer=tokenizer)
    return nlp
def answer_question(context, question, qa_model):
    result = qa_model({'question': question, 'context': context})
    return result['answer']
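
# The QA pipeline returns a dict such as
# {'score': 0.98, 'start': 12, 'end': 19, 'answer': '...'};
# only the answer span is surfaced to the user here.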
# Load models for voice chatbot
@st.cache_resource
def load_voice_models():
    # Speech-to-Text
    processor = AutoProcessor.from_pretrained("openai/whisper-small")
    stt_model = AutoModelForSpeechSeq2Seq.from_pretrained("openai/whisper-small")
    stt_pipe = pipeline(
        "automatic-speech-recognition",
        model=stt_model,
        tokenizer=processor.tokenizer,
        feature_extractor=processor.feature_extractor,
        return_timestamps=True,  # Enable timestamps for long-form audio
    )
    # Text-to-Speech
    tts_model = Text2Speech.from_pretrained("espnet/espnet_tts_vctk_espnet_spk_voxceleb12_rawnet")
    return stt_pipe, tts_model
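
# NOTE: Whisper's feature extractor expects 16 kHz input; audio captured at
# other rates should be passed as {"raw": array, "sampling_rate": sr} so the
# pipeline can resample. The ESPnet model's output sample rate is exposed as
# tts_model.fs.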
# Groq API Function
def groq_chat(prompt):
    try:
        chat_completion = groq_client.chat.completions.create(
            messages=[{"role": "user", "content": prompt}],
            model="llama-3.3-70b-versatile",
        )
        return chat_completion.choices[0].message.content
    except APIConnectionError as e:
        return f"Groq API Connection Error: {e}"
    except AuthenticationError as e:
        return f"Groq API Authentication Error: {e}"
    except Exception as e:
        return f"General Groq API Error: {e}"
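
# Example usage (assuming the API key is configured):
#   groq_chat("Summarize the extracted text in one sentence.")
# Returns the assistant's reply as a plain string, or an error message.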
# Streamlit App
def main():
    st.title("Multi-Modal Chatbot: Image Text & Voice")
    # Sidebar for mode selection
    mode = st.sidebar.radio("Select Mode", ["Image Text & QA", "Voice Chatbot"])
    if mode == "Image Text & QA":
        # Image Text Extraction & QA
        st.header("Image Text Extraction & Question Answering")
        uploaded_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
        if uploaded_file is not None:
            image = Image.open(uploaded_file)
            st.image(image, caption="Uploaded Image", use_container_width=True)
            if st.button("Extract Text and Enable Question Answering"):
                with st.spinner("Extracting text..."):
                    # Persist the text in session state: button values reset on
                    # every rerun, so the QA form below would otherwise vanish
                    # as soon as "Answer" is clicked.
                    st.session_state["extracted_text"] = extract_text_from_image(image)
            if "extracted_text" in st.session_state:
                st.write("Extracted Text:")
                st.write(st.session_state["extracted_text"])
                qa_model = load_qa_model()
                question = st.text_input("Ask a question about the image text:")
                if st.button("Answer"):
                    if question:
                        with st.spinner("Answering..."):
                            answer = answer_question(st.session_state["extracted_text"], question, qa_model)
                            st.write("Answer:", answer)
                    else:
                        st.warning("Please enter a question.")
    elif mode == "Voice Chatbot":
        # Voice Chatbot
        st.header("Voice-Enabled Chatbot")
        # Load the STT pipeline and TTS model (cached across reruns).
        stt_pipe, tts_model = load_voice_models()
        # Audio recorder
        st.write("Record your voice:")
        webrtc_ctx = webrtc_streamer(
            key="audio-recorder",
            mode=WebRtcMode.SENDONLY,
            audio_processor_factory=AudioRecorder,
            media_stream_constraints={"audio": True, "video": False},
        )
        if webrtc_ctx.audio_processor:
            st.write("Recording... Press 'Stop' to finish recording.")
            # Save recorded audio to a WAV file
            if st.button("Stop and Process Recording"):
                audio_frames = webrtc_ctx.audio_processor.audio_frames
                if audio_frames:
                    # Frames arrive as (channels, samples) arrays; join them
                    # along the sample axis and flatten (assumes a mono track).
                    audio_data = np.concatenate(audio_frames, axis=1).flatten()
                    # Save at the stream's native rate (WebRTC is typically 48 kHz).
                    sample_rate = webrtc_ctx.audio_processor.sample_rate or 48000
                    sf.write("recorded_audio.wav", audio_data, samplerate=sample_rate)
                    st.success("Recording saved as recorded_audio.wav")
                    # Process the recorded audio; passing the sampling rate lets
                    # the pipeline resample to Whisper's expected 16 kHz.
                    speech, sr = sf.read("recorded_audio.wav")
                    output = stt_pipe({"raw": speech, "sampling_rate": sr})  # Transcribe with timestamps
                    # Debug: Print the transcribed text
                    st.write("Transcribed Text:", output['text'])
                    # Display the text with timestamps (optional)
                    if 'chunks' in output:
                        st.write("Transcribed Text with Timestamps:")
                        for chunk in output['chunks']:
                            st.write(f"{chunk['timestamp'][0]:.2f} - {chunk['timestamp'][1]:.2f}: {chunk['text']}")
                    # Generate response using Groq API
                    try:
                        # Debug: Print the input text
                        st.write("Input Text:", output['text'])
                        chat_completion = groq_client.chat.completions.create(
                            messages=[{"role": "user", "content": output['text']}],
                            # NOTE: verify this id against Groq's current model
                            # list; older ids like this one have been deprecated.
                            model="mixtral-8x7b-32768",
                            temperature=0.5,
                            max_tokens=1024,
                        )
                        # Debug: Print the API response
                        st.write("API Response:", chat_completion)
                        # Extract the generated response
                        response = chat_completion.choices[0].message.content
                        st.write("Generated Response:", response)
                        # Convert response to speech. NOTE: Text2Speech has no
                        # `spembs` attribute; multi-speaker models may need a
                        # speaker embedding passed explicitly, so we rely on
                        # the model's defaults here.
                        tts_output = tts_model(response)
                        wav = tts_output["wav"].numpy()
                        # Save and play the speech at the model's sample rate
                        sf.write("response.wav", wav, tts_model.fs)
                        st.audio("response.wav")
                    except Exception as e:
                        st.error(f"Error generating response: {e}")
                else:
                    st.error("No audio recorded. Please try again.")
    # Groq Chat Section (Common for both modes)
    st.subheader("General Chat (Powered by Groq)")
    groq_prompt = st.text_input("Enter your message:")
    if st.button("Send"):
        if groq_prompt:
            with st.spinner("Generating response..."):
                groq_response = groq_chat(groq_prompt)
                st.write("Response:", groq_response)
        else:
            st.warning("Please enter a message.")
# Audio recorder class: collects raw frames from the WebRTC audio track.
class AudioRecorder(AudioProcessorBase):
    def __init__(self):
        self.audio_frames = []
        self.sample_rate = None

    def recv(self, frame: av.AudioFrame) -> av.AudioFrame:
        # Record the stream's sample rate so the saved WAV matches the capture.
        self.sample_rate = frame.sample_rate
        self.audio_frames.append(frame.to_ndarray())
        return frame
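
# NOTE: depending on the streamlit-webrtc version, the callback-style API
# (audio_frame_callback) may be preferred over AudioProcessorBase; this
# recorder assumes a release where the processor-factory API is available.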
if __name__ == "__main__":
    main()
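
# To run locally (assumed): streamlit run app.py
# with GROQ_API_KEY exported in the environment.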