AshenH committed on
Commit
2a4b15b
·
verified ·
1 Parent(s): 2de875c

Update tools/sql_tool.py

Browse files
Files changed (1) hide show
  1. tools/sql_tool.py +82 -228
tools/sql_tool.py CHANGED
@@ -1,236 +1,90 @@
1
- # tools/sql_tool.py
2
  import os
3
- import re
4
- from typing import Optional, Tuple, List
5
-
6
  import duckdb
7
  import pandas as pd
 
8
 
9
- # ------------------------------------------------------------
10
- # Connection config
11
- # ------------------------------------------------------------
12
- DUCKDB_PATH = os.getenv("DUCKDB_PATH", "alm.duckdb")
13
-
14
- # If you need to attach a catalog (e.g., MotherDuck), put the full ATTACH here.
15
- # Example:
16
- DUCKDB_ATTACH_SQL=ATTACH 'md:my_db' AS my_db;
17
-
18
- # Preferred identifiers (we will fall back automatically if they don't exist)
19
- PREF_CATALOG = os.getenv("SQL_DEFAULT_DB", "my_db") # catalog (optional)
20
- PREF_SCHEMA = os.getenv("SQL_DEFAULT_SCHEMA", "main") # schema
21
- PREF_TABLE = os.getenv("SQL_DEFAULT_TABLE", "masterdataset_v") # table
22
 
 
23
 
24
- class SQLTool:
25
  """
26
- NL→SQL helper for DuckDB with:
27
- - optional pre-attach SQL (DUCKDB_ATTACH_SQL)
28
- - robust table path resolution (tries 3-part → 2-part → 1-part → information_schema scan)
29
  """
