# cap_frontend/chatbot.py
# (Hugging Face Space file — commit 5780b29, "changes made")
# Standard library
import os
import time

# Third-party
import streamlit as st
from langchain.callbacks.base import BaseCallbackHandler
from langchain.chat_models import init_chat_model
from langchain_community.agent_toolkits import create_sql_agent
from langchain_community.agent_toolkits.sql.toolkit import SQLDatabaseToolkit
from langchain_community.utilities import SQLDatabase
# Database connection string.
# SECURITY NOTE(review): credentials are hardcoded here; move them to
# st.secrets or an environment variable before deploying publicly.
url = "postgresql://postgres.qxvpaoeakhddzabctekw:8&CiDRpTFbRRBrT@aws-0-ap-south-1.pooler.supabase.com:5432/postgres"

# Initialize session state variables if they don't exist.
if "db" not in st.session_state:
    # Reuse `url` — the original repeated the full literal a second time,
    # inviting the two copies to drift apart.
    st.session_state.db = SQLDatabase.from_uri(url)

if "agent_chain_output" not in st.session_state:
    st.session_state.agent_chain_output = ""
# Custom callback handler for streaming output
class StreamHandler(BaseCallbackHandler):
    """LangChain callback handler that streams agent progress into a
    Streamlit placeholder as a single accumulating markdown transcript.
    """

    def __init__(self, container):
        # container: a Streamlit placeholder (e.g. st.empty()) exposing .markdown()
        self.container = container
        # Full transcript so far; re-rendered wholesale after every event.
        self.text = ""

    def _render(self):
        """Push the accumulated transcript to the Streamlit container."""
        self.container.markdown(self.text)

    def on_llm_start(self, serialized, prompts, **kwargs):
        self.text += "🧠 Starting to think...\n\n"
        self._render()

    def on_llm_new_token(self, token, **kwargs):
        self.text += token
        self._render()

    def on_llm_end(self, response, **kwargs):
        self.text += "\n\n✅ Thinking complete.\n\n"
        self._render()

    def on_tool_start(self, serialized, input_str, **kwargs):
        # BUG FIX: serialized['name'] raised KeyError when the key was absent
        # (LangChain can pass sparse or None `serialized` payloads).
        tool_name = (serialized or {}).get("name", "unknown")
        self.text += f"🔧 Using tool: {tool_name}\n"
        self.text += f"Tool input: {input_str}\n\n"
        self._render()

    def on_tool_end(self, output, **kwargs):
        self.text += f"Tool output: {output}\n\n"
        self._render()

    def on_chain_start(self, serialized, inputs, **kwargs):
        chain_type = (serialized or {}).get("name", "Chain")
        self.text += f"⛓️ Starting {chain_type}...\n"
        self._render()

    def on_chain_end(self, outputs, **kwargs):
        # (original used an f-string with no placeholders here)
        self.text += "⛓️ Chain complete.\n\n"
        self._render()

    def on_agent_action(self, action, **kwargs):
        self.text += f"🤖 Agent action: {action.tool}\n"
        self.text += f"Action input: {action.tool_input}\n\n"
        self._render()

    def on_agent_finish(self, finish, **kwargs):
        self.text += f"🏁 Agent finished: {finish.return_values.get('output')}\n\n"
        self._render()
# Main function to run the app
def main():
    """Render the Streamlit UI and answer natural-language questions about
    the PostgreSQL database via a LangChain SQL agent.

    Reads the shared DB connection from ``st.session_state.db`` (created at
    module import time). All output goes to the Streamlit page; errors are
    shown with ``st.error`` rather than raised.
    """
    st.title("🤖 Database Query Assistant")
    st.write("Ask questions about your database in natural language.")

    # --- Sidebar: database information and display options -----------------
    with st.sidebar:
        st.header("Database Information")
        st.write("Connected to: PostgreSQL database")  # f-prefix removed: no placeholders

        show_sql = st.checkbox("Show SQL Queries", value=True)
        show_thinking = st.checkbox("Show Agent Thinking", value=True)

        if st.button("Show Database Schema"):
            with st.spinner("Fetching database schema..."):
                schema = st.session_state.db.get_table_info()
                st.code(schema)

    # --- Query input --------------------------------------------------------
    query_options = ["Free-form query", "Get employee details"]
    query_type = st.radio("Query Type:", query_options)

    if query_type == "Free-form query":
        query = st.text_area(
            "Enter your query:",
            height=100,
            placeholder="Example: What are the top 3 highest paid employees?",
        )
    else:
        employee_id = st.text_input("Enter Employee ID:", placeholder="Example: 1001")
        query = f"Give details of employee ID {employee_id}" if employee_id else ""

    # --- Action buttons -----------------------------------------------------
    col1, col2 = st.columns([1, 5])
    with col1:
        process_button = st.button("Run Query", type="primary", use_container_width=True)
    with col2:
        # BUG FIX: the original rendered "Clear Results" only while
        # process_button was True (i.e. during the single rerun caused by
        # clicking "Run Query"), so it could never actually be clicked.
        # Render it unconditionally instead.
        if st.button("Clear Results", use_container_width=True):
            st.session_state.agent_chain_output = ""
            # st.experimental_rerun() was removed in recent Streamlit releases.
            st.rerun()

    # --- Results ------------------------------------------------------------
    st.header("Results")

    if show_thinking:
        thinking_container = st.expander("Agent Thinking Process", expanded=True)
        thinking_output = thinking_container.empty()
    result_container = st.container()
    # NOTE(review): a single st.empty() holds one element, so the .write()
    # below replaces the "### Answer" markdown — kept as in the original.
    result_output = result_container.empty()

    if process_button and query:
        # Stream agent events into the expander only when requested.
        stream_handler = StreamHandler(thinking_output) if show_thinking else None
        callbacks = [stream_handler] if stream_handler else None

        try:
            # SECURITY: prefer an environment variable; the hardcoded key is
            # kept only as a fallback so existing deployments keep working.
            api_key = os.environ.get(
                "GROQ_API_KEY",
                "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",
            )

            llm = init_chat_model(
                "llama-3.3-70b-versatile",
                model_provider="groq",
                api_key=api_key,
                streaming=True,
                callbacks=callbacks,
            )

            toolkit = SQLDatabaseToolkit(db=st.session_state.db, llm=llm)
            agent_executor = create_sql_agent(
                llm,
                toolkit=toolkit,
                verbose=show_thinking,
                callbacks=callbacks,
                # BUG FIX: without this, result has no "intermediate_steps"
                # key and the "Show SQL Queries" section below never appears.
                agent_executor_kwargs={"return_intermediate_steps": True},
            )

            with st.spinner("Processing your query..."):
                # AgentExecutor inputs are a dict keyed by "input"; pass it
                # explicitly rather than relying on string coercion.
                result = agent_executor.invoke({"input": query})

            result_container.success("Query processed successfully!")
            result_output.markdown("### Answer")
            result_output.write(result["output"])

            # Surface the raw SQL the agent executed, if requested.
            if show_sql and "intermediate_steps" in result:
                sql_container = st.expander("SQL Queries Used", expanded=True)
                for step in result["intermediate_steps"]:
                    action = step[0]  # (AgentAction, observation) pair
                    if hasattr(action, "tool") and action.tool == "sql_db_query":
                        sql_container.code(action.tool_input, language="sql")

        except Exception as e:
            # Top-level UI boundary: show the failure instead of crashing the app.
            st.error(f"An error occurred: {str(e)}")


if __name__ == "__main__":
    main()