AshenH commited on
Commit
0d9239a
·
verified ·
1 Parent(s): 47a613b

Update tools/sql_tool.py

Browse files
Files changed (1) hide show
  1. tools/sql_tool.py +148 -110
tools/sql_tool.py CHANGED
@@ -1,117 +1,155 @@
1
- # app.py
2
  import os
3
- import pandas as pd
4
- import gradio as gr
5
 
6
- from tools.sql_tool import SQLTool
7
- from tools.ts_preprocess import build_timeseries
8
 
9
- # ==========================================================
10
- # CONFIG
11
- # ==========================================================
12
  DUCKDB_PATH = os.getenv("DUCKDB_PATH", "alm.duckdb")
 
 
13
  DEFAULT_SCHEMA = os.getenv("SQL_DEFAULT_SCHEMA", "my_db")
14
  DEFAULT_TABLE = os.getenv("SQL_DEFAULT_TABLE", "masterdataset_v")
15
 
16
- sql_tool = SQLTool(DUCKDB_PATH)
17
-
18
- INTRO = f"""
19
- ### ALM LLM Demo
20
-
21
- Connected to **DuckDB** at `{DUCKDB_PATH}` using table **{DEFAULT_SCHEMA}.{DEFAULT_TABLE}**.
22
-
23
- **Try:**
24
- - *"show me the top 10 fds by portfolio value"*
25
- - *"top 10 assets by portfolio value"*
26
- - *"sum portfolio value by currency"*
27
- """
28
-
29
- # ==========================================================
30
- # BACKEND HANDLERS
31
- # ==========================================================
32
- def run_nl(nl_query: str):
33
- """Handle natural-language queries."""
34
- if not nl_query or not nl_query.strip():
35
- return pd.DataFrame(), "", "Please enter a query.", pd.DataFrame(), pd.DataFrame()
36
-
37
- try:
38
- df, sql, why = sql_tool.query_from_nl(nl_query)
39
- except Exception as e:
40
- return pd.DataFrame(), "", f"Error: {e}", pd.DataFrame(), pd.DataFrame()
41
-
42
- try:
43
- cf, gap = build_timeseries(df)
44
- except Exception:
45
- cf, gap = pd.DataFrame(), pd.DataFrame()
46
-
47
- return df, sql.strip(), why, cf, gap
48
-
49
-
50
- def run_sql(sql_text: str):
51
- """Handle raw SQL execution."""
52
- if not sql_text or not sql_text.strip():
53
- return pd.DataFrame(), "Please paste a SQL statement.", pd.DataFrame(), pd.DataFrame()
54
-
55
- try:
56
- df = sql_tool.run_sql(sql_text)
57
- except Exception as e:
58
- return pd.DataFrame(), f"Error: {e}", pd.DataFrame(), pd.DataFrame()
59
-
60
- try:
61
- cf, gap = build_timeseries(df)
62
- except Exception:
63
- cf, gap = pd.DataFrame(), pd.DataFrame()
64
-
65
- return df, "OK", cf, gap
66
-
67
-
68
- # ==========================================================
69
- # GRADIO UI
70
- # ==========================================================
71
- with gr.Blocks(title="ALM LLM") as demo:
72
- gr.Markdown(INTRO)
73
-
74
- # ---- Tab 1: Natural language ----
75
- with gr.Tab("Ask in Natural Language"):
76
- nl = gr.Textbox(
77
- label="Ask a question",
78
- placeholder="e.g., show me the top 10 fds by portfolio value",
79
- lines=2,
80
- )
81
- btn = gr.Button("Run")
82
- sql_out = gr.Textbox(label="Generated SQL", interactive=False)
83
- why_out = gr.Textbox(label="Reasoning", interactive=False)
84
- df_out = gr.Dataframe(label="Query Result", interactive=True)
85
- cf_out = gr.Dataframe(label="Projected Cash-Flows (if applicable)", interactive=True)
86
- gap_out = gr.Dataframe(label="Liquidity Gap (monthly)", interactive=True)
87
-
88
- btn.click(
89
- fn=run_nl,
90
- inputs=[nl],
91
- outputs=[df_out, sql_out, why_out, cf_out, gap_out],
92
- )
93
-
94
- # ---- Tab 2: Raw SQL ----
95
- with gr.Tab("Run Raw SQL"):
96
- sql_in = gr.Code(
97
- label="SQL",
98
- language="sql",
99
- value=f"SELECT * FROM {DEFAULT_SCHEMA}.{DEFAULT_TABLE} LIMIT 20;",
100
- )
101
- btn2 = gr.Button("Execute")
102
- df2 = gr.Dataframe(label="Result", interactive=True)
103
- status = gr.Textbox(label="Status", interactive=False)
104
- cf2 = gr.Dataframe(label="Projected Cash-Flows (if applicable)", interactive=True)
105
- gap2 = gr.Dataframe(label="Liquidity Gap (monthly)", interactive=True)
106
-
107
- btn2.click(
108
- fn=run_sql,
109
- inputs=[sql_in],
110
- outputs=[df2, status, cf2, gap2],
111
- )
112
-
113
- # ==========================================================
114
- # LAUNCH
115
- # ==========================================================
116
- if __name__ == "__main__":
117
- demo.launch(server_name="0.0.0.0", server_port=int(os.environ.get("PORT", 7860)))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # tools/sql_tool.py
2
  import os