30
-
31
- def __init__(self, db_path: Optional[str] = None):
32
- self.db_path = db_path or DUCKDB_PATH
33
- self.con = duckdb.connect(self.db_path)
34
-
35
- # Optional: run user-supplied ATTACH (safe no-op if empty)
36
- if DUCKDB_ATTACH_SQL:
37
- try:
38
- self.con.execute(DUCKDB_ATTACH_SQL)
39
- except Exception as e:
40
- # Don't crash the app on attach issues; we still try local tables
41
- print(f"[WARN] DUCKDB_ATTACH_SQL failed: {e}")
42
-
43
- self.full_table = self._resolve_full_table(PREF_CATALOG, PREF_SCHEMA, PREF_TABLE)
44
-
45
- # ------------------------------------------------------------
46
- # Resolution helpers
47
- # ------------------------------------------------------------
48
- def _try_probe(self, path: str) -> bool:
49
- """Return True if SELECT * FROM <path> LIMIT 1 succeeds."""
50
- try:
51
- self.con.execute(f"SELECT * FROM {path} LIMIT 1")
52
- return True
53
- except Exception:
54
- return False
55
-
56
- def _scan_information_schema(self, table_name: str) -> Optional[str]:
57
- """
58
- Look for <schema>.<table> (and <catalog>.<schema>.<table> if available)
59
- in information_schema. Return a best guess path string or None.
60
- """
61
- q = """
62
- SELECT table_catalog, table_schema, table_name
63
- FROM information_schema.tables
64
- WHERE lower(table_name) = ?
65
- ORDER BY table_catalog, table_schema
66
- """
67
- rows = self.con.execute(q, [table_name.lower()]).fetchall()
68
- if not rows:
69
- return None
70
-
71
- # Prefer matches in preferred schema/catalog when possible
72
- # 1) exact catalog+schema
73
- for cat, sch, t in rows:
74
- if (cat or "").lower() == (PREF_CATALOG or "").lower() and sch.lower() == PREF_SCHEMA.lower():
75
- candidate = f"{cat}.{sch}.{t}" if cat else f"{sch}.{t}"
76
- if self._try_probe(candidate):
77
- return candidate
78
-
79
- # 2) exact schema (2-part)
80
- for cat, sch, t in rows:
81
- if sch.lower() == PREF_SCHEMA.lower():
82
- candidate = f"{sch}.{t}"
83
- if self._try_probe(candidate):
84
- return candidate
85
-
86
- # 3) first working row (prefer 3-part if catalog present)
87
- for cat, sch, t in rows:
88
- candidate = f"{cat}.{sch}.{t}" if cat else f"{sch}.{t}"
89
- if self._try_probe(candidate):
90
- return candidate
91
-
92
- return None
93
-
94
- def _resolve_full_table(self, catalog: Optional[str], schema: Optional[str], table: str) -> str:
95
- """
96
- Return a working fully qualified path for the table by trying:
97
- - <catalog>.<schema>.<table> (3-part)
98
- - <schema>.<table> (2-part)
99
- - <table> (1-part)
100
- - information_schema scan (best effort)
101
- """
102
- candidates: List[str] = []
103
-
104
- if catalog:
105
- candidates.append(f"{catalog}.{schema}.{table}")
106
- if schema:
107
- candidates.append(f"{schema}.{table}")
108
- candidates.append(table)
109
-
110
- for path in candidates:
111
- if self._try_probe(path):
112
- print(f"[INFO] Using table path: {path}")
113
- return path
114
-
115
- # Fallback: scan information_schema
116
- scanned = self._scan_information_schema(table)
117
- if scanned:
118
- print(f"[INFO] Using table path (scanned): {scanned}")
119
- return scanned
120
-
121
- # Last resort: keep preferred 3-part (will raise on first query)
122
- fallback = f"{catalog}.{schema}.{table}" if catalog else f"{schema}.{table}"
123
- print(f"[WARN] Could not resolve table path; falling back to: {fallback}")
124
- return fallback
125
-
126
- # ------------------------------------------------------------
127
- # Run SQL directly
128
- # ------------------------------------------------------------
129
- def run_sql(self, sql: str) -> pd.DataFrame:
130
- return self.con.execute(sql).df()
131
-
132
- # ------------------------------------------------------------
133
- # NL → SQL
134
- # ------------------------------------------------------------
135
- def _nl_to_sql(self, message: str) -> Tuple[str, str]:
136
- full_table = self.full_table
137
- m = (message or "").strip().lower()
138
-
139
- def has_any(txt, words):
140
- return any(w in txt for w in words)
141
-
142
- # Extract "top N"
143
- limit = None
144
- m_top = re.search(r"\btop\s+(\d+)", m)
145
- if m_top:
146
- limit = int(m_top.group(1))
147
-
148
- # 1. Top N FDs
149
- if has_any(m, ["fd", "fixed deposit", "deposits"]) and has_any(
150
- m, ["top", "largest", "biggest"]
151
- ) and has_any(m, ["portfolio value", "portfolio_value"]):
152
- n = limit or 10
153
- sql = f"""
154
- SELECT contract_number, Portfolio_value, Interest_rate, currency, segments
155
- FROM {full_table}
156
- WHERE lower(product) = 'fd'
157
- ORDER BY Portfolio_value DESC
158
- LIMIT {n};
159
- """
160
- why = f"Top {n} fixed deposits by Portfolio_value from {full_table}"
161
- return sql, why
162
-
163
- # 2. Top N Assets
164
- if has_any(m, ["asset", "loan", "advances"]) and has_any(
165
- m, ["top", "largest", "biggest"]
166
- ) and has_any(m, ["portfolio value", "portfolio_value"]):
167
- n = limit or 10
168
- sql = f"""
169
- SELECT contract_number, Portfolio_value, Interest_rate, currency, segments
170
- FROM {full_table}
171
- WHERE lower(product) = 'assets'
172
- ORDER BY Portfolio_value DESC
173
- LIMIT {n};
174
- """
175
- why = f"Top {n} assets by Portfolio_value from {full_table}"
176
- return sql, why
177
-
178
- # 3. Aggregate by segment/currency
179
- if has_any(m, ["sum", "total", "avg", "average"]) and has_any(
180
- m, ["segment", "currency"]
181
- ):
182
- agg = "SUM" if has_any(m, ["sum", "total"]) else "AVG"
183
- dim = "segments" if "segment" in m else "currency"
184
- sql = f"""
185
- SELECT {dim}, {agg}(Portfolio_value) AS {agg.lower()}_Portfolio_value
186
- FROM {full_table}
187
- GROUP BY 1
188
- ORDER BY 2 DESC;
189
- """
190
- why = f"{agg} Portfolio_value grouped by {dim} from {full_table}"
191
- return sql, why
192
-
193
- # 4. Generic filters
194
- product = None
195
- if "fd" in m or "deposit" in m:
196
- product = "fd"
197
- elif "asset" in m or "loan" in m or "advance" in m:
198
- product = "assets"
199
-
200
- parts = [f"SELECT * FROM {full_table} WHERE 1=1"]
201
- why_parts = [f"Filtered rows from {full_table}"]
202
-
203
- if product:
204
- parts.append(f"AND lower(product) = '{product}'")
205
- why_parts.append(f"product = {product}")
206
-
207
- cur_match = re.search(r"\b(currency|in)\s+([a-z]{3})\b", m)
208
- if cur_match:
209
- cur = cur_match.group(2).upper()
210
- parts.append(f"AND upper(currency) = '{cur}'")
211
- why_parts.append(f"currency = {cur}")
212
-
213
- seg_match = re.search(r"(segment|for)\s+([a-z0-9_\- ]+)", m)
214
- if seg_match:
215
- seg = seg_match.group(2).strip()
216
- if seg:
217
- parts.append(f"AND lower(segments) LIKE '%{seg.lower()}%'")
218
- why_parts.append(f"segments like '{seg}'")
219
-
220
- if limit:
221
- parts.append(f"LIMIT {limit}")
222
-
223
- fallback_sql = " ".join(parts) + ";"
224
- fallback_why = "; ".join(why_parts)
225
- return fallback_sql, fallback_why
226
-
227
- # ------------------------------------------------------------
228
- # Public wrappers
229
- # ------------------------------------------------------------
230
- def query_from_nl(self, message: str):
231
- sql, why = self._nl_to_sql(message)
232
- df = self.run_sql(sql)
233
- return df, sql, why
234
-
235
- def get_full_table_path(self) -> str:
236
- return self.full_table
 
