Aasher committed on
Commit
36bebda
·
1 Parent(s): db63cc0
Files changed (5) hide show
  1. .streamlit/config.toml +2 -1
  2. app.py +592 -0
  3. groq_models.py +23 -19
  4. requirements.txt +5 -2
  5. utils.py +20 -1
.streamlit/config.toml CHANGED
@@ -6,4 +6,5 @@ textColor="#f5f8fc"
6
  font="sans serif"
7
 
8
  [server]
9
- runOnSave = true
 
 
6
  font="sans serif"
7
 
8
  [server]
9
+ runOnSave = true
10
+ maxUploadSize = 10
app.py ADDED
@@ -0,0 +1,592 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from audio_recorder_streamlit import audio_recorder
3
+ from groq_models import create_groq_agent, groq_chatbot, get_tools, summarizer_model
4
+ from langchain_community.document_loaders import Docx2txtLoader
5
+ from langchain_community.document_loaders import TextLoader
6
+ from PIL import Image
7
+ from io import BytesIO
8
+ import base64
9
+ import docx
10
+ from streamlit_lottie import st_lottie
11
+ import json
12
+ from utils import set_safety_settings, about, extract_all_pages_as_images, speech_to_text
13
+ import google.generativeai as genai
14
+ import os, random, validators
15
+ import tempfile
16
+ import asyncio
17
+ import edge_tts
18
+ from dotenv import load_dotenv
19
+ load_dotenv()
20
+
21
# Global Streamlit page configuration (must run before any other st.* call).
st.set_page_config(
    page_title="Super GPT",
    page_icon="⚡",
    layout="wide",
    initial_sidebar_state="auto",
    menu_items={"About": about(), "Get Help":"https://www.linkedin.com/in/aasher-kamal-a227a124b/"},
)

###--- Title ---###
st.markdown("""
<h1 style='text-align: center;'>
<span style='color: #F81F6F;'>Super</span>
<span style='color: #f5f8fc;'>AI Assistant</span>
</h1>
""", unsafe_allow_html=True)


# Gemini models: used for the multi-modal path (text, image, audio, video, PDF).
google_models = [
    "gemini-1.5-flash",
    "gemini-1.5-pro",
]

# Text-only models served via the Groq API (chatbot / agent / summarizer paths).
groq_models = [
    "llama-3.1-8b-instant",
    "llama-3.1-70b-versatile",
    "llama3-70b-8192",
    "llama3-8b-8192",
    "gemma2-9b-it",
    "mixtral-8x7b-32768"
]

# Display name -> edge-tts voice identifier for the spoken-response feature.
voices = {
    "William":"en-AU-WilliamNeural",
    "James":"en-PH-JamesNeural",
    "Jenny":"en-US-JennyNeural",
    "US Guy":"en-US-GuyNeural",
    "Sawara":"hi-IN-SwaraNeural",
}
59
+
60
+
61
def speech_recoginition():
    """Empty placeholder for a speech-recognition entry point.

    NOTE(review): this stub is never called — speech input is handled via
    ``speech_to_text`` from utils. The name also contains a typo
    ("recoginition"); kept unchanged to avoid altering the public interface.
    """
    pass
63
+
64
@st.cache_data
def load_lottie_file(filepath: str):
    """Load and cache a Lottie animation from a JSON file.

    Args:
        filepath: Path to the ``.json`` animation file.

    Returns:
        The parsed JSON object, suitable for ``st_lottie``.
    """
    # Fix: Lottie files are UTF-8 JSON — be explicit so the platform's
    # locale default encoding (e.g. cp1252 on Windows) cannot break the parse.
    with open(filepath, "r", encoding="utf-8") as f:
        return json.load(f)
68
+
69
+
70
async def generate_speech(text, voice):
    """Synthesize ``text`` into a temporary MP3 file via edge-tts.

    Returns the path of the generated file; the caller is responsible for
    deleting it once it has been played.
    """
    tts = edge_tts.Communicate(text, voice)
    # delete=False: the file must survive past this function for playback.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as out:
        speech_path = out.name
        await tts.save(speech_path)
    return speech_path
