File size: 6,744 Bytes
4709845
 
 
 
 
 
5780b29
4709845
 
 
 
5780b29
4709845
5780b29
 
 
297c326
4709845
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5780b29
 
 
 
 
 
 
 
 
4709845
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5780b29
4709845
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
297c326
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
import streamlit as st
from langchain_community.agent_toolkits.sql.toolkit import SQLDatabaseToolkit
from langchain_community.utilities import SQLDatabase
from langchain.chat_models import init_chat_model
from langchain_community.agent_toolkits import create_sql_agent
from langchain.callbacks.base import BaseCallbackHandler
import time

# --- Database connection -----------------------------------------------------
# SECURITY(review): this connection string embeds a live username/password.
# Move the credentials to an environment variable or st.secrets before
# deploying — anyone with read access to this file can reach the database.
url = "postgresql://postgres.qxvpaoeakhddzabctekw:8&CiDRpTFbRRBrT@aws-0-ap-south-1.pooler.supabase.com:5432/postgres"

# Initialize session state variables if they don't exist.
# FIX: reuse `url` instead of repeating the literal (the two copies could
# silently drift apart).
if "db" not in st.session_state:
    st.session_state.db = SQLDatabase.from_uri(url)

if "agent_chain_output" not in st.session_state:
    st.session_state.agent_chain_output = ""

# Custom callback handler for streaming agent activity into a Streamlit element.
class StreamHandler(BaseCallbackHandler):
    """LangChain callback handler that renders agent progress live.

    Accumulates a markdown transcript in ``self.text`` and re-renders the
    whole transcript into the given Streamlit container after every event.
    """

    def __init__(self, container):
        # container: a Streamlit element supporting .markdown() (e.g. st.empty()).
        self.container = container
        self.text = ""

    def _render(self):
        # Single place that pushes the accumulated transcript to the UI.
        self.container.markdown(self.text)

    def on_llm_start(self, serialized, prompts, **kwargs):
        self.text += "🧠 Starting to think...\n\n"
        self._render()

    def on_llm_new_token(self, token, **kwargs):
        self.text += token
        self._render()

    def on_llm_end(self, response, **kwargs):
        self.text += "\n\n✅ Thinking complete.\n\n"
        self._render()

    def on_tool_start(self, serialized, input_str, **kwargs):
        # BUG FIX: ``serialized`` may be None or lack "name" depending on the
        # LangChain version; .get() on a guaranteed dict avoids a
        # KeyError/TypeError that would abort the stream mid-run.
        tool_name = (serialized or {}).get("name", "unknown tool")
        self.text += f"🔧 Using tool: {tool_name}\n"
        self.text += f"Tool input: {input_str}\n\n"
        self._render()

    def on_tool_end(self, output, **kwargs):
        self.text += f"Tool output: {output}\n\n"
        self._render()

    def on_chain_start(self, serialized, inputs, **kwargs):
        # Same None-guard as on_tool_start; "Chain" is the fallback label.
        chain_type = (serialized or {}).get("name", "Chain")
        self.text += f"⛓️ Starting {chain_type}...\n"
        self._render()

    def on_chain_end(self, outputs, **kwargs):
        # (was an f-string with no placeholders)
        self.text += "⛓️ Chain complete.\n\n"
        self._render()

    def on_agent_action(self, action, **kwargs):
        self.text += f"🤖 Agent action: {action.tool}\n"
        self.text += f"Action input: {action.tool_input}\n\n"
        self._render()

    def on_agent_finish(self, finish, **kwargs):
        self.text += f"🏁 Agent finished: {finish.return_values.get('output')}\n\n"
        self._render()

