|
|
import json
import os
import sqlite3
from contextlib import closing
from datetime import datetime
from typing import List, Optional

import pytz
import streamlit as st
import tqdm
from langchain_chroma import Chroma
from langchain_community.chat_message_histories import SQLChatMessageHistory
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.utilities import SQLDatabase
from langchain_core.documents import Document
from langchain_core.documents import Document
from langchain_core.messages import HumanMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.runnables import ConfigurableFieldSpec
from langchain_core.runnables import RunnableParallel
from langchain_core.runnables import RunnablePassthrough
from langchain_core.runnables import RunnablePassthrough
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_groq import ChatGroq
# NOTE: this import shadows the langchain_community HuggingFaceEmbeddings above.
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
|
|
|
|
|
|
|
st.title(body="🦜🔗 Reminder AI")

# Intent-classification prompt, sent as a single human message.
# ``{traning_data}`` (spelling kept: callers pass this exact key) is meant to
# carry labelled reference examples; ``{input}`` is the user's message.
tagging_prompt = ChatPromptTemplate.from_messages(
    [
        (
            "human",
            """
    Extract the desired information from the following passage.

    Only extract the properties mentioned in the 'Classification' function.

    Training data for reference:
    {traning_data}

    Passage:
    {input}
    """,
        )
    ]
)
|
|
|
|
|
# Extraction prompt: a fixed system instruction, then the raw user text as ``{text}``.
prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            "You are an expert extraction algorithm. Only extract relevant "
            "information from the text. If you do not know the value of an "
            "attribute asked to extract, return null for the attribute's value.",
        ),
        ("human", "{text}"),
    ]
)
|
|
|
|
|
# SECURITY: API keys must never be hard-coded in source. The key that used to be
# committed on this line is exposed in version control and should be revoked.
# ChatGroq reads GROQ_API_KEY from the environment, so require it to be set there.
if not os.environ.get("GROQ_API_KEY"):
    st.error("GROQ_API_KEY is not set. Export it in the environment before starting the app.")
    st.stop()
|
|
|
|
|
class Classification(BaseModel):
    # Structured-output schema for the intent classifier (bound to ``llm`` via
    # with_structured_output). Despite its name, ``sentiment`` holds the user's
    # intent and is constrained to one of the five labels below; the chat
    # handler branches on these exact strings.
    sentiment: str = Field(..., enum=["set reminder", "update reminder", "check reminder", "remove reminder", "other content"])
|
|
|
|
|
# Groq model id shared by every client below; override it with the GROQ_MODEL
# environment variable instead of editing three hard-coded copies.
GROQ_MODEL = os.environ.get("GROQ_MODEL", "llama-3.1-70b-versatile")

# Intent classifier: replies are parsed into a ``Classification`` instance.
llm = ChatGroq(model=GROQ_MODEL).with_structured_output(Classification)
# Temperature-0 client used for structured reminder-field extraction.
llm_st = ChatGroq(model=GROQ_MODEL, temperature=0)
# General client for free-form chat and the reminder RAG prompts.
llm_chat = ChatGroq(model=GROQ_MODEL)
|
|
|
|
|
class SetReminder(BaseModel):
    # Structured-output schema for extracting one reminder's fields; used by
    # ``runnable`` (prompt | llm_st.with_structured_output(schema=SetReminder)).
    # All three fields are free-form strings: no date/time format is enforced.
    reason: str = Field(..., description="Reason of the reminder")

    time: str = Field(..., description="What time is the reminder set for?")

    date: str = Field(..., description="What date is the reminder set for?")
|
|
|
|
|
# Intent classifier: tagging prompt piped into the Classification-bound model.
tagging_chain = tagging_prompt.pipe(llm)
# Reminder extractor: extraction prompt piped into the temperature-0 model,
# which returns a SetReminder instance.
runnable = prompt.pipe(llm_st.with_structured_output(schema=SetReminder))
|
|
|
|
|
|
|
|
def _now_in(tz_name):
    """Return the current timezone-aware datetime in the IANA zone ``tz_name``."""
    return datetime.now(pytz.timezone(tz_name))


