# Coolui / app.py
# Last commit: c013b9f "Update app.py" (verified), by Wazzever.
import streamlit as st
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field
from typing import List, Optional
from langchain_core.prompts import ChatPromptTemplate
import json
from langchain_chroma import Chroma
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_core.documents import Document
import os
from datetime import datetime
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables import RunnableParallel
from langchain_groq import ChatGroq
from langchain_community.chat_message_histories import SQLChatMessageHistory
from langchain_core.runnables import ConfigurableFieldSpec
from langchain_core.messages import HumanMessage
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_core.documents import Document
from langchain_core.runnables import RunnablePassthrough
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.utilities import SQLDatabase
import pytz
import sqlite3
import tqdm
# Streamlit app layout
st.title("🦜🔗 Reminder AI")

# Intent-classification prompt. It is piped into `llm` (Groq bound to the
# Classification schema) to build `tagging_chain`, so the model's answer is a
# Classification object rather than free text.
# Template variables:
#   {traning_data} -- (sic) reference material for the model; the misspelled
#                     name must match the key the caller passes in.
#   {input}        -- the user's message to classify.
tagging_prompt = ChatPromptTemplate.from_template(
"""
Extract the desired information from the following passage.
Only extract the properties mentioned in the 'Classification' function.
Training data for reference:
{traning_data}
Passage:
{input}
"""
)
# Generic extraction prompt: a system instruction followed by the raw text in
# {text}. It is combined with the SetReminder structured-output model to form
# `runnable`. NOTE(review): `runnable` is not referenced anywhere else in the
# visible code.
prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            # Adjacent string literals are concatenated into one instruction.
            "You are an expert extraction algorithm. "
            "Only extract relevant information from the text. "
            "If you do not know the value of an attribute asked to extract, "
            "return null for the attribute's value.",
        ),
        ("human", "{text}"),
    ]
)
# SECURITY: the Groq API key must be supplied through the environment (or the
# deployment's secret store), never written into source. The key that used to
# be hard-coded on this line has been committed to version control and must be
# treated as leaked: revoke and rotate it.
if not os.environ.get("GROQ_API_KEY"):
    raise RuntimeError(
        "GROQ_API_KEY is not set; export it (or configure it as a secret) "
        "before starting the app."
    )
# Structured-output schema for intent classification; bound to `llm` via
# with_structured_output. The tagging prompt refers to it by its class name
# ("the 'Classification' function"), so the name must not change.
# NOTE(review): documented with comments rather than a docstring because
# LangChain may use a model's docstring as the tool description sent to the
# LLM; adding one could change classification behaviour.
class Classification(BaseModel):
    # Despite the field name, this carries the user's *intent*, restricted to
    # the enum values below. Callers read it as `res.sentiment`.
    sentiment: str = Field(..., enum=["set reminder", "update reminder", "check reminder", "remove reminder", "other content"])
# Groq model ID shared by every client below. Previously the literal was
# repeated three times; it can now be overridden via the GROQ_MODEL
# environment variable without touching code.
# NOTE(review): confirm the default model ID is still served by Groq.
GROQ_MODEL = os.environ.get("GROQ_MODEL", "llama-3.1-70b-versatile")

# Intent classifier: returns a Classification instance instead of text.
llm = ChatGroq(model=GROQ_MODEL).with_structured_output(Classification)
# Deterministic (temperature=0) client, used for SetReminder extraction.
llm_st = ChatGroq(model=GROQ_MODEL, temperature=0)
# Default-temperature client for free-form chat and the reminder SQL prompts.
llm_chat = ChatGroq(model=GROQ_MODEL)
# Structured-output schema for pulling reminder fields out of free text; used
# with `llm_st.with_structured_output(schema=SetReminder)`.
# NOTE(review): documented with comments rather than a docstring because
# LangChain may use a model's docstring as the tool description sent to the
# LLM. The Field descriptions below ARE sent to the model as field hints.
class SetReminder(BaseModel):
    # What the reminder is for.
    reason: str = Field(..., description="Reason of the reminder")
    # Time as free text, in whatever format the model produces.
    time: str = Field(..., description="What time is the reminder set for?")
    # Date as free text, in whatever format the model produces.
    date: str = Field(..., description="What date is the reminder set for?")