# Main function to run the app
def main():
    """Render the Streamlit UI and answer natural-language DB queries.

    Reads the shared ``SQLDatabase`` from ``st.session_state.db`` and drives a
    LangChain SQL agent, optionally streaming the agent's thinking and the
    generated SQL back into the page.
    """
    st.title("🤖 Database Query Assistant")
    st.write("Ask questions about your database in natural language.")

    # Sidebar: connection info and display options.
    with st.sidebar:
        st.header("Database Information")
        st.write("Connected to: PostgreSQL database")  # was an f-string with no placeholders

        # Options
        show_sql = st.checkbox("Show SQL Queries", value=True)
        show_thinking = st.checkbox("Show Agent Thinking", value=True)

        # Database schema button
        if st.button("Show Database Schema"):
            with st.spinner("Fetching database schema..."):
                st.code(st.session_state.db.get_table_info())

    # Query input section
    query_options = ["Free-form query", "Get employee details"]
    query_type = st.radio("Query Type:", query_options)

    if query_type == "Free-form query":
        query = st.text_area("Enter your query:", height=100,
                             placeholder="Example: What are the top 3 highest paid employees?")
    else:
        employee_id = st.text_input("Enter Employee ID:", placeholder="Example: 1001")
        query = f"Give details of employee ID {employee_id}" if employee_id else ""

    # Action buttons.
    # BUG FIX: "Clear Results" used to be rendered only inside
    # `if process_button:` — it appeared for a single script run and its click
    # was consumed by the rerun, so it could never actually clear anything.
    # It is now rendered unconditionally.
    col1, col2 = st.columns([1, 5])
    with col1:
        process_button = st.button("Run Query", type="primary", use_container_width=True)
    with col2:
        if st.button("Clear Results", use_container_width=True):
            st.session_state.agent_chain_output = ""
            st.rerun()  # st.experimental_rerun() is deprecated/removed

    # Results section
    st.header("Results")

    # Containers for streaming output and the final answer.
    if show_thinking:
        thinking_container = st.expander("Agent Thinking Process", expanded=True)
        thinking_output = thinking_container.empty()

    result_container = st.container()
    result_output = result_container.empty()

    # Process the query when the button is clicked and a query exists.
    if process_button and query:
        stream_handler = StreamHandler(thinking_output) if show_thinking else None

        try:
            # SECURITY(review): hardcoded API key — move to st.secrets or an
            # environment variable; this grants anyone with the file API access.
            api_key = "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"

            # Create a new LLM instance with streaming callbacks wired in.
            llm = init_chat_model(
                "llama-3.3-70b-versatile",
                model_provider="groq",
                api_key=api_key,
                streaming=True,
                callbacks=[stream_handler] if stream_handler else None,
            )

            # Create the agent with callbacks.
            # BUG FIX: intermediate steps only appear in the result when the
            # executor is built with return_intermediate_steps=True; without it
            # the "Show SQL Queries" section below was dead code.
            toolkit = SQLDatabaseToolkit(db=st.session_state.db, llm=llm)
            agent_executor = create_sql_agent(
                llm,
                toolkit=toolkit,
                verbose=show_thinking,
                agent_executor_kwargs={"return_intermediate_steps": True},
                callbacks=[stream_handler] if stream_handler else None,
            )

            # Process the query and display the result.
            with st.spinner("Processing your query..."):
                result = agent_executor.invoke(query)

                result_container.success("Query processed successfully!")
                # BUG FIX: result_output is a single st.empty() slot, so the
                # former separate .markdown() then .write() calls overwrote the
                # "### Answer" heading. Render both in one markdown payload.
                result_output.markdown(f"### Answer\n\n{result['output']}")

                # Show the SQL if requested.
                if show_sql and "intermediate_steps" in result:
                    sql_container = st.expander("SQL Queries Used", expanded=True)
                    for step in result["intermediate_steps"]:
                        action = step[0]
                        if getattr(action, "tool", None) == "sql_db_query":
                            sql_container.code(action.tool_input, language="sql")

        except Exception as e:
            # Broad catch is acceptable here: this is the top-level UI boundary.
            st.error(f"An error occurred: {e}")

# Run the Streamlit app when this file is executed as a script.
if __name__ == "__main__":
    main()