Aasher commited on
Commit
4fff95e
·
1 Parent(s): 5486094
Files changed (6) hide show
  1. .gitignore +1 -0
  2. groq_models.py +123 -0
  3. requirements.txt +6 -1
  4. test4.py +3 -3
  5. test5.py +565 -0
  6. utils.py +1 -0
.gitignore CHANGED
@@ -167,6 +167,7 @@ code_not_using_vertex.py
167
  test.py
168
  test2.py
169
  test3.py
 
170
  tts.py
171
  files_upload.py
172
  main.py
 
167
  test.py
168
  test2.py
169
  test3.py
170
+ test4.py
171
  tts.py
172
  files_upload.py
173
  main.py
groq_models.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain_groq import ChatGroq
2
+ from langchain_core.output_parsers import StrOutputParser
3
+ from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder, PromptTemplate
4
+ from langchain_community.document_loaders import YoutubeLoader, WebBaseLoader
5
+ from langchain.chains.summarize import load_summarize_chain
6
+ from langchain_core.tools import Tool
7
+ from langchain_community.tools import DuckDuckGoSearchRun
8
+ from langchain.agents import create_react_agent
9
+ from langchain.agents import AgentExecutor
10
+ from langchain_community.callbacks.streamlit import StreamlitCallbackHandler
11
+ from langchain_community.utilities import WikipediaAPIWrapper, ArxivAPIWrapper
12
+ import streamlit as st
13
+
14
def groq_chatbot(model_params, question, api_key, chat_history):
    """Stream a conversational answer from a Groq-hosted chat model.

    Args:
        model_params: dict with 'model', 'temperature' and 'max_tokens'.
        question: the latest user question.
        api_key: Groq API key.
        chat_history: prior chat messages, injected via MessagesPlaceholder.

    Returns:
        An iterator of response-text chunks (suitable for st.write_stream).
    """
    llm = ChatGroq(model=model_params['model'], api_key=api_key,
                   temperature=model_params["temperature"],
                   max_tokens=model_params['max_tokens']
                   )

    system_template = (
        """Given a chat history and the latest user question
        which might reference context in the chat history,
        Answer the user question in a polite and professional manner."""
    )
    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system_template),
            MessagesPlaceholder(variable_name="chat_history"),
            # Fixed typo in the prompt text ("Questioin" -> "Question").
            ("user", "Question: {question}")
        ]
    )
    chain = prompt | llm | StrOutputParser()

    return chain.stream({"question": question, "chat_history": chat_history})
35
+
36
+
37
def get_prompt():
    """Build the ReAct prompt template used by the Groq agent.

    The template declares the {tools}/{tool_names} placeholders that
    create_react_agent requires, plus {chat_history}, {input} and
    {agent_scratchpad}.
    """
    # Typo fix in the instruction text: "relavant" -> "relevant".
    prompt = ChatPromptTemplate.from_template("""
    Answer the following user questions as best you can. Use the available tools to find the answer.
    You have access to the following tools:\n
    {tools}\n\n
    To use a tool, please use the following format:
    ```
    Thought: Do I need to use a tool? Yes
    Action: the action to take, should be one of [{tool_names}]
    Action Input: the input to the action
    Observation: the result of the action
    ```
    If one tool doesn't give the relevant information, use another tool.
    When you have a response to say to the Human, or if you do not need to use a tool, you MUST use the format:

    ```
    Thought: Do I need to use a tool? No
    Final Answer: [your response here]
    ```
    Begin!

    Previous conversation history:
    {chat_history}
    New input: {input}

    {agent_scratchpad}
    """)
    return prompt
65
+
66
+
67
def create_groq_agent(model_params, api_key, tools, question, chat_history):
    """Run a ReAct agent over the given tools and return its final answer.

    Intermediate agent steps are rendered live in the Streamlit app through a
    StreamlitCallbackHandler.
    """
    chat_model = ChatGroq(model=model_params['model'], api_key=api_key,
                          temperature=model_params["temperature"],
                          )

    react_agent = create_react_agent(chat_model, tools, get_prompt())
    executor = AgentExecutor(agent=react_agent, tools=tools, verbose=True,
                             handle_parsing_errors=True, max_iterations=7)

    streamlit_callback = StreamlitCallbackHandler(st.container())
    result = executor.invoke({"input":question, "chat_history":chat_history},
                             {"callbacks": [streamlit_callback]})
    return result['output']