# Intent classifier: tagging prompt -> Groq model bound to Classification.
tagging_chain = tagging_prompt | llm
# Reminder-field extractor: extraction prompt -> temperature-0 model bound to
# SetReminder. NOTE(review): `runnable` is not referenced anywhere else in the
# visible code.
runnable = prompt | llm_st.with_structured_output(schema=SetReminder)
def get_current_time(tz_name="Australia/Adelaide"):
    """Return the current wall-clock time in ``tz_name`` as a 12-hour string.

    The format is ``"%I:%M:%S %p"``: zero-padded 12-hour hours, minutes and
    seconds, followed by the AM/PM marker (e.g. ``"09:05:30 PM"``). It is not
    a 24-hour ``HH:MM:SS`` value, despite what an earlier comment claimed.
    The AM/PM text comes from ``%p`` and is therefore locale-dependent.

    Args:
        tz_name: Timezone name passed to ``pytz.timezone``. Defaults to
            Adelaide, the timezone used throughout this app.
    """
    now = datetime.now(pytz.timezone(tz_name))
    return now.strftime("%I:%M:%S %p")
def get_date():
    """Return today's date in the Australia/Adelaide timezone as ``YYYY-MM-DD``."""
    return datetime.now(pytz.timezone('Australia/Adelaide')).strftime("%Y-%m-%d")
def create_reminders_table(db_path="reminders.db"):
    """Create the ``reminders`` table in ``db_path`` if it does not exist yet.

    Columns: ``id`` (integer primary key), ``time``, ``date`` and ``reason``
    (all non-null text). Safe to call repeatedly thanks to IF NOT EXISTS.

    Args:
        db_path: SQLite database file; defaults to ``reminders.db``.

    sqlite3 errors are printed rather than raised (best-effort setup).
    """
    sql_statement = """
    CREATE TABLE IF NOT EXISTS reminders (
        id INTEGER PRIMARY KEY,
        time TEXT NOT NULL,
        date TEXT NOT NULL,
        reason TEXT NOT NULL
    );
    """
    try:
        conn = sqlite3.connect(db_path)
        try:
            # `with conn` commits on success / rolls back on error, but it does
            # NOT close the connection -- the explicit close below does that.
            with conn:
                conn.execute(sql_statement)
        finally:
            conn.close()
        print("sucess")
    except sqlite3.Error as e:
        print(e)


create_reminders_table()
def list_reminders(db_path="reminders.db"):
    """Print every row of the ``reminders`` table, one tuple per line.

    Debugging helper. sqlite3 errors (e.g. a missing table) are printed as
    ``"Error: ..."`` instead of being raised.

    Args:
        db_path: SQLite database file; defaults to ``reminders.db``.
    """
    try:
        conn = sqlite3.connect(db_path)
        try:
            rows = conn.execute("SELECT * FROM reminders;").fetchall()
        finally:
            # The sqlite3 connection context manager never closes the
            # connection, so close it explicitly to avoid leaking handles.
            conn.close()
    except sqlite3.Error as e:
        print(f"Error: {e}")
        return
    for row in rows:
        print(row)
def get_reminders(db_path="reminders.db"):
    """Return every row of the ``reminders`` table.

    Args:
        db_path: SQLite database file; defaults to ``reminders.db``.

    Returns:
        A list of ``(id, time, date, reason)`` tuples (the table's column
        order) as SQLite yields them, or ``[]`` if the query fails; in that
        case the error is printed as ``"Error: ..."``.
    """
    try:
        conn = sqlite3.connect(db_path)
        try:
            return conn.execute("SELECT * FROM reminders;").fetchall()
        finally:
            # The sqlite3 connection context manager never closes the
            # connection, so close it explicitly to avoid leaking handles.
            conn.close()
    except sqlite3.Error as e:
        print(f"Error: {e}")
        return []
def format_reminders_for_context(reminders):
    """Render reminder rows as prompt context, one newline-terminated line each.

    Each row is indexed as ``(id, time, date, reason, ...)``, matching the
    column order of the ``reminders`` table; extra trailing items are ignored.
    An empty input yields ``""``.
    """
    return "".join(
        f"ID: {row[0]}, Time: {row[1]}, Date: {row[2]}, Reason: {row[3]}\n"
        for row in reminders
    )
def get_session_history(user_id: str, conversation_id: str):
    """Return the SQLite-backed message history for one user's conversation.

    All histories live in ``memory.db``. The session key joins the two ids
    with a double dash, i.e. ``"<user_id>--<conversation_id>"``.
    """
    session_key = f"{user_id}--{conversation_id}"
    return SQLChatMessageHistory(session_key, "sqlite:///memory.db")
# Prompt for free-form chat. Messages, in order:
#   1. a system message carrying {today_date} and {time_now};
#   2. the stored conversation, injected under "history";
#   3. the new user message from {input}.
# The system text previously read "Your are a friendly assistant Ai! Today
# date is ..."; the grammar is fixed so the model gets a clean instruction.
chatting_prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            "You are a friendly AI assistant! Today's date is {today_date} and the time now is {time_now}.",
        ),
        MessagesPlaceholder(variable_name="history"),
        ("human", "{input}"),
    ]
)
runnable_chat = chatting_prompt | llm_chat
# Chat chain with persisted memory. The wrapper obtains the history store from
# get_session_history; the user's message is read from the "input" key, and
# stored turns fill the prompt's "history" placeholder.
# history_factory_config declares two config fields whose ids match
# get_session_history's parameter names. Both default to "" when a caller
# does not supply them.
with_message_history = RunnableWithMessageHistory(
    runnable_chat,
    get_session_history,
    input_messages_key="input",
    history_messages_key="history",
    history_factory_config=[
        ConfigurableFieldSpec(
            id="user_id",  # first argument of get_session_history
            annotation=str,
            name="User ID",
            description="Unique identifier for the user.",
            default="",
            is_shared=True,
        ),
        ConfigurableFieldSpec(
            id="conversation_id",  # second argument of get_session_history
            annotation=str,
            name="Conversation ID",
            description="Unique identifier for the conversation.",
            default="",
            is_shared=True,
        ),
    ],
)
def get_ai_response(user_input, user_id="123", conversation_id="1"):
    """Stream a free-form chat reply for ``user_input`` using persisted history.

    Args:
        user_input: The user's message.
        user_id: Owner of the stored conversation. get_session_history
            combines it with ``conversation_id`` into the session key.
        conversation_id: Identifier of the stored conversation.

    Returns:
        A ``(response, error)`` tuple. On success: the full reply text and
        ``None``. On failure: whatever text arrived before the error, plus an
        ``"Error: ..."`` message.
    """
    parts = []
    try:
        for chunk in with_message_history.stream(
            {"input": user_input, "time_now": get_current_time(), "today_date": get_date()},
            # RunnableWithMessageHistory reads its history_factory_config fields
            # from config["configurable"]. The ids used to be passed as top-level
            # config keys, which not every langchain_core version forwards. When
            # they are dropped, the "" defaults apply and all chats share a
            # single history.
            config={"configurable": {"user_id": user_id, "conversation_id": conversation_id}},
        ):
            parts.append(chunk.content)
    except Exception as e:  # best-effort: hand the error to the UI instead of crashing
        return "".join(parts), f"Error: {e}"
    return "".join(parts), None
