"""Streamlit front-end that answers natural-language questions with a LangChain SQL agent.

Configuration is read from the environment:

* ``DATABASE_URL``  -- SQLAlchemy URL of the database to query.
* ``GROQ_API_KEY``  -- API key for the Groq-hosted chat model.

NOTE(security): an earlier revision of this file embedded the database password
and the Groq API key in source. Treat both as leaked and rotate them.
"""
import os
import time  # currently unused; kept so nothing importing it from here breaks

import streamlit as st
from langchain.callbacks.base import BaseCallbackHandler
from langchain.chat_models import init_chat_model
from langchain_community.agent_toolkits import create_sql_agent
from langchain_community.agent_toolkits.sql.toolkit import SQLDatabaseToolkit
from langchain_community.utilities import SQLDatabase

DATABASE_URL_ENV = "DATABASE_URL"
GROQ_API_KEY_ENV = "GROQ_API_KEY"
MODEL_NAME = "llama-3.3-70b-versatile"
MODEL_PROVIDER = "groq"

# Database connection string (from the environment, never from source control).
url = os.environ.get(DATABASE_URL_ENV, "")

# Initialize session state variables if they don't exist.
if "db" not in st.session_state:
    if not url:
        st.error(
            f"Set the {DATABASE_URL_ENV} environment variable to the database "
            "connection URL before starting the app."
        )
        st.stop()
    st.session_state.db = SQLDatabase.from_uri(url)

if "agent_chain_output" not in st.session_state:
    # Last successful answer; shown again on reruns until "Clear Results" is pressed.
    st.session_state.agent_chain_output = ""


class StreamHandler(BaseCallbackHandler):
    """Render LangChain LLM/tool/chain/agent events into a Streamlit placeholder.

    Every event appends a line to ``self.text``. The whole accumulated text is
    then re-rendered as markdown into ``container``, so the placeholder always
    shows the full log so far.
    """

    def __init__(self, container):
        # container: a Streamlit element with a ``markdown`` method (e.g. ``st.empty()``).
        self.container = container
        self.text = ""

    def _emit(self, fragment):
        """Append *fragment* to the log and re-render the placeholder."""
        self.text += fragment
        self.container.markdown(self.text)

    def on_llm_start(self, serialized, prompts, **kwargs):
        self._emit("🧠 Starting to think...\n\n")

    def on_llm_new_token(self, token, **kwargs):
        self._emit(token)

    def on_llm_end(self, response, **kwargs):
        self._emit("\n\n✅ Thinking complete.\n\n")

    def on_tool_start(self, serialized, input_str, **kwargs):
        # ``serialized`` may be None or lack "name" depending on the tool/runtime.
        tool_name = (serialized or {}).get("name", "tool")
        self._emit(f"🔧 Using tool: {tool_name}\nTool input: {input_str}\n\n")

    def on_tool_end(self, output, **kwargs):
        self._emit(f"Tool output: {output}\n\n")

    def on_chain_start(self, serialized, inputs, **kwargs):
        # Newer langchain-core versions may pass ``serialized=None`` for runnables.
        chain_type = (serialized or {}).get("name", "Chain")
        self._emit(f"⛓️ Starting {chain_type}...\n")

    def on_chain_end(self, outputs, **kwargs):
        self._emit("⛓️ Chain complete.\n\n")

    def on_agent_action(self, action, **kwargs):
        self._emit(f"🤖 Agent action: {action.tool}\nAction input: {action.tool_input}\n\n")

    def on_agent_finish(self, finish, **kwargs):
        self._emit(f"🏁 Agent finished: {finish.return_values.get('output')}\n\n")


def _render_sidebar():
    """Draw the sidebar (info, display options, schema button); return (show_sql, show_thinking)."""
    with st.sidebar:
        st.header("Database Information")
        st.write("Connected to: PostgreSQL database")

        # Options
        show_sql = st.checkbox("Show SQL Queries", value=True)
        show_thinking = st.checkbox("Show Agent Thinking", value=True)

        # Database schema button
        if st.button("Show Database Schema"):
            with st.spinner("Fetching database schema..."):
                schema = st.session_state.db.get_table_info()
            st.code(schema)
    return show_sql, show_thinking