81
+
82
+
83
def get_tools():
    """Assemble the research tools (ArXiv, Wikipedia, DuckDuckGo) for the agent."""
    arxiv_tool = Tool(
        name="ArXiv",
        func=ArxivAPIWrapper(top_k_results=2, doc_content_chars_max=500).run,
        description="A useful tool for searching scientific and research papers.",
    )
    wikipedia_tool = Tool(
        name="Wikipedia",
        func=WikipediaAPIWrapper(top_k_results=2, doc_content_chars_max=500).run,
        description="A useful tool for searching the Internet to find information on world events, issues, dates, years, etc.",
    )
    search_tool = Tool(
        name="DuckDuckGo Search",
        func=DuckDuckGoSearchRun().run,
        description="Useful for when you need to search the internet to find latest information, facts and figures that another tool can't find.",
    )
    return [arxiv_tool, wikipedia_tool, search_tool]
102
+
103
def summarizer_model(model_params, api_key, url):
    """Summarize a YouTube video transcript or a web page with a Groq model.

    Args:
        model_params: dict with 'model', 'temperature' and 'max_tokens'.
        api_key: Groq API key.
        url: YouTube video URL (long or short form) or any web page URL.

    Returns:
        The summary text in markdown.
    """
    llm = ChatGroq(model=model_params['model'], api_key=api_key,
                   temperature=model_params["temperature"],
                   max_tokens=model_params['max_tokens']
                   )

    # Also recognize youtu.be short links, which the original check missed.
    if "youtube.com" in url or "youtu.be" in url:
        loader = YoutubeLoader.from_youtube_url(url, add_video_info=True)
    else:
        loader = WebBaseLoader(web_path=url)

    data = loader.load()

    prompt_template = """Provide a summary of the following content in proper markdown:
    Content:\n{text}"""

    prompt = PromptTemplate(input_variables=["text"], template=prompt_template)

    # "stuff" puts the whole document into a single prompt; fine for short
    # pages, but may exceed the model context for very long content.
    chain = load_summarize_chain(llm=llm, chain_type="stuff", prompt=prompt)
    output = chain.run(data)
    return output
requirements.txt CHANGED
@@ -11,4 +11,9 @@ langchain-groq
11
  langchain_community
12
  pypdf
13
  pdfplumber
14
- edge-tts
 
 
 
 
 
 
11
  langchain_community
12
  pypdf
13
  pdfplumber
14
+ edge-tts
15
+ arxiv
16
+ wikipedia
17
+ duckduckgo-search
18
+ langchainhub
19
+ validators
test4.py CHANGED
@@ -355,7 +355,7 @@ else:
355
  )
356
  st.divider()
357
  tip = "If you upload a PDF, it will be sent to LLM."
358
- pdf_upload = st.file_uploader("Upload a PDF", type="pdf", key="pdf_uploaded", on_change=add_pdf_file_to_messages, help=)
359
  ###---- Groq Models Sidebar Customization----###
360
  else:
361
  pass # will add later
@@ -424,7 +424,7 @@ else:
424
  valid_content = [
425
  content for content in message["content"]
426
  if not (
427
- (content["type"] == "text" and content["text"] == "Please Answer what is asked in the audio.") or
428
  content["type"] == "pdf_file"
429
  )
430
  ]
@@ -463,7 +463,7 @@ else:
463
  "role": "user",
464
  "content": [{
465
  "type": "text",
466
- "text": "Please Answer what is asked in the audio.",
467
  }]
468
  }
469
  )
 
355
  )
356
  st.divider()
357
  tip = "If you upload a PDF, it will be sent to LLM."
358
+ pdf_upload = st.file_uploader("Upload a PDF", type="pdf", key="pdf_uploaded", on_change=add_pdf_file_to_messages, help=tip)
359
  ###---- Groq Models Sidebar Customization----###
360
  else:
361
  pass # will add later
 
424
  valid_content = [
425
  content for content in message["content"]
426
  if not (
427
+ (content["type"] == "text" and content["text"] == "Please Answer the Question asked in the audio.") or
428
  content["type"] == "pdf_file"
429
  )
430
  ]
 
463
  "role": "user",
464
  "content": [{
465
  "type": "text",
466
+ "text": "Please Answer the Question asked in the audio.",
467
  }]
468
  }
469
  )
