Wazzever commited on
Commit
2961d44
·
verified ·
1 Parent(s): 4a0a5ca

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +171 -92
app.py CHANGED
@@ -1,36 +1,64 @@
1
  import streamlit as st
2
- from langchain_core.prompts import ChatPromptTemplate
3
- from langchain_core.pydantic_v1 import BaseModel, Field
4
- from typing import List, Optional
5
- from langchain_core.prompts import ChatPromptTemplate
6
- import warnings
7
- import json
8
- from langchain_chroma import Chroma
9
- from langchain_community.embeddings import HuggingFaceEmbeddings
10
- from langchain_text_splitters import RecursiveCharacterTextSplitter
11
- from langchain_core.prompts import ChatPromptTemplate
12
- from langchain_core.output_parsers import StrOutputParser
13
- from langchain_core.runnables import RunnablePassthrough
14
- from langchain_core.documents import Document
15
- import os
16
- from datetime import datetime
17
- from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
18
- from langchain_core.runnables import RunnableParallel
19
- from langchain_groq import ChatGroq
20
- from langchain_community.chat_message_histories import SQLChatMessageHistory
21
- from langchain_core.runnables import ConfigurableFieldSpec
22
- from langchain_core.messages import HumanMessage
23
- from langchain_core.runnables.history import RunnableWithMessageHistory
24
- from langchain_core.documents import Document
25
- from langchain_core.runnables import RunnablePassthrough
26
- from langchain_huggingface import HuggingFaceEmbeddings
27
- from langchain_community.utilities import SQLDatabase
28
- import pytz
29
- import sqlite3
30
- import tqdm
31
-
32
-
33
- # Database setup
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  def create_reminders_table():
35
  sql_statement = """
36
  CREATE TABLE IF NOT EXISTS reminders (
@@ -40,26 +68,36 @@ def create_reminders_table():
40
  reason TEXT NOT NULL
41
  );
42
  """
 
43
  try:
44
  with sqlite3.connect('reminders.db') as conn:
45
  cursor = conn.cursor()
46
  cursor.execute(sql_statement)
47
  conn.commit()
 
48
  except sqlite3.Error as e:
49
- st.error(f"An error occurred: {e}")
50
 
51
- # Helper functions
52
- def get_current_time():
53
- adelaide_tz = pytz.timezone('Australia/Adelaide')
54
- now = datetime.now(adelaide_tz)
55
- return now.strftime("%I:%M:%S %p") # Returns time in HH:MM:SS format
 
 
 
 
 
 
 
 
 
 
56
 
57
- def get_date():
58
- adelaide_tz = pytz.timezone('Australia/Adelaide')
59
- return datetime.now(adelaide_tz).strftime("%Y-%m-%d")
60
 
61
  def get_reminders():
62
  sql_statement = "SELECT * FROM reminders;"
 
63
  try:
64
  with sqlite3.connect('reminders.db') as conn:
65
  cursor = conn.cursor()
@@ -67,25 +105,70 @@ def get_reminders():
67
  rows = cursor.fetchall()
68
  return rows
69
  except sqlite3.Error as e:
70
- st.error(f"An error occurred: {e}")
71
  return []
72
 
73
-
74
-
75
  def format_reminders_for_context(reminders):
76
  context = ""
77
  for reminder in reminders:
78
  context += f"ID: {reminder[0]}, Time: {reminder[1]}, Date: {reminder[2]}, Reason: {reminder[3]}\n"
79
  return context
80
 
81
- # Define models and prompts
82
- class Classification(BaseModel):
83
- sentiment: str = Field(..., enum=["set reminder", "update reminder", "check reminder", "remove reminder", "other content"])
84
-
85
- llm = ChatGroq(model="llama-3.1-70b-versatile").with_structured_output(Classification)
86
- llm_st = ChatGroq(model="llama-3.1-70b-versatile", temperature=0)
87
- llm_chat = ChatGroq(model="llama-3.1-70b-versatile")
88
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
 