3
+ import re
4
+ from typing import Optional, Tuple
5
 
6
+ import duckdb
 
7
 
8
+ # DuckDB file path (can be overridden in Space settings)
 
 
9
  DUCKDB_PATH = os.getenv("DUCKDB_PATH", "alm.duckdb")
10
+
11
+ # Default schema/table -> your path my_db.masterdataset_v
12
  DEFAULT_SCHEMA = os.getenv("SQL_DEFAULT_SCHEMA", "my_db")
13
  DEFAULT_TABLE = os.getenv("SQL_DEFAULT_TABLE", "masterdataset_v")
14
 
15
+
16
+ def _full_table(schema: Optional[str] = None, table: Optional[str] = None) -> str:
17
+ schema = schema or DEFAULT_SCHEMA
18
+ table = table or DEFAULT_TABLE
19
+ return f"{schema}.{table}"
20
+
21
+
22
+ class SQLTool:
23
+ """
24
+ Minimal NL→SQL helper wired to my_db.masterdataset_v with a DuckDB runner.
25
+ """
26
+
27
+ def __init__(self, db_path: Optional[str] = None):
28
+ self.db_path = db_path or DUCKDB_PATH
29
+ self.con = duckdb.connect(self.db_path)
30
+
31
+ # -------------------------
32
+ # SQL Runner
33
+ # -------------------------
34
+ def run_sql(self, sql: str):
35
+ return self.con.execute(sql).df()
36
+
37
+ # -------------------------
38
+ # NL SQL
39
+ # -------------------------
40
+ def _nl_to_sql(
41
+ self, message: str, schema: Optional[str] = None, table: Optional[str] = None
42
+ ) -> Tuple[str, str]:
43
+ """
44
+ Returns (sql, rationale). Small template library covering common queries.
45
+ Falls back to a filtered SELECT or a sample.
46
+ """
47
+ full_table = _full_table(schema, table)
48
+ m = (message or "").strip().lower()
49
+
50
+ def has_any(txt, words):
51
+ return any(w in txt for w in words)
52
+
53
+ # Extract "top N"
54
+ limit = None
55
+ m_top = re.search(r"\btop\s+(\d+)", m)
56
+ if m_top:
57
+ limit = int(m_top.group(1))
58
+
59
+ # 1) Top N FDs by Portfolio_value
60
+ if has_any(m, ["fd", "fixed deposit", "deposits"]) and has_any(
61
+ m, ["top", "largest", "biggest"]
62
+ ) and has_any(m, ["portfolio value", "portfolio_value"]):
63
+ n = limit or 10
64
+ sql = f"""
65
+ SELECT contract_number, Portfolio_value, Interest_rate, currency, segments
66
+ FROM {full_table}
67
+ WHERE lower(product) = 'fd'
68
+ ORDER BY Portfolio_value DESC
69
+ LIMIT {n};
70
+ """
71
+ why = f"Top {n} fixed deposits by Portfolio_value from {full_table}"
72
+ return sql, why
73
+
74
+ # 2) Top N Assets by Portfolio_value
75
+ if has_any(m, ["asset", "loan", "advances"]) and has_any(
76
+ m, ["top", "largest", "biggest"]
77
+ ) and has_any(m, ["portfolio value", "portfolio_value"]):
78
+ n = limit or 10
79
+ sql = f"""
80
+ SELECT contract_number, Portfolio_value, Interest_rate, currency, segments
81
+ FROM {full_table}
82
+ WHERE lower(product) = 'assets'
83
+ ORDER BY Portfolio_value DESC
84
+ LIMIT {n};
85
+ """
86
+ why = f"Top {n} assets by Portfolio_value from {full_table}"
87
+ return sql, why
88
+
89
+ # 3) Aggregate (SUM/AVG) by segment or currency
90
+ if has_any(m, ["sum", "total", "avg", "average"]) and has_any(
91
+ m, ["segment", "currency"]
92
+ ):
93
+ agg = "SUM" if has_any(m, ["sum", "total"]) else "AVG"
94
+ dim = "segments" if "segment" in m else "currency"
95
+ sql = f"""
96
+ SELECT {dim}, {agg}(Portfolio_value) AS {agg.lower()}_Portfolio_value
97
+ FROM {full_table}
98
+ GROUP BY 1
99
+ ORDER BY 2 DESC;
100
+ """
101
+ why = f"{agg} Portfolio_value grouped by {dim} from {full_table}"
102
+ return sql, why
103
+
104
+ # 4) Generic filters
105
+ product = None
106
+ if "fd" in m or "deposit" in m:
107
+ product = "fd"
108
+ elif "asset" in m or "loan" in m or "advance" in m:
109
+ product = "assets"
110
+
111
+ parts = [f"SELECT * FROM {full_table} WHERE 1=1"]
112
+ why_parts = [f"Filtered rows from {full_table}"]
113
+
114
+ if product:
115
+ parts.append(f"AND lower(product) = '{product}'")
116
+ why_parts.append(f"product = {product}")
117
+
118
+ # currency filter like: "in lkr", "currency usd"
119
+ cur_match = re.search(r"\b(currency|in)\s+([a-z]{3})\b", m)
120
+ if cur_match:
121
+ cur = cur_match.group(2).upper()
122
+ parts.append(f"AND upper(currency) = '{cur}'")
123
+ why_parts.append(f"currency = {cur}")
124
+
125
+ # segment filter like: "segment retail" or "for corporate"
126
+ seg_match = re.search(r"(segment|for)\s+([a-z0-9_\- ]+)", m)
127
+ if seg_match:
128
+ seg = seg_match.group(2).strip()
129
+ if seg:
130
+ parts.append(f"AND lower(segments) LIKE '%{seg.lower()}%'")
131
+ why_parts.append(f"segments like '{seg}'")
132
+
133
+ if limit:
134
+ parts.append(f"LIMIT {limit}")
135
+
136
+ fallback_sql = " ".join(parts) + ";"
137
+ fallback_why = "; ".join(why_parts)
138
+ return fallback_sql, fallback_why
139
+
140
+ # Public helpers
141
+ def query_from_nl(self, message: str):
142
+ sql, why = self._nl_to_sql(message)
143
+ df = self.run_sql(sql)
144
+ return df, sql, why
145
+
146
+ def table_exists(self, schema: Optional[str] = None, table: Optional[str] = None) -> bool:
147
+ schema = schema or DEFAULT_SCHEMA
148
+ table = table or DEFAULT_TABLE
149
+ q = f"""
150
+ SELECT COUNT(*) AS n
151
+ FROM information_schema.tables
152
+ WHERE table_schema = '{schema}' AND table_name = '{table}';
153
+ """
154
+ n = self.con.execute(q).fetchone()[0]
155
+ return n > 0