76
+
77
+
78
def get_audio_player(file_path):
    """Return an auto-playing HTML <audio> tag embedding the MP3 as base64."""
    with open(file_path, "rb") as audio_file:
        encoded = base64.b64encode(audio_file.read()).decode()
    return f'<audio autoplay="true" src="data:audio/mp3;base64,{encoded}">'
83
+
84
def generate_voice(text, voice):
    """Speak ``text`` in the browser using the selected edge-tts voice."""
    # Strip markdown markers and emojis that would otherwise be read aloud.
    cleaned = text.translate(str.maketrans('', '', '#-*_😊👋😄😁🥳👍🤩😂😎'))
    with st.spinner("Generating voice response..."):
        audio_path = asyncio.run(generate_speech(cleaned, voice))
        # Inject an autoplaying <audio> element, then drop the temp file.
        st.markdown(get_audio_player(audio_path), unsafe_allow_html=True)
        os.unlink(audio_path)
91
+
92
+
93
def get_llm_info(available_models):
    """Render the sidebar model picker and generation-parameter controls.

    Args:
        available_models: Model names to offer, already filtered by which
            API keys the user supplied.

    Returns:
        Tuple of (model name, model type: "groq"/"google"/None,
        temperature, max_tokens).
    """
    with st.sidebar:
        # Fix: original contained a doubled assignment ("tip =tip = ...").
        tip = "Select Gemini models if you require multi-modal capabilities (text, image, audio and video inputs)"
        model = st.selectbox("Choose LLM:", available_models, help=tip)

        # Infer the provider from the model-name prefix.
        model_type = None
        if model.startswith(("llama", "gemma", "mixtral")): model_type = "groq"
        elif model.startswith("gemini"): model_type = "google"

        with st.popover("⚙️Model Parameters", use_container_width=True):
            temp = st.slider("Temperature:", min_value=0.0,
                             max_value=2.0, value=0.5, step=0.5)

            max_tokens = st.slider("Maximum Tokens:", min_value=100,
                                   max_value=2000, value=400, step=200)
    return model, model_type, temp, max_tokens
109
+
110
+
111
###--- Function to convert base64 to temp file ---###
def base64_to_temp_file(base64_string, unique_name, file_extension):
    """Decode a base64 data URL and write it to ``<unique_name>.<file_extension>``.

    Returns the path of the file that was written (in the working directory).
    """
    # Everything after the first comma is the base64 payload
    # (drops the "data:<mime>;base64," prefix).
    payload = base64_string.split(",")[1]
    out_path = f"{unique_name}.{file_extension}"
    with open(out_path, "wb") as out:
        out.write(base64.b64decode(payload))
    return out_path