def get_current_time(tz_name="Australia/Adelaide"):
    """Return the current wall-clock time in ``tz_name``.

    Args:
        tz_name: IANA time-zone name; defaults to Adelaide as before.

    Returns:
        The time formatted as 12-hour ``HH:MM:SS AM/PM`` (``%I:%M:%S %p``).
    """
    return _now_in(tz_name).strftime("%I:%M:%S %p")


def get_date(tz_name="Australia/Adelaide"):
    """Return today's date in ``tz_name`` formatted as ``YYYY-MM-DD``.

    Args:
        tz_name: IANA time-zone name; defaults to Adelaide as before.
    """
    return _now_in(tz_name).strftime("%Y-%m-%d")
|
|
|
|
|
|
|
|
def create_reminders_table(db_path="reminders.db"):
    """Create the ``reminders`` table in the SQLite file ``db_path`` if missing.

    Columns are ``(id, time, date, reason)``; all but ``id`` are required text.
    Prints ``success`` when done. SQLite errors are printed rather than raised,
    so start-up continues even if the database is unavailable.
    """
    sql_statement = """
    CREATE TABLE IF NOT EXISTS reminders (
        id INTEGER PRIMARY KEY,
        time TEXT NOT NULL,
        date TEXT NOT NULL,
        reason TEXT NOT NULL
    );
    """
    try:
        # sqlite3's connection context manager only commits/rolls back; it does
        # not close the connection, so ``closing`` is needed to release it.
        with closing(sqlite3.connect(db_path)) as conn:
            with conn:
                conn.execute(sql_statement)
        print("success")
    except sqlite3.Error as e:
        print(e)


create_reminders_table()
|
|
|
|
|
|
|
|
def list_reminders(db_path="reminders.db"):
    """Print every row of the ``reminders`` table, one tuple per line.

    SQLite errors are printed as ``Error: ...`` instead of being raised.
    """
    try:
        # ``closing`` actually releases the connection; sqlite3's own context
        # manager only handles the transaction.
        with closing(sqlite3.connect(db_path)) as conn:
            rows = conn.execute("SELECT * FROM reminders;").fetchall()
    except sqlite3.Error as e:
        print(f"Error: {e}")
        return
    for row in rows:
        print(row)
|
|
|
|
|
|
|
|
def get_reminders(db_path="reminders.db"):
    """Return all rows of the ``reminders`` table.

    Returns:
        A list of ``(id, time, date, reason)`` tuples, or ``[]`` (after printing
        ``Error: ...``) if the query fails.
    """
    try:
        # ``closing`` releases the connection; sqlite3's own context manager
        # does not close it.
        with closing(sqlite3.connect(db_path)) as conn:
            return conn.execute("SELECT * FROM reminders;").fetchall()
    except sqlite3.Error as e:
        print(f"Error: {e}")
        return []
|
|
|
|
|
def format_reminders_for_context(reminders):
    """Render reminder rows as plain text for inclusion in an LLM prompt.

    Args:
        reminders: Iterable of rows indexed as ``(id, time, date, reason)``.

    Returns:
        One newline-terminated ``ID: ..., Time: ..., Date: ..., Reason: ...``
        line per row, or ``""`` when there are no rows.
    """
    # join builds the result in one pass instead of quadratic ``+=`` concatenation.
    return "".join(
        f"ID: {row[0]}, Time: {row[1]}, Date: {row[2]}, Reason: {row[3]}\n"
        for row in reminders
    )
|
|
|
|
|
def get_session_history(user_id: str, conversation_id: str):
    """Build the persisted chat history for one user/conversation pair.

    Messages are stored in the SQLite database ``memory.db`` under the session
    key ``"<user_id>--<conversation_id>"``.
    """
    session_key = "{}--{}".format(user_id, conversation_id)
    database_url = "sqlite:///memory.db"
    return SQLChatMessageHistory(session_key, database_url)
