File size: 8,981 Bytes
0dd285c
 
 
 
 
 
 
 
 
 
 
 
 
5b26f53
 
0dd285c
5b26f53
 
0dd285c
 
 
5b26f53
0dd285c
 
5b26f53
0dd285c
 
5b26f53
0dd285c
5b26f53
 
0dd285c
5b26f53
 
ec96023
0dd285c
efb8ba7
0dd285c
5b26f53
5bea3fb
 
5b26f53
5bea3fb
 
5b26f53
5bea3fb
efb8ba7
 
 
 
0dd285c
5b26f53
0dd285c
5b26f53
 
0dd285c
5b26f53
 
 
0dd285c
 
 
 
5b26f53
0dd285c
 
 
 
5b26f53
 
0dd285c
5bea3fb
0dd285c
5b26f53
0dd285c
 
5b26f53
 
5bea3fb
 
 
 
5b26f53
 
 
0dd285c
5b26f53
0dd285c
5b26f53
0dd285c
 
 
 
 
5b26f53
0dd285c
 
 
 
5b26f53
0dd285c
 
 
5bea3fb
0dd285c
 
 
5b26f53
 
 
 
0dd285c
 
 
 
 
 
 
 
 
 
 
5bea3fb
 
0dd285c
5b26f53
5bea3fb
 
0dd285c
 
 
c796379
0dd285c
7df1f40
0dd285c
 
5bea3fb
0dd285c
5b26f53
0dd285c
 
 
 
7df1f40
0dd285c
ec96023
0dd285c
 
5b26f53
0dd285c
 
5b26f53
0dd285c
5b26f53
0dd285c
411b037
 
0dd285c
5bea3fb
411b037
0dd285c
5bea3fb
411b037
0dd285c
 
 
 
5bea3fb
0dd285c
c796379
5bea3fb
0dd285c
 
5bea3fb
c796379
0dd285c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
# Import necessary libraries and modules for various tasks
from dotenv import load_dotenv  # For loading environment variables from a .env file (such as database credentials)
from langchain_core.messages import AIMessage, HumanMessage  # For handling messages from the AI and user
from langchain_core.prompts import ChatPromptTemplate  # To create templates that will guide the chatbot's responses
from langchain_core.runnables import RunnablePassthrough  # To enable chaining of different operations (like inputs/outputs)
from langchain_community.utilities import SQLDatabase  # A tool to help connect to SQL databases using LangChain
from langchain_core.output_parsers import StrOutputParser  # To parse outputs into plain text
from langchain_groq import ChatGroq  # This integrates the Groq model for generating chat responses
import streamlit as st  # Streamlit is used for building the web app (user interface)
import os  # To access environment variables (e.g., credentials or other settings)
import psycopg2  # A PostgreSQL database adapter to enable connections to the database

# Load environment variables (such as DB credentials) from the .env file
load_dotenv()

# Function to establish a connection to the PostgreSQL database
def init_database() -> SQLDatabase:
    """Open a connection to the PostgreSQL database described by the environment.

    Connection settings are read from the ``DB_USER``, ``DB_PASSWORD``,
    ``DB_HOST``, ``DB_PORT`` and ``DB_NAME`` environment variables, each with
    the fallback shown below. They are assembled into a
    ``postgresql+psycopg2`` URI and handed to ``SQLDatabase.from_uri``.

    Returns:
        The connected ``SQLDatabase``, or ``None`` if anything raised while
        connecting (the error is reported to the UI with ``st.error``).
    """
    # Function-scope import so the block does not depend on file-level changes.
    from urllib.parse import quote

    try:
        # Retrieve database connection details from environment variables, or fall back to defaults
        user = os.getenv("DB_USER", "postgres")
        # NOTE(review): a literal password is used as the fallback here; kept for
        # backward compatibility, but it should be supplied via DB_PASSWORD only.
        password = os.getenv("DB_PASSWORD", "beekhani143")
        host = os.getenv("DB_HOST", "localhost")
        port = os.getenv("DB_PORT", "5432")
        database = os.getenv("DB_NAME", "")

        # Percent-encode the credentials: characters such as "@", ":" or "/"
        # inside the user name or password would otherwise be read as URI
        # delimiters and corrupt the connection string.
        safe_user = quote(user, safe="")
        safe_password = quote(password, safe="")

        # Construct the database URI with the necessary credentials for PostgreSQL
        db_uri = f"postgresql+psycopg2://{safe_user}:{safe_password}@{host}:{port}/{database}"

        # Connect to the database using the SQLDatabase utility and return the instance
        return SQLDatabase.from_uri(db_uri)
    except Exception as e:
        # UI boundary: surface any connection problem on the page instead of crashing the app
        st.error(f"Failed to connect to database: {e}")
        return None