test5.py ADDED
@@ -0,0 +1,565 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from audio_recorder_streamlit import audio_recorder
3
+ from groq_models import create_groq_agent, groq_chatbot, get_tools, summarizer_model
4
+ from langchain_community.document_loaders import Docx2txtLoader
5
+ from langchain_community.document_loaders import TextLoader
6
+ from PIL import Image
7
+ from io import BytesIO
8
+ import base64
9
+ from streamlit_lottie import st_lottie
10
+ import json
11
+ from utils import set_safety_settings, about, extract_all_pages_as_images
12
+ import google.generativeai as genai
13
+ import os, random, validators
14
+ import tempfile
15
+ import asyncio
16
+ import edge_tts
17
+ from dotenv import load_dotenv
18
+ load_dotenv()
19
+
20
# Streamlit page-level configuration; must be the first st.* call in the app.
st.set_page_config(
    page_title="Super GPT",
    page_icon="⚑",
    layout="wide",
    initial_sidebar_state="auto",
    menu_items={"About": about(), "Get Help":"https://www.linkedin.com/in/aasher-kamal-a227a124b/"},
)

###--- Title ---###
st.markdown("""
<h1 style='text-align: center;'>
<span style='color: #F81F6F;'>Super</span>
<span style='color: #f5f8fc;'>AI Assistant</span>
</h1>
""", unsafe_allow_html=True)


# Gemini model ids offered when a Google API key is present.
google_models = [
    "gemini-1.5-flash",
    "gemini-1.5-pro",
]

# Groq-hosted model ids offered when a Groq API key is present.
groq_models = [
    "llama-3.1-8b-instant",
    "llama-3.1-70b-versatile",
    "llama3-70b-8192",
    "llama3-8b-8192",
    "gemma2-9b-it",
    "mixtral-8x7b-32768"
]

# Display name -> edge-tts voice id used for spoken responses.
voices = {
    "William":"en-AU-WilliamNeural",
    "James":"en-PH-JamesNeural",
    "Jenny":"en-US-JennyNeural",
    "US Guy":"en-US-GuyNeural",
    "Sawara":"hi-IN-SwaraNeural",
}
58
+
59
+
60
def speech_recoginition():
    """Placeholder for speech-to-text; not implemented yet, returns None.

    NOTE(review): the name is misspelled ("recoginition") but it is called
    elsewhere in this file, so renaming it would break the call site.
    """
    pass
62
+
63
@st.cache_data
def load_lottie_file(filepath: str):
    """Read and parse a Lottie animation JSON file (cached across reruns)."""
    with open(filepath, "r") as handle:
        animation = json.load(handle)
    return animation
67
+
68
+
69
async def generate_speech(text, voice):
    """Synthesize *text* with edge-tts into a temporary MP3 and return its path.

    The file is created with delete=False, so the caller is responsible for
    removing the returned file when done.
    """
    communicate = edge_tts.Communicate(text, voice)
    # NOTE(review): the temp file is still open while edge-tts writes to the
    # same path; fine on POSIX, but may fail on Windows — confirm.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as temp_file:
        await communicate.save(temp_file.name)
        temp_file_path = temp_file.name
    return temp_file_path
75
+
76
+
77
def get_audio_player(file_path):
    """Return an auto-playing HTML <audio> tag embedding the MP3 at *file_path*."""
    with open(file_path, "rb") as audio_file:
        raw = audio_file.read()
    encoded = base64.b64encode(raw).decode()
    return f'<audio autoplay="true" src="data:audio/mp3;base64,{encoded}">'
82
+
83
def get_llm_info(available_models):
    """Render the sidebar model picker and parameter controls.

    Args:
        available_models: model ids to offer, filtered by available API keys.

    Returns:
        Tuple of (model, model_type, temperature, max_tokens) where
        model_type is "groq", "google" or None.
    """
    with st.sidebar:
        # Fixed the accidental double assignment (`tip =tip = ...`).
        tip = "Select Gemini models if you require multi-modal capabilities (text, image, audio and video inputs)"
        model = st.selectbox("Choose LLM:", available_models, help=tip)

        # Infer the provider from the model-name prefix.
        model_type = None
        if model.startswith(("llama", "gemma", "mixtral")):
            model_type = "groq"
        elif model.startswith("gemini"):
            model_type = "google"

        with st.popover("βš™οΈModel Parameters", use_container_width=True):
            temp = st.slider("Temperature:", min_value=0.0,
                             max_value=2.0, value=0.5, step=0.5)

            max_tokens = st.slider("Maximum Tokens:", min_value=100,
                                   max_value=2000, value=400, step=200)
    return model, model_type, temp, max_tokens