119
+
120
##----Preparing messages for Gemini----##
def messages_to_gemini(messages):
    """Convert the app's message history into Gemini's chat format.

    Consecutive messages from the same role are merged into one Gemini
    message; media parts are uploaded through Gemini's File API on first
    use, identified by their random ``unique_name``.
    """
    gemini_messages = []
    prev_role = None
    # Display names (minus extension) of files already on Gemini's File API,
    # so the same clip is not re-uploaded on every Streamlit rerun.
    uploaded_files = set([file.display_name.split(".")[0] for file in genai.list_files()])

    for message in messages:
        if prev_role and (prev_role == message["role"]):
            # Same speaker as the previous turn: extend the last message.
            gemini_message = gemini_messages[-1]
        else:
            # Gemini uses role "model" where the app uses "assistant".
            gemini_message = {
                "role": "model" if message["role"] == "assistant" else "user",
                "parts": [],
            }

        for content in message["content"]:
            if content["type"] in ["text","docx_file"]:
                # DOCX content was already extracted to plain text upstream.
                gemini_message["parts"].append(content[content["type"]])

            elif content["type"] == "image_url":
                gemini_message["parts"].append(base64_to_image(content["image_url"]["url"]))

            elif content["type"] in ["video_file", "audio_file", "speech_input"]:
                file_name = content['unique_name']

                if file_name not in uploaded_files:
                    # Materialize the base64 payload so the File API can read it.
                    temp_file_path = base64_to_temp_file(content[content["type"]], file_name, "mp4" if content["type"] == "video_file" else "wav")

                    with st.spinner(f"Sending {content['type'].replace('_', ' ')} to Gemini..."):
                        gemini_message["parts"].append(genai.upload_file(path=temp_file_path))
                    os.remove(temp_file_path)
                # NOTE(review): when the file was already uploaded, no part is
                # appended at all — the existing remote file is not re-attached
                # to the conversation. Confirm this is intentional.

            elif content["type"] == "pdf_file":
                if content['pdf_file'].split(".")[0] not in uploaded_files:
                    with st.spinner("Sending your PDF to Gemini..."):
                        gemini_message["parts"].append(genai.upload_file(path=content['pdf_file']))
                    os.remove(content['pdf_file'])

        # Only push a new entry when the role changed; merged turns were
        # appended to gemini_messages[-1] in place.
        if prev_role != message["role"]:
            gemini_messages.append(gemini_message)

        prev_role = message["role"]

    return gemini_messages
164
+
165
+
166
##-- Converting base64 to image ---##
def base64_to_image(base64_string):
    """Decode a base64 data-URL string into a PIL Image."""
    # Everything after the first comma is the base64 payload.
    payload = base64_string.split(",")[1]
    decoded = base64.b64decode(payload)
    return Image.open(BytesIO(decoded))
171
+
172
+
173
def add_pdf_docx_file_to_messages():
    """on_change callback: push the uploaded PDF/DOCX into the message history.

    PDFs are saved to disk and referenced by filename (uploaded to Gemini
    later); DOCX files are converted to plain text immediately.
    """
    upload = st.session_state.pdf_docx_uploaded
    if not upload:
        return

    if upload.type == "application/pdf":
        # Persist the PDF under a random name; the Gemini upload happens later.
        pdf_filename = f"pdf_{random.randint(1000, 9999)}.pdf"
        with open(pdf_filename, "wb") as f:
            f.write(upload.read())
        entry = {
            "type": "pdf_file",
            "pdf_file": pdf_filename,
        }
    else:
        # DOCX: extract every paragraph's text, joined with spaces.
        document = docx.Document(upload)
        entry = {
            "type": "docx_file",
            "docx_file": " ".join(paragraph.text for paragraph in document.paragraphs),
        }

    st.session_state.messages.append(
        {
            "role": "user",
            "content": [entry],
        }
    )
213
+
214
+
215
def save_uploaded_video(video_file, file_path):
    """Write an uploaded video file-like object out to ``file_path``."""
    data = video_file.read()
    with open(file_path, "wb") as out:
        out.write(data)
218
+
219
##--- Function for adding media files to session_state messages ---###
def add_media_files_to_messages():
    """on_change callback: append the uploaded image/audio/video to the history.

    Media is stored inline as a base64 data URL; audio and video also get a
    random ``unique_name`` used later to dedupe Gemini File-API uploads.
    """
    if st.session_state.uploaded_file:
        file_type = st.session_state.uploaded_file.type
        file_content = st.session_state.uploaded_file.getvalue()

        if file_type.startswith("image"):
            img = base64.b64encode(file_content).decode()
            st.session_state.messages.append(
                {
                    "role": "user",
                    "content": [{
                        "type": "image_url",
                        "image_url": {"url": f"data:{file_type};base64,{img}"}
                    }]
                }
            )
        elif file_type == "video/mp4":
            video_base64 = base64.b64encode(file_content).decode()
            unique_id = random.randint(1000, 9999)
            # file_name = st.session_state.uploaded_file.name
            # file_path = os.path.join(tempfile.gettempdir(), file_name)
            # save_uploaded_video(st.session_state.uploaded_file, file_path)

            st.session_state.messages.append(
                {
                    "role": "user",
                    "content": [{
                        "type": "video_file",
                        "video_file": f"data:{file_type};base64,{video_base64}",
                        "unique_name": f"temp_{unique_id}"
                    }]
                }
            )
        elif file_type.startswith("audio"):
            audio_base64 = base64.b64encode(file_content).decode()
            unique_id = random.randint(1000, 9999)
            st.session_state.messages.append(
                {
                    "role": "user",
                    "content": [{
                        "type": "audio_file",
                        "audio_file": f"data:{file_type};base64,{audio_base64}",
                        "unique_name": f"temp_{unique_id}"
                    }]
                }
            )
