# Page header captured along with this file: "Spaces: Build error"; file size 4,847 bytes.
5bea3fb ec96023 c796379 bda4e9b 5bea3fb ec96023 5bea3fb ec96023 efb8ba7 5bea3fb bda4e9b 5bea3fb efb8ba7 5bea3fb c796379 5bea3fb efb8ba7 c796379 b0cfa99 5bea3fb b0cfa99 5bea3fb b0cfa99 5bea3fb b0cfa99 5bea3fb 411b037 c796379 5bea3fb c796379 5bea3fb c796379 7df1f40 5bea3fb 7df1f40 5bea3fb c796379 7df1f40 5bea3fb 7df1f40 ec96023 5bea3fb ec96023 411b037 5bea3fb 411b037 5bea3fb 411b037 5bea3fb c796379 5bea3fb c796379 5bea3fb c796379 5bea3fb c796379 5bea3fb | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 | from dotenv import load_dotenv
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_community.utilities import SQLDatabase
from langchain_core.output_parsers import StrOutputParser
from langchain_openai import ChatOpenAI
from langchain_groq import ChatGroq
import streamlit as st
def init_database(user: str, password: str, host: str, port: str, database: str) -> SQLDatabase:
    """Open a MySQL database through the mysql-connector driver and wrap it for LangChain.

    Args:
        user: MySQL user name.
        password: Password for ``user``. It may contain URL-reserved characters
            such as ``@``, ``:``, ``/`` or ``#``.
        host: Server host name or IP address.
        port: Server port, as text (e.g. ``"3306"``).
        database: Name of the database to open.

    Returns:
        A ``SQLDatabase`` created from a ``mysql+mysqlconnector://`` URI.
    """
    # Local import keeps this fix self-contained (no change to the file's import block).
    from urllib.parse import quote

    # The credentials are embedded in a URI. Unescaped reserved characters would
    # be read as URI delimiters, so percent-encode everything. quote(..., safe="")
    # is used instead of quote_plus because SQLAlchemy decodes "%20" back to a
    # space but leaves "+" as a literal plus.
    safe_user = quote(user, safe="")
    safe_password = quote(password, safe="")
    db_uri = f"mysql+mysqlconnector://{safe_user}:{safe_password}@{host}:{port}/{database}"
    return SQLDatabase.from_uri(db_uri)
def get_sql_chain(db):
    """Build a runnable that writes a SQL query for the user's question.

    The returned chain takes a mapping with ``"question"`` and ``"chat_history"``
    keys. It adds a ``"schema"`` entry taken from ``db.get_table_info()``, fills
    the prompt, calls the Groq model, and returns the model's text output.
    """
    sql_template = """
You are a data analyst at a company. You are interacting with a user who is asking you questions about the company's database.
Based on the table schema below, write a SQL query that would answer the user's question. Take the conversation history into account.
<SCHEMA>{schema}</SCHEMA>
Conversation History: {chat_history}
Write only the SQL query and nothing else. Do not wrap the SQL query in any other text, not even backticks.
For example:
Question: which 3 artists have the most tracks?
SQL Query: SELECT ArtistId, COUNT(*) as track_count FROM Track GROUP BY ArtistId ORDER BY track_count DESC LIMIT 3;
Question: Name 10 artists
SQL Query: SELECT Name FROM Artist LIMIT 10;
Your turn:
Question: {question}
SQL Query:
"""
    sql_prompt = ChatPromptTemplate.from_template(sql_template)
    # Alternative model: ChatOpenAI(model="gpt-4-0125-preview")
    sql_llm = ChatGroq(model="mixtral-8x7b-32768", temperature=0)

    # The input's own argument is ignored; the schema always comes from `db`.
    with_schema = RunnablePassthrough.assign(schema=lambda _inputs: db.get_table_info())
    return with_schema | sql_prompt | sql_llm | StrOutputParser()
