Suresh Beekhani
Update app.py
0dd285c unverified
raw
history blame
8.98 kB
# Import necessary libraries and modules for various tasks
import os # To access environment variables (e.g., credentials or other settings)
import re  # For stripping Markdown code fences from model-generated SQL
from urllib.parse import quote  # For percent-encoding credentials in the database URI

from dotenv import load_dotenv # For loading environment variables from a .env file (such as database credentials)
from langchain_community.utilities import SQLDatabase # A tool to help connect to SQL databases using LangChain
from langchain_core.messages import AIMessage, HumanMessage # For handling messages from the AI and user
from langchain_core.output_parsers import StrOutputParser # To parse outputs into plain text
from langchain_core.prompts import ChatPromptTemplate # To create templates that will guide the chatbot's responses
from langchain_core.runnables import RunnablePassthrough # To enable chaining of different operations (like inputs/outputs)
from langchain_groq import ChatGroq # This integrates the Groq model for generating chat responses
import psycopg2 # A PostgreSQL database adapter to enable connections to the database
import streamlit as st # Streamlit is used for building the web app (user interface)
# Load environment variables (such as DB credentials) from a .env file into
# os.environ, so the os.getenv() lookups later in this file can see them.
# NOTE(review): python-dotenv does not override variables that are already set
# in the real environment by default — confirm that is the intended precedence.
load_dotenv()
# Default Groq model for both SQL generation and answer phrasing; override it
# with the GROQ_MODEL environment variable.
# NOTE(review): Groq has announced deprecation of mixtral-8x7b-32768 — confirm
# it is still served, or set GROQ_MODEL to a supported model.
DEFAULT_MODEL = os.getenv("GROQ_MODEL", "mixtral-8x7b-32768")

# (environment variable, fallback) for each connection setting. Both
# init_database() and the sidebar read these, so they share one set of defaults.
# No password is hard-coded: provide DB_PASSWORD via the environment / .env file.
DB_SETTING_DEFAULTS = {
    "host": ("DB_HOST", "localhost"),
    "port": ("DB_PORT", "5432"),
    "user": ("DB_USER", "postgres"),
    "password": ("DB_PASSWORD", ""),
    "database": ("DB_NAME", ""),
}


def db_setting_from_env(name: str) -> str:
    """Return connection setting *name* from its environment variable or fallback."""
    env_var, fallback = DB_SETTING_DEFAULTS[name]
    return os.getenv(env_var, fallback)


def build_db_uri(user: str, password: str, host: str, port: str, database: str) -> str:
    """Build a SQLAlchemy URI for PostgreSQL via psycopg2.

    The user name and password are percent-encoded, so characters such as
    ``@``, ``:``, ``/`` or spaces inside them cannot corrupt the URI.
    """
    return (
        f"postgresql+psycopg2://{quote(user, safe='')}:{quote(password, safe='')}"
        f"@{host}:{port}/{database}"
    )


def init_database(
    *,
    user: "str | None" = None,
    password: "str | None" = None,
    host: "str | None" = None,
    port: "str | None" = None,
    database: "str | None" = None,
) -> "SQLDatabase | None":
    """Connect to PostgreSQL and wrap the connection in a LangChain SQLDatabase.

    Each keyword argument overrides the matching environment variable
    (DB_USER, DB_PASSWORD, DB_HOST, DB_PORT, DB_NAME). Arguments left as None
    fall back to the environment, then to DB_SETTING_DEFAULTS.

    Returns:
        The SQLDatabase instance, or None if connecting failed. In that case
        the error is also shown in the Streamlit UI.
    """
    overrides = {
        "user": user,
        "password": password,
        "host": host,
        "port": port,
        "database": database,
    }
    settings = {
        name: value if value is not None else db_setting_from_env(name)
        for name, value in overrides.items()
    }
    try:
        return SQLDatabase.from_uri(build_db_uri(**settings))
    except Exception as e:  # UI boundary: report any connection problem instead of crashing
        st.error(f"Failed to connect to database: {e}")
        return None


def clean_sql_query(raw: str) -> str:
    """Strip Markdown code fences and backslashes from model-generated SQL.

    Backslashes are removed, presumably because the model sometimes escapes
    identifiers (e.g. ``first\\_name``). NOTE: this also removes legitimate
    backslashes inside SQL string literals.
    """
    text = re.sub(r"^\s*```(?:sql)?\s*|\s*```\s*$", "", raw, flags=re.IGNORECASE)
    return text.replace("\\", "").strip()


def format_chat_history(messages: "list | str") -> str:
    """Render chat messages as one ``Speaker: text`` line each, for the prompts.

    Messages are labelled by their LangChain ``type`` attribute ("human" becomes
    "Human", "ai" becomes "AI"; any other type is shown as-is). A string is
    assumed to be already formatted and is returned unchanged.
    """
    if isinstance(messages, str):
        return messages
    speakers = {"human": "Human", "ai": "AI"}
    return "\n".join(
        f"{speakers.get(message.type, message.type)}: {message.content}"
        for message in messages
    )


