AshenH committed on
Commit
3fbd26b
·
verified ·
1 Parent(s): f8cd124

Update tools/sql_tool.py

Browse files
Files changed (1) hide show
  1. tools/sql_tool.py +113 -18
tools/sql_tool.py CHANGED
@@ -1,41 +1,133 @@
1
  # tools/sql_tool.py
2
  import os
3
  import re
4
- from typing import Optional, Tuple
5
 
6
  import duckdb
 
7
 
8
- # DuckDB connection file
 
 
9
  DUCKDB_PATH = os.getenv("DUCKDB_PATH", "alm.duckdb")
10
 
11
- # Fully qualified schema path confirmed from your server
12
- # my_db.main.masterdataset_v
13
- DEFAULT_DB = os.getenv("SQL_DEFAULT_DB", "my_db")
14
- DEFAULT_SCHEMA = os.getenv("SQL_DEFAULT_SCHEMA", "main")
15
- DEFAULT_TABLE = os.getenv("SQL_DEFAULT_TABLE", "masterdataset_v")
16
 
17
- def _full_table(db: Optional[str] = None,
18
- schema: Optional[str] = None,
19
- table: Optional[str] = None) -> str:
20
- """Return fully qualified <db>.<schema>.<table>"""
21
- db = db or DEFAULT_DB
22
- schema = schema or DEFAULT_SCHEMA
23
- table = table or DEFAULT_TABLE
24
- return f"{db}.{schema}.{table}"
25
 
26
 
27
  class SQLTool:
28
- """Natural-language → SQL helper for DuckDB"""
 
 
 
 
29
 
30
  def __init__(self, db_path: Optional[str] = None):
31
  self.db_path = db_path or DUCKDB_PATH
32
  self.con = duckdb.connect(self.db_path)
33
- self.full_table = _full_table()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
 
35
  # ------------------------------------------------------------
36
  # Run SQL directly
37
  # ------------------------------------------------------------
38
- def run_sql(self, sql: str):
39
  return self.con.execute(sql).df()
40
 
41
  # ------------------------------------------------------------
@@ -140,3 +232,6 @@ class SQLTool:
140
  sql, why = self._nl_to_sql(message)
141
  df = self.run_sql(sql)
142
  return df, sql, why
 
 
 
 
1
  # tools/sql_tool.py
2
  import os
3
  import re
4
+ from typing import Optional, Tuple, List
5
 
6
  import duckdb
7
+ import pandas as pd
8
 
9
# ------------------------------------------------------------
# Connection settings (each one can be overridden via the environment)
# ------------------------------------------------------------
# Path of the DuckDB database file that SQLTool connects to.
DUCKDB_PATH = os.environ.get("DUCKDB_PATH", "alm.duckdb")

# Optional statement executed right after connecting, e.g. to attach a
# catalog such as MotherDuck. Provide the complete ATTACH statement:
#   DUCKDB_ATTACH_SQL=ATTACH 'md:my_db' AS my_db;
# Surrounding whitespace is stripped; an empty value means "nothing to run".
DUCKDB_ATTACH_SQL = os.environ.get("DUCKDB_ATTACH_SQL", "").strip()

# Preferred identifiers. They are only a starting point: the table path
# resolution falls back to other forms when these do not exist.
PREF_CATALOG = os.environ.get("SQL_DEFAULT_DB", "my_db")  # catalog (optional)
PREF_SCHEMA = os.environ.get("SQL_DEFAULT_SCHEMA", "main")  # schema
PREF_TABLE = os.environ.get("SQL_DEFAULT_TABLE", "masterdataset_v")  # table
 
 
 
 
23
 
24
 
25
  class SQLTool:
26
+ """
27
+ NL→SQL helper for DuckDB with:
28
+ - optional pre-attach SQL (DUCKDB_ATTACH_SQL)
29
+ - robust table path resolution (tries 3-part → 2-part → 1-part → information_schema scan)
30
+ """
31
 
32
  def __init__(self, db_path: Optional[str] = None):
33
  self.db_path = db_path or DUCKDB_PATH
34
  self.con = duckdb.connect(self.db_path)
35
+
36
+ # Optional: run user-supplied ATTACH (safe no-op if empty)
37
+ if DUCKDB_ATTACH_SQL:
38
+ try:
39
+ self.con.execute(DUCKDB_ATTACH_SQL)
40
+ except Exception as e:
41
+ # Don't crash the app on attach issues; we still try local tables
42
+ print(f"[WARN] DUCKDB_ATTACH_SQL failed: {e}")
43
+
44
+ self.full_table = self._resolve_full_table(PREF_CATALOG, PREF_SCHEMA, PREF_TABLE)
45
+
46
+ # ------------------------------------------------------------
47
+ # Resolution helpers
48
+ # ------------------------------------------------------------
49
def _try_probe(self, path: str) -> bool:
    """Report whether ``path`` can be queried on the current connection.

    Executes ``SELECT * FROM <path> LIMIT 1``. Any exception raised by that
    statement is taken to mean the path is unusable.

    Returns:
        True when the statement executes, False when it raises.
    """
    probe_sql = f"SELECT * FROM {path} LIMIT 1"
    try:
        self.con.execute(probe_sql)
    except Exception:
        # Deliberately broad: the caller only needs a yes/no answer.
        return False
    else:
        return True