90
  message_max = """
91
  Answer this question using the database provided only.
@@ -186,84 +269,88 @@ vectorstore = Chroma.from_documents(documents=splits, embedding=HuggingFaceEmbed
186
  # Set up the retriever to fetch relevant phrases based on the user's query
187
  retriever = vectorstore.as_retriever()
188
 
189
- # Define the main function to handle user interaction
 
190
  def handle_user_input(user_input):
191
  time_now = get_current_time()
192
  current_date = get_date()
193
 
194
  sentimenttocheck = ""
195
- try:
196
- res = tagging_chain.invoke({"input": user_input, "traning_data": retriever})
197
- sentimenttocheck = res.sentiment
198
- except Exception as e:
199
- st.error(f"An error occurred: {e}")
200
- return None
201
 
202
  if sentimenttocheck == "set reminder":
203
- reminders = get_reminders()
204
- database = format_reminders_for_context(reminders)
205
  new_remind = ""
206
- for new_reminder in rag_chain_tt.stream({"question": user_input, "context": database, "date_tt": current_date, "current_time": time_now}):
207
  new_remind += new_reminder.content
208
-
209
  if new_remind == "repeated_reminder":
210
- st.info("Reminder already existed!")
211
  else:
212
  try:
213
  with sqlite3.connect('reminders.db') as conn:
214
  cursor = conn.cursor()
 
 
215
  cursor.execute(new_remind)
216
  conn.commit()
 
 
217
  if cursor.rowcount > 0:
218
- st.success("New reminder created.")
219
  else:
220
- st.warning("No reminder created, Errors occurred.")
 
221
  except sqlite3.Error as e:
222
- st.error(f"An error occurred while creating the reminder: {e}")
223
 
224
  elif sentimenttocheck == "update reminder":
225
- reminders = get_reminders()
226
- database = format_reminders_for_context(reminders)
227
- updated_remind = ""
228
- for updated_cont in rag_chaining.stream({"question": user_input, "context": database, "date_tt": current_date}):
229
  updated_remind += updated_cont.content
 
230
  if updated_remind == "reminder_x":
231
- st.info("No reminder found to change!")
232
  else:
233
  try:
234
  with sqlite3.connect('reminders.db') as conn:
235
  cursor = conn.cursor()
 
 
236
  cursor.execute(updated_remind)
237
  conn.commit()
 
 
238
  if cursor.rowcount > 0:
239
- st.success("Reminder updated successfully.")
240
  else:
241
- st.warning("No reminder found to update with the given details.")
 
242
  except sqlite3.Error as e:
243
- st.error(f"An error occurred while updating the reminder: {e}")
244
 
245
  elif sentimenttocheck == "check reminder":
246
  reminders = get_reminders()
247
  database = format_reminders_for_context(reminders)
248
  response_max = ""
249
- for max in rag_chain.stream({"question": user_input, "context": database, "date_tt": current_date}):
250
  response_max += max.content
251
- st.info(f"Reminder details: {response_max}")
252
 
253
  elif sentimenttocheck == "remove reminder":
254
- st.info("Remove Reminder functionality needs to be implemented.")
 
255
  else:
256
  ai_response, error_message = get_ai_response(user_input)
257
- if error_message:
258
- st.error(error_message)
259
- else:
260
- st.info(ai_response)
261
 
262
  # Streamlit app layout
263
  st.title("🦜🔗 Reminder AI")
264
 
265
- create_reminders_table()
266
-
267
  with st.form("reminder_form"):
268
  user_input = st.text_area("Enter your request:")
269
  submit_button = st.form_submit_button("Submit")
@@ -271,11 +358,3 @@ with st.form("reminder_form"):
271
  if submit_button and user_input:
272
  handle_user_input(user_input)
273
 
274
- st.write("Existing reminders:")
275
- reminders_list = get_reminders()
276
- if reminders_list:
277
- for reminder in reminders_list:
278
- st.write(f"ID: {reminder[0]}, Time: {reminder[1]}, Date: {reminder[2]}, Reason: {reminder[3]}")
279
- else:
280
- st.write("No reminders found.")
281
-
 
1
  import streamlit as st
2
+
3
+
4
+ warnings.filterwarnings("ignore", category=DeprecationWarning, module="langchain_community")
5
+
6
+ tagging_prompt = ChatPromptTemplate.from_template(
7
+ """
8
+ Extract the desired information from the following passage.
9
+
10
+ Only extract the properties mentioned in the 'Classification' function.
11
+
12
+ Training data for reference:
13
+ {traning_data}
14
+
15
+ Passage:
16
+ {input}
17
+ """
18
+ )
19
+
20
+ prompt = ChatPromptTemplate.from_messages(
21
+ [
22
+ (
23
+ "system",
24
+ "You are an expert extraction algorithm. "
25
+ "Only extract relevant information from the text. "
26
+ "If you do not know the value of an attribute asked to extract, "
27
+ "return null for the attribute's value.",
28
+ ),
29
+ ("human", "{text}"),
30
+ ]
31
+ )
32
+
33
+ os.environ["GROQ_API_KEY"] = "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"
34
+
35
+ class Classification(BaseModel):
36
+ sentiment: str = Field(..., enum=["set reminder", "update reminder", "check reminder", "remove reminder", "other content"])
37
+
38
+ llm = ChatGroq(model="llama-3.1-70b-versatile").with_structured_output(Classification)
39
+ llm_st = ChatGroq(model="llama-3.1-70b-versatile", temperature=0)
40
+ llm_chat = ChatGroq(model="llama-3.1-70b-versatile")
41
+
42
+ class SetReminder(BaseModel):
43
+ reason: str = Field(..., description="Reason of the reminder")
44
+ time: str = Field(..., description="What time is the reminder set for?")
45
+ date: str = Field(..., description="What date is the reminder set for?")
46
+
47
+ tagging_chain = tagging_prompt | llm
48
+ runnable = prompt | llm_st.with_structured_output(schema=SetReminder)
49
+
50
+
51
+ def get_current_time():
52
+ adelaide_tz = pytz.timezone('Australia/Adelaide')
53
+ now = datetime.now(adelaide_tz)
54
+ return now.strftime("%I:%M:%S %p") # Returns time in HH:MM:SS format
55
+
56
+ def get_date():
57
+ adelaide_tz = pytz.timezone('Australia/Adelaide')
58
+ current_date = datetime.now(adelaide_tz).strftime("%Y-%m-%d")
59
+ return current_date
60
+
61
+
62
  def create_reminders_table():
63
  sql_statement = """