|
|
|
|
|
# Free-form chat prompt. ``{today_date}`` and ``{time_now}`` are filled on each
# request; the ``history`` placeholder is populated by RunnableWithMessageHistory
# (history_messages_key="history"). Grammar of the system message fixed
# ("Your are" / "Today date").
chatting_prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            "You are a friendly AI assistant! Today's date is {today_date} and the time now is {time_now}.",
        ),
        MessagesPlaceholder(variable_name="history"),
        ("human", "{input}"),
    ]
)
|
|
|
|
|
# Chat chain: system prompt + stored history + user turn, sent to the chat model.
runnable_chat = chatting_prompt.pipe(llm_chat)

# Wrap the chat chain so each call loads and saves its history through
# get_session_history. The session is chosen by two shared configurable string
# fields, both defaulting to "".
with_message_history = RunnableWithMessageHistory(
    runnable_chat,
    get_session_history,
    input_messages_key="input",
    history_messages_key="history",
    history_factory_config=[
        ConfigurableFieldSpec(
            id=field_id,
            annotation=str,
            name=label,
            description=summary,
            default="",
            is_shared=True,
        )
        for field_id, label, summary in (
            ("user_id", "User ID", "Unique identifier for the user."),
            ("conversation_id", "Conversation ID", "Unique identifier for the conversation."),
        )
    ],
)
|
|
|
|
|
|
|
|
def get_ai_response(user_input):
    """Stream a free-form chat reply to ``user_input`` using persisted history.

    Returns:
        ``(response, error)``. ``error`` is ``None`` on success; otherwise it is
        an ``"Error: ..."`` string, and ``response`` holds whatever text was
        streamed before the failure.
    """
    chunks = []
    try:
        # RunnableWithMessageHistory's documented contract reads the session ids
        # from config["configurable"]; only some langchain-core versions forward
        # unknown top-level config keys there, otherwise the "" defaults are used.
        for chunk in with_message_history.stream(
            {"input": user_input, "time_now": get_current_time(), "today_date": get_date()},
            config={"configurable": {"user_id": "123", "conversation_id": "1"}},
        ):
            chunks.append(chunk.content)
    except Exception as e:
        return "".join(chunks), f"Error: {e}"
    return "".join(chunks), None
|
|
|
|
|
# Prompt for "check reminder" questions: answer only from the formatted reminder
# rows. Placeholders: {question}, {date_tt} (today's date), {context} (the rows).
message_max = """
Answer this question using the database provided only.

{question}

This is today's date: {date_tt}. So if user says "today," it means the date of today!

Database sqlite:
{context}

If no reminder is set on the day user requested, reply with: "No reminder found on that day."

If a reminder is found:
- Directly reply with the specific action and time, e.g., "You need to sleep at 10pm tomorrow."

Avoid providing additional context or listing all reminders.
"""

# A single human message built from the template, piped into the chat model.
prompt_max = ChatPromptTemplate.from_template(message_max)

rag_chain = prompt_max.pipe(llm_chat)
|
|
|
|
|
# Prompt for "update reminder" requests. The model's reply is executed verbatim
# as SQL, so it must be exactly one complete UPDATE statement or the sentinel
# ``reminder_x``. Fixed: the old example dropped the "UPDATE reminders" line,
# used dd/mm/yyyy dates (stored dates are YYYY-MM-DD) and double-quoted values
# (SQLite may read those as column names).
# Placeholders: {question}, {date_tt} (today's date), {context} (reminder rows).
message_sec = """
Answer this question using the database provided only.

{question}

This is today's date: {date_tt}. So if user says "today," it means the date of today!

Database sqlite:
{context}

Search for the reminder the user is asking to change or update. If it is found, reply with only one SQL statement in this format, with no code fences and no other text:

UPDATE reminders
SET date = 'new date here', time = 'new time here', reason = 'new reason here'
WHERE date = 'old date here' AND time = 'old time here' AND reason = 'old reason here';

Use the YYYY-MM-DD date format and single quotes around every value (write a single quote inside a value as two single quotes).

Example:

UPDATE reminders
SET date = '2024-08-22', time = '10 AM', reason = 'Playing phone'
WHERE date = '2024-06-23' AND time = '11 AM' AND reason = 'Playing phone';

If no reminder is found, reply with reminder_x.
Avoid providing additional context or listing all reminders.
"""