# Function to create a process (chain) that generates SQL queries based on user input and previous conversation
def get_sql_chain(db):
    """Build the chain that turns a question and chat history into a SQL query string.

    The returned runnable expects a mapping with ``question`` and
    ``chat_history`` keys. It adds a ``schema`` entry taken from
    ``db.get_table_info()``, fills the prompt below, sends it to the Groq
    model and parses the model output to plain text.
    """
    # Instructions for the model; the placeholders are filled in when the chain runs.
    template = """
    You are a data analyst at a company. You are interacting with a user who is asking you questions about the company's database.
    Based on the table schema below, write a SQL query that would answer the user's question. Take the conversation history into account.

    <SCHEMA>{schema}</SCHEMA>
    Conversation History: {chat_history}
    Write only the SQL query and nothing else.
    
    Question: {question}
    SQL Query:
    """

    sql_prompt = ChatPromptTemplate.from_template(template)
    # temperature=0 keeps the generated SQL as deterministic as the model allows
    model = ChatGroq(model="mixtral-8x7b-32768", temperature=0)

    # Step 1: attach the table schema to the incoming inputs (the inputs themselves are ignored when fetching it).
    with_schema = RunnablePassthrough.assign(schema=lambda _: db.get_table_info())

    # Steps 2-4: render the prompt, call the model, reduce the reply to a string.
    return with_schema | sql_prompt | model | StrOutputParser()

# Function to generate a natural language response based on SQL query and database result
def get_response(user_query: str, db: SQLDatabase, chat_history: list):
    """Answer a user question in natural language using the database.

    A SQL query is generated with ``get_sql_chain(db)``, executed with
    ``db.run`` and the query, its result, the table schema and the
    conversation history are then given to the Groq model to phrase the answer.

    Args:
        user_query: The question typed by the user.
        db: Database wrapper used for schema lookup and query execution.
        chat_history: Conversation messages so far, inserted into both prompts.

    Returns:
        The model's answer as plain text (output of ``StrOutputParser``).

    Exceptions raised while generating or executing the query are not caught
    here and propagate to the caller.
    """
    # First, get the SQL chain (responsible for generating SQL queries)
    sql_chain = get_sql_chain(db)

    # Template to guide how the AI responds to the user's query based on the SQL results
    template = """
    You are a data analyst at a company. Based on the table schema, SQL query, and response, write a natural language response.
    <SCHEMA>{schema}</SCHEMA>
    Conversation History: {chat_history}
    SQL Query: <SQL>{query}</SQL>
    User question: {question}
    SQL Response: {response}
    """

    # Create a new prompt template for generating a response
    prompt = ChatPromptTemplate.from_template(template)
    # Initialize the Groq model for response generation
    llm = ChatGroq(model="mixtral-8x7b-32768", temperature=0)

    def run_query(variables):
        """Execute the generated SQL and return the database's result."""
        # Backslashes are stripped before execution (presumably the model
        # sometimes escapes characters such as underscores — TODO confirm).
        sql = variables["query"].replace("\\", "")
        # Debugging: print the SQL that is actually sent to the database.
        # Doing it here is the only point where the query text is available;
        # the chain's final output is the natural-language answer.
        print(f"SQL Query: {sql}")
        return db.run(sql)

    # Chain the following tasks:
    # 1. Generate the SQL query using the earlier chain.
    # 2. Get the schema and execute the query on the database.
    # 3. Return the natural language response based on the query and its results.
    chain = (
        RunnablePassthrough.assign(query=sql_chain)
        .assign(
            schema=lambda _: db.get_table_info(),
            response=run_query,
        )
        | prompt
        | llm
        | StrOutputParser()
    )

    # Invoke the chain to generate the final response based on the user query and history
    return chain.invoke({
        "question": user_query,
        "chat_history": chat_history,
    })