# Function to create a process (chain) that generates SQL queries based on user input and previous conversation
def get_sql_chain(db, model: str = DEFAULT_MODEL):
    """Build a runnable that turns a question plus chat history into SQL.

    The runnable takes a dict with "question" and "chat_history" keys, adds the
    database schema under "schema", and returns the model's raw text output.
    That output is expected to be a single SQL query, possibly wrapped in
    Markdown fences.

    Args:
        db: The SQLDatabase whose table info is injected into the prompt.
        model: Groq model name (temperature 0, for more deterministic SQL).
    """
    template = """
You are a data analyst at a company. You are interacting with a user who is asking you questions about the company's database.
Based on the table schema below, write a SQL query that would answer the user's question. Take the conversation history into account.
<SCHEMA>{schema}</SCHEMA>
Conversation History: {chat_history}
Write only the SQL query and nothing else.
Question: {question}
SQL Query:
"""
    prompt = ChatPromptTemplate.from_template(template)
    llm = ChatGroq(model=model, temperature=0)

    def get_schema(_):
        return db.get_table_info()

    # Add the schema, fill the prompt, call the model, then return plain text.
    return RunnablePassthrough.assign(schema=get_schema) | prompt | llm | StrOutputParser()


# Function to generate a natural language response based on SQL query and database result
def get_response(user_query: str, db: "SQLDatabase", chat_history: list, model: str = DEFAULT_MODEL) -> str:
    """Answer *user_query* in natural language by generating and running SQL.

    Pipeline: the SQL chain writes a query, the query is cleaned and executed
    against *db*, and a second model call phrases the result for the user.

    Args:
        user_query: The user's question.
        db: Database to query.
        chat_history: Earlier messages (e.g. HumanMessage / AIMessage), rendered
            into the prompts as "Human: ..." / "AI: ..." lines. It should not
            already contain *user_query*, which is passed separately.
        model: Groq model name used for both model calls.

    Returns:
        The model's natural-language answer.

    Raises:
        Whatever db.run or the model client raise (for example on invalid SQL);
        the UI in main() catches and displays these.
    """
    sql_chain = get_sql_chain(db, model)
    template = """
You are a data analyst at a company. Based on the table schema, SQL query, and response, write a natural language response.
<SCHEMA>{schema}</SCHEMA>
Conversation History: {chat_history}
SQL Query: <SQL>{query}</SQL>
User question: {question}
SQL Response: {response}
"""
    prompt = ChatPromptTemplate.from_template(template)
    llm = ChatGroq(model=model, temperature=0)

    def run_query(inputs: dict):
        # Debug output: the SQL that is actually executed (not the final answer).
        print(f"SQL Query: {inputs['query']}")
        return db.run(inputs["query"])

    chain = (
        RunnablePassthrough.assign(query=sql_chain | clean_sql_query)  # generate + clean SQL
        .assign(
            schema=lambda _: db.get_table_info(),
            response=run_query,  # execute the cleaned query
        )
        | prompt
        | llm
        | StrOutputParser()
    )
    return chain.invoke({
        "question": user_query,
        "chat_history": format_chat_history(chat_history),
    })


def main() -> None:
    """Render the Streamlit chat UI; Streamlit re-runs this on every interaction."""
    # Called first: older Streamlit versions require set_page_config to be the
    # first Streamlit command on the page.
    st.set_page_config(page_title="Chat with PostgreSQL", page_icon=":speech_balloon:")
    st.title("Chat with PostgreSQL")

    # Initialize the chat session on the first run.
    if "chat_history" not in st.session_state:
        st.session_state.chat_history = [
            AIMessage(content="Hello! I'm a SQL assistant. Ask me anything about your database."),
        ]

    # Sidebar: connection settings, pre-filled from the environment.
    with st.sidebar:
        st.subheader("Settings")
        st.write("Connect to your database and start chatting.")
        host = st.text_input("Host", value=db_setting_from_env("host"))
        port = st.text_input("Port", value=db_setting_from_env("port"))
        user = st.text_input("User", value=db_setting_from_env("user"))
        # Not pre-filled, so the server's DB_PASSWORD is never sent to the browser.
        password = st.text_input(
            "Password",
            type="password",
            help="Leave blank to use DB_PASSWORD from the environment.",
        )
        database = st.text_input("Database", value=db_setting_from_env("database"))
        if st.button("Connect"):
            with st.spinner("Connecting to database..."):
                db = init_database(
                    user=user,
                    password=password or None,  # blank -> fall back to DB_PASSWORD
                    host=host,
                    port=port,
                    database=database,
                )
            if db is not None:
                st.session_state.db = db
                st.success("Connected to the database!")
            else:
                st.error("Connection failed. Please check your settings.")

    # Replay the conversation so far.
    for message in st.session_state.chat_history:
        if isinstance(message, AIMessage):
            with st.chat_message("AI"):
                st.markdown(message.content)
        elif isinstance(message, HumanMessage):
            with st.chat_message("Human"):
                st.markdown(message.content)

    user_query = st.chat_input("Type a message...")
    if user_query and user_query.strip():
        # History *before* this question; the question itself is passed separately.
        previous_messages = list(st.session_state.chat_history)
        st.session_state.chat_history.append(HumanMessage(content=user_query))
        with st.chat_message("Human"):
            st.markdown(user_query)
        with st.chat_message("AI"):
            if "db" not in st.session_state:
                response = "Please connect to a database first using the sidebar."
            else:
                try:
                    response = get_response(user_query, st.session_state.db, previous_messages)
                except Exception as e:  # bad SQL, API errors, ...: show them instead of crashing the app
                    response = f"Sorry, I couldn't answer that: {e}"
            st.markdown(response)
        st.session_state.chat_history.append(AIMessage(content=response))


# Streamlit executes the script as __main__, so `streamlit run app.py` still
# renders the app, while importing this module (e.g. from tests) has no UI side effects.
if __name__ == "__main__":
    main()