Spaces:
Runtime error
Runtime error
File size: 7,878 Bytes
5a3fcad 2d8f702 5a3fcad | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 | """
SQL Query Generation Agent
This module implements the SQL agent that generates and executes SQL queries
against SQLite databases created from user-uploaded CSV files.
The agent uses LangGraph's ReAct pattern for reasoning and tool use,
combined with LangChain's SQLDatabaseToolkit for database operations.
Example:
>>> from src.agents import sql_pipeline
>>> result = sql_pipeline(files, "Show top 10 customers by revenue", model)
"""
import os
from typing import List, Optional, Any
import pandas as pd
from sqlalchemy import create_engine
from langchain_community.agent_toolkits import SQLDatabaseToolkit
from langchain_community.utilities import SQLDatabase
from langgraph.prebuilt import create_react_agent
from ..prompts import get_sql_prompt
# SQLite file that uploaded CSVs are written to (a relative path, so it lives
# in the process's current working directory), plus the SQLAlchemy URI for it.
# The URI is derived from the file name so the two can never disagree.
DEFAULT_DB_FILE = "database.db"
DEFAULT_DB_PATH = f"sqlite:///{DEFAULT_DB_FILE}"
class SQLAgent:
    """
    Agent for generating and executing SQL queries from natural language.

    This agent converts user questions into SQL queries, executes them
    against a SQLite database, and returns formatted results.

    Attributes:
        model: The LLM model for query generation.
        db_uri: SQLAlchemy URI the CSV files are loaded into, or ``None`` to
            use the module-level ``DEFAULT_DB_PATH``.
        database: The SQLDatabase instance to query against (``None`` until
            ``load_files`` succeeds).
        toolkit: SQLDatabaseToolkit providing database tools (``None`` until
            ``load_files`` succeeds).

    Example:
        >>> agent = SQLAgent(model=azure_llm)
        >>> agent.load_files(csv_files)
        >>> result = agent.query("What are the top 5 products by sales?")
    """

    def __init__(self, model: Any, db_uri: Optional[str] = None):
        """
        Initialize the SQL agent.

        Args:
            model: The LLM model instance for query generation.
                Must be a LangChain-compatible chat model.
            db_uri: Optional SQLAlchemy database URI to load files into.
                ``None`` (the default) means ``DEFAULT_DB_PATH``, a single
                ``database.db`` file shared by every agent in the process;
                pass a distinct URI to keep separate uploads isolated.
        """
        self.model = model
        self.db_uri = db_uri
        self.database: Optional[SQLDatabase] = None
        self.toolkit: Optional[SQLDatabaseToolkit] = None

    @staticmethod
    def _file_path(f: Any) -> str:
        """Return the filesystem path for an upload object or a plain path."""
        # Check for str/PathLike first: pathlib.Path also has a ``.name``
        # attribute, but it is only the basename, not the full path.
        if isinstance(f, (str, os.PathLike)):
            return os.fspath(f)
        return os.fspath(f.name)

    def load_files(self, files: List[Any]) -> bool:
        """
        Load CSV files into a SQLite database.

        Each CSV file becomes a table in the database, with the table name
        derived from the filename (without extension). A table with the same
        name is replaced; tables from earlier loads with other names remain.

        Args:
            files: List of file objects with a ``.name`` attribute holding the
                CSV path, or plain ``str``/``os.PathLike`` paths.

        Returns:
            bool: True if database creation succeeded, False otherwise (the
            error is printed). On failure, any previously loaded
            ``database``/``toolkit`` are left unchanged.

        Example:
            >>> agent.load_files([file1, file2])
            True
        """
        uri = self.db_uri or DEFAULT_DB_PATH
        engine = None
        try:
            engine = create_engine(uri)
            with engine.begin() as connection:
                for f in files:
                    path = self._file_path(f)
                    table_name = os.path.splitext(os.path.basename(path))[0]
                    df = pd.read_csv(path)
                    df.to_sql(table_name, connection, if_exists="replace", index=False)
                    print(f"Loaded table: {table_name}")
            # Wrap the engine that already holds the data instead of opening a
            # second one via SQLDatabase.from_uri: avoids an orphaned engine
            # and keeps in-memory SQLite URIs pointing at the same database.
            database = SQLDatabase(engine)
            toolkit = SQLDatabaseToolkit(db=database, llm=self.model)
        except Exception as e:
            print(f"Database creation error: {e}")
            if engine is not None:
                engine.dispose()  # release pooled connections on failure
            return False
        # Publish both together so the agent never holds a mismatched pair.
        self.database = database
        self.toolkit = toolkit
        print(f"Database created: {engine.url.database}")
        return True

    def query(self, question: str) -> str:
        """
        Execute a natural language query against the database.

        This method uses LangGraph's ReAct agent to:
        1. Inspect the database schema
        2. Generate an appropriate SQL query
        3. Execute the query
        4. Format and return the results

        Args:
            question: The natural language question about the data.

        Returns:
            str: The text of the agent's final message, trimmed to start at
            the "### Final answer" heading when present. Any exception raised
            while building or running the agent is returned as a string
            prefixed with "SQL agent error:".

        Raises:
            ValueError: If database has not been loaded.

        Example:
            >>> result = agent.query("Show average order value by region")
        """
        if self.database is None or self.toolkit is None:
            raise ValueError("Database not loaded. Call load_files() first.")
        try:
            agent_executor = create_react_agent(
                self.model,
                self.toolkit.get_tools(),
                prompt=get_sql_prompt(),
            )
            # invoke() yields the final graph state. Only its last message is
            # the agent's answer; earlier ones are the user question, tool
            # calls, and raw tool output, which must not leak into the reply.
            final_state = agent_executor.invoke(
                {"messages": [{"role": "user", "content": question}]}
            )
            messages = final_state.get("messages") or []
            if not messages:
                return ""
            answer = self._message_text(messages[-1].content)
            return self._extract_final_answer(answer)
        except Exception as e:
            return f"SQL agent error: {e}"

    @staticmethod
    def _message_text(content: Any) -> str:
        """
        Flatten a LangChain message ``content`` value into plain text.

        Content is either a string or a list of string / dict content blocks;
        only ``{"type": "text"}`` dict blocks contribute text.
        """
        if content is None:
            return ""
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            parts = []
            for block in content:
                if isinstance(block, str):
                    parts.append(block)
                elif isinstance(block, dict) and block.get("type") == "text":
                    parts.append(str(block.get("text", "")))
            return "".join(parts)
        return str(content)

    def _extract_final_answer(self, output: str) -> str:
        """
        Extract the final answer section from the agent output.

        Args:
            output: The agent's final message text.

        Returns:
            str: Everything from the first "### Final answer" heading onward,
            or the full output if the heading is absent.
        """
        marker = "### Final answer"
        pos = output.find(marker)
        if pos != -1:
            return output[pos:]
        return output