def get_response(user_query: str, db: SQLDatabase, chat_history: list):
    """Answer ``user_query`` in natural language using data from ``db``.

    Two LLM passes are chained:
    - the chain from ``get_sql_chain`` writes a SQL query;
    - that query is executed with ``db.run``;
    - a second prompt turns the schema, history, query and raw result into a
      prose answer.

    Args:
        user_query: The user's latest question.
        db: Database used for both schema lookup and query execution.
        chat_history: Prior messages, interpolated into both prompts as
            ``{chat_history}``.

    Returns:
        The second model's reply, as a string.
    """
    query_writer = get_sql_chain(db)

    answer_template = """
You are a data analyst at a company. You are interacting with a user who is asking you questions about the company's database.
Based on the table schema below, question, sql query, and sql response, write a natural language response.
<SCHEMA>{schema}</SCHEMA>
Conversation History: {chat_history}
SQL Query: <SQL>{query}</SQL>
User question: {question}
SQL Response: {response}"""
    answer_prompt = ChatPromptTemplate.from_template(answer_template)
    # Alternative model: ChatOpenAI(model="gpt-4-0125-preview")
    answer_llm = ChatGroq(model="mixtral-8x7b-32768", temperature=0)

    def fetch_schema(_inputs):
        return db.get_table_info()

    def execute_query(inputs):
        # NOTE(review): the model-generated SQL is executed verbatim. Nothing here
        # restricts it to read-only statements; consider a read-only DB user.
        return db.run(inputs["query"])

    # Stage 1 adds "query"; stage 2 uses it to add "schema" and "response".
    with_query = RunnablePassthrough.assign(query=query_writer)
    with_context = with_query.assign(schema=fetch_schema, response=execute_query)
    answer_chain = with_context | answer_prompt | answer_llm | StrOutputParser()

    chain_inputs = {
        "question": user_query,
        "chat_history": chat_history,
    }
    return answer_chain.invoke(chain_inputs)
# Load variables from a local .env file, if present. Presumably this supplies
# the LLM provider's API key; confirm against the deployment.
load_dotenv()

# Configure the page before other page commands. Streamlit's docs ask for
# set_page_config to come first.
st.set_page_config(page_title="Chat with MySQL", page_icon=":speech_balloon:")

# Seed the conversation once per browser session with a greeting from the assistant.
if "chat_history" not in st.session_state:
    greeting = AIMessage(content="Hello! I'm a SQL assistant. Ask me anything about your database.")
    st.session_state.chat_history = [greeting]

st.title("Chat with MySQL")
# Sidebar: connection settings and the "Connect" action.
with st.sidebar:
    st.subheader("Settings")
    st.write("This is a simple chat application using MySQL. Connect to the database and start chatting.")

    # Each widget's value is stored in st.session_state under its `key`.
    # NOTE(review): the defaults are root/admin credentials. Fine for a local demo,
    # but confirm before any shared deployment.
    st.text_input("Host", value="localhost", key="Host")
    st.text_input("Port", value="3306", key="Port")
    st.text_input("User", value="root", key="User")
    st.text_input("Password", type="password", value="admin", key="Password")
    st.text_input("Database", value="Chinook", key="Database")

    if st.button("Connect"):
        with st.spinner("Connecting to database..."):
            try:
                db = init_database(
                    st.session_state["User"],
                    st.session_state["Password"],
                    st.session_state["Host"],
                    st.session_state["Port"],
                    st.session_state["Database"],
                )
            except Exception as err:
                # UI boundary: bad credentials, an unreachable host or a malformed
                # port should be reported to the user, not crash the script run.
                st.error(f"Could not connect to database: {err}")
            else:
                # Only a successful connection replaces the stored handle.
                st.session_state.db = db
                st.success("Connected to database!")
# Replay the stored conversation on every rerun. Messages that are neither
# AI nor human are skipped.
for past_message in st.session_state.chat_history:
    if isinstance(past_message, AIMessage):
        speaker = "AI"
    elif isinstance(past_message, HumanMessage):
        speaker = "Human"
    else:
        continue
    with st.chat_message(speaker):
        st.markdown(past_message.content)
# Handle a newly submitted message: echo it, answer it, and record both turns.
user_query = st.chat_input("Type a message...")
if user_query and user_query.strip():
    if "db" not in st.session_state:
        # The handle is only stored after "Connect" succeeds in the sidebar.
        # Reading st.session_state.db before then raises instead of guiding the user.
        st.warning("Please connect to a database using the sidebar before asking a question.")
    else:
        st.session_state.chat_history.append(HumanMessage(content=user_query))
        with st.chat_message("Human"):
            st.markdown(user_query)

        with st.chat_message("AI"):
            try:
                response = get_response(user_query, st.session_state.db, st.session_state.chat_history)
            except Exception as err:
                # UI boundary: LLM/API failures and invalid generated SQL are shown
                # to the user. The unanswered question is removed from history so
                # it is not replayed into later prompts.
                st.session_state.chat_history.pop()
                st.error(f"Sorry, something went wrong while answering: {err}")
            else:
                st.markdown(response)
                st.session_state.chat_history.append(AIMessage(content=response))