99
+
100
+
101
+ ###--- Function to convert base64 to temp file ---###
102
def base64_to_temp_file(base64_string, unique_name, file_extension):
    """Decode a data-URL base64 payload and write it to '<unique_name>.<ext>'.

    Returns the path of the file that was written.
    """
    _, _, payload = base64_string.partition(",")
    out_path = f"{unique_name}.{file_extension}"
    with open(out_path, "wb") as out_file:
        out_file.write(base64.b64decode(payload))
    return out_path
109
+
110
+
111
def messages_to_gemini(messages):
    """Convert the session chat history into Gemini's role/parts format.

    Consecutive messages from the same role are merged into one Gemini
    message. Media parts (video/audio/speech/PDF) are uploaded to Gemini via
    genai.upload_file, using previously-uploaded display names to avoid
    re-uploading the same file; local temp files are removed after upload.
    """
    gemini_messages = []
    prev_role = None
    # Base names of files already uploaded to Gemini (dedup key).
    uploaded_files = set([file.display_name.split(".")[0] for file in genai.list_files()])

    for message in messages:
        # Same role as the previous message: append parts to the last entry.
        if prev_role and (prev_role == message["role"]):
            gemini_message = gemini_messages[-1]
        else:
            gemini_message = {
                "role": "model" if message["role"] == "assistant" else "user",
                "parts": [],
            }

        for content in message["content"]:
            if content["type"] == "text":
                gemini_message["parts"].append(content["text"])

            elif content["type"] == "image_url":
                # Images are stored as data URLs; decode to a PIL image.
                gemini_message["parts"].append(base64_to_image(content["image_url"]["url"]))

            elif content["type"] in ["video_file", "audio_file", "speech_input"]:
                file_name = content['unique_name']

                if file_name not in uploaded_files:
                    # Write the base64 payload to a temp file so genai can upload it.
                    temp_file_path = base64_to_temp_file(content[content["type"]], file_name, "mp4" if content["type"] == "video_file" else "wav")

                    with st.spinner(f"Sending {content['type'].replace('_', ' ')} to Gemini..."):
                        gemini_message["parts"].append(genai.upload_file(path=temp_file_path))
                    os.remove(temp_file_path)

            elif content["type"] == "pdf_file":
                # PDFs are already on disk; upload once, then delete locally.
                if content['pdf_file'].split(".")[0] not in uploaded_files:
                    with st.spinner("Sending your PDF to Gemini..."):
                        gemini_message["parts"].append(genai.upload_file(path=content['pdf_file']))
                    os.remove(content['pdf_file'])

        if prev_role != message["role"]:
            gemini_messages.append(gemini_message)

        prev_role = message["role"]

    return gemini_messages
