Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -3,13 +3,11 @@ import streamlit as st
|
|
| 3 |
import pandas as pd
|
| 4 |
import sqlite3
|
| 5 |
import logging
|
| 6 |
-
from langchain.agents import create_sql_agent
|
| 7 |
from langchain.agents.agent_toolkits import SQLDatabaseToolkit
|
| 8 |
from langchain.llms import OpenAI
|
| 9 |
from langchain.sql_database import SQLDatabase
|
| 10 |
-
from langchain.prompts import
|
| 11 |
-
PromptTemplate,
|
| 12 |
-
)
|
| 13 |
from langchain.evaluation import load_evaluator
|
| 14 |
|
| 15 |
# Initialize logging
|
|
@@ -19,7 +17,7 @@ logging.basicConfig(level=logging.INFO)
|
|
| 19 |
if 'history' not in st.session_state:
|
| 20 |
st.session_state.history = []
|
| 21 |
|
| 22 |
-
# OpenAI API key
|
| 23 |
openai_api_key = os.getenv("OPENAI_API_KEY")
|
| 24 |
|
| 25 |
# Check if the API key is set
|
|
@@ -33,7 +31,7 @@ st.write("Upload a CSV file to get started, or use the default dataset.")
|
|
| 33 |
|
| 34 |
csv_file = st.file_uploader("Upload your CSV file", type=["csv"])
|
| 35 |
if csv_file is None:
|
| 36 |
-
data = pd.read_csv("default_data.csv") # Ensure this file exists
|
| 37 |
st.write("Using default_data.csv file.")
|
| 38 |
table_name = "default_table"
|
| 39 |
else:
|
|
@@ -42,19 +40,19 @@ else:
|
|
| 42 |
st.write(f"Data Preview ({csv_file.name}):")
|
| 43 |
st.dataframe(data.head())
|
| 44 |
|
| 45 |
-
# Step 2: Load CSV data into
|
| 46 |
db_file = 'my_database.db'
|
| 47 |
conn = sqlite3.connect(db_file)
|
| 48 |
data.to_sql(table_name, conn, index=False, if_exists='replace')
|
| 49 |
|
| 50 |
-
# SQL table metadata
|
| 51 |
valid_columns = list(data.columns)
|
| 52 |
st.write(f"Valid columns: {valid_columns}")
|
| 53 |
|
| 54 |
-
# Create SQLDatabase instance
|
| 55 |
engine = SQLDatabase.from_uri(f"sqlite:///{db_file}", include_tables=[table_name])
|
| 56 |
|
| 57 |
-
# Step 3: Define
|
| 58 |
few_shot_examples = [
|
| 59 |
{
|
| 60 |
"input": "What is the total revenue for each category?",
|
|
@@ -78,12 +76,15 @@ for ex in few_shot_examples:
|
|
| 78 |
# Prepare table information
|
| 79 |
table_info = f"Table: {table_name}\nColumns: {', '.join(valid_columns)}"
|
| 80 |
|
|
|
|
|
|
|
|
|
|
| 81 |
# Step 4: Define the prompt template
|
| 82 |
system_message = """
|
| 83 |
You are an expert data analyst who can convert natural language questions into SQL queries.
|
| 84 |
|
| 85 |
-
|
| 86 |
-
{
|
| 87 |
|
| 88 |
Follow these guidelines:
|
| 89 |
1. Only use the columns and tables provided.
|
|
@@ -104,19 +105,16 @@ Question: {input}
|
|
| 104 |
{agent_scratchpad}
|
| 105 |
"""
|
| 106 |
|
| 107 |
-
# Initialize the LLM
|
| 108 |
-
llm = OpenAI(temperature=0, openai_api_key=openai_api_key)
|
| 109 |
-
|
| 110 |
# Step 5: Create the agent
|
| 111 |
toolkit = SQLDatabaseToolkit(db=engine, llm=llm)
|
| 112 |
tools = toolkit.get_tools()
|
| 113 |
-
tool_names = [tool.name for tool in tools]
|
| 114 |
tool_descriptions = "\n".join([f"{tool.name}: {tool.description}" for tool in tools])
|
| 115 |
|
| 116 |
# Create the prompt
|
| 117 |
agent_prompt = PromptTemplate(
|
| 118 |
template=system_message,
|
| 119 |
-
input_variables=["input", "agent_scratchpad", "table_info", "few_shot_examples", "
|
| 120 |
)
|
| 121 |
|
| 122 |
# Create the agent
|
|
@@ -146,14 +144,15 @@ def process_input():
|
|
| 146 |
table_info=table_info,
|
| 147 |
few_shot_examples=few_shot_str,
|
| 148 |
agent_scratchpad="",
|
| 149 |
-
|
|
|
|
| 150 |
)
|
| 151 |
|
| 152 |
# Extract the SQL query from the agent's response
|
| 153 |
sql_query = response.strip()
|
| 154 |
logging.info(f"Generated SQL Query: {sql_query}")
|
| 155 |
|
| 156 |
-
#
|
| 157 |
try:
|
| 158 |
result = pd.read_sql_query(sql_query, conn)
|
| 159 |
|
|
@@ -161,12 +160,12 @@ def process_input():
|
|
| 161 |
assistant_response = "The query returned no results. Please try a different question."
|
| 162 |
st.session_state.history.append({"role": "assistant", "content": assistant_response})
|
| 163 |
else:
|
| 164 |
-
#
|
| 165 |
result_display = result.head(10)
|
| 166 |
st.session_state.history.append({"role": "assistant", "content": "Here are the results:"})
|
| 167 |
st.session_state.history.append({"role": "assistant", "content": result_display})
|
| 168 |
|
| 169 |
-
# Generate insights
|
| 170 |
insights_template = """
|
| 171 |
You are an expert data analyst. Based on the user's question and the SQL query result provided below, generate a concise analysis that includes key data insights and actionable recommendations. Limit the response to a maximum of 150 words.
|
| 172 |
|
|
@@ -183,7 +182,7 @@ def process_input():
|
|
| 183 |
result_str = result_display.to_string(index=False)
|
| 184 |
insights = insights_chain.run({'question': user_prompt, 'result': result_str})
|
| 185 |
|
| 186 |
-
# Append
|
| 187 |
st.session_state.history.append({"role": "assistant", "content": insights})
|
| 188 |
except Exception as e:
|
| 189 |
logging.error(f"An error occurred during SQL execution: {e}")
|
|
@@ -194,10 +193,10 @@ def process_input():
|
|
| 194 |
assistant_response = f"Error: {e}"
|
| 195 |
st.session_state.history.append({"role": "assistant", "content": assistant_response})
|
| 196 |
|
| 197 |
-
# Reset
|
| 198 |
st.session_state['user_input'] = ''
|
| 199 |
|
| 200 |
-
# Step 7: Display
|
| 201 |
for message in st.session_state.history:
|
| 202 |
if message['role'] == 'user':
|
| 203 |
st.markdown(f"**User:** {message['content']}")
|
|
@@ -208,5 +207,5 @@ for message in st.session_state.history:
|
|
| 208 |
else:
|
| 209 |
st.markdown(f"**Assistant:** {message['content']}")
|
| 210 |
|
| 211 |
-
#
|
| 212 |
st.text_input("Enter your message:", key='user_input', on_change=process_input)
|
|
|
|
| 3 |
import pandas as pd
|
| 4 |
import sqlite3
|
| 5 |
import logging
|
| 6 |
+
from langchain.agents import create_sql_agent
|
| 7 |
from langchain.agents.agent_toolkits import SQLDatabaseToolkit
|
| 8 |
from langchain.llms import OpenAI
|
| 9 |
from langchain.sql_database import SQLDatabase
|
| 10 |
+
from langchain.prompts import PromptTemplate
|
|
|
|
|
|
|
| 11 |
from langchain.evaluation import load_evaluator
|
| 12 |
|
| 13 |
# Initialize logging
|
|
|
|
| 17 |
if 'history' not in st.session_state:
|
| 18 |
st.session_state.history = []
|
| 19 |
|
| 20 |
+
# OpenAI API key
|
| 21 |
openai_api_key = os.getenv("OPENAI_API_KEY")
|
| 22 |
|
| 23 |
# Check if the API key is set
|
|
|
|
| 31 |
|
| 32 |
csv_file = st.file_uploader("Upload your CSV file", type=["csv"])
|
| 33 |
if csv_file is None:
|
| 34 |
+
data = pd.read_csv("default_data.csv") # Ensure this file exists
|
| 35 |
st.write("Using default_data.csv file.")
|
| 36 |
table_name = "default_table"
|
| 37 |
else:
|
|
|
|
| 40 |
st.write(f"Data Preview ({csv_file.name}):")
|
| 41 |
st.dataframe(data.head())
|
| 42 |
|
| 43 |
+
# Step 2: Load CSV data into SQLite database
|
| 44 |
db_file = 'my_database.db'
|
| 45 |
conn = sqlite3.connect(db_file)
|
| 46 |
data.to_sql(table_name, conn, index=False, if_exists='replace')
|
| 47 |
|
| 48 |
+
# SQL table metadata
|
| 49 |
valid_columns = list(data.columns)
|
| 50 |
st.write(f"Valid columns: {valid_columns}")
|
| 51 |
|
| 52 |
+
# Create SQLDatabase instance
|
| 53 |
engine = SQLDatabase.from_uri(f"sqlite:///{db_file}", include_tables=[table_name])
|
| 54 |
|
| 55 |
+
# Step 3: Define few-shot examples
|
| 56 |
few_shot_examples = [
|
| 57 |
{
|
| 58 |
"input": "What is the total revenue for each category?",
|
|
|
|
| 76 |
# Prepare table information
|
| 77 |
table_info = f"Table: {table_name}\nColumns: {', '.join(valid_columns)}"
|
| 78 |
|
| 79 |
+
# Initialize the LLM
|
| 80 |
+
llm = OpenAI(temperature=0, openai_api_key=openai_api_key)
|
| 81 |
+
|
| 82 |
# Step 4: Define the prompt template
|
| 83 |
system_message = """
|
| 84 |
You are an expert data analyst who can convert natural language questions into SQL queries.
|
| 85 |
|
| 86 |
+
You have access to the following tools:
|
| 87 |
+
{tools}
|
| 88 |
|
| 89 |
Follow these guidelines:
|
| 90 |
1. Only use the columns and tables provided.
|
|
|
|
| 105 |
{agent_scratchpad}
|
| 106 |
"""
|
| 107 |
|
|
|
|
|
|
|
|
|
|
| 108 |
# Step 5: Create the agent
|
| 109 |
toolkit = SQLDatabaseToolkit(db=engine, llm=llm)
|
| 110 |
tools = toolkit.get_tools()
|
| 111 |
+
tool_names = ", ".join([tool.name for tool in tools])
|
| 112 |
tool_descriptions = "\n".join([f"{tool.name}: {tool.description}" for tool in tools])
|
| 113 |
|
| 114 |
# Create the prompt
|
| 115 |
agent_prompt = PromptTemplate(
|
| 116 |
template=system_message,
|
| 117 |
+
input_variables=["input", "agent_scratchpad", "table_info", "few_shot_examples", "tools", "tool_names"]
|
| 118 |
)
|
| 119 |
|
| 120 |
# Create the agent
|
|
|
|
| 144 |
table_info=table_info,
|
| 145 |
few_shot_examples=few_shot_str,
|
| 146 |
agent_scratchpad="",
|
| 147 |
+
tools=tool_descriptions,
|
| 148 |
+
tool_names=tool_names
|
| 149 |
)
|
| 150 |
|
| 151 |
# Extract the SQL query from the agent's response
|
| 152 |
sql_query = response.strip()
|
| 153 |
logging.info(f"Generated SQL Query: {sql_query}")
|
| 154 |
|
| 155 |
+
# Execute SQL query
|
| 156 |
try:
|
| 157 |
result = pd.read_sql_query(sql_query, conn)
|
| 158 |
|
|
|
|
| 160 |
assistant_response = "The query returned no results. Please try a different question."
|
| 161 |
st.session_state.history.append({"role": "assistant", "content": assistant_response})
|
| 162 |
else:
|
| 163 |
+
# Display results
|
| 164 |
result_display = result.head(10)
|
| 165 |
st.session_state.history.append({"role": "assistant", "content": "Here are the results:"})
|
| 166 |
st.session_state.history.append({"role": "assistant", "content": result_display})
|
| 167 |
|
| 168 |
+
# Generate insights
|
| 169 |
insights_template = """
|
| 170 |
You are an expert data analyst. Based on the user's question and the SQL query result provided below, generate a concise analysis that includes key data insights and actionable recommendations. Limit the response to a maximum of 150 words.
|
| 171 |
|
|
|
|
| 182 |
result_str = result_display.to_string(index=False)
|
| 183 |
insights = insights_chain.run({'question': user_prompt, 'result': result_str})
|
| 184 |
|
| 185 |
+
# Append insights to history
|
| 186 |
st.session_state.history.append({"role": "assistant", "content": insights})
|
| 187 |
except Exception as e:
|
| 188 |
logging.error(f"An error occurred during SQL execution: {e}")
|
|
|
|
| 193 |
assistant_response = f"Error: {e}"
|
| 194 |
st.session_state.history.append({"role": "assistant", "content": assistant_response})
|
| 195 |
|
| 196 |
+
# Reset user input
|
| 197 |
st.session_state['user_input'] = ''
|
| 198 |
|
| 199 |
+
# Step 7: Display conversation history
|
| 200 |
for message in st.session_state.history:
|
| 201 |
if message['role'] == 'user':
|
| 202 |
st.markdown(f"**User:** {message['content']}")
|
|
|
|
| 207 |
else:
|
| 208 |
st.markdown(f"**Assistant:** {message['content']}")
|
| 209 |
|
| 210 |
+
# Input field
|
| 211 |
st.text_input("Enter your message:", key='user_input', on_change=process_input)
|