# Seed the conversation once per browser session with the assistant's greeting
if "chat_history" not in st.session_state:
    st.session_state["chat_history"] = [
        AIMessage(content="Hello! I'm a SQL assistant. Ask me anything about your database."),
    ]

# Page setup: browser-tab title and icon first, then the heading shown on the page
st.set_page_config(
    page_title="Chat with PostgreSQL",
    page_icon=":speech_balloon:",
)
st.title("Chat with PostgreSQL")

# Sidebar configuration for database connection settings
with st.sidebar:
    st.subheader("Settings")  # Heading for the settings section
    st.write("Connect to your database and start chatting.")  # Instruction text for users

    # Input fields for the connection details, pre-filled from the environment
    host = st.text_input("Host", value=os.getenv("DB_HOST", "localhost"))
    port = st.text_input("Port", value=os.getenv("DB_PORT", "5432"))
    user = st.text_input("User", value=os.getenv("DB_USER", "postgres"))
    # NOTE(review): a literal password is used as the pre-filled fallback; kept
    # as-is for compatibility, but it should come from DB_PASSWORD only.
    password = st.text_input("Password", type="password", value=os.getenv("DB_PASSWORD", "beekhani143"))
    database = st.text_input("Database", value=os.getenv("DB_NAME", "db"))

    # Button to attempt database connection
    if st.button("Connect"):
        with st.spinner("Connecting to database..."):  # Show a spinner while connecting
            # init_database() takes no arguments and reads its settings from
            # the environment, so the values typed above must be published
            # there first; otherwise the form fields have no effect.
            # NOTE(review): os.environ is process-wide, so these values are
            # shared by every session served by this process.
            os.environ.update(
                {
                    "DB_HOST": host,
                    "DB_PORT": port,
                    "DB_USER": user,
                    "DB_PASSWORD": password,
                    "DB_NAME": database,
                }
            )
            db = init_database()
            if db is not None:
                st.session_state.db = db  # Keep the connection for later chat turns
                st.success("Connected to the database!")
            else:
                st.error("Connection failed. Please check your settings.")

# Display the chat history (both AI and user messages)
for message in st.session_state.chat_history:
    # Pick the chat bubble label from the message type; other types are not shown
    if isinstance(message, AIMessage):
        speaker = "AI"
    elif isinstance(message, HumanMessage):
        speaker = "Human"
    else:
        continue

    with st.chat_message(speaker):
        st.markdown(message.content)

# Input field for the user to type their message
user_query = st.chat_input("Type a message...")  # Field to capture user query
if user_query and user_query.strip():  # Ignore empty or whitespace-only input
    st.session_state.chat_history.append(HumanMessage(content=user_query))  # Save the user query in chat history

    with st.chat_message("Human"):  # Display the user's message in the chat
        st.markdown(user_query)

    with st.chat_message("AI"):  # Generate and display the AI response
        if "db" not in st.session_state:
            # The connection is only stored once "Connect" succeeds in the
            # sidebar; without this guard the attribute access below raises.
            response = "Please connect to a database from the sidebar before asking questions."
        else:
            try:
                response = get_response(user_query, st.session_state.db, st.session_state.chat_history)
            except Exception as e:
                # UI boundary: a failed query or model call becomes the
                # assistant's reply, so the history keeps one answer per question.
                response = f"Sorry, I couldn't answer that: {e}"
        st.markdown(response)

    st.session_state.chat_history.append(AIMessage(content=response))  # Add the AI's response to the chat history