# Prompt for "check reminder" questions. Template variables:
#   {question} -- the user's message
#   {date_tt}  -- today's date, so the model can resolve "today"
#   {context}  -- the reminders table rendered as text
# The model must answer only from that context. If nothing matches, it must
# reply with the fixed sentence "No reminder found on that day."
message_max = """
Answer this question using the database provided only.
{question}
This is today's date: {date_tt}. So if user says "today," it means the date of today!
Database sqlite:
{context}
If no reminder is set on the day user requested, reply with: "No reminder found on that day."
If a reminder is found:
- Directly reply with the specific action and time, e.g., "You need to sleep at 10pm tomorrow."
Avoid providing additional context or listing all reminders.
"""
prompt_max = ChatPromptTemplate.from_messages([("human", message_max)])
# Answers are produced by the default-temperature chat model.
rag_chain = prompt_max | llm_chat
# Prompt that turns an "update reminder" request into a single SQL UPDATE.
# Template variables: {question}, {date_tt} (today's date) and {context} (the
# reminders table as text). The model replies with the bare token
# "reminder_x" when no matching reminder exists.
# Fixes relative to the previous text:
# - The worked example now includes the "UPDATE reminders" line. Without it,
#   the model tended to emit a statement that does not parse.
# - The example uses ISO YYYY-MM-DD dates, the format the insert prompt
#   stores. The old DD/MM/YYYY example produced WHERE clauses that never
#   matched stored rows.
message_sec = """
Answer this question using the database provided only.
{question}
This is today's date: {date_tt}. So if user says "today," it means the date of today!
Database sqlite:
{context}
Search for the reminder the user is asking to change or update. If it is found, reply only with a statement in this format:
UPDATE reminders
SET date = new date here, time = new time here, reason = new reason here
WHERE date = old date here AND time = old time here AND reason = old reason here;
Write new dates as YYYY-MM-DD, and copy the old date, time and reason exactly as they appear in the database.
Example:
UPDATE reminders
SET date = "2024-08-22", time = "10:00 AM", reason = "Playing phone"
WHERE date = "2024-06-23" AND time = "11:00 AM" AND reason = "Playing phone";
If no reminder is found, reply with reminder_x.
Avoid providing additional context or listing all reminders.
"""
prompt_one = ChatPromptTemplate.from_messages([("human", message_sec)])
rag_chaining = prompt_one | llm_chat
# Prompt that turns a "set reminder" request into a single SQL INSERT.
# Template variables: {question}, {date_tt} (today's date), {current_time}
# and {context} (the reminders table as text). The model replies with the
# bare token "repeated_reminder" when an identical reminder already exists.
# The instructions' typos and grammar are fixed ("speicify", "If there
# already have the reminder ..."); the format and tokens are unchanged.
message_thr = """
Answer this question using the database provided only.
{question}
This is today's date: {date_tt}. So if user says "today," it means the date of today!
Current time: {current_time}
Database sqlite:
{context}
If no date is specified, assume today. If the user says "one hour later" or similar, add that to the current time to get the reminder time (for example: the current time is 10 PM and the user says "1 hour later", so 10 PM + 1 hour = 11 PM).
Don't say anything extra. If I am asking to be reminded of something, reply only in this format:
INSERT INTO reminders (date, time, reason)
VALUES (date, time, reason);
Example: (remember to use "" for values to avoid error)
INSERT INTO reminders (date, time, reason)
VALUES ("2024-08-23", "8:00 AM", "Homework");
If a reminder with the same date, the same time and the same reason already exists, reply with repeated_reminder.
Avoid providing additional context or listing all reminders.
"""
prompt_tt = ChatPromptTemplate.from_messages([("human", message_thr)])
rag_chain_tt = prompt_tt | llm_chat
@st.cache_resource
def _build_intent_vectorstore(data_path="important_classification_data.json"):
    """Build a Chroma store of labelled example phrases for intent hints.

    Each record in ``data_path`` must provide an ``"input"`` phrase and an
    ``"intent"`` label; the label is kept in the document metadata.

    Streamlit re-executes this whole script on every interaction. Without
    ``st.cache_resource``, every message would re-read the JSON, create a new
    HuggingFaceEmbeddings instance and re-embed every example. Repeated
    ``from_documents`` calls in one process may also add duplicate entries to
    the in-memory collection.
    """
    with open(data_path, "r", encoding="utf-8") as f:
        records = json.load(f)
    documents = [
        Document(page_content=record["input"], metadata={"intent": record["intent"]})
        for record in records
    ]
    splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    return Chroma.from_documents(
        documents=splitter.split_documents(documents),
        embedding=HuggingFaceEmbeddings(),
    )


# Vector store of the intent-classification examples (built once per process).
vectorstore = _build_intent_vectorstore()
# Retriever that fetches the examples most similar to the user's message.
retriever = vectorstore.as_retriever()
# Initialise the transcript once per browser session. Streamlit reruns this
# script on every interaction, so the transcript has to live in session_state.
if "messages" not in st.session_state:
    st.session_state.messages = []

# Replay the stored transcript on each rerun. Entries should be
# {"role": ..., "content": ...} dicts. Anything else (earlier versions of this
# app appended bare strings) is shown as an assistant message rather than
# crashing the page with a TypeError.
for message in st.session_state.messages:
    if isinstance(message, dict):
        role = message.get("role", "assistant")
        content = message.get("content", "")
    else:
        role, content = "assistant", str(message)
    with st.chat_message(role):
        st.markdown(content)
def _stream_text(chain, inputs):
    """Stream ``chain`` on ``inputs``; return the concatenated text, stripped."""
    return "".join(chunk.content for chunk in chain.stream(inputs)).strip()


def _strip_code_fences(text):
    """Remove a surrounding Markdown code fence (e.g. ```sql ... ```) from LLM output."""
    text = text.strip()
    if text.startswith("```"):
        # Drop the opening fence line, which may carry a language tag.
        text = text.split("\n", 1)[1] if "\n" in text else ""
        text = text.rsplit("```", 1)[0]
    return text.strip()


def _is_sentinel(reply, token):
    """Return True if ``reply`` is just ``token``, ignoring quotes and trailing periods."""
    return reply.strip().strip("`'\".").strip() == token