+
161
+
162
+ ##-- Converting base64 to image ---##
163
+ def base64_to_image(base64_string):
164
+ base64_string = base64_string.split(",")[1]
165
+
166
+ return Image.open(BytesIO(base64.b64decode(base64_string)))
167
+
168
+ def add_pdf_file_to_messages():
169
+ if st.session_state.pdf_uploaded:
170
+ # Save the PDF file
171
+ pdf_id = random.randint(1000, 9999)
172
+ pdf_filename = f"pdf_{pdf_id}.pdf"
173
+ with open(pdf_filename, "wb") as f:
174
+ f.write(st.session_state.pdf_uploaded.read())
175
+
176
+ # Add the PDF file to session_state messages
177
+ st.session_state.messages.append(
178
+ {
179
+ "role": "user",
180
+ "content": [{
181
+ "type": "pdf_file",
182
+ "pdf_file": pdf_filename,
183
+ }]
184
+ }
185
+ )
186
+
187
+ def save_uploaded_video(video_file, file_path):
188
+ with open(file_path, "wb") as f:
189
+ f.write(video_file.read())
190
+
191
+ ##--- Function for adding media files to session_state messages ---###
192
+ def add_media_files_to_messages():
193
+ if st.session_state.uploaded_file:
194
+ file_type = st.session_state.uploaded_file.type
195
+ file_content = st.session_state.uploaded_file.getvalue()
196
+
197
+ if file_type.startswith("image"):
198
+ img = base64.b64encode(file_content).decode()
199
+ st.session_state.messages.append(
200
+ {
201
+ "role": "user",
202
+ "content": [{
203
+ "type": "image_url",
204
+ "image_url": {"url": f"data:{file_type};base64,{img}"}
205
+ }]
206
+ }
207
+ )
208
+ elif file_type == "video/mp4":
209
+ video_base64 = base64.b64encode(file_content).decode()
210
+ unique_id = random.randint(1000, 9999)
211
+ # file_name = st.session_state.uploaded_file.name
212
+ # file_path = os.path.join(tempfile.gettempdir(), file_name)
213
+ # save_uploaded_video(st.session_state.uploaded_file, file_path)
214
+
215
+ st.session_state.messages.append(
216
+ {
217
+ "role": "user",
218
+ "content": [{
219
+ "type": "video_file",
220
+ "video_file": f"data:{file_type};base64,{video_base64}",
221
+ "unique_name": f"temp_{unique_id}"
222
+ }]
223
+ }
224
+ )
225
+ elif file_type.startswith("audio"):
226
+ audio_base64 = base64.b64encode(file_content).decode()
227
+ unique_id = random.randint(1000, 9999)
228
+ st.session_state.messages.append(
229
+ {
230
+ "role": "user",
231
+ "content": [{
232
+ "type": "audio_file",
233
+ "audio_file": f"data:{file_type};base64,{audio_base64}",
234
+ "unique_name": f"temp_{unique_id}"
235
+ }]
236
+ }
237
+ )
238
+
239
+ ###--- FUNCTION TO ADD CAMERA IMAGE TO MESSAGES ---##
240
def add_camera_img_to_messages():
    """Append the captured camera image to the chat history as a data URL."""
    if "camera_img" not in st.session_state or not st.session_state.camera_img:
        return
    encoded = base64.b64encode(st.session_state.camera_img.getvalue()).decode()
    st.session_state.messages.append(
        {
            "role": "user",
            "content": [{
                "type": "image_url",
                "image_url": {"url": f"data:image/jpeg;base64,{encoded}"}
            }]
        }
    )
252
+
253
##--- FUNCTION TO RESET CONVERSATION ---##
def reset_conversation():
    """Clear all chat state and delete every file previously uploaded to Gemini."""
    if "messages" in st.session_state and len(st.session_state.messages) > 0:
        st.session_state.pop("messages", None)
    if "groq_chat_history" in st.session_state and len(st.session_state.groq_chat_history) > 1:
        st.session_state.pop("groq_chat_history", None)

    # Remove remote files so stale uploads are not reused in a new conversation.
    for file in genai.list_files():
        genai.delete_file(file.name)

    # Reset the uploaded files list
    if "uploaded_files" in st.session_state:
        st.session_state.pop("uploaded_files", None)

    if "pdf_uploaded" in st.session_state:
        st.session_state.pop("pdf_uploaded", None)
269
+
270
##--- FUNCTION TO STREAM GEMINI RESPONSE ---##
def stream_gemini_response(model_params, api_key):
    """Stream a Gemini completion for the current chat history.

    Yields response text chunks (for st.write_stream) and, once the stream is
    exhausted, appends the full assistant reply to st.session_state.messages.
    """
    response_message = ""

    genai.configure(api_key=api_key)
    model = genai.GenerativeModel(
        model_name = model_params["model"],
        generation_config={
            "temperature": model_params["temperature"],
            "max_output_tokens": model_params["max_tokens"],
        },
        safety_settings=set_safety_settings(),
        system_instruction="""You are a helpful assistant who asnwers user's questions professionally and politely."""
    )
    gemini_messages = messages_to_gemini(st.session_state.messages)

    for chunk in model.generate_content(contents=gemini_messages, stream=True):
        # NOTE(review): chunk.text can raise when a chunk is safety-blocked;
        # the `or ""` only guards an empty string — confirm desired handling.
        chunk_text = chunk.text or ""
        response_message += chunk_text
        yield chunk_text

    # Persist the complete assistant reply once streaming finishes.
    st.session_state.messages.append({
        "role": "assistant",
        "content": [
            {
                "type": "text",
                "text": response_message,
            }
        ]})