56
+
57
def _scan_information_schema(self, table_name: str) -> Optional[str]:
    """Search information_schema for ``table_name`` and return a usable path.

    The table name is matched case-insensitively. Candidate paths are tried
    in priority order and the first one accepted by ``_try_probe`` wins:

    1. rows whose catalog and schema equal PREF_CATALOG / PREF_SCHEMA
       (3-part path, or 2-part when the row has no catalog);
    2. rows whose schema equals PREF_SCHEMA, as 2-part ``schema.table``;
    3. every row (3-part when the row has a catalog, else 2-part).

    The tiers overlap (a row from tier 1 shows up again in tier 3, and
    several rows can yield the same 2-part path), so the candidates are
    de-duplicated and each distinct path is probed only once.

    Args:
        table_name: Unqualified table name to look for.

    Returns:
        The first candidate path that probes successfully, or None when
        there is no matching row or no candidate is queryable.
    """
    q = """
        SELECT table_catalog, table_schema, table_name
        FROM information_schema.tables
        WHERE lower(table_name) = ?
        ORDER BY table_catalog, table_schema
    """
    rows = self.con.execute(q, [table_name.lower()]).fetchall()
    if not rows:
        return None

    # Normalise the preferences once instead of once per row.
    want_catalog = (PREF_CATALOG or "").lower()
    want_schema = (PREF_SCHEMA or "").lower()

    exact_matches: List[str] = []
    schema_matches: List[str] = []
    all_rows: List[str] = []
    for cat, sch, tbl in rows:
        qualified = f"{cat}.{sch}.{tbl}" if cat else f"{sch}.{tbl}"
        if sch.lower() == want_schema:
            if (cat or "").lower() == want_catalog:
                exact_matches.append(qualified)
            schema_matches.append(f"{sch}.{tbl}")
        all_rows.append(qualified)

    # dict.fromkeys drops repeats while keeping first-seen (priority) order.
    for candidate in dict.fromkeys(exact_matches + schema_matches + all_rows):
        if self._try_probe(candidate):
            return candidate

    return None
94
+
95
def _resolve_full_table(self, catalog: Optional[str], schema: Optional[str], table: str) -> str:
    """Return a queryable path for ``table``, trying progressively looser forms.

    Attempts, in order:
      - the preferred path built from whichever of ``catalog`` / ``schema``
        are non-empty (``catalog.schema.table``, ``schema.table``,
        ``catalog.table`` or just ``table``)
      - ``schema.table`` (when both catalog and schema were given)
      - ``table``
      - an information_schema scan (best effort)

    Empty or None qualifiers are left out of every path, so a missing
    schema can never produce names such as ``my_db.None.tbl`` or ``.tbl``.

    Args:
        catalog: Preferred catalog name; may be None or empty.
        schema: Preferred schema name; may be None or empty.
        table: Table name.

    Returns:
        The first path that probes successfully. If nothing works, the
        preferred path is returned unverified, so the failure surfaces on
        the first real query instead of at construction time.
    """
    qualifiers = [part for part in (catalog, schema) if part]
    preferred = ".".join(qualifiers + [table])

    # Most specific first, then progressively less qualified spellings.
    candidates: List[str] = [preferred]
    if catalog and schema:
        candidates.append(f"{schema}.{table}")
    if qualifiers:
        candidates.append(table)

    for path in candidates:
        if self._try_probe(path):
            print(f"[INFO] Using table path: {path}")
            return path

    # Fallback: look the table up in information_schema.
    scanned = self._scan_information_schema(table)
    if scanned:
        print(f"[INFO] Using table path (scanned): {scanned}")
        return scanned

    # Last resort: keep the preferred path (it will raise on the first query).
    fallback = preferred
    print(f"[WARN] Could not resolve table path; falling back to: {fallback}")
    return fallback
126
 
127
  # ------------------------------------------------------------
128
  # Run SQL directly
129
  # ------------------------------------------------------------
130
def run_sql(self, sql: str) -> pd.DataFrame:
    """Execute ``sql`` on the open connection and return the result via ``.df()``."""
    result = self.con.execute(sql)
    return result.df()
132
 
133
  # ------------------------------------------------------------
 
232
  sql, why = self._nl_to_sql(message)
233
  df = self.run_sql(sql)
234
  return df, sql, why
235
+
236
def get_full_table_path(self) -> str:
    """Return the table path stored in ``self.full_table``."""
    resolved_path = self.full_table
    return resolved_path