266
+
267
###--- FUNCTION TO ADD CAMERA IMAGE TO MESSAGES ---##
def add_camera_img_to_messages():
    """on_change callback: append the captured camera photo as an image message."""
    if "camera_img" in st.session_state and st.session_state.camera_img:
        encoded = base64.b64encode(st.session_state.camera_img.getvalue()).decode()
        new_message = {
            "role": "user",
            "content": [{
                "type": "image_url",
                "image_url": {"url": f"data:image/jpeg;base64,{encoded}"},
            }],
        }
        st.session_state.messages.append(new_message)
280
+
281
##--- FUNCTION TO RESET CONVERSATION ---##
def reset_conversation():
    """Clear chat histories, uploaded-file state, and remote Gemini files."""
    if "messages" in st.session_state and len(st.session_state.messages) > 0:
        st.session_state.pop("messages", None)
    # NOTE: groq history is only cleared once it holds more than one entry.
    if "groq_chat_history" in st.session_state and len(st.session_state.groq_chat_history) > 1:
        st.session_state.pop("groq_chat_history", None)

    # Delete every file previously uploaded to Gemini's File API.
    for file in genai.list_files():
        genai.delete_file(file.name)

    # Drop upload bookkeeping so the next run starts fresh.
    for state_key in ("uploaded_files", "pdf_docx_uploaded"):
        if state_key in st.session_state:
            st.session_state.pop(state_key, None)
297
+
298
##--- FUNCTION TO STREAM GEMINI RESPONSE ---##
def stream_gemini_response(model_params, api_key):
    """Stream a Gemini completion for the current message history.

    Yields response text chunks (for ``st.write_stream``), then appends the
    assembled assistant reply to ``st.session_state.messages`` so it renders
    on the next rerun.

    Args:
        model_params: Dict with "model", "temperature" and "max_tokens".
        api_key: Google AI Studio API key.
    """
    response_message = ""

    genai.configure(api_key=api_key)
    model = genai.GenerativeModel(
        model_name = model_params["model"],
        generation_config={
            "temperature": model_params["temperature"],
            "max_output_tokens": model_params["max_tokens"],
        },
        safety_settings=set_safety_settings(),
        # Fix: corrected "asnwers" -> "answers" in the system prompt.
        system_instruction="""You are a helpful assistant who answers user's questions professionally and politely."""
    )
    gemini_messages = messages_to_gemini(st.session_state.messages)

    for chunk in model.generate_content(contents=gemini_messages, stream=True):
        chunk_text = chunk.text or ""
        response_message += chunk_text
        yield chunk_text

    # Persist the full reply in app format for history replay.
    st.session_state.messages.append({
        "role": "assistant",
        "content": [
            {
                "type": "text",
                "text": response_message,
            }
        ]})
327
+
328
# "summarize" mirrors the sidebar Summarize button (it uses key="summarize").
if "summarize" not in st.session_state:
    st.session_state.summarize = False
##--- API KEYS ---##
with st.sidebar:
    st.logo("logo.png")
    api_cols = st.columns(2)
    with api_cols[0]:
        with st.popover("🔐 Groq", use_container_width=True):
            groq_api_key = st.text_input("Click [here](https://console.groq.com/keys) to get your Groq API key", value=os.getenv("GROQ_API_KEY") , type="password")

    with api_cols[1]:
        with st.popover("🔐 Google", use_container_width=True):
            google_api_key = st.text_input("Click [here](https://aistudio.google.com/app/apikey) to get your Google API key", value=os.getenv("GOOGLE_API_KEY") , type="password")

