Spaces:
Build error
Build error
File size: 6,744 Bytes
4709845 5780b29 4709845 5780b29 4709845 5780b29 297c326 4709845 5780b29 4709845 5780b29 4709845 297c326 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 |
import os
import time

import streamlit as st
from langchain.callbacks.base import BaseCallbackHandler
from langchain.chat_models import init_chat_model
from langchain_community.agent_toolkits import create_sql_agent
from langchain_community.agent_toolkits.sql.toolkit import SQLDatabaseToolkit
from langchain_community.utilities import SQLDatabase
# Database connection string. The DATABASE_URL environment variable takes
# precedence; the literal below is kept only as a backward-compatible fallback.
# NOTE(security): the fallback embeds a database password in source control.
# Rotate that password and delete the literal once DATABASE_URL is configured.
url = os.environ.get(
    "DATABASE_URL",
    "postgresql://postgres.qxvpaoeakhddzabctekw:8&CiDRpTFbRRBrT@aws-0-ap-south-1.pooler.supabase.com:5432/postgres",
)

# Initialize session state variables if they don't exist, so Streamlit reruns
# of this script reuse the same SQLDatabase instead of reconnecting each time.
if "db" not in st.session_state:
    # Use ``url`` rather than repeating the connection-string literal.
    st.session_state.db = SQLDatabase.from_uri(url)
if "agent_chain_output" not in st.session_state:
    st.session_state.agent_chain_output = ""
# Custom callback handler for streaming output
class StreamHandler(BaseCallbackHandler):
    """Stream LangChain LLM/tool/chain/agent events into a Streamlit element.

    Every callback appends a short line to ``self.text`` and then re-renders
    the whole accumulated transcript as Markdown into ``container``.
    """

    def __init__(self, container):
        """Remember the render target and start with an empty transcript.

        Args:
            container: Any object with a ``markdown(text)`` method; the app
                passes an ``st.empty()`` placeholder, which is overwritten
                with the full transcript on each update.
        """
        self.container = container
        self.text = ""

    def _append(self, chunk):
        """Append *chunk* to the transcript and redraw the container."""
        self.text += chunk
        self.container.markdown(self.text)

    @staticmethod
    def _name_of(serialized, default):
        """Return ``serialized["name"]``, or *default* if it is unavailable."""
        # LangChain can pass ``serialized=None`` (or a dict without "name")
        # for some runnables/tools; fall back instead of raising mid-run.
        if isinstance(serialized, dict):
            return serialized.get("name") or default
        return default

    def on_llm_start(self, serialized, prompts, **kwargs):
        """Note that an LLM call has begun."""
        self._append("🧠 Starting to think...\n\n")

    def on_llm_new_token(self, token, **kwargs):
        """Append each streamed token as it arrives."""
        self._append(token)

    def on_llm_end(self, response, **kwargs):
        """Note that the LLM call has finished."""
        self._append("\n\n✅ Thinking complete.\n\n")

    def on_tool_start(self, serialized, input_str, **kwargs):
        """Show which tool is being invoked and with what input."""
        tool_name = self._name_of(serialized, "unknown tool")
        self._append(f"🔧 Using tool: {tool_name}\nTool input: {input_str}\n\n")

    def on_tool_end(self, output, **kwargs):
        """Show the tool's output."""
        self._append(f"Tool output: {output}\n\n")

    def on_chain_start(self, serialized, inputs, **kwargs):
        """Announce the start of a chain, labelled by its serialized name."""
        chain_type = self._name_of(serialized, "Chain")
        self._append(f"⚙️ Starting {chain_type}...\n")

    def on_chain_end(self, outputs, **kwargs):
        """Announce that a chain has completed."""
        self._append("⚙️ Chain complete.\n\n")

    def on_agent_action(self, action, **kwargs):
        """Show the tool the agent chose and the input it passed."""
        self._append(
            f"🤖 Agent action: {action.tool}\nAction input: {action.tool_input}\n\n"
        )

    def on_agent_finish(self, finish, **kwargs):
        """Show the agent's final ``output`` value."""
        self._append(f"🏁 Agent finished: {finish.return_values.get('output')}\n\n")
# Main function to run the app and its private rendering helpers.


def _render_sidebar():
    """Draw the sidebar; return the ``(show_sql, show_thinking)`` options."""
    with st.sidebar:
        st.header("Database Information")
        st.write("Connected to: PostgreSQL database")
        # Options
        show_sql = st.checkbox("Show SQL Queries", value=True)
        show_thinking = st.checkbox("Show Agent Thinking", value=True)
        # Database schema button
        if st.button("Show Database Schema"):
            with st.spinner("Fetching database schema..."):
                schema = st.session_state.db.get_table_info()
            st.code(schema)
    return show_sql, show_thinking


def _read_query():
    """Render the query-type selector and inputs; return the question text.

    Returns an empty string when the user has not entered anything yet.
    """
    query_options = ["Free-form query", "Get employee details"]
    query_type = st.radio("Query Type:", query_options)
    if query_type == "Free-form query":
        return st.text_area(
            "Enter your query:",
            height=100,
            placeholder="Example: What are the top 3 highest paid employees?",
        )
    employee_id = st.text_input(
        "Enter Employee ID:", placeholder="Example: 1001"
    ).strip()
    return f"Give details of employee ID {employee_id}" if employee_id else ""


def _rerun():
    """Restart the script run, preferring ``st.rerun`` when it exists."""
    # ``st.experimental_rerun`` was deprecated in favour of ``st.rerun`` and
    # has been removed from newer Streamlit releases; keep it as a fallback.
    rerun = getattr(st, "rerun", None) or st.experimental_rerun
    rerun()


def _build_agent(stream_handler, show_sql, show_thinking):
    """Create a Groq-backed SQL agent executor over ``st.session_state.db``.

    Args:
        stream_handler: A ``StreamHandler`` to attach, or ``None``.
        show_sql: If true, the executor returns ``intermediate_steps`` so
            the generated SQL can be displayed afterwards.
        show_thinking: Passed through as the executor's ``verbose`` flag.

    Returns:
        The executor built by ``create_sql_agent``.
    """
    callbacks = [stream_handler] if stream_handler else None
    # NOTE(security): a Groq API key is committed here as a fallback. Rotate it
    # and supply GROQ_API_KEY via the environment instead.
    api_key = os.environ.get(
        "GROQ_API_KEY", "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"
    )
    llm = init_chat_model(
        "llama-3.3-70b-versatile",
        model_provider="groq",
        api_key=api_key,
        streaming=True,
        callbacks=callbacks,
    )
    toolkit = SQLDatabaseToolkit(db=st.session_state.db, llm=llm)
    # Executor-level options must go through ``agent_executor_kwargs``:
    # without ``return_intermediate_steps`` the result never contains the SQL,
    # and the executor (not the inner agent) is what emits the
    # on_agent_action / on_agent_finish callbacks.
    executor_kwargs = {"return_intermediate_steps": show_sql}
    if callbacks:
        executor_kwargs["callbacks"] = callbacks
    return create_sql_agent(
        llm,
        toolkit=toolkit,
        verbose=show_thinking,
        agent_executor_kwargs=executor_kwargs,
    )


def _render_sql(intermediate_steps):
    """List every query the agent sent to the ``sql_db_query`` tool."""
    sql_container = st.expander("SQL Queries Used", expanded=True)
    for step in intermediate_steps:
        action = step[0]
        if getattr(action, "tool", None) != "sql_db_query":
            continue
        tool_input = action.tool_input
        # Tool-calling agents may pass structured input like {"query": "..."}.
        if isinstance(tool_input, dict):
            tool_input = tool_input.get("query", str(tool_input))
        sql_container.code(tool_input, language="sql")


def main():
    """Render the Database Query Assistant page and answer a submitted query."""
    st.title("🤖 Database Query Assistant")
    st.write("Ask questions about your database in natural language.")

    show_sql, show_thinking = _render_sidebar()
    query = _read_query()

    # Process / clear buttons
    col1, col2 = st.columns([1, 5])
    with col1:
        process_button = st.button(
            "Run Query", type="primary", use_container_width=True
        )
    with col2:
        # Draw Clear on every run: clicking it starts a rerun in which
        # "Run Query" is no longer pressed, so a Clear button drawn only
        # while "Run Query" is pressed would never register its click.
        clear_button = st.button("Clear Results", use_container_width=True)
    if clear_button:
        st.session_state.agent_chain_output = ""
        _rerun()

    # Results section
    st.header("Results")
    thinking_output = None
    if show_thinking:
        thinking_container = st.expander("Agent Thinking Process", expanded=True)
        thinking_output = thinking_container.empty()
    result_container = st.container()

    if not process_button:
        return
    if not query.strip():
        result_container.warning("Please enter a query first.")
        return

    stream_handler = StreamHandler(thinking_output) if show_thinking else None
    try:
        agent_executor = _build_agent(stream_handler, show_sql, show_thinking)
        with st.spinner("Processing your query..."):
            result = agent_executor.invoke({"input": query})
    except Exception as e:  # UI boundary: report the failure instead of crashing
        st.error(f"An error occurred: {e}")
        return

    # Write into a container (not an ``st.empty()`` placeholder) so the
    # banner, heading and answer all stay visible, in that order.
    result_container.success("Query processed successfully!")
    result_container.markdown("### Answer")
    result_container.write(result["output"])

    if show_sql and "intermediate_steps" in result:
        _render_sql(result["intermediate_steps"])


if __name__ == "__main__":
    main()