64
  CREATE TABLE IF NOT EXISTS reminders (
 
68
  reason TEXT NOT NULL
69
  );
70
  """
71
+
72
  try:
73
  with sqlite3.connect('reminders.db') as conn:
74
  cursor = conn.cursor()
75
  cursor.execute(sql_statement)
76
  conn.commit()
77
+ print("sucess")
78
  except sqlite3.Error as e:
79
+ print(e)
80
 
81
+ create_reminders_table()
82
+
83
+
84
+ def list_reminders():
85
+ sql_statement = "SELECT * FROM reminders;"
86
+
87
+ try:
88
+ with sqlite3.connect('reminders.db') as conn:
89
+ cursor = conn.cursor()
90
+ cursor.execute(sql_statement)
91
+ rows = cursor.fetchall()
92
+ for row in rows:
93
+ print(row)
94
+ except sqlite3.Error as e:
95
+ print(f"Error: {e}")
96
 
 
 
 
97
 
98
  def get_reminders():
99
  sql_statement = "SELECT * FROM reminders;"
100
+
101
  try:
102
  with sqlite3.connect('reminders.db') as conn:
103
  cursor = conn.cursor()
 
105
  rows = cursor.fetchall()
106
  return rows
107
  except sqlite3.Error as e:
108
+ print(f"Error: {e}")
109
  return []
110
 
 
 
111
  def format_reminders_for_context(reminders):
112
  context = ""
113
  for reminder in reminders:
114
  context += f"ID: {reminder[0]}, Time: {reminder[1]}, Date: {reminder[2]}, Reason: {reminder[3]}\n"
115
  return context
116
 
117
+ def get_session_history(user_id: str, conversation_id: str):
118
+ return SQLChatMessageHistory(f"{user_id}--{conversation_id}", "sqlite:///memory.db")
119
+
120
+ chatting_prompt = ChatPromptTemplate.from_messages(
121
+ [
122
+ (
123
+ "system",
124
+ "Your are a friendly assistant Ai! Today date is {today_date} and Time now is {time_now}",
125
+ ),
126
+ MessagesPlaceholder(variable_name="history"),
127
+ ("human", "{input}"),
128
+ ]
129
+ )
130
+
131
+ runnable_chat = chatting_prompt | llm_chat
132
+
133
+
134
+ with_message_history = RunnableWithMessageHistory(
135
+ runnable_chat,
136
+ get_session_history,
137
+ input_messages_key="input",
138
+ history_messages_key="history",
139
+ history_factory_config=[
140
+ ConfigurableFieldSpec(
141
+ id="user_id",
142
+ annotation=str,
143
+ name="User ID", #user_id
144
+ description="Unique identifier for the user.",
145
+ default="",
146
+ is_shared=True,
147
+ ),
148
+ ConfigurableFieldSpec(
149
+ id="conversation_id",
150
+ annotation=str,
151
+ name="Conversation ID", #session
152
+ description="Unique identifier for the conversation.",
153
+ default="",
154
+ is_shared=True,
155
+ ),
156
+ ],
157
+ )
158
+
159
+
160
+ def get_ai_response(user_input):
161
+ response = ""
162
+ try:
163
+ for s in with_message_history.stream(
164
+ {"input": user_input, "time_now": get_current_time(), "today_date": get_date()},
165
+ config={"user_id": "123", "conversation_id": "1"}
166
+ ):
167
+ response += s.content
168
+ except Exception as e:
169
+ error_message = f"Error: {e}"
170
+ return response, error_message
171
+ return response, None
172
 
173
  message_max = """
