AshenH commited on
Commit
d861fbf
·
verified ·
1 Parent(s): 081c73d

Update tools/sql_tool.py

Browse files
Files changed (1) hide show
  1. tools/sql_tool.py +63 -10
tools/sql_tool.py CHANGED
@@ -247,22 +247,64 @@ class SQLTool:
247
  def _nl_to_sql(self, message: str) -> str:
248
  """
249
  Convert natural language to SQL query.
250
- This is a simple heuristic - replace with proper NL2SQL model for production.
 
 
 
 
 
 
 
 
 
251
  """
252
  m = message.lower()
253
 
 
 
 
 
 
254
  # If it's already SQL, return as-is (after validation)
255
  if re.match(r'^\s*select\s', m, re.IGNORECASE):
256
  return message.strip()
257
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
258
  # Template-based generation (customize for your schema)
259
  if "avg" in m or "average" in m:
260
  if "by month" in m or "monthly" in m:
261
- return """
262
  SELECT
263
  DATE_TRUNC('month', date_col) AS month,
264
  AVG(metric_col) AS avg_metric
265
- FROM analytics.fact_table
266
  GROUP BY 1
267
  ORDER BY 1 DESC
268
  LIMIT 100;
@@ -274,27 +316,38 @@ LIMIT 100;
274
  limit = match.group(1) if match else "10"
275
  return f"""
276
  SELECT *
277
- FROM analytics.fact_table
278
  ORDER BY metric_col DESC
279
  LIMIT {limit};
280
  """
281
 
282
  if "count" in m:
283
- return """
284
  SELECT
285
  category_col,
286
  COUNT(*) AS count
287
- FROM analytics.fact_table
288
  GROUP BY 1
289
  ORDER BY 2 DESC
290
  LIMIT 100;
291
  """
292
 
293
- # Default fallback
 
 
 
 
 
294
  return """
295
- SELECT *
296
- FROM analytics.fact_table
297
- LIMIT 100;
 
 
 
 
 
 
298
  """
299
 
300
  def run(self, message: str) -> pd.DataFrame:
 
247
  def _nl_to_sql(self, message: str) -> str:
248
  """
249
  Convert natural language to SQL query.
250
+
251
+ IMPORTANT: This is a simple heuristic template system.
252
+ For production, either:
253
+ 1. Replace table/column names with your actual schema, OR
254
+ 2. Integrate a proper NL2SQL model (e.g., T5, CodeGen, GPT), OR
255
+ 3. Have users write SQL directly
256
+
257
+ To customize: Set these environment variables or edit the code:
258
+ - SQL_DEFAULT_SCHEMA (default: "analytics")
259
+ - SQL_DEFAULT_TABLE (default: "fact_table")
260
  """
261
  m = message.lower()
262
 
263
+ # Get configurable defaults
264
+ default_schema = os.getenv("SQL_DEFAULT_SCHEMA", "analytics")
265
+ default_table = os.getenv("SQL_DEFAULT_TABLE", "fact_table")
266
+ full_table = f"{default_schema}.{default_table}"
267
+
268
  # If it's already SQL, return as-is (after validation)
269
  if re.match(r'^\s*select\s', m, re.IGNORECASE):
270
  return message.strip()
271
 
272
+ # Special keyword: show tables/schemas
273
+ if any(keyword in m for keyword in ["show tables", "list tables", "available tables", "what tables"]):
274
+ return """
275
+ SELECT table_schema, table_name, table_type
276
+ FROM information_schema.tables
277
+ WHERE table_schema NOT IN ('information_schema', 'pg_catalog')
278
+ ORDER BY table_schema, table_name
279
+ LIMIT 100;
280
+ """
281
+
282
+ if any(keyword in m for keyword in ["show schemas", "list schemas", "available schemas"]):
283
+ return """
284
+ SELECT DISTINCT table_schema
285
+ FROM information_schema.tables
286
+ WHERE table_schema NOT IN ('information_schema', 'pg_catalog')
287
+ ORDER BY table_schema;
288
+ """
289
+
290
+ if "show columns" in m or "describe table" in m or "table structure" in m:
291
+ # Try to extract table name from message
292
+ return f"""
293
+ SELECT column_name, data_type, is_nullable
294
+ FROM information_schema.columns
295
+ WHERE table_schema = '{default_schema}'
296
+ ORDER BY ordinal_position
297
+ LIMIT 100;
298
+ """
299
+
300
  # Template-based generation (customize for your schema)
301
  if "avg" in m or "average" in m:
302
  if "by month" in m or "monthly" in m:
303
+ return f"""
304
  SELECT
305
  DATE_TRUNC('month', date_col) AS month,
306
  AVG(metric_col) AS avg_metric
307
+ FROM {full_table}
308
  GROUP BY 1
309
  ORDER BY 1 DESC
310
  LIMIT 100;
 
316
  limit = match.group(1) if match else "10"
317
  return f"""
318
  SELECT *
319
+ FROM {full_table}
320
  ORDER BY metric_col DESC
321
  LIMIT {limit};
322
  """
323
 
324
  if "count" in m:
325
+ return f"""
326
  SELECT
327
  category_col,
328
  COUNT(*) AS count
329
+ FROM {full_table}
330
  GROUP BY 1
331
  ORDER BY 2 DESC
332
  LIMIT 100;
333
  """
334
 
335
+ # Default fallback - show available tables instead of failing
336
+ logger.warning(
337
+ f"Could not generate specific SQL for query: '{message}'. "
338
+ f"Returning list of available tables. "
339
+ f"Configure SQL_DEFAULT_SCHEMA and SQL_DEFAULT_TABLE or write SQL directly."
340
+ )
341
  return """
342
+ SELECT
343
+ table_schema,
344
+ table_name,
345
+ table_type,
346
+ 'Run: SELECT * FROM ' || table_schema || '.' || table_name || ' LIMIT 5' as example_query
347
+ FROM information_schema.tables
348
+ WHERE table_schema NOT IN ('information_schema', 'pg_catalog')
349
+ ORDER BY table_schema, table_name
350
+ LIMIT 50;
351
  """
352
 
353
  def run(self, message: str) -> pd.DataFrame: