Dua Rajper commited on
Commit
d2af615
·
verified ·
1 Parent(s): 9a6ae7d

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +197 -0
app.py ADDED
@@ -0,0 +1,197 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import streamlit as st
3
+ from groq import Groq, APIConnectionError, AuthenticationError
4
+ from transformers import (
5
+ pipeline,
6
+ AutoTokenizer,
7
+ AutoModelForQuestionAnswering,
8
+ AutoProcessor,
9
+ AutoModelForSpeechSeq2Seq,
10
+ )
11
+ from espnet2.bin.tts_inference import Text2Speech
12
+ from PIL import Image
13
+ import easyocr
14
+ import soundfile as sf
15
+ from pydub import AudioSegment
16
+ import io
17
+ from streamlit_webrtc import webrtc_streamer, WebRtcMode, AudioProcessorBase
18
+ import av
19
+ import numpy as np
20
+
21
# --- Groq credentials -------------------------------------------------------
# The API key is expected to be injected through the environment (e.g. the
# Hugging Face Space Secrets). Without it the app cannot reach Groq, so we
# stop rendering immediately with a visible error.
GROQ_API_KEY = os.environ.get("GROQ_API_KEY")
if not GROQ_API_KEY:
    st.error("Groq API key not found. Please add it to the Hugging Face Space Secrets.")
    st.stop()

# One shared Groq client, reused by every request helper below.
groq_client = Groq(api_key=GROQ_API_KEY)
29
+
30
# OCR helpers
@st.cache_resource
def _get_ocr_reader():
    """Build the EasyOCR English reader once and reuse it across reruns.

    Constructing easyocr.Reader loads (and on first use downloads) the
    detection/recognition models, which is far too slow to repeat on every
    Streamlit rerun — hence the cache_resource decorator.
    """
    return easyocr.Reader(['en'])


def extract_text_from_image(image):
    """Run OCR on *image* and return all detected text joined by spaces.

    Args:
        image: A PIL image (as produced by Image.open in main()).

    Returns:
        str: Every detected text fragment, space-separated; empty string
        when nothing is detected.
    """
    reader = _get_ocr_reader()
    # easyocr.readtext accepts a file path, raw bytes, or a numpy array —
    # a PIL image must be converted first or the call fails.
    result = reader.readtext(np.asarray(image))
    # readtext yields (bbox, text, confidence) triples; keep only the text.
    return " ".join(detection[1] for detection in result)
36
+
37
# Question answering (DistilBERT fine-tuned on SQuAD)
@st.cache_resource
def load_qa_model():
    """Return a cached Hugging Face question-answering pipeline.

    The tokenizer/model pair is downloaded once and kept alive for the
    lifetime of the Streamlit process via st.cache_resource.
    """
    checkpoint = "distilbert/distilbert-base-cased-distilled-squad"
    return pipeline(
        'question-answering',
        model=AutoModelForQuestionAnswering.from_pretrained(checkpoint),
        tokenizer=AutoTokenizer.from_pretrained(checkpoint),
    )
45
+
46
def answer_question(context, question, qa_model):
    """Extract the answer to *question* from *context* using *qa_model*.

    Args:
        context: Source text the answer is drawn from.
        question: Natural-language question about the context.
        qa_model: A question-answering pipeline; called with a dict holding
            'question' and 'context' keys, expected to return a mapping
            containing an 'answer' key.

    Returns:
        The answer string reported by the pipeline.
    """
    prediction = qa_model({'question': question, 'context': context})
    return prediction['answer']
49
+
50
# Voice chatbot model loaders
@st.cache_resource
def load_voice_models():
    """Load and cache the speech models used by the voice chatbot.

    Returns:
        tuple: (stt_pipe, tts_model) where stt_pipe is a Whisper-small
        automatic-speech-recognition pipeline (timestamps enabled) and
        tts_model is an ESPnet Text2Speech instance.
    """
    # Speech-to-text: Whisper small, assembled from its processor parts.
    whisper_processor = AutoProcessor.from_pretrained("openai/whisper-small")
    whisper_model = AutoModelForSpeechSeq2Seq.from_pretrained("openai/whisper-small")
    asr = pipeline(
        "automatic-speech-recognition",
        model=whisper_model,
        tokenizer=whisper_processor.tokenizer,
        feature_extractor=whisper_processor.feature_extractor,
        return_timestamps=True,  # required for long-form (>30 s) audio
    )
    # Text-to-speech: ESPnet multi-speaker model.
    synthesizer = Text2Speech.from_pretrained("espnet/espnet_tts_vctk_espnet_spk_voxceleb12_rawnet")
    return asr, synthesizer
66
+
67
# Groq chat helper
def groq_chat(prompt):
    """Send *prompt* to the Groq chat API and return the model's reply.

    Errors are folded into the return value as human-readable strings so
    the Streamlit UI can display them without crashing.

    Args:
        prompt: The user's message.

    Returns:
        str: The assistant reply, or an error description on failure.
    """
    try:
        completion = groq_client.chat.completions.create(
            messages=[{"role": "user", "content": prompt}],
            model="llama-3.3-70b-versatile",
        )
    except APIConnectionError as e:
        return f"Groq API Connection Error: {e}"
    except AuthenticationError as e:
        return f"Groq API Authentication Error: {e}"
    except Exception as e:
        return f"General Groq API Error: {e}"
    return completion.choices[0].message.content