174
  Answer this question using the database provided only.
 
269
  # Set up the retriever to fetch relevant phrases based on the user's query
270
  retriever = vectorstore.as_retriever()
271
 
272
+
273
+
274
  def handle_user_input(user_input):
275
  time_now = get_current_time()
276
  current_date = get_date()
277
 
278
  sentimenttocheck = ""
279
+ res = tagging_chain.invoke({"input": user_input, "traning_data": retriever})
280
+ sentimenttocheck += res.sentiment
 
 
 
 
281
 
282
  if sentimenttocheck == "set reminder":
283
+ reminders_thr = get_reminders()
284
+ database_thr = format_reminders_for_context(reminders_thr)
285
  new_remind = ""
286
+ for new_reminder in rag_chain_tt.stream({"question": user_input, "context": database_thr, "date_tt": get_date(), "current_time": time_now}):
287
  new_remind += new_reminder.content
288
+ st.write(f"New reminder: {new_remind}")
289
  if new_remind == "repeated_reminder":
290
+ st.write("Reminder already existed!")
291
  else:
292
  try:
293
  with sqlite3.connect('reminders.db') as conn:
294
  cursor = conn.cursor()
295
+
296
+ # Execute the update query
297
  cursor.execute(new_remind)
298
  conn.commit()
299
+
300
+ # Check if the update was successful
301
  if cursor.rowcount > 0:
302
+ st.write("New reminder created.")
303
  else:
304
+ st.write("No reminder created, errors occurred.")
305
+
306
  except sqlite3.Error as e:
307
+ st.write(f"An error occurred while updating the reminder: {e}")
308
 
309
  elif sentimenttocheck == "update reminder":
310
+ reminders_sec = get_reminders()
311
+ database_sec = format_reminders_for_context(reminders_sec)
312
+ updated_remind = ""
313
+ for updated_cont in rag_chaining.stream({"question": user_input, "context": database_sec, "date_tt": get_date()}):
314
  updated_remind += updated_cont.content
315
+ st.write(f"Database: {updated_remind}")
316
  if updated_remind == "reminder_x":
317
+ st.write("No reminder found to change!")
318
  else:
319
  try:
320
  with sqlite3.connect('reminders.db') as conn:
321
  cursor = conn.cursor()
322
+
323
+ # Execute the update query
324
  cursor.execute(updated_remind)
325
  conn.commit()
326
+
327
+ # Check if the update was successful
328
  if cursor.rowcount > 0:
329
+ st.write("Reminder updated successfully.")
330
  else:
331
+ st.write("No reminder found to update with the given details.")
332
+
333
  except sqlite3.Error as e:
334
+ st.write(f"An error occurred while updating the reminder: {e}")
335
 
336
  elif sentimenttocheck == "check reminder":
337
  reminders = get_reminders()
338
  database = format_reminders_for_context(reminders)
339
  response_max = ""
340
+ for max in rag_chain.stream({"question": user_input, "context": database, "date_tt": get_date()}):
341
  response_max += max.content
342
+ st.write(f"Database remind: {response_max}")
343
 
344
  elif sentimenttocheck == "remove reminder":
345
+ st.write("Remove Reminder")
346
+
347
  else:
348
  ai_response, error_message = get_ai_response(user_input)
349
+ st.write(f"{ai_response}")
 
 
 
350
 
351
  # Streamlit app layout
352
  st.title("🦜🔗 Reminder AI")
353
 
 
 
354
  with st.form("reminder_form"):
355
  user_input = st.text_area("Enter your request:")
356
  submit_button = st.form_submit_button("Submit")
 
358
  if submit_button and user_input:
359
  handle_user_input(user_input)
360