##--- API KEY CHECK ---##
# Keys are sanity-checked by their vendor prefixes ("gsk" for Groq, "AIza"
# for Google); the app only proceeds when at least one looks plausible.
if (groq_api_key == "" or groq_api_key is None or "gsk" not in groq_api_key) and (google_api_key == "" or google_api_key is None or "AIza" not in google_api_key):
    st.warning("Please Add an API Key to proceed.")

####--- LLM SIDEBAR ---###
else:
    with st.sidebar:
        st.divider()
        columns = st.columns(2)
        # animation
        with columns[0]:
            lottie_animation = load_lottie_file("animation.json")
            if lottie_animation:
                st_lottie(lottie_animation, height=100, width=100, quality="high", key="lottie_anim")

        with columns[1]:
            # When toggled, the selectbox stores the chosen voice under
            # st.session_state.voice_response (checked before each reply).
            if st.toggle("Voice Response"):
                response_voice = st.selectbox("Available Voices:", options=voices.keys(), key="voice_response")

        # Only offer models for which the user supplied an API key.
        available_models = [] + (google_models if google_api_key else []) + (groq_models if groq_api_key else [])
        model, model_type, temperature, max_tokens = get_llm_info(available_models)

        model_params = {
            "model": model,
            "temperature": temperature,
            "max_tokens": max_tokens
        }
        st.divider()

        ###---- Google Gemini Sidebar Customization----###
        if model_type == "google":
            st.write("Upload a file or take a picture")

            media_cols = st.columns(2)

            with media_cols[0]:
                with st.popover("📁 Upload", use_container_width=True):
                    st.file_uploader(
                        "Upload an image, audio or a video",
                        type=["png", "jpg", "jpeg", "wav", "mp3", "mp4"],
                        accept_multiple_files=False,
                        key="uploaded_file",
                        on_change=add_media_files_to_messages,
                    )

            with media_cols[1]:
                with st.popover("📷 Camera", use_container_width=True):
                    activate_camera = st.checkbox("Activate camera")
                    if activate_camera:
                        st.camera_input(
                            "Take a picture",
                            key="camera_img",
                            on_change=add_camera_img_to_messages,
                        )
            st.divider()
            tip = "If you upload a PDF, it will be sent to LLM."
            pdf_upload = st.file_uploader("Upload a PDF", type=["pdf", "docx"], key="pdf_docx_uploaded", on_change=add_pdf_docx_file_to_messages, help=tip)

        ###---- Groq Models Sidebar Customization----###
        else:
            groq_llm_type = st.radio(label="Select the LLM type:", key="groq_llm_type",options=["Agent", "Chatbot", "Summarizer"], horizontal=True)
            if groq_llm_type == "Summarizer":
                url = st.text_input("Enter YT video or Webpage URL:", key="url_to_summarize",
                                    help="Only Youtube videos having captions can be summarized.")

                summarize_button = st.button("Summarize", type="primary", use_container_width=True, key="summarize")

            elif groq_llm_type == "Agent":
                tools = st.multiselect("Select Tools for Agent",key="selected_tools",default=["Wikipedia", "ArXiv", "DuckDuckGo Search"],
                                       options=["Wikipedia", "ArXiv", "DuckDuckGo Search"])


    ######----- Main Interface -----#######
    chat_col1, chat_col2 = st.columns([1,4])

    with chat_col1:
        ###--- Audio Recording ---###
        audio_bytes = audio_recorder("Speak",
                                     pause_threshold=3,
                                     neutral_color="#f5f8fc",
                                     recording_color="#f81f6f",
                                     icon_name="microphone-lines",
                                     icon_size="3x")

        ###--- Reset Conversation ---###
        st.button(
            "🗑 Reset",
            use_container_width=True,
            on_click=reset_conversation,
            help="If clicked, conversation will be reset.",
        )
        ###--- Session state variable ---###
        if "pdf_docx_uploaded" not in st.session_state:
            st.session_state.pdf_docx_uploaded = None

        if st.session_state.pdf_docx_uploaded:
            file_name = st.session_state.pdf_docx_uploaded.name
            st.info(f"Your file :green['{file_name}'] has been uploaded!")


    if "messages" not in st.session_state:
        st.session_state.messages = []
    if "uploaded_files" not in st.session_state:
        st.session_state.uploaded_files = []
    if "groq_chat_history" not in st.session_state:
        st.session_state.groq_chat_history = []

    ###-- Handle speech input --###
    # Hash the latest recording so a Streamlit rerun does not re-append
    # the same clip to the history.
    speech_file_added = False
    if "prev_speech_hash" not in st.session_state:
        st.session_state.prev_speech_hash = None

    if audio_bytes and st.session_state.prev_speech_hash != hash(audio_bytes):
        st.session_state.prev_speech_hash = hash(audio_bytes)
        speech_base64 = base64.b64encode(audio_bytes).decode()
        unique_id = random.randint(1000, 9999)

        st.session_state.messages.append(
            {
                "role": "user",
                "content": [{
                    "type": "speech_input",
                    "speech_input": f"data:audio/wav;base64,{speech_base64}",
                    "unique_name": f"temp_{unique_id}"
                }]
            }
        )
        speech_file_added = True


    with chat_col2:
        message_container = st.container(height=400, border=False)

        # Replay the Gemini-format history (text + media parts).
        for message in st.session_state.messages:
            avatar = "assistant.png" if message["role"] == "assistant" else "user.png"
            # Hide the internal audio-stub prompt and raw PDF/DOCX payloads
            # from the chat display.
            valid_content = [
                content for content in message["content"]
                if not (
                    (content["type"] == "text" and content["text"] == "Please Answer the Question asked in the audio.") or
                    content["type"] in ["pdf_file","docx_file"]
                )
            ]
            if valid_content:
                with message_container.chat_message(message["role"], avatar=avatar):
                    for content in message["content"]:
                        if content["type"] == "text":
                            st.markdown(content["text"])
                        elif content["type"] == "image_url":
                            st.image(content["image_url"]["url"])
                        elif content["type"] == "video_file":
                            st.video(content["video_file"])
                        elif content["type"] == "audio_file":
                            st.audio(content["audio_file"], autoplay=True)
                        elif content["type"] == "speech_input":
                            st.audio(content["speech_input"])

        # Replay the plain-text Groq history.
        for msg in st.session_state.groq_chat_history:
            avatar = "assistant.png" if msg["role"] == "assistant" else "user.png"
            with message_container.chat_message(msg["role"], avatar=avatar):
                st.markdown(msg['content'])

    ###---- Summarizer model------###
    if model_type == "groq" and groq_llm_type == "Summarizer":
        if st.session_state.summarize:
            with message_container.chat_message("assistant", avatar="assistant.png"):
                if not url.strip():
                    st.error("Please enter a URL")
                elif not validators.url(url):
                    st.error("Please enter a valid URL")
                else:
                    with st.spinner("Summarizing..."):
                        final_response = summarizer_model(model_params=model_params, api_key=groq_api_key, url=url)
                        st.markdown(final_response)
                        st.session_state.groq_chat_history.append({"role": "assistant", "content": final_response})

    ###----- User Question -----###
    else:
        # NOTE(review): the walrus binds prompt to the OR of BOTH operands,
        # so when only speech was recorded prompt is the boolean True rather
        # than text — the speech branches below compensate. Confirm intended.
        if prompt:= st.chat_input("Type you question", key="question") or speech_file_added:
            ###------- GROQ MODELS --------###
            if model_type == "groq":

                if not speech_file_added:
                    message_container.chat_message("user", avatar="user.png").markdown(prompt)
                    st.session_state.groq_chat_history.append({"role": "user", "content": prompt})
                else:
                    # NOTE(review): this rebinds the imported speech_to_text
                    # function to its own result for the rest of this run —
                    # consider a different local name.
                    speech_to_text = speech_to_text(audio_bytes)
                    message_container.chat_message("user", avatar="user.png").markdown(speech_to_text)
                    st.session_state.groq_chat_history.append({"role": "user", "content": speech_to_text})

                with message_container.chat_message("assistant", avatar="assistant.png"):

                    try:
                        if groq_llm_type == "Chatbot":
                            final_response = st.write_stream(groq_chatbot(model_params=model_params, api_key=groq_api_key,
                                                                          question=prompt, chat_history=st.session_state.groq_chat_history))

                        elif groq_llm_type == "Agent":
                            final_response = create_groq_agent(model_params=model_params, api_key=groq_api_key,
                                                               question=prompt,
                                                               tools=get_tools(tools),
                                                               chat_history=st.session_state.groq_chat_history,)

                            st.markdown(final_response)
                        st.session_state.groq_chat_history.append({"role": "assistant", "content": final_response})

                        if "voice_response" in st.session_state and st.session_state.voice_response:
                            response_voice = st.session_state.voice_response
                            generate_voice(final_response, voices[response_voice])

                    except Exception as e:
                        st.error(f"An error occurred: {e}", icon="❌")

            ###-------- GEMINI MODELS -------###
            else:
                if not speech_file_added:
                    message_container.chat_message("user", avatar="user.png").markdown(prompt)

                    st.session_state.messages.append(
                        {
                            "role": "user",
                            "content": [{
                                "type": "text",
                                "text": prompt,
                            }]
                        }
                    )
                ###----Google Gemini Response----###
                else:
                    # Stub prompt telling Gemini to answer the recorded audio;
                    # it is filtered out of the chat display above.
                    st.session_state.messages.append(
                        {
                            "role": "user",
                            "content": [{
                                "type": "text",
                                "text": "Please Answer the Question asked in the audio.",
                            }]
                        }
                    )

                ###----- Generate response -----###
                with message_container.chat_message("assistant", avatar="assistant.png"):
                    try:
                        final_response = st.write_stream(stream_gemini_response(model_params=model_params, api_key= google_api_key))

                        if "voice_response" in st.session_state and st.session_state.voice_response:
                            response_voice = st.session_state.voice_response
                            generate_voice(final_response, voices[response_voice])

                    except Exception as e:
                        st.error(f"An error occurred: {e}", icon="❌")
