Spaces:
Build error
Build error
import os
import time

import streamlit as st
from langchain.callbacks.base import BaseCallbackHandler
from langchain.chat_models import init_chat_model
from langchain_community.agent_toolkits import create_sql_agent
from langchain_community.agent_toolkits.sql.toolkit import SQLDatabaseToolkit
from langchain_community.utilities import SQLDatabase
# Database connection string.
# The DSN embeds the database password, so it is read from the environment
# (for example a deployment secret) rather than being committed to the source.
url = os.environ.get("DATABASE_URL", "")

# Initialize session state variables if they don't exist
if "db" not in st.session_state:
    if not url:
        # Without a DSN nothing else in the app can work; halt this script
        # run with an actionable message instead of a connection traceback.
        st.error(
            "DATABASE_URL is not set. Configure it as an environment "
            "variable or deployment secret."
        )
        st.stop()
    st.session_state.db = SQLDatabase.from_uri(url)
if "agent_chain_output" not in st.session_state:
    st.session_state.agent_chain_output = ""
# Custom callback handler for streaming output
class StreamHandler(BaseCallbackHandler):
    """Mirror LLM, tool, chain and agent events into a Streamlit container.

    Each event appends a line (or a token) to ``self.text`` and re-renders
    the whole accumulated transcript with ``container.markdown``.

    Args:
        container: Object exposing ``markdown(text)``; the transcript is
            rendered into it on every event.
    """

    # NOTE(review): the leading glyphs in the messages below look like
    # mis-decoded emoji; they are left untouched here - confirm the intended
    # characters against the original file before changing them.

    def __init__(self, container):
        self.container = container
        self.text = ""

    def _emit(self, chunk):
        """Append *chunk* to the transcript and re-render it."""
        self.text += chunk
        self.container.markdown(self.text)

    def on_llm_start(self, serialized, prompts, **kwargs):
        """Announce that the model has started generating."""
        self._emit("π§ Starting to think...\n\n")

    def on_llm_new_token(self, token, **kwargs):
        """Append a streamed token as soon as it arrives."""
        self._emit(token)

    def on_llm_end(self, response, **kwargs):
        """Mark the end of a model generation."""
        self._emit("\n\nβ Thinking complete.\n\n")

    def on_tool_start(self, serialized, input_str, **kwargs):
        """Log the tool being invoked and the input it receives."""
        # ``serialized`` may be None or lack "name"; fall back to a
        # placeholder instead of raising inside the callback.
        tool_name = (serialized or {}).get("name", "unknown")
        self._emit(
            f"π§ Using tool: {tool_name}\n"
            f"Tool input: {input_str}\n\n"
        )

    def on_tool_end(self, output, **kwargs):
        """Log the value returned by the tool."""
        self._emit(f"Tool output: {output}\n\n")

    def on_chain_start(self, serialized, inputs, **kwargs):
        """Log the start of a chain run."""
        # ``serialized`` can be None; presumably some LangChain versions pass
        # the run name as a ``name`` keyword instead - use it when present.
        chain_type = (
            (serialized or {}).get("name") or kwargs.get("name") or "Chain"
        )
        self._emit(f"βοΈ Starting {chain_type}...\n")

    def on_chain_end(self, outputs, **kwargs):
        """Log the end of a chain run."""
        self._emit("βοΈ Chain complete.\n\n")

    def on_agent_action(self, action, **kwargs):
        """Log the tool the agent chose and the input it will pass."""
        self._emit(
            f"π€ Agent action: {action.tool}\n"
            f"Action input: {action.tool_input}\n\n"
        )

    def on_agent_finish(self, finish, **kwargs):
        """Log the agent's final answer."""
        # A missing output would otherwise be rendered as the literal "None".
        final_output = finish.return_values.get("output", "")
        self._emit(f"π Agent finished: {final_output}\n\n")
def _rerun():
    """Restart the script run using whichever rerun API Streamlit provides."""
    # ``st.rerun`` replaced ``st.experimental_rerun``; support both.
    rerun = getattr(st, "rerun", None) or getattr(st, "experimental_rerun")
    rerun()


def _sql_queries_from_steps(steps):
    """Return the SQL text of every ``sql_db_query`` call found in *steps*.

    Args:
        steps: Iterable of ``(action, observation)`` pairs as returned under
            the ``intermediate_steps`` key of the agent result.

    Returns:
        A list with one entry per ``sql_db_query`` action, in call order.
    """
    queries = []
    for step in steps:
        action = step[0]
        if getattr(action, "tool", None) != "sql_db_query":
            continue
        tool_input = action.tool_input
        # Tool input may be a plain string or a dict carrying the SQL under
        # "query" (assumed key - TODO confirm for the agent type in use).
        if isinstance(tool_input, dict):
            tool_input = tool_input.get("query", str(tool_input))
        queries.append(tool_input)
    return queries