def _read_query():
    """Draw the query-input widgets and return the natural-language query ('' if none)."""
    query_options = ["Free-form query", "Get employee details"]
    query_type = st.radio("Query Type:", query_options)

    if query_type == "Free-form query":
        query = st.text_area(
            "Enter your query:",
            height=100,
            placeholder="Example: What are the top 3 highest paid employees?",
        )
        return query.strip()

    employee_id = st.text_input("Enter Employee ID:", placeholder="Example: 1001").strip()
    # NOTE(security): user text is fed to an agent that generates and runs SQL with
    # whatever privileges DATABASE_URL grants. Point it at a read-only role.
    return f"Give details of employee ID {employee_id}" if employee_id else ""


def _extract_sql_queries(steps):
    """Return the SQL strings the agent sent to the ``sql_db_query`` tool.

    ``steps`` is the executor's ``intermediate_steps`` list of (action, observation)
    pairs. Tool-calling agents may pass ``tool_input`` as a dict with a "query"
    key. ReAct agents pass a plain string. Both forms are handled.
    """
    queries = []
    for action, _observation in steps:
        if getattr(action, "tool", None) != "sql_db_query":
            continue
        tool_input = action.tool_input
        if isinstance(tool_input, dict):
            tool_input = tool_input.get("query", tool_input)
        queries.append(str(tool_input))
    return queries


def _process_query(query, show_sql, stream_handler, result_container):
    """Run the SQL agent on *query* and render the answer (and optionally its SQL)."""
    api_key = os.environ.get(GROQ_API_KEY_ENV)
    if not api_key:
        st.error(f"Set the {GROQ_API_KEY_ENV} environment variable to use the assistant.")
        return

    # Callbacks are supplied once, at invoke time, so they propagate to every child
    # run (LLM, tools, chains). Also attaching them to the LLM would stream each
    # token twice.
    callbacks = [stream_handler] if stream_handler else []

    try:
        llm = init_chat_model(
            MODEL_NAME,
            model_provider=MODEL_PROVIDER,
            api_key=api_key,
            streaming=True,
        )
        toolkit = SQLDatabaseToolkit(db=st.session_state.db, llm=llm)
        agent_executor = create_sql_agent(
            llm,
            toolkit=toolkit,
            verbose=stream_handler is not None,
            # Without this the result never contains "intermediate_steps",
            # so "Show SQL Queries" had nothing to display.
            agent_executor_kwargs={"return_intermediate_steps": True},
        )

        with st.spinner("Processing your query..."):
            result = agent_executor.invoke({"input": query}, config={"callbacks": callbacks})
    except Exception as e:  # UI boundary: surface any agent/LLM/DB failure to the user
        st.error(f"An error occurred: {str(e)}")
        return

    answer = result["output"]
    st.session_state.agent_chain_output = answer

    # Write into the container directly. A single ``st.empty()`` slot would let the
    # answer overwrite the "### Answer" header.
    result_container.success("Query processed successfully!")
    result_container.markdown("### Answer")
    result_container.write(answer)

    if show_sql:
        sql_queries = _extract_sql_queries(result.get("intermediate_steps", []))
        if sql_queries:
            sql_container = st.expander("SQL Queries Used", expanded=True)
            for sql in sql_queries:
                sql_container.code(sql, language="sql")


def main():
    """Render the app: sidebar, query input, run/clear buttons and the results area."""
    st.title("🤖 Database Query Assistant")
    st.write("Ask questions about your database in natural language.")

    show_sql, show_thinking = _render_sidebar()
    query = _read_query()

    col1, col2 = st.columns([1, 5])
    with col1:
        process_button = st.button("Run Query", type="primary", use_container_width=True)
    with col2:
        # Rendered on every run so its click is actually observed. It used to exist
        # only in the run where "Run Query" was pressed, so clicks were lost.
        clear_button = st.button("Clear Results", use_container_width=True)

    if clear_button:
        # The results area is drawn below, so no explicit rerun is needed.
        st.session_state.agent_chain_output = ""

    # Results section
    st.header("Results")

    thinking_output = None
    if show_thinking:
        thinking_container = st.expander("Agent Thinking Process", expanded=True)
        thinking_output = thinking_container.empty()

    result_container = st.container()

    if process_button:
        if not query:
            st.warning("Please enter a query before running it.")
            return
        stream_handler = StreamHandler(thinking_output) if thinking_output is not None else None
        _process_query(query, show_sql, stream_handler, result_container)
    elif st.session_state.agent_chain_output:
        # Keep showing the last answer across unrelated reruns (e.g. toggling options).
        result_container.markdown("### Answer")
        result_container.write(st.session_state.agent_chain_output)


if __name__ == "__main__":
    main()