prompt_one = ChatPromptTemplate.from_messages([("human", message_sec)])

rag_chaining = prompt_one | llm_chat
|
|
|
|
|
# Prompt for "set reminder" requests. The model's reply is executed verbatim as
# SQL, so it must be exactly one INSERT statement or the sentinel
# ``repeated_reminder``. Fixed typos/grammar, and switched the example to
# single-quoted SQL literals (SQLite may read double-quoted text as a column name).
# Placeholders: {question}, {date_tt}, {current_time}, {context}.
message_thr = """
Answer this question using the database provided only.

{question}

This is today's date: {date_tt}. So if user says "today," it means the date of today!

Current time: {current_time}

Database sqlite:
{context}

If no date is specified, assume today. If the user says "one hour later" or similar, add that to the current time to get the reminder time (for example: the current time is 10 PM and the user says 1 hour later, so 10 PM + 1 hour = 11 PM).
Don't say anything extra. If I am asking to be reminded of something, reply with only one SQL statement in this format, with no code fences:

INSERT INTO reminders (date, time, reason)
VALUES ('date', 'time', 'reason');

Example (use the YYYY-MM-DD date format and single quotes around values; write a single quote inside a value as two single quotes):

INSERT INTO reminders (date, time, reason)
VALUES ('2024-08-23', '8:00 AM', 'Homework');

If a reminder with the same date, same time and same reason already exists, reply with repeated_reminder.

Avoid providing additional context or listing all reminders.
"""

prompt_tt = ChatPromptTemplate.from_messages([("human", message_thr)])

rag_chain_tt = prompt_tt | llm_chat
|
|
|
|
|
|
|
|
@st.cache_resource
def _build_vectorstore(data_path="important_classification_data.json"):
    """Embed the labelled intent examples in ``data_path`` into a Chroma store.

    The JSON is expected to be a list of ``{"input": ..., "intent": ...}``
    objects. This runs at module level of a Streamlit script, so without
    caching the file would be re-read and every example re-embedded on each
    rerun; ``st.cache_resource`` builds it once per process. (Repeated
    ``Chroma.from_documents`` calls may also add duplicate entries to the same
    in-memory collection — TODO confirm for the installed chromadb version.)
    """
    with open(data_path, "r", encoding="utf-8") as f:
        examples = json.load(f)

    documents = [
        Document(page_content=item["input"], metadata={"intent": item["intent"]})
        for item in examples
    ]
    splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    return Chroma.from_documents(
        documents=splitter.split_documents(documents),
        embedding=HuggingFaceEmbeddings(),
    )


vectorstore = _build_vectorstore()
retriever = vectorstore.as_retriever()
|
|
|
|
|
|
|
|
# Create the per-session chat transcript on first use, then redraw every stored
# {"role": ..., "content": ...} entry on each run of the script.
if "messages" not in st.session_state:
    st.session_state.messages = []

for entry in st.session_state.messages:
    st.chat_message(entry["role"]).markdown(entry["content"])
|
|
|
|
|
|
|
|
def _stream_text(chain, inputs):
    """Stream ``chain`` on ``inputs`` and return the concatenated chunk text."""
    return "".join(chunk.content for chunk in chain.stream(inputs))


def _format_examples(examples):
    """Render retrieved labelled examples as ``utterance -> intent`` lines."""
    return "\n".join(
        f"{doc.page_content} -> {doc.metadata.get('intent', '')}" for doc in examples
    )


