Spaces:
Build error
Build error
Suresh Beekhani
commited on
Update app.py
Browse files- src/app.py +97 -93
src/app.py
CHANGED
|
@@ -1,40 +1,41 @@
|
|
| 1 |
-
# Import necessary libraries and modules
|
| 2 |
-
from dotenv import load_dotenv
|
| 3 |
-
from langchain_core.messages import AIMessage, HumanMessage #
|
| 4 |
-
from langchain_core.prompts import ChatPromptTemplate #
|
| 5 |
-
from langchain_core.runnables import RunnablePassthrough # To
|
| 6 |
-
from langchain_community.utilities import SQLDatabase # SQL
|
| 7 |
-
from langchain_core.output_parsers import StrOutputParser # To parse outputs
|
| 8 |
-
#
|
| 9 |
-
|
| 10 |
-
import
|
| 11 |
-
import
|
| 12 |
-
|
| 13 |
-
# Load environment variables from the .env file
|
| 14 |
load_dotenv()
|
| 15 |
|
| 16 |
-
# Function to
|
| 17 |
def init_database() -> SQLDatabase:
|
| 18 |
try:
|
| 19 |
-
#
|
| 20 |
-
user = os.getenv("DB_USER", "
|
| 21 |
-
password = os.getenv("DB_PASSWORD", "
|
| 22 |
host = os.getenv("DB_HOST", "localhost")
|
| 23 |
-
port = os.getenv("DB_PORT", "
|
| 24 |
-
database = os.getenv("DB_NAME", "
|
| 25 |
|
| 26 |
-
# Construct the database URI
|
| 27 |
-
db_uri = f"
|
| 28 |
|
| 29 |
-
#
|
| 30 |
return SQLDatabase.from_uri(db_uri)
|
| 31 |
except Exception as e:
|
|
|
|
| 32 |
st.error(f"Failed to connect to database: {e}")
|
| 33 |
return None
|
| 34 |
|
| 35 |
-
# Function to create a chain that generates SQL queries
|
| 36 |
def get_sql_chain(db):
|
| 37 |
-
# SQL
|
| 38 |
template = """
|
| 39 |
You are a data analyst at a company. You are interacting with a user who is asking you questions about the company's database.
|
| 40 |
Based on the table schema below, write a SQL query that would answer the user's question. Take the conversation history into account.
|
|
@@ -47,33 +48,32 @@ def get_sql_chain(db):
|
|
| 47 |
SQL Query:
|
| 48 |
"""
|
| 49 |
|
| 50 |
-
# Create a prompt from the above
|
| 51 |
prompt = ChatPromptTemplate.from_template(template)
|
| 52 |
-
|
| 53 |
-
# Initialize Groq model for generating SQL queries (can switch to OpenAI if needed)
|
| 54 |
llm = ChatGroq(model="mixtral-8x7b-32768", temperature=0)
|
| 55 |
|
| 56 |
-
#
|
| 57 |
def get_schema(_):
|
| 58 |
return db.get_table_info()
|
| 59 |
|
| 60 |
-
#
|
| 61 |
-
# 1.
|
| 62 |
-
# 2.
|
| 63 |
-
# 3.
|
| 64 |
return (
|
| 65 |
-
RunnablePassthrough.assign(schema=get_schema) #
|
| 66 |
-
| prompt #
|
| 67 |
-
| llm #
|
| 68 |
-
| StrOutputParser() # Parse the
|
| 69 |
)
|
| 70 |
|
| 71 |
-
# Function to generate a
|
| 72 |
def get_response(user_query: str, db: SQLDatabase, chat_history: list):
|
| 73 |
-
#
|
| 74 |
sql_chain = get_sql_chain(db)
|
| 75 |
-
|
| 76 |
-
#
|
| 77 |
template = """
|
| 78 |
You are a data analyst at a company. Based on the table schema, SQL query, and response, write a natural language response.
|
| 79 |
<SCHEMA>{schema}</SCHEMA>
|
|
@@ -83,90 +83,94 @@ def get_response(user_query: str, db: SQLDatabase, chat_history: list):
|
|
| 83 |
SQL Response: {response}
|
| 84 |
"""
|
| 85 |
|
| 86 |
-
# Create a
|
| 87 |
prompt = ChatPromptTemplate.from_template(template)
|
| 88 |
-
|
| 89 |
-
# Initialize Groq model (alternative: OpenAI)
|
| 90 |
llm = ChatGroq(model="mixtral-8x7b-32768", temperature=0)
|
| 91 |
-
|
| 92 |
-
#
|
|
|
|
|
|
|
|
|
|
| 93 |
chain = (
|
| 94 |
-
RunnablePassthrough.assign(query=sql_chain)
|
| 95 |
-
|
| 96 |
-
|
|
|
|
| 97 |
)
|
| 98 |
-
| prompt # Use prompt
|
| 99 |
-
| llm #
|
| 100 |
-
| StrOutputParser() # Parse the
|
| 101 |
)
|
| 102 |
-
|
| 103 |
-
#
|
| 104 |
-
|
| 105 |
"question": user_query,
|
| 106 |
"chat_history": chat_history,
|
| 107 |
})
|
| 108 |
|
| 109 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 110 |
if "chat_history" not in st.session_state:
|
| 111 |
-
# Initialize chat history with a welcome message from AI
|
| 112 |
st.session_state.chat_history = [
|
|
|
|
| 113 |
AIMessage(content="Hello! I'm a SQL assistant. Ask me anything about your database."),
|
| 114 |
]
|
| 115 |
|
| 116 |
-
#
|
| 117 |
-
st.set_page_config(page_title="Chat with
|
| 118 |
-
|
| 119 |
-
# Streamlit app title
|
| 120 |
-
st.title("Chat with MySQL")
|
| 121 |
|
| 122 |
-
# Sidebar for database connection settings
|
| 123 |
with st.sidebar:
|
| 124 |
-
st.subheader("Settings")
|
| 125 |
-
st.write("Connect to your database and start chatting.")
|
| 126 |
|
| 127 |
-
#
|
| 128 |
host = st.text_input("Host", value=os.getenv("DB_HOST", "localhost"))
|
| 129 |
-
port = st.text_input("Port", value=os.getenv("DB_PORT", "
|
| 130 |
-
user = st.text_input("User", value=os.getenv("DB_USER", "
|
| 131 |
-
password = st.text_input("Password", type="password", value=os.getenv("DB_PASSWORD", "
|
| 132 |
-
database = st.text_input("Database", value=os.getenv("DB_NAME", "
|
| 133 |
|
| 134 |
-
# Button to
|
| 135 |
if st.button("Connect"):
|
| 136 |
-
with st.spinner("Connecting to database..."):
|
| 137 |
-
#
|
| 138 |
-
db = init_database()
|
| 139 |
if db:
|
| 140 |
-
st.session_state.db = db
|
| 141 |
-
st.success("Connected to the database!")
|
| 142 |
else:
|
| 143 |
-
st.error("Connection failed. Please check your settings.")
|
| 144 |
|
| 145 |
-
# Display chat history
|
| 146 |
for message in st.session_state.chat_history:
|
| 147 |
if isinstance(message, AIMessage):
|
| 148 |
-
# Display AI
|
| 149 |
-
with st.chat_message("AI"):
|
| 150 |
st.markdown(message.content)
|
| 151 |
elif isinstance(message, HumanMessage):
|
| 152 |
-
# Display human
|
| 153 |
-
with st.chat_message("Human"):
|
| 154 |
st.markdown(message.content)
|
| 155 |
|
| 156 |
-
# Input field for user
|
| 157 |
-
user_query = st.chat_input("Type a message...")
|
| 158 |
-
if user_query and user_query.strip():
|
| 159 |
-
#
|
| 160 |
-
st.session_state.chat_history.append(HumanMessage(content=user_query))
|
| 161 |
|
| 162 |
-
# Display user's message in the chat
|
| 163 |
-
with st.chat_message("Human"):
|
| 164 |
st.markdown(user_query)
|
| 165 |
|
| 166 |
-
# Generate and display AI
|
| 167 |
-
|
| 168 |
-
response = get_response(user_query, st.session_state.db, st.session_state.chat_history)
|
| 169 |
st.markdown(response)
|
| 170 |
|
| 171 |
-
# Add AI's response to the chat history
|
| 172 |
-
st.session_state.chat_history.append(AIMessage(content=response))
|
|
|
|
| 1 |
+
# Import necessary libraries and modules for various tasks
|
| 2 |
+
from dotenv import load_dotenv # For loading environment variables from a .env file (such as database credentials)
|
| 3 |
+
from langchain_core.messages import AIMessage, HumanMessage # For handling messages from the AI and user
|
| 4 |
+
from langchain_core.prompts import ChatPromptTemplate # To create templates that will guide the chatbot's responses
|
| 5 |
+
from langchain_core.runnables import RunnablePassthrough # To enable chaining of different operations (like inputs/outputs)
|
| 6 |
+
from langchain_community.utilities import SQLDatabase # A tool to help connect to SQL databases using LangChain
|
| 7 |
+
from langchain_core.output_parsers import StrOutputParser # To parse outputs into plain text
|
| 8 |
+
from langchain_groq import ChatGroq # This integrates the Groq model for generating chat responses
|
| 9 |
+
import streamlit as st # Streamlit is used for building the web app (user interface)
|
| 10 |
+
import os # To access environment variables (e.g., credentials or other settings)
|
| 11 |
+
import psycopg2 # A PostgreSQL database adapter to enable connections to the database
|
| 12 |
+
|
| 13 |
+
# Load environment variables (such as DB credentials) from the .env file
|
| 14 |
load_dotenv()
|
| 15 |
|
| 16 |
+
# Function to establish a connection to the PostgreSQL database
|
| 17 |
def init_database() -> SQLDatabase:
|
| 18 |
try:
|
| 19 |
+
# Retrieve database connection details from environment variables, or set default values
|
| 20 |
+
user = os.getenv("DB_USER", "postgres")
|
| 21 |
+
password = os.getenv("DB_PASSWORD", "beekhani143")
|
| 22 |
host = os.getenv("DB_HOST", "localhost")
|
| 23 |
+
port = os.getenv("DB_PORT", "5432")
|
| 24 |
+
database = os.getenv("DB_NAME", "")
|
| 25 |
|
| 26 |
+
# Construct the database URI (a URL-like string) with the necessary credentials for PostgreSQL
|
| 27 |
+
db_uri = f"postgresql+psycopg2://{user}:{password}@{host}:{port}/{database}"
|
| 28 |
|
| 29 |
+
# Connect to the database using the SQLDatabase utility and return the instance
|
| 30 |
return SQLDatabase.from_uri(db_uri)
|
| 31 |
except Exception as e:
|
| 32 |
+
# If connection fails, display an error message on the Streamlit UI
|
| 33 |
st.error(f"Failed to connect to database: {e}")
|
| 34 |
return None
|
| 35 |
|
| 36 |
+
# Function to create a process (chain) that generates SQL queries based on user input and previous conversation
|
| 37 |
def get_sql_chain(db):
|
| 38 |
+
# Template to guide how SQL queries are generated. The bot receives table schema and conversation history.
|
| 39 |
template = """
|
| 40 |
You are a data analyst at a company. You are interacting with a user who is asking you questions about the company's database.
|
| 41 |
Based on the table schema below, write a SQL query that would answer the user's question. Take the conversation history into account.
|
|
|
|
| 48 |
SQL Query:
|
| 49 |
"""
|
| 50 |
|
| 51 |
+
# Create a prompt template from the above instructions
|
| 52 |
prompt = ChatPromptTemplate.from_template(template)
|
| 53 |
+
# Initialize the Groq model for generating responses with low randomness (temperature=0 for more deterministic outputs)
|
|
|
|
| 54 |
llm = ChatGroq(model="mixtral-8x7b-32768", temperature=0)
|
| 55 |
|
| 56 |
+
# Function to get the schema (structure) of the tables in the database
|
| 57 |
def get_schema(_):
|
| 58 |
return db.get_table_info()
|
| 59 |
|
| 60 |
+
# Create a chain of operations:
|
| 61 |
+
# 1. First, get the database schema.
|
| 62 |
+
# 2. Then, use the prompt template to guide query creation.
|
| 63 |
+
# 3. Finally, parse the output as a plain text SQL query.
|
| 64 |
return (
|
| 65 |
+
RunnablePassthrough.assign(schema=get_schema) # Pass the schema into the chain
|
| 66 |
+
| prompt # Use the prompt template
|
| 67 |
+
| llm # Generate a response using the Groq model
|
| 68 |
+
| StrOutputParser() # Parse the response as a string (SQL query)
|
| 69 |
)
|
| 70 |
|
| 71 |
+
# Function to generate a natural language response based on SQL query and database result
|
| 72 |
def get_response(user_query: str, db: SQLDatabase, chat_history: list):
|
| 73 |
+
# First, get the SQL chain (responsible for generating SQL queries)
|
| 74 |
sql_chain = get_sql_chain(db)
|
| 75 |
+
|
| 76 |
+
# Template to guide how the AI responds to the user's query based on the SQL results
|
| 77 |
template = """
|
| 78 |
You are a data analyst at a company. Based on the table schema, SQL query, and response, write a natural language response.
|
| 79 |
<SCHEMA>{schema}</SCHEMA>
|
|
|
|
| 83 |
SQL Response: {response}
|
| 84 |
"""
|
| 85 |
|
| 86 |
+
# Create a new prompt template for generating a response
|
| 87 |
prompt = ChatPromptTemplate.from_template(template)
|
| 88 |
+
# Initialize the Groq model for response generation
|
|
|
|
| 89 |
llm = ChatGroq(model="mixtral-8x7b-32768", temperature=0)
|
| 90 |
+
|
| 91 |
+
# Chain the following tasks:
|
| 92 |
+
# 1. Generate the SQL query using the earlier chain.
|
| 93 |
+
# 2. Get the schema and execute the query on the database.
|
| 94 |
+
# 3. Return the natural language response based on the query and its results.
|
| 95 |
chain = (
|
| 96 |
+
RunnablePassthrough.assign(query=sql_chain) # Generate SQL query
|
| 97 |
+
.assign(
|
| 98 |
+
schema=lambda _: db.get_table_info(), # Pass the schema to the next step
|
| 99 |
+
response=lambda vars: db.run(vars["query"].replace("\\", "")), # Execute the SQL query and clean up backslashes
|
| 100 |
)
|
| 101 |
+
| prompt # Use the prompt template for generating natural language response
|
| 102 |
+
| llm # Generate the final response using the model
|
| 103 |
+
| StrOutputParser() # Parse the output into plain text
|
| 104 |
)
|
| 105 |
+
|
| 106 |
+
# Invoke the chain to generate the final response based on the user query and history
|
| 107 |
+
result = chain.invoke({
|
| 108 |
"question": user_query,
|
| 109 |
"chat_history": chat_history,
|
| 110 |
})
|
| 111 |
|
| 112 |
+
# Debugging: Print the SQL query being executed
|
| 113 |
+
if isinstance(result, str):
|
| 114 |
+
print(f"SQL Query: {result}")
|
| 115 |
+
else:
|
| 116 |
+
sql_query = result.get('query', 'No query generated')
|
| 117 |
+
print(f"SQL Query: {sql_query}")
|
| 118 |
+
|
| 119 |
+
# Return the result (natural language response)
|
| 120 |
+
return result
|
| 121 |
+
|
| 122 |
+
# Initialize the chat session when Streamlit app starts
|
| 123 |
if "chat_history" not in st.session_state:
|
|
|
|
| 124 |
st.session_state.chat_history = [
|
| 125 |
+
# First message from AI assistant
|
| 126 |
AIMessage(content="Hello! I'm a SQL assistant. Ask me anything about your database."),
|
| 127 |
]
|
| 128 |
|
| 129 |
+
# Streamlit page configuration: Set the page title and icon
|
| 130 |
+
st.set_page_config(page_title="Chat with PostgreSQL", page_icon=":speech_balloon:")
|
| 131 |
+
st.title("Chat with PostgreSQL") # Display title on the webpage
|
|
|
|
|
|
|
| 132 |
|
| 133 |
+
# Sidebar configuration for database connection settings
|
| 134 |
with st.sidebar:
|
| 135 |
+
st.subheader("Settings") # Display a heading for the settings section
|
| 136 |
+
st.write("Connect to your database and start chatting.") # Instruction text for users
|
| 137 |
|
| 138 |
+
# Input fields for database connection details (host, port, user, password, and database name)
|
| 139 |
host = st.text_input("Host", value=os.getenv("DB_HOST", "localhost"))
|
| 140 |
+
port = st.text_input("Port", value=os.getenv("DB_PORT", "5432"))
|
| 141 |
+
user = st.text_input("User", value=os.getenv("DB_USER", "postgres"))
|
| 142 |
+
password = st.text_input("Password", type="password", value=os.getenv("DB_PASSWORD", "beekhani143"))
|
| 143 |
+
database = st.text_input("Database", value=os.getenv("DB_NAME", "db"))
|
| 144 |
|
| 145 |
+
# Button to attempt database connection
|
| 146 |
if st.button("Connect"):
|
| 147 |
+
with st.spinner("Connecting to database..."): # Show a spinner while connecting
|
| 148 |
+
db = init_database() # Call the function to connect to the database
|
|
|
|
| 149 |
if db:
|
| 150 |
+
st.session_state.db = db # Save the connection in session state
|
| 151 |
+
st.success("Connected to the database!") # Display success message
|
| 152 |
else:
|
| 153 |
+
st.error("Connection failed. Please check your settings.") # Display error message if connection fails
|
| 154 |
|
| 155 |
+
# Display the chat history (both AI and user messages)
|
| 156 |
for message in st.session_state.chat_history:
|
| 157 |
if isinstance(message, AIMessage):
|
| 158 |
+
with st.chat_message("AI"): # Display AI messages in the chat
|
|
|
|
| 159 |
st.markdown(message.content)
|
| 160 |
elif isinstance(message, HumanMessage):
|
| 161 |
+
with st.chat_message("Human"): # Display human messages in the chat
|
|
|
|
| 162 |
st.markdown(message.content)
|
| 163 |
|
| 164 |
+
# Input field for the user to type their message
|
| 165 |
+
user_query = st.chat_input("Type a message...") # Field to capture user query
|
| 166 |
+
if user_query and user_query.strip(): # If the user entered a valid query
|
| 167 |
+
st.session_state.chat_history.append(HumanMessage(content=user_query)) # Save the user query in chat history
|
|
|
|
| 168 |
|
| 169 |
+
with st.chat_message("Human"): # Display the user's message in the chat
|
|
|
|
| 170 |
st.markdown(user_query)
|
| 171 |
|
| 172 |
+
with st.chat_message("AI"): # Generate and display the AI response
|
| 173 |
+
response = get_response(user_query, st.session_state.db, st.session_state.chat_history) # Get the AI's response
|
|
|
|
| 174 |
st.markdown(response)
|
| 175 |
|
| 176 |
+
st.session_state.chat_history.append(AIMessage(content=response)) # Add the AI's response to the chat history
|
|
|