299
+
300
# Default the Summarize-button state before any widget reads it.
if "summarize" not in st.session_state:
    st.session_state.summarize = False
##--- API KEYS ---##
# API-key entry popovers; .env values (GROQ_API_KEY / GOOGLE_API_KEY) pre-fill.
with st.sidebar:
    st.logo("logo.png")
    api_cols = st.columns(2)
    with api_cols[0]:
        with st.popover("πŸ” Groq", use_container_width=True):
            groq_api_key = st.text_input("Click [here](https://console.groq.com/keys) to get your Groq API key", value=os.getenv("GROQ_API_KEY") , type="password")

    with api_cols[1]:
        with st.popover("πŸ” Google", use_container_width=True):
            google_api_key = st.text_input("Click [here](https://aistudio.google.com/app/apikey) to get your Google API key", value=os.getenv("GOOGLE_API_KEY") , type="password")
313
+
314
##--- API KEY CHECK ---##
# Crude key validation: Groq keys contain "gsk", Google keys contain "AIza".
if (groq_api_key == "" or groq_api_key is None or "gsk" not in groq_api_key) and (google_api_key == "" or google_api_key is None or "AIza" not in google_api_key):
    st.warning("Please Add an API Key to proceed.")

####--- LLM SIDEBAR ---###
# Everything below (sidebar + chat UI) only renders once a key is present.
else:
    with st.sidebar:
        st.divider()
        columns = st.columns(2)
        # animation
        with columns[0]:
            lottie_animation = load_lottie_file("animation.json")
            if lottie_animation:
                st_lottie(lottie_animation, height=100, width=100, quality="high", key="lottie_anim")

        with columns[1]:
            # Optional spoken responses; the chosen voice lives in session state.
            if st.toggle("Voice Response"):
                response_voice = st.selectbox("Available Voices:", options=voices.keys(), key="voice_response")

        # Only offer models whose provider key was supplied.
        available_models = [] + (google_models if google_api_key else []) + (groq_models if groq_api_key else [])
        model, model_type, temperature, max_tokens = get_llm_info(available_models)

        model_params = {
            "model": model,
            "temperature": temperature,
            "max_tokens": max_tokens
        }
        st.divider()

        ###---- Google Gemini Sidebar Customization----###
        if model_type == "google":
            st.write("Upload a file or take a picture")

            media_cols = st.columns(2)

            with media_cols[0]:
                with st.popover("πŸ“ Upload", use_container_width=True):
                    st.file_uploader(
                        "Upload an image, audio or a video",
                        type=["png", "jpg", "jpeg", "wav", "mp3", "mp4"],
                        accept_multiple_files=False,
                        key="uploaded_file",
                        on_change=add_media_files_to_messages,
                    )

            with media_cols[1]:
                with st.popover("πŸ“· Camera", use_container_width=True):
                    activate_camera = st.checkbox("Activate camera")
                    if activate_camera:
                        st.camera_input(
                            "Take a picture",
                            key="camera_img",
                            on_change=add_camera_img_to_messages,
                        )
            st.divider()
            tip = "If you upload a PDF, it will be sent to LLM."
            pdf_upload = st.file_uploader("Upload a PDF", type="pdf", key="pdf_uploaded", on_change=add_pdf_file_to_messages, help=tip)
        ###---- Groq Models Sidebar Customization----###
        else:
            groq_llm_type = st.radio(label="Select the LLM type:", key="groq_llm_type",options=["Agent", "Chatbot", "Summarizer"], horizontal=True)
            if groq_llm_type == "Summarizer":
                url = st.text_input("Enter YT video or Webpage URL:", key="url_to_summarize",
                                    help="Only Youtube videos having captions can be summarized.")

                # key="summarize" makes the button state readable as
                # st.session_state.summarize elsewhere in the script.
                summarize_button = st.button("Summarize", type="primary", use_container_width=True, key="summarize")
