AshenH committed on
Commit
f69d955
·
verified ·
1 Parent(s): 0d9239a

Update tools/sql_tool.py

Browse files
Files changed (1) hide show
  1. tools/sql_tool.py +29 -42
tools/sql_tool.py CHANGED
@@ -5,46 +5,44 @@ from typing import Optional, Tuple
5
 
6
  import duckdb
7
 
8
- # DuckDB file path (can be overridden in Space settings)
9
  DUCKDB_PATH = os.getenv("DUCKDB_PATH", "alm.duckdb")
10
 
11
- # Default schema/table -> your path my_db.masterdataset_v
12
- DEFAULT_SCHEMA = os.getenv("SQL_DEFAULT_SCHEMA", "my_db")
13
- DEFAULT_TABLE = os.getenv("SQL_DEFAULT_TABLE", "masterdataset_v")
14
-
15
-
16
- def _full_table(schema: Optional[str] = None, table: Optional[str] = None) -> str:
 
 
 
 
 
17
  schema = schema or DEFAULT_SCHEMA
18
  table = table or DEFAULT_TABLE
19
- return f"{schema}.{table}"
20
 
21
 
22
  class SQLTool:
23
- """
24
- Minimal NL→SQL helper wired to my_db.masterdataset_v with a DuckDB runner.
25
- """
26
 
27
  def __init__(self, db_path: Optional[str] = None):
28
  self.db_path = db_path or DUCKDB_PATH
29
  self.con = duckdb.connect(self.db_path)
 
30
 
31
- # -------------------------
32
- # SQL Runner
33
- # -------------------------
34
  def run_sql(self, sql: str):
35
  return self.con.execute(sql).df()
36
 
37
- # -------------------------
38
  # NL → SQL
39
- # -------------------------
40
- def _nl_to_sql(
41
- self, message: str, schema: Optional[str] = None, table: Optional[str] = None
42
- ) -> Tuple[str, str]:
43
- """
44
- Returns (sql, rationale). Small template library covering common queries.
45
- Falls back to a filtered SELECT or a sample.
46
- """
47
- full_table = _full_table(schema, table)
48
  m = (message or "").strip().lower()
49
 
50
  def has_any(txt, words):
@@ -56,7 +54,7 @@ class SQLTool:
56
  if m_top:
57
  limit = int(m_top.group(1))
58
 
59
- # 1) Top N FDs by Portfolio_value
60
  if has_any(m, ["fd", "fixed deposit", "deposits"]) and has_any(
61
  m, ["top", "largest", "biggest"]
62
  ) and has_any(m, ["portfolio value", "portfolio_value"]):
@@ -71,7 +69,7 @@ class SQLTool:
71
  why = f"Top {n} fixed deposits by Portfolio_value from {full_table}"
72
  return sql, why
73
 
74
- # 2) Top N Assets by Portfolio_value
75
  if has_any(m, ["asset", "loan", "advances"]) and has_any(
76
  m, ["top", "largest", "biggest"]
77
  ) and has_any(m, ["portfolio value", "portfolio_value"]):
@@ -86,7 +84,7 @@ class SQLTool:
86
  why = f"Top {n} assets by Portfolio_value from {full_table}"
87
  return sql, why
88
 
89
- # 3) Aggregate (SUM/AVG) by segment or currency
90
  if has_any(m, ["sum", "total", "avg", "average"]) and has_any(
91
  m, ["segment", "currency"]
92
  ):
@@ -101,7 +99,7 @@ class SQLTool:
101
  why = f"{agg} Portfolio_value grouped by {dim} from {full_table}"
102
  return sql, why
103
 
104
- # 4) Generic filters
105
  product = None
106
  if "fd" in m or "deposit" in m:
107
  product = "fd"
@@ -115,14 +113,12 @@ class SQLTool:
115
  parts.append(f"AND lower(product) = '{product}'")
116
  why_parts.append(f"product = {product}")
117
 
118
- # currency filter like: "in lkr", "currency usd"
119
  cur_match = re.search(r"\b(currency|in)\s+([a-z]{3})\b", m)
120
  if cur_match:
121
  cur = cur_match.group(2).upper()
122
  parts.append(f"AND upper(currency) = '{cur}'")
123
  why_parts.append(f"currency = {cur}")
124
 
125
- # segment filter like: "segment retail" or "for corporate"
126
  seg_match = re.search(r"(segment|for)\s+([a-z0-9_\- ]+)", m)
127
  if seg_match:
128
  seg = seg_match.group(2).strip()
@@ -137,19 +133,10 @@ class SQLTool:
137
  fallback_why = "; ".join(why_parts)
138
  return fallback_sql, fallback_why
139
 
140
- # Public helpers
 
 
141
  def query_from_nl(self, message: str):
142
  sql, why = self._nl_to_sql(message)