591
+
592
+
groq_models.py CHANGED
@@ -80,25 +80,29 @@ def create_groq_agent(model_params, api_key, tools, question, chat_history):
80
  return response['output']
81
 
82
 
83
- def get_tools():
84
- wikipedia = WikipediaAPIWrapper(top_k_results=2, doc_content_chars_max=500)
85
- wikipedia_tool = Tool(name="Wikipedia",
86
- func=wikipedia.run,
87
- description="A useful tool for searching the Internet to find information on world events, issues, dates, years, etc.")
88
- arxiv = ArxivAPIWrapper(top_k_results=2, doc_content_chars_max=500)
89
- arxiv_tool = Tool(name="ArXiv",
90
- func=arxiv.run,
91
- description="A useful tool for searching scientific and research papers."
92
-
93
- )
94
- search = DuckDuckGoSearchRun()
95
- search_tool = Tool(
96
- name="DuckDuckGo Search",
97
- func=search.run,
98
- description="Useful for when you need to search the internet to find latest information, facts and figures that another tool can't find.",
99
- )
100
-
101
- return [arxiv_tool, wikipedia_tool, search_tool]
 
 
 
 
102
 
103
  def summarizer_model(model_params, api_key, url):
104
  llm = ChatGroq(model=model_params['model'], api_key=api_key,
 
80
  return response['output']
81
 
82
 
83
def get_tools(selected_tools):
    """Build LangChain Tool objects for the agent.

    Args:
        selected_tools: Tool display names chosen in the sidebar multiselect.

    Returns:
        The matching Tool instances, in the order they were selected.
    """
    wikipedia_tool = Tool(
        name="Wikipedia",
        func=WikipediaAPIWrapper(top_k_results=2, doc_content_chars_max=500).run,
        description="A useful tool for searching the Internet to find information on world events, issues, dates, years, etc."
    )
    arxiv_tool = Tool(
        name="ArXiv",
        func=ArxivAPIWrapper(top_k_results=2, doc_content_chars_max=500).run,
        description="A useful tool for searching scientific and research papers."
    )
    search_tool = Tool(
        name="DuckDuckGo Search",
        func=DuckDuckGoSearchRun().run,
        description="Useful for when you need to search the internet to find latest information, facts and figures that another tool can't find."
    )

    # Map display names to instances, then keep only the user's selection.
    available = {
        "Wikipedia": wikipedia_tool,
        "ArXiv": arxiv_tool,
        "DuckDuckGo Search": search_tool,
    }
    return [available[tool_name] for tool_name in selected_tools]
105
+
106
 
107
  def summarizer_model(model_params, api_key, url):
108
  llm = ChatGroq(model=model_params['model'], api_key=api_key,
requirements.txt CHANGED
@@ -10,10 +10,13 @@ langchain
10
  langchain-groq
11
  langchain_community
12
  pypdf
13
- pdfplumber
14
  edge-tts
15
  arxiv
16
  wikipedia
17
  duckduckgo-search
18
  langchainhub
19
- validators
 
 
 
 
 
10
  langchain-groq
11
  langchain_community
12
  pypdf
 
13
  edge-tts
14
  arxiv
15
  wikipedia
16
  duckduckgo-search
17
  langchainhub
18
+ validators
19
+ python-docx
20
+ SpeechRecognition
21
+ youtube-transcript-api
22
+ pytube
utils.py CHANGED
@@ -2,6 +2,8 @@ import streamlit as st
2
  from streamlit_vertical_slider import vertical_slider
3
  import pdfplumber
4
  from langchain_core.prompts import ChatPromptTemplate
 
 
5
 
6
  @st.dialog("Confirm Selection 👇", width="large")
7
  def visualize_display_page(selection_dict):
@@ -84,4 +86,21 @@ def set_safety_settings():
84
  },
85
  ]