def _clean_llm_output(text):
    """Trim whitespace and any Markdown code fence the model wrapped around its reply."""
    cleaned = text.strip()
    if cleaned.startswith("```"):
        lines = cleaned.strip("`").strip().splitlines()
        # Drop an optional language tag such as ```sql on the opening fence.
        if lines and lines[0].strip().lower() == "sql":
            lines = lines[1:]
        cleaned = "\n".join(lines).strip()
    return cleaned


def _execute_reminder_sql(statement, allowed_prefix, success_msg, noop_msg, action):
    """Execute one model-generated write statement against ``reminders.db``.

    The SQL comes straight from LLM output, so it is treated as untrusted: only a
    statement starting with ``allowed_prefix`` (upper-case, single-spaced) is
    run, and sqlite3 refuses to execute more than one statement per call.

    Returns:
        A user-facing status message.
    """
    normalized = " ".join(statement.split()).upper()
    if not normalized.startswith(allowed_prefix):
        return "Sorry, I couldn't turn that into a valid reminder change, so nothing was saved."
    try:
        # ``closing`` releases the connection; ``with conn`` commits or rolls back.
        with closing(sqlite3.connect("reminders.db")) as conn:
            with conn:
                cursor = conn.execute(statement)
    except (sqlite3.Error, sqlite3.Warning) as e:
        # sqlite3.Warning (raised for multi-statement input on older Pythons)
        # is not a subclass of sqlite3.Error.
        return f"An error occurred while {action} the reminder: {e}"
    return success_msg if cursor.rowcount > 0 else noop_msg


def _handle_message(user_input):
    """Classify ``user_input`` by intent and return the assistant's reply text."""
    today = get_date()
    now = get_current_time()

    # Ground the classifier with the most similar labelled examples (not the
    # retriever object itself).
    examples = _format_examples(retriever.invoke(user_input))
    classification = tagging_chain.invoke({"input": user_input, "traning_data": examples})
    # Structured output can come back empty; treat that as ordinary chat.
    intent = classification.sentiment if classification is not None else "other content"

    if intent == "set reminder":
        context = format_reminders_for_context(get_reminders())
        new_remind = _clean_llm_output(_stream_text(
            rag_chain_tt,
            {"question": user_input, "context": context, "date_tt": today, "current_time": now},
        ))
        print(f"New reminder: {new_remind}")
        if new_remind == "repeated_reminder":
            return "Reminder already existed!"
        return _execute_reminder_sql(
            new_remind, "INSERT INTO REMINDERS",
            "New reminder created.", "No reminder created, Errors occur", "creating",
        )

    if intent == "update reminder":
        context = format_reminders_for_context(get_reminders())
        updated_remind = _clean_llm_output(_stream_text(
            rag_chaining, {"question": user_input, "context": context, "date_tt": today},
        ))
        print(f"Database: {updated_remind}")
        if updated_remind == "reminder_x":
            return "No reminder found to change!"
        return _execute_reminder_sql(
            updated_remind, "UPDATE REMINDERS",
            "Reminder updated successfully.",
            "No reminder found to update with the given details.",
            "updating",
        )

    if intent == "check reminder":
        # Uses the read-only "check" prompt, not the INSERT-generating one.
        context = format_reminders_for_context(get_reminders())
        return _stream_text(rag_chain, {"question": user_input, "context": context, "date_tt": today})

    if intent == "remove reminder":
        return "Removing reminders is not supported yet."

    response, error = get_ai_response(user_input)
    if error:
        return f"{response}\n\n{error}" if response else error
    return response


if user_input := st.chat_input("What is up?"):
    st.chat_message("user").markdown(user_input)
    st.session_state.messages.append({"role": "user", "content": user_input})

    try:
        response_st = _handle_message(user_input)
    except Exception as e:
        # Top-level boundary: report the failure as a normal assistant message.
        # It is stored as a role/content dict like every other history entry,
        # so the transcript replay loop can still render it.
        response_st = f"An error occurred: {e}. Please try again."

    with st.chat_message("assistant"):
        st.markdown(response_st)
    st.session_state.messages.append({"role": "assistant", "content": response_st})