Update tools/sql_tool.py
Browse files - tools/sql_tool.py +63 -10
tools/sql_tool.py
CHANGED
|
@@ -247,22 +247,64 @@ class SQLTool:
|
|
| 247 |
def _nl_to_sql(self, message: str) -> str:
|
| 248 |
"""
|
| 249 |
Convert natural language to SQL query.
|
| 250 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 251 |
"""
|
| 252 |
m = message.lower()
|
| 253 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 254 |
# If it's already SQL, return as-is (after validation)
|
| 255 |
if re.match(r'^\s*select\s', m, re.IGNORECASE):
|
| 256 |
return message.strip()
|
| 257 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 258 |
# Template-based generation (customize for your schema)
|
| 259 |
if "avg" in m or "average" in m:
|
| 260 |
if "by month" in m or "monthly" in m:
|
| 261 |
-
return """
|
| 262 |
SELECT
|
| 263 |
DATE_TRUNC('month', date_col) AS month,
|
| 264 |
AVG(metric_col) AS avg_metric
|
| 265 |
-
FROM
|
| 266 |
GROUP BY 1
|
| 267 |
ORDER BY 1 DESC
|
| 268 |
LIMIT 100;
|
|
@@ -274,27 +316,38 @@ LIMIT 100;
|
|
| 274 |
limit = match.group(1) if match else "10"
|
| 275 |
return f"""
|
| 276 |
SELECT *
|
| 277 |
-
FROM
|
| 278 |
ORDER BY metric_col DESC
|
| 279 |
LIMIT {limit};
|
| 280 |
"""
|
| 281 |
|
| 282 |
if "count" in m:
|
| 283 |
-
return """
|
| 284 |
SELECT
|
| 285 |
category_col,
|
| 286 |
COUNT(*) AS count
|
| 287 |
-
FROM
|
| 288 |
GROUP BY 1
|
| 289 |
ORDER BY 2 DESC
|
| 290 |
LIMIT 100;
|
| 291 |
"""
|
| 292 |
|
| 293 |
-
# Default fallback
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 294 |
return """
|
| 295 |
-
SELECT
|
| 296 |
-
|
| 297 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 298 |
"""
|
| 299 |
|
| 300 |
def run(self, message: str) -> pd.DataFrame:
|
|
|
|
| 247 |
def _nl_to_sql(self, message: str) -> str:
|
| 248 |
"""
|
| 249 |
Convert natural language to SQL query.
|
| 250 |
+
|
| 251 |
+
IMPORTANT: This is a simple heuristic template system.
|
| 252 |
+
For production, either:
|
| 253 |
+
1. Replace table/column names with your actual schema, OR
|
| 254 |
+
2. Integrate a proper NL2SQL model (e.g., T5, CodeGen, GPT), OR
|
| 255 |
+
3. Have users write SQL directly
|
| 256 |
+
|
| 257 |
+
To customize: Set these environment variables or edit the code:
|
| 258 |
+
- SQL_DEFAULT_SCHEMA (default: "analytics")
|
| 259 |
+
- SQL_DEFAULT_TABLE (default: "fact_table")
|
| 260 |
"""
|
| 261 |
m = message.lower()
|
| 262 |
|
| 263 |
+
# Get configurable defaults
|
| 264 |
+
default_schema = os.getenv("SQL_DEFAULT_SCHEMA", "analytics")
|
| 265 |
+
default_table = os.getenv("SQL_DEFAULT_TABLE", "fact_table")
|
| 266 |
+
full_table = f"{default_schema}.{default_table}"
|
| 267 |
+
|
| 268 |
# If it's already SQL, return as-is (after validation)
|
| 269 |
if re.match(r'^\s*select\s', m, re.IGNORECASE):
|
| 270 |
return message.strip()
|
| 271 |
|
| 272 |
+
# Special keyword: show tables/schemas
|
| 273 |
+
if any(keyword in m for keyword in ["show tables", "list tables", "available tables", "what tables"]):
|
| 274 |
+
return """
|
| 275 |
+
SELECT table_schema, table_name, table_type
|
| 276 |
+
FROM information_schema.tables
|
| 277 |
+
WHERE table_schema NOT IN ('information_schema', 'pg_catalog')
|
| 278 |
+
ORDER BY table_schema, table_name
|
| 279 |
+
LIMIT 100;
|
| 280 |
+
"""
|
| 281 |
+
|
| 282 |
+
if any(keyword in m for keyword in ["show schemas", "list schemas", "available schemas"]):
|
| 283 |
+
return """
|
| 284 |
+
SELECT DISTINCT table_schema
|
| 285 |
+
FROM information_schema.tables
|
| 286 |
+
WHERE table_schema NOT IN ('information_schema', 'pg_catalog')
|
| 287 |
+
ORDER BY table_schema;
|
| 288 |
+
"""
|
| 289 |
+
|
| 290 |
+
if "show columns" in m or "describe table" in m or "table structure" in m:
|
| 291 |
+
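# TODO: extract the table name from the message. For now this lists the columns of every table in the default schema (there is no table_name filter).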
|
| 292 |
+
return f"""
|
| 293 |
+
SELECT column_name, data_type, is_nullable
|
| 294 |
+
FROM information_schema.columns
|
| 295 |
+
WHERE table_schema = '{default_schema}'
|
| 296 |
+
ORDER BY ordinal_position
|
| 297 |
+
LIMIT 100;
|
| 298 |
+
"""
|
| 299 |
+
|
| 300 |
# Template-based generation (customize for your schema)
|
| 301 |
if "avg" in m or "average" in m:
|
| 302 |
if "by month" in m or "monthly" in m:
|
| 303 |
+
return f"""
|
| 304 |
SELECT
|
| 305 |
DATE_TRUNC('month', date_col) AS month,
|
| 306 |
AVG(metric_col) AS avg_metric
|
| 307 |
+
FROM {full_table}
|
| 308 |
GROUP BY 1
|
| 309 |
ORDER BY 1 DESC
|
| 310 |
LIMIT 100;
|
|
|
|
| 316 |
limit = match.group(1) if match else "10"
|
| 317 |
return f"""
|
| 318 |
SELECT *
|
| 319 |
+
FROM {full_table}
|
| 320 |
ORDER BY metric_col DESC
|
| 321 |
LIMIT {limit};
|
| 322 |
"""
|
| 323 |
|
| 324 |
if "count" in m:
|
| 325 |
+
return f"""
|
| 326 |
SELECT
|
| 327 |
category_col,
|
| 328 |
COUNT(*) AS count
|
| 329 |
+
FROM {full_table}
|
| 330 |
GROUP BY 1
|
| 331 |
ORDER BY 2 DESC
|
| 332 |
LIMIT 100;
|
| 333 |
"""
|
| 334 |
|
| 335 |
+
# Default fallback - show available tables instead of failing
|
| 336 |
+
logger.warning(
|
| 337 |
+
f"Could not generate specific SQL for query: '{message}'. "
|
| 338 |
+
f"Returning list of available tables. "
|
| 339 |
+
f"Configure SQL_DEFAULT_SCHEMA and SQL_DEFAULT_TABLE or write SQL directly."
|
| 340 |
+
)
|
| 341 |
return """
|
| 342 |
+
SELECT
|
| 343 |
+
table_schema,
|
| 344 |
+
table_name,
|
| 345 |
+
table_type,
|
| 346 |
+
'Run: SELECT * FROM ' || table_schema || '.' || table_name || ' LIMIT 5' as example_query
|
| 347 |
+
FROM information_schema.tables
|
| 348 |
+
WHERE table_schema NOT IN ('information_schema', 'pg_catalog')
|
| 349 |
+
ORDER BY table_schema, table_name
|
| 350 |
+
LIMIT 50;
|
| 351 |
"""
|
| 352 |
|
| 353 |
def run(self, message: str) -> pd.DataFrame:
|