# Main function to run the app
def main():
    """Render the query UI and answer questions with a LangChain SQL agent.

    Reads the Groq API key from the ``GROQ_API_KEY`` environment variable and
    the database handle from ``st.session_state.db``. The last answer is kept
    in ``st.session_state.agent_chain_output`` so it survives reruns until
    "Clear Results" is pressed.
    """
    st.title("π€ Database Query Assistant")
    st.write("Ask questions about your database in natural language.")

    # Sidebar for database information and settings
    with st.sidebar:
        st.header("Database Information")
        st.write("Connected to: PostgreSQL database")
        # Options
        show_sql = st.checkbox("Show SQL Queries", value=True)
        show_thinking = st.checkbox("Show Agent Thinking", value=True)
        # Database schema button
        if st.button("Show Database Schema"):
            with st.spinner("Fetching database schema..."):
                schema = st.session_state.db.get_table_info()
            st.code(schema)

    # Query input section
    query_options = ["Free-form query", "Get employee details"]
    query_type = st.radio("Query Type:", query_options)
    if query_type == "Free-form query":
        query = st.text_area(
            "Enter your query:",
            height=100,
            placeholder="Example: What are the top 3 highest paid employees?",
        )
    else:
        employee_id = st.text_input(
            "Enter Employee ID:", placeholder="Example: 1001"
        )
        employee_id = (employee_id or "").strip()
        query = f"Give details of employee ID {employee_id}" if employee_id else ""
    # Whitespace-only input is treated as "no query".
    query = (query or "").strip()

    # Action buttons. Both are rendered on every run: a button created only
    # inside ``if process_button`` disappears on the rerun triggered by its
    # own click, so that click could never be observed.
    col1, col2 = st.columns([1, 5])
    with col1:
        process_button = st.button(
            "Run Query", type="primary", use_container_width=True
        )
    with col2:
        clear_button = st.button("Clear Results", use_container_width=True)
    if clear_button:
        st.session_state.agent_chain_output = ""
        _rerun()

    # Results section
    st.header("Results")
    thinking_output = None
    if show_thinking:
        thinking_container = st.expander("Agent Thinking Process", expanded=True)
        thinking_output = thinking_container.empty()
    result_container = st.container()

    if not process_button:
        # No new run requested: redisplay the previous answer, if any.
        previous_answer = st.session_state.get("agent_chain_output", "")
        if previous_answer:
            result_container.markdown("### Answer")
            result_container.write(previous_answer)
        return
    if not query:
        st.warning("Please enter a query before running.")
        return

    # The API key is a credential: read it from the environment instead of
    # embedding it in the source.
    api_key = os.environ.get("GROQ_API_KEY")
    if not api_key:
        st.error(
            "GROQ_API_KEY is not set. Configure it as an environment "
            "variable or deployment secret."
        )
        return

    callbacks = [StreamHandler(thinking_output)] if show_thinking else []

    try:
        llm = init_chat_model(
            "llama-3.3-70b-versatile",
            model_provider="groq",
            api_key=api_key,
            streaming=True,
        )
        toolkit = SQLDatabaseToolkit(db=st.session_state.db, llm=llm)
        agent_executor = create_sql_agent(
            llm,
            toolkit=toolkit,
            verbose=show_thinking,
            # Needed so the result includes the steps the SQL display reads.
            agent_executor_kwargs={"return_intermediate_steps": True},
        )
        # Callbacks supplied at invoke time are inherited by the LLM, tools
        # and chains of this run, so the handler is attached once here.
        with st.spinner("Processing your query..."):
            result = agent_executor.invoke(
                {"input": query}, config={"callbacks": callbacks}
            )
    except Exception as e:  # UI boundary: report the failure to the user.
        st.error(f"An error occurred: {str(e)}")
        return

    answer = result["output"]
    st.session_state.agent_chain_output = answer

    # Write into the container itself: an ``empty()`` placeholder holds a
    # single element, so the heading would be replaced by the answer.
    result_container.success("Query processed successfully!")
    result_container.markdown("### Answer")
    result_container.write(answer)

    # Show the SQL if requested
    if show_sql:
        sql_queries = _sql_queries_from_steps(
            result.get("intermediate_steps", [])
        )
        if sql_queries:
            sql_container = st.expander("SQL Queries Used", expanded=True)
            for sql in sql_queries:
                sql_container.code(sql, language="sql")
# Script entry point: build the UI only when this file is executed as the
# main module, not when it is imported.
if __name__ == "__main__":
    main()