86
 
87
- return safety_settings
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  from streamlit_vertical_slider import vertical_slider
3
  import pdfplumber
4
  from langchain_core.prompts import ChatPromptTemplate
5
+ import speech_recognition as sr
6
+ import tempfile
7
 
8
  @st.dialog("Confirm Selection 👇", width="large")
9
  def visualize_display_page(selection_dict):
 
86
  },
87
  ]
88
 
89
+ return safety_settings
90
+
91
def speech_to_text(audio_bytes):
    """Transcribe recorded WAV audio to English text using Google's free
    speech-recognition endpoint.

    Args:
        audio_bytes: Raw WAV audio bytes (as produced by the recorder widget).

    Returns:
        The transcribed text, or None when recognition fails (an error is
        shown in the UI instead of raising).
    """
    import os  # local import: utils does not import os at module level

    # Persist the recording so SpeechRecognition can open it as a file.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as recording:
        recording.write(audio_bytes)
        temp_file_path = recording.name

    try:
        r = sr.Recognizer()
        with sr.AudioFile(temp_file_path) as source:
            recorded_voice = r.record(source)

        try:
            return r.recognize_google(recorded_voice, language="en")
        except sr.UnknownValueError:
            # Fix: show a readable message instead of the raw exception object.
            st.error("Could not understand the audio. Please try again.")
        except sr.RequestError as e:
            # Fix: surface the failure in the UI instead of only printing
            # to the server console.
            st.error(f"Could not request results from the speech recognition service: {e}")
    finally:
        # Fix: the temp file was created with delete=False and previously
        # leaked on every call — always remove it.
        os.remove(temp_file_path)