1
+ from langchain_core.tools import tool
2
  import os
 
 
 
3
  import duckdb
4
  import pandas as pd
5
+ import warnings
6
 
7
+ # Suppress warnings that might clutter the output
8
+ warnings.filterwarnings("ignore")
 
 
 
 
 
 
 
 
 
 
 
9
 
10
+ # --- Database Connection Setup ---
11
 
12
def get_md_connection() -> "duckdb.DuckDBPyConnection":
    """
    Open a new connection to MotherDuck.

    The service token is read from the MOTHERDUCK_TOKEN environment variable
    and passed to DuckDB inside the ``md:`` connection string. The string
    names no database, so which database is active is left to MotherDuck's
    default behaviour.

    Returns:
        An open DuckDB connection. The caller is responsible for closing it.

    Raises:
        ConnectionError: If MOTHERDUCK_TOKEN is unset, empty, or only whitespace.
    """
    # Secrets copied into env files often carry a trailing newline or spaces;
    # strip them so they never end up inside the connection string.
    token = (os.environ.get('MOTHERDUCK_TOKEN') or "").strip()
    if not token:
        raise ConnectionError(
            "MOTHERDUCK_TOKEN environment variable is not set. "
            "Please ensure it is configured in your secrets to connect to the database."
        )

    # To target one specific database, put its name right after the scheme
    # (for example 'md:my_db?motherduck_token=...').
    return duckdb.connect(f'md:?motherduck_token={token}')
29
+
30
+ # --- SQL Tools ---
31
+
32
@tool
def run_duckdb_query(query: str) -> str:
    """
    Runs a read-only SQL query against the connected MotherDuck database and returns the results as a string.
    The query must be valid DuckDB SQL. This tool only supports SELECT queries, one statement per call.
    """
    # Initialised up front so the cleanup in `finally` works even when
    # validation or the connection attempt fails.
    conn = None
    try:
        # Validate before connecting: a rejected query should never open a
        # session. Trailing semicolons/whitespace are harmless, so drop them
        # before looking for statement separators.
        statement = (query or "").strip().rstrip("; \t\r\n")

        # Enforce read-only constraint
        if not statement.lower().startswith('select'):
            return "Error: Only read-only SELECT queries are allowed."

        # A remaining ';' means stacked statements ("SELECT 1; DROP TABLE t"),
        # which would slip past the prefix check above. This is deliberately
        # conservative: a ';' inside a string literal is rejected as well.
        if ';' in statement:
            return "Error: Only a single SELECT statement is allowed."

        conn = get_md_connection()

        # Execute the query and fetch the results into a pandas DataFrame
        result_df = conn.execute(statement).fetchdf()

        if result_df.empty:
            return "Query executed successfully, but no rows were returned."

        # Return the DataFrame as a string
        return result_df.to_string(index=False)

    except ConnectionError as e:
        return f"Connection Error: {e}"
    except Exception as e:
        # Tool boundary: report the failure to the caller as text instead of raising.
        return f"DuckDB Query Error: {e}"
    finally:
        # Always close the connection
        if conn is not None:
            conn.close()
62
+
63
@tool
def get_table_schema(table_name: str = "positions") -> str:
    """
    Returns the schema (column names and data types) for the specified table in the MotherDuck database.
    Defaults to the 'positions' table.
    """
    # Initialised up front so the cleanup in `finally` works even when
    # validation or the connection attempt fails.
    conn = None
    try:
        name = (table_name or "").strip()
        if not name:
            return "Error: A table name is required."

        conn = get_md_connection()

        # The name is embedded in a SQL string literal, so double any single
        # quote; otherwise a quote in the input would end the literal and let
        # the rest of the text run as SQL.
        escaped = name.replace("'", "''")

        # Use PRAGMA table_info to get the schema details dynamically
        schema_df = conn.execute(f"PRAGMA table_info('{escaped}')").fetchdf()

        if schema_df.empty:
            return f"Error: Table '{table_name}' not found in the MotherDuck database."

        # Format the schema into a simple string: name TYPE, name TYPE, ...
        return ", ".join(
            f"{column} {dtype}"
            for column, dtype in zip(schema_df["name"], schema_df["type"])
        )

    except ConnectionError as e:
        return f"Connection Error: {e}"
    except Exception as e:
        # Tool boundary: report the failure to the caller as text instead of raising.
        return f"DuckDB Schema Error: {e}"
    finally:
        # Always close the connection
        if conn is not None:
            conn.close()