379
+
380
+
381
    ######----- Main Interface -----#######
    # Left column: voice input + controls; right column: the chat transcript.
    chat_col1, chat_col2 = st.columns([1,3.5])

    with chat_col1:
        ###--- Audio Recording ---###
        audio_bytes = audio_recorder("Speak",
                                     neutral_color="#f5f8fc",
                                     recording_color="#f81f6f",
                                     icon_name="microphone-lines",
                                     icon_size="3x")

        ###--- Reset Conversation ---###
        st.button(
            "πŸ—‘ Reset",
            use_container_width=True,
            on_click=reset_conversation,
            help="If clicked, conversation will be reset.",
        )
        if "pdf_uploaded" not in st.session_state:
            st.session_state.pdf_uploaded = None

        # Preview the uploaded PDF as page images with an adjustable zoom.
        if st.session_state.pdf_uploaded:
            pdf_pages = extract_all_pages_as_images(st.session_state.pdf_uploaded)
            st.session_state["pdf_pages"] = pdf_pages
            zoom_level = st.slider(label="",label_visibility="collapsed",
                min_value=100, max_value=1000, value=400, step=100, key="zoom_level"
            )
            with st.container(height=200, border=True):
                for page_image in pdf_pages:
                    st.image(page_image, width=zoom_level)

    # Per-session chat state; Gemini messages and Groq chat history are
    # tracked in separate structures.
    if "messages" not in st.session_state:
        st.session_state.messages = []
    if "uploaded_files" not in st.session_state:
        st.session_state.uploaded_files = []
    if "groq_chat_history" not in st.session_state:
        st.session_state.groq_chat_history = []

    # Handle speech input: append a recording only once per distinct clip,
    # using the hash of the raw bytes to detect repeats across reruns.
    speech_file_added = False
    if "prev_speech_hash" not in st.session_state:
        st.session_state.prev_speech_hash = None

    if audio_bytes and st.session_state.prev_speech_hash != hash(audio_bytes):
        st.session_state.prev_speech_hash = hash(audio_bytes)
        speech_base64 = base64.b64encode(audio_bytes).decode()
        unique_id = random.randint(1000, 9999)
        st.session_state.messages.append(
            {
                "role": "user",
                "content": [{
                    "type": "speech_input",
                    "speech_input": f"data:audio/wav;base64,{speech_base64}",
                    "unique_name": f"temp_{unique_id}"
                }]
            }
        )
        speech_file_added = True
439
+
440
+
441
    with chat_col2:
        # Scrollable transcript area shared by both providers.
        message_container = st.container(height=400, border=False)

        # Replay the Gemini-format history, hiding internal-only items
        # (the synthetic audio-question prompt and raw PDF attachments).
        for message in st.session_state.messages:
            avatar = "assistant.png" if message["role"] == "assistant" else "user.png"
            valid_content = [
                content for content in message["content"]
                if not (
                    (content["type"] == "text" and content["text"] == "Please Answer the Question asked in the audio.") or
                    content["type"] == "pdf_file"
                )
            ]
            if valid_content:
                with message_container.chat_message(message["role"], avatar=avatar):
                    for content in message["content"]:
                        if content["type"] == "text":
                            st.markdown(content["text"])
                        elif content["type"] == "image_url":
                            st.image(content["image_url"]["url"])
                        elif content["type"] == "video_file":
                            st.video(content["video_file"])
                        elif content["type"] == "audio_file":
                            st.audio(content["audio_file"], autoplay=True)
                        elif content["type"] == "speech_input":
                            st.audio(content["speech_input"])

        # Replay the plain-text Groq history.
        for msg in st.session_state.groq_chat_history:
            avatar = "assistant.png" if msg["role"] == "assistant" else "user.png"
            with message_container.chat_message(msg["role"], avatar=avatar):
                st.markdown(msg['content'])