def _execute_reminder_sql(statement, expected_prefix, require_where=False):
    """Run one LLM-generated SQL statement on reminders.db and return its rowcount.

    The statement comes from a language model that sees user input, so it is
    treated as untrusted. Safeguards:
    - it must start with ``expected_prefix`` (compared upper-case, with
      whitespace collapsed);
    - when ``require_where`` is set, it must contain a WHERE clause;
    - ``execute`` itself rejects more than one statement.

    Raises:
        ValueError: if the statement is not of the expected kind.
        sqlite3.Error / sqlite3.Warning: on database errors. Older Pythons
            raise Warning for multiple statements.
    """
    normalised = " ".join(statement.split()).upper()
    if not normalised.startswith(expected_prefix):
        raise ValueError(f"unexpected SQL from the model: {statement!r}")
    if require_where and " WHERE " not in normalised:
        raise ValueError(f"refusing to run a statement without a WHERE clause: {statement!r}")
    conn = sqlite3.connect("reminders.db")
    try:
        with conn:  # commits on success, rolls back on error
            return conn.execute(statement).rowcount
    finally:
        conn.close()  # the connection context manager does not close it


def _classify_intent(text):
    """Return the Classification intent for ``text`` ("other content" as a fallback).

    Labelled example phrases from the retriever are passed to the tagging
    prompt. Previously the retriever object itself was interpolated, so the
    model only saw its repr.
    """
    examples = retriever.invoke(text)
    hints = "\n".join(
        f"{doc.page_content} -> {doc.metadata.get('intent', '')}" for doc in examples
    )
    result = tagging_chain.invoke({"input": text, "traning_data": hints})
    # Defensive: treat a missing structured result as ordinary chat.
    return result.sentiment if result is not None else "other content"


def _respond(text):
    """Route ``text`` by intent and return the assistant's reply as markdown text."""
    time_now = get_current_time()
    current_date = get_date()
    intent = _classify_intent(text)

    if intent == "set reminder":
        context = format_reminders_for_context(get_reminders())
        reply = _strip_code_fences(_stream_text(rag_chain_tt, {
            "question": text,
            "context": context,
            "date_tt": current_date,
            "current_time": time_now,
        }))
        print(f"New reminder: {reply}")
        if _is_sentinel(reply, "repeated_reminder"):
            return "Reminder already existed!"
        try:
            rowcount = _execute_reminder_sql(reply, "INSERT INTO REMINDERS")
        except (sqlite3.Error, sqlite3.Warning, ValueError) as e:
            return f"An error occurred while creating the reminder: {e}"
        return "New reminder created." if rowcount > 0 else "No reminder created, Errors occur"

    if intent == "update reminder":
        context = format_reminders_for_context(get_reminders())
        reply = _strip_code_fences(_stream_text(rag_chaining, {
            "question": text,
            "context": context,
            "date_tt": current_date,
        }))
        print(f"Database: {reply}")
        if _is_sentinel(reply, "reminder_x"):
            return "No reminder found to change!"
        try:
            rowcount = _execute_reminder_sql(reply, "UPDATE REMINDERS", require_where=True)
        except (sqlite3.Error, sqlite3.Warning, ValueError) as e:
            return f"An error occurred while updating the reminder: {e}"
        return "Reminder updated successfully." if rowcount > 0 else "No reminder found to update with the given details."

    if intent == "check reminder":
        # Uses the dedicated "check" prompt (rag_chain). The INSERT-generating
        # chain used here before also required a missing "current_time" input.
        context = format_reminders_for_context(get_reminders())
        return _stream_text(rag_chain, {
            "question": text,
            "context": context,
            "date_tt": current_date,
        })

    if intent == "remove reminder":
        return "Removing reminders is not supported yet."

    # get_ai_response returns (text, error); render both rather than the tuple.
    response, error = get_ai_response(text)
    return f"{response}\n\n{error}".strip() if error else response


# React to user input
if user_input := st.chat_input("What is up?"):
    # Echo the user's message and record it in the transcript.
    st.chat_message("user").markdown(user_input)
    st.session_state.messages.append({"role": "user", "content": user_input})
    try:
        response_st = _respond(user_input)
    except Exception as e:  # top-level UI boundary: report the failure instead of crashing
        response_st = f"An error occurred: {e}. Please try again."
    # Always render and store a dict entry so the history replay stays valid.
    with st.chat_message("assistant"):
        st.markdown(response_st)
    st.session_state.messages.append({"role": "assistant", "content": response_st})