def create_db(files: List[Any], db_uri: Optional[str] = None) -> Optional[SQLDatabase]:
    """
    Create a SQLite database from uploaded CSV files.

    This is a convenience function for creating databases without
    instantiating the full SQLAgent class. Each CSV becomes a table named
    after its file (without extension); a same-named table is replaced.

    Args:
        files: List of file objects with a ``.name`` attribute holding the
            CSV path, or plain ``str``/``os.PathLike`` paths.
        db_uri: Optional SQLAlchemy URI to write to; ``None`` (the default)
            means ``DEFAULT_DB_PATH``.

    Returns:
        SQLDatabase: The created database instance, or None on error (the
        error is printed).

    Example:
        >>> db = create_db(uploaded_files)
        >>> if db:
        ...     print(db.get_table_names())
    """
    engine = None
    try:
        engine = create_engine(db_uri or DEFAULT_DB_PATH)
        with engine.begin() as connection:
            for f in files:
                # pathlib.Path has a basename-only ``.name``, so test for
                # path types before falling back to the upload's ``.name``.
                if isinstance(f, (str, os.PathLike)):
                    path = os.fspath(f)
                else:
                    path = os.fspath(f.name)
                table_name = os.path.splitext(os.path.basename(path))[0]
                df = pd.read_csv(path)
                df.to_sql(table_name, connection, if_exists="replace", index=False)
        # Reuse the engine holding the data rather than opening a second one
        # via SQLDatabase.from_uri, which would leave this engine orphaned.
        return SQLDatabase(engine)
    except Exception as e:
        print(f"Database error: {e}")
        if engine is not None:
            engine.dispose()  # release pooled connections on failure
        return None
def sql_pipeline(
    tables: List[Any],
    question: str,
    model: Optional[Any] = None
) -> str:
    """
    End-to-end pipeline for SQL query generation and execution.

    This function handles the complete workflow:
    1. Creates a SQLite database from uploaded files
    2. Initializes the SQL agent
    3. Generates and executes the query
    4. Returns formatted results

    Args:
        tables: List of file objects containing CSV data.
        question: The natural language question to answer.
        model: The LLM model to use. It is required in practice: when None,
            no query is run and an error message is returned.

    Returns:
        str: The query results in Markdown format, or a message starting
        with "Error:" when the model is missing, no files were given, or the
        database could not be built.

    Example:
        >>> result = sql_pipeline(files, "Show monthly sales trends", model)
    """
    if model is None:
        return "Error: No LLM model provided."
    # Without files the agent would query an empty database, or stale tables
    # left over from an earlier load of the shared database file.
    if not tables:
        return "Error: No files provided."
    agent = SQLAgent(model=model)
    if not agent.load_files(tables):
        return "Error: Could not create database from files."
    return agent.query(question)
|