471
+
472
+
473
    ###----- User Question -----###
    # NOTE(review): `:=` binds lowest, so `prompt` receives the WHOLE
    # `chat_input(...) or speech_file_added or summarize` expression — when the
    # trigger is the speech recorder or the Summarize button, `prompt` is a
    # bool rather than text. Confirm this is intentional before relying on
    # `prompt` downstream.
    if prompt:= st.chat_input("Type you question", key="question") or speech_file_added or st.session_state.summarize:
        if model_type == "groq":

            # Typed question: echo it and record it in the Groq history.
            if not speech_file_added and not st.session_state.summarize:
                message_container.chat_message("user", avatar="user.png").markdown(prompt)
                st.session_state.groq_chat_history.append({"role": "user", "content": prompt})
            elif speech_file_added:
                # NOTE(review): speech_recoginition() is an unimplemented stub
                # and returns None — the Groq path records None here.
                speech_to_text = speech_recoginition()
                st.session_state.groq_chat_history.append({"role": "user", "content": speech_to_text})

            with message_container.chat_message("assistant", avatar="assistant.png"):

                try:
                    if groq_llm_type == "Chatbot":
                        final_response = st.write_stream(groq_chatbot(model_params=model_params, api_key=groq_api_key,
                                                                      question=prompt, chat_history=st.session_state.groq_chat_history))

                    elif groq_llm_type == "Agent":
                        final_response = create_groq_agent(model_params=model_params, api_key=groq_api_key,
                                                           question=prompt,
                                                           tools=get_tools(),
                                                           chat_history=st.session_state.groq_chat_history,)

                        st.markdown(final_response)

                    elif groq_llm_type == "Summarizer":
                        if not url.strip():
                            st.error("Please enter a URL")
                        elif not validators.url(url):
                            st.error("Please enter a valid URL")
                        else:
                            with st.spinner("Summarizing..."):
                                final_response = summarizer_model(model_params=model_params, api_key=groq_api_key, url=url)
                                st.markdown(final_response)

                    st.session_state.groq_chat_history.append({"role": "assistant", "content": final_response})

                    # Optional text-to-speech playback of the answer.
                    if "voice_response" in st.session_state and st.session_state.voice_response:
                        response_voice = st.session_state.voice_response
                        text_to_speak = (final_response).translate(str.maketrans('', '', '#-*_πŸ˜ŠπŸ‘‹πŸ˜„πŸ˜πŸ₯³πŸ‘πŸ€©πŸ˜‚πŸ˜Ž')) # Removing special chars and emojis
                        with st.spinner("Generating voice response..."):
                            temp_file_path = asyncio.run(generate_speech(text_to_speak, voices[response_voice]))
                            audio_player_html = get_audio_player(temp_file_path) # Create an audio player
                            st.markdown(audio_player_html, unsafe_allow_html=True)
                        os.unlink(temp_file_path) # Clean up the temporary audio file

                except Exception as e:
                    st.error(f"An error occurred: {e}", icon="❌")

        else:
            # Gemini path: typed questions are stored as text; a speech
            # recording instead gets a synthetic instruction message.
            if not speech_file_added:
                message_container.chat_message("user", avatar="user.png").markdown(prompt)

                st.session_state.messages.append(
                    {
                        "role": "user",
                        "content": [{
                            "type": "text",
                            "text": prompt,
                        }]
                    }
                )
            ###----Google Gemini Response----###
            else:
                st.session_state.messages.append(
                    {
                        "role": "user",
                        "content": [{
                            "type": "text",
                            "text": "Please Answer the Question asked in the audio.",
                        }]
                    }
                )

            ###----- Generate response -----###
            with message_container.chat_message("assistant", avatar="assistant.png"):
                try:
                    final_response = st.write_stream(stream_gemini_response(model_params=model_params, api_key= google_api_key))

                    # Optional text-to-speech playback of the answer.
                    if "voice_response" in st.session_state and st.session_state.voice_response:
                        response_voice = st.session_state.voice_response
                        text_to_speak = (final_response).translate(str.maketrans('', '', '#-*_πŸ˜ŠπŸ‘‹πŸ˜„πŸ˜πŸ₯³πŸ‘πŸ€©πŸ˜‚πŸ˜Ž')) # Removing special chars and emojis
                        with st.spinner("Generating voice response..."):
                            temp_file_path = asyncio.run(generate_speech(text_to_speak, voices[response_voice]))
                            audio_player_html = get_audio_player(temp_file_path) # Create an audio player
                            st.markdown(audio_player_html, unsafe_allow_html=True)
                        os.unlink(temp_file_path) # Clean up the temporary audio file

                except Exception as e:
                    st.error(f"An error occurred: {e}", icon="❌")
564
+
565
+
utils.py CHANGED
@@ -1,6 +1,7 @@
1
  import streamlit as st
2
  from streamlit_vertical_slider import vertical_slider
3
  import pdfplumber
 
4
 
5
  @st.dialog("Confirm Selection πŸ‘‡", width="large")
6
  def visualize_display_page(selection_dict):
 
1
  import streamlit as st
2
  from streamlit_vertical_slider import vertical_slider
3
  import pdfplumber
4
+ from langchain_core.prompts import ChatPromptTemplate
5
 
6
  @st.dialog("Confirm Selection πŸ‘‡", width="large")
7
  def visualize_display_page(selection_dict):