143
  df = self.run_sql(sql)
144
  return df, sql, why
145
-
146
- def table_exists(self, schema: Optional[str] = None, table: Optional[str] = None) -> bool:
147
- schema = schema or DEFAULT_SCHEMA
148
- table = table or DEFAULT_TABLE
149
- q = f"""
150
- SELECT COUNT(*) AS n
151
- FROM information_schema.tables
152
- WHERE table_schema = '{schema}' AND table_name = '{table}';
153
- """
154
- n = self.con.execute(q).fetchone()[0]
155
- return n > 0
 
5
 
6
  import duckdb
7
 
8
+ # DuckDB connection file
9
  DUCKDB_PATH = os.getenv("DUCKDB_PATH", "alm.duckdb")
10
 
11
+ # Fully qualified schema path confirmed from your server
12
+ # my_db.main.masterdataset_v
13
+ DEFAULT_DB = os.getenv("SQL_DEFAULT_DB", "my_db")
14
+ DEFAULT_SCHEMA = os.getenv("SQL_DEFAULT_SCHEMA", "main")
15
+ DEFAULT_TABLE = os.getenv("SQL_DEFAULT_TABLE", "masterdataset_v")
16
+
17
+ def _full_table(db: Optional[str] = None,
18
+ schema: Optional[str] = None,
19
+ table: Optional[str] = None) -> str:
20
+ """Return fully qualified <db>.<schema>.<table>"""
21
+ db = db or DEFAULT_DB
22
  schema = schema or DEFAULT_SCHEMA
23
  table = table or DEFAULT_TABLE
24
+ return f"{db}.{schema}.{table}"
25
 
26
 
27
  class SQLTool:
28
+ """Natural-language → SQL helper for DuckDB"""
 
 
29
 
30
  def __init__(self, db_path: Optional[str] = None):
31
  self.db_path = db_path or DUCKDB_PATH
32
  self.con = duckdb.connect(self.db_path)
33
+ self.full_table = _full_table()
34
 
35
+ # ------------------------------------------------------------
36
+ # Run SQL directly
37
+ # ------------------------------------------------------------
38
  def run_sql(self, sql: str):
39
  return self.con.execute(sql).df()
40
 
41
+ # ------------------------------------------------------------
42
  # NL → SQL
43
+ # ------------------------------------------------------------
44
+ def _nl_to_sql(self, message: str) -> Tuple[str, str]:
45
+ full_table = self.full_table
 
 
 
 
 
 
46
  m = (message or "").strip().lower()
47
 
48
  def has_any(txt, words):
 
54
  if m_top:
55
  limit = int(m_top.group(1))
56
 
57
+ # 1. Top N FDs
58
  if has_any(m, ["fd", "fixed deposit", "deposits"]) and has_any(
59
  m, ["top", "largest", "biggest"]
60
  ) and has_any(m, ["portfolio value", "portfolio_value"]):
 
69
  why = f"Top {n} fixed deposits by Portfolio_value from {full_table}"
70
  return sql, why
71
 
72
+ # 2. Top N Assets
73
  if has_any(m, ["asset", "loan", "advances"]) and has_any(
74
  m, ["top", "largest", "biggest"]
75
  ) and has_any(m, ["portfolio value", "portfolio_value"]):
 
84
  why = f"Top {n} assets by Portfolio_value from {full_table}"
85
  return sql, why
86
 
87
+ # 3. Aggregate by segment/currency
88
  if has_any(m, ["sum", "total", "avg", "average"]) and has_any(
89
  m, ["segment", "currency"]
90
  ):
 
99
  why = f"{agg} Portfolio_value grouped by {dim} from {full_table}"
100
  return sql, why
101
 
102
+ # 4. Generic filters
103
  product = None
104
  if "fd" in m or "deposit" in m:
105
  product = "fd"
 
113
  parts.append(f"AND lower(product) = '{product}'")
114
  why_parts.append(f"product = {product}")
115
 
 
116
  cur_match = re.search(r"\b(currency|in)\s+([a-z]{3})\b", m)
117
  if cur_match:
118
  cur = cur_match.group(2).upper()
119
  parts.append(f"AND upper(currency) = '{cur}'")
120
  why_parts.append(f"currency = {cur}")
121
 
 
122
  seg_match = re.search(r"(segment|for)\s+([a-z0-9_\- ]+)", m)
123
  if seg_match:
124
  seg = seg_match.group(2).strip()
 
133
  fallback_why = "; ".join(why_parts)
134
  return fallback_sql, fallback_why
135
 
136
+ # ------------------------------------------------------------
137
+ # Public wrappers
138
+ # ------------------------------------------------------------
139
  def query_from_nl(self, message: str):
140
  sql, why = self._nl_to_sql(message)
141
  df = self.run_sql(sql)
142
  return df, sql, why