81
+
82
# Streamlit app
def main():
    """Render the Streamlit UI for both chat modes.

    Modes (selected in the sidebar):
      * "Image Text & QA" — OCR an uploaded image, then answer questions
        about the extracted text with the DistilBERT QA pipeline.
      * "Voice Chatbot"   — record audio in the browser, transcribe it with
        Whisper, answer via the Groq API, and speak the reply with ESPnet.

    A Groq-powered free-text chat box is rendered at the bottom in both
    modes.
    """
    st.title("Multi-Modal Chatbot: Image Text & Voice")

    # Sidebar for mode selection
    mode = st.sidebar.radio("Select Mode", ["Image Text & QA", "Voice Chatbot"])

    if mode == "Image Text & QA":
        st.header("Image Text Extraction & Question Answering")
        uploaded_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])

        if uploaded_file is not None:
            image = Image.open(uploaded_file)
            st.image(image, caption="Uploaded Image", use_container_width=True)

            if st.button("Extract Text and Enable Question Answering"):
                with st.spinner("Extracting text..."):
                    # BUGFIX: persist the OCR result. Streamlit reruns the
                    # whole script on every widget interaction, so a plain
                    # local variable would be undefined by the time the
                    # "Answer" button is clicked.
                    st.session_state["extracted_text"] = extract_text_from_image(image)

            extracted_text = st.session_state.get("extracted_text")
            if extracted_text is not None:
                st.write("Extracted Text:")
                st.write(extracted_text)

                qa_model = load_qa_model()

                question = st.text_input("Ask a question about the image text:")
                if st.button("Answer"):
                    if question:
                        with st.spinner("Answering..."):
                            answer = answer_question(extracted_text, question, qa_model)
                        st.write("Answer:", answer)
                    else:
                        st.warning("Please enter a question.")

    elif mode == "Voice Chatbot":
        st.header("Voice-Enabled Chatbot")

        # BUGFIX: these models were previously referenced without ever being
        # loaded, raising NameError on the first recording.
        stt_pipe, tts_model = load_voice_models()

        # Browser-side audio recorder (send-only: nothing is played back).
        st.write("Record your voice:")
        webrtc_ctx = webrtc_streamer(
            key="audio-recorder",
            mode=WebRtcMode.SENDONLY,
            audio_processor_factory=AudioRecorder,
            media_stream_constraints={"audio": True, "video": False},
        )

        if webrtc_ctx.audio_processor:
            st.write("Recording... Press 'Stop' to finish recording.")
            if st.button("Stop and Process Recording"):
                audio_frames = webrtc_ctx.audio_processor.audio_frames
                if audio_frames:
                    # Combine the per-frame ndarrays into one signal and save
                    # it as WAV so both Whisper and the user can replay it.
                    # NOTE(review): assumes mono 16 kHz frames — confirm the
                    # WebRTC track settings actually match.
                    audio_data = np.concatenate(audio_frames)
                    sf.write("recorded_audio.wav", audio_data, samplerate=16000)
                    st.success("Recording saved as recorded_audio.wav")

                    # Transcribe (timestamps enabled for long-form audio).
                    speech, _ = sf.read("recorded_audio.wav")
                    output = stt_pipe(speech)
                    st.write("Transcribed Text:", output['text'])
                    if 'chunks' in output:
                        st.write("Transcribed Text with Timestamps:")
                        for chunk in output['chunks']:
                            st.write(f"{chunk['timestamp'][0]:.2f} - {chunk['timestamp'][1]:.2f}: {chunk['text']}")

                    # Answer with Groq, then synthesize the reply.
                    try:
                        st.write("Input Text:", output['text'])
                        chat_completion = groq_client.chat.completions.create(
                            messages=[{"role": "user", "content": output['text']}],
                            model="mixtral-8x7b-32768",
                            temperature=0.5,
                            max_tokens=1024,
                        )
                        st.write("API Response:", chat_completion)
                        response = chat_completion.choices[0].message.content
                        st.write("Generated Response:", response)
                        # NOTE(review): ESPnet Text2Speech instances do not
                        # normally expose a `spembs` attribute — verify how
                        # speaker embeddings must be supplied for this
                        # checkpoint.
                        speech, *_ = tts_model(response, spembs=tts_model.spembs[0])
                        st.write("TTS Output:", speech)
                        # Save and play the synthesized speech.
                        sf.write("response.wav", speech, 22050)
                        st.audio("response.wav")
                    except Exception as e:
                        st.error(f"Error generating response: {e}")
                else:
                    st.error("No audio recorded. Please try again.")

    # Groq Chat Section (common to both modes)
    st.subheader("General Chat (Powered by Groq)")
    groq_prompt = st.text_input("Enter your message:")
    if st.button("Send"):
        if groq_prompt:
            with st.spinner("Generating response..."):
                groq_response = groq_chat(groq_prompt)
            st.write("Response:", groq_response)
        else:
            st.warning("Please enter a message.")
186
+
187
# Audio recorder: collects raw audio frames from the browser's WebRTC stream.
class AudioRecorder(AudioProcessorBase):
    """Accumulates every received audio frame as a numpy array.

    The collected frames are concatenated and written to disk by the voice
    chatbot branch of main().
    """

    def __init__(self):
        # Per-frame ndarrays, in arrival order.
        self.audio_frames = []

    def recv(self, frame: av.AudioFrame) -> av.AudioFrame:
        """Store the frame's samples and pass the frame through unchanged."""
        samples = frame.to_ndarray()
        self.audio_frames.append(samples)
        return frame
195
+
196
# Script entry point (run via `streamlit run app.py`).
if __name__ == "__main__":
    main()