AshenH commited on
Commit
6e66f3a
·
verified ·
1 Parent(s): 0ffc27e

Update tools/sql_tool.py

Browse files
Files changed (1) hide show
  1. tools/sql_tool.py +112 -138
tools/sql_tool.py CHANGED
@@ -1,143 +1,117 @@
1
- # tools/sql_tool.py
2
  import os
3
- import re
4
- import duckdb
5
- from typing import Optional, Tuple
6
 
7
- DUCKDB_PATH = os.getenv("DUCKDB_PATH", "alm.duckdb")
 
8
 
9
- # Defaults point to your real table; can be overridden via Space secrets
10
- DEFAULT_SCHEMA = os.getenv("SQL_DEFAULT_SCHEMA", "main")
 
 
 
11
  DEFAULT_TABLE = os.getenv("SQL_DEFAULT_TABLE", "masterdataset_v")
12
 
13
- def _full_table(schema: Optional[str] = None, table: Optional[str] = None) -> str:
14
- schema = schema or DEFAULT_SCHEMA
15
- table = table or DEFAULT_TABLE
16
- return f"{schema}.{table}"
17
-
18
- class SQLTool:
19
- """
20
- Minimal NL→SQL helper wired to main.masterdataset_v with a DuckDB runner.
21
- """
22
- def __init__(self, db_path: Optional[str] = None):
23
- self.db_path = db_path or DUCKDB_PATH
24
- self.con = duckdb.connect(self.db_path)
25
-
26
- def run_sql(self, sql: str):
27
- return self.con.execute(sql).df()
28
-
29
- # -------------------------
30
- # NL → SQL
31
- # -------------------------
32
- def _nl_to_sql(self, message: str, schema: Optional[str] = None, table: Optional[str] = None) -> Tuple[str, str]:
33
- """
34
- Returns (sql, rationale). Very small template library covering your common queries.
35
- Falls back to SHOW TABLES if no match.
36
- """
37
- full_table = _full_table(schema, table)
38
- m = message.strip().lower()
39
-
40
- # Common synonyms
41
- def has_any(txt, words):
42
- return any(w in txt for w in words)
43
-
44
- # Extract a "top N"
45
- limit = None
46
- m_top = re.search(r"\btop\s+(\d+)", m)
47
- if m_top:
48
- limit = int(m_top.group(1))
49
-
50
- # 1) Top N FDs by Portfolio_value
51
- if has_any(m, ["fd", "fixed deposit", "deposits"]) and has_any(m, ["top", "largest", "biggest"]) and has_any(m, ["portfolio value", "portfolio_value"]):
52
- n = limit or 10
53
- sql = f"""
54
- SELECT contract_number, Portfolio_value, Interest_rate, currency, segments
55
- FROM {full_table}
56
- WHERE lower(product) = 'fd'
57
- ORDER BY Portfolio_value DESC
58
- LIMIT {n};
59
- """
60
- why = f"Top {n} fixed deposits by Portfolio_value from {full_table}"
61
- return sql, why
62
-
63
- # 2) Top N Assets by Portfolio_value
64
- if has_any(m, ["asset", "loan", "advances"]) and has_any(m, ["top", "largest", "biggest"]) and has_any(m, ["portfolio value", "portfolio_value"]):
65
- n = limit or 10
66
- sql = f"""
67
- SELECT contract_number, Portfolio_value, Interest_rate, currency, segments
68
- FROM {full_table}
69
- WHERE lower(product) = 'assets'
70
- ORDER BY Portfolio_value DESC
71
- LIMIT {n};
72
- """
73
- why = f"Top {n} assets by Portfolio_value from {full_table}"
74
- return sql, why
75
-
76
- # 3) Aggregate (SUM/AVG) by segment or currency
77
- if has_any(m, ["sum", "total", "avg", "average"]) and has_any(m, ["segment", "currency"]):
78
- agg = "SUM" if has_any(m, ["sum", "total"]) else "AVG"
79
- dim = "segments" if "segment" in m else "currency"
80
- sql = f"""
81
- SELECT {dim}, {agg}(Portfolio_value) AS {agg.lower()}_Portfolio_value
82
- FROM {full_table}
83
- GROUP BY 1
84
- ORDER BY 2 DESC;
85
- """
86
- why = f"{agg} Portfolio_value grouped by {dim} from {full_table}"
87
- return sql, why
88
-
89
- # 4) Filter by product, currency, or segment
90
- product = None
91
- if "fd" in m or "deposit" in m:
92
- product = "fd"
93
- elif "asset" in m or "loan" in m or "advance" in m:
94
- product = "assets"
95
-
96
- parts = [f"SELECT * FROM {full_table} WHERE 1=1"]
97
- why_parts = [f"Filtered rows from {full_table}"]
98
-
99
- if product:
100
- parts.append(f"AND lower(product) = '{product}'")
101
- why_parts.append(f"product = {product}")
102
-
103
- # currency filter like: "in lkr", "currency usd"
104
- cur = None
105
- cur_match = re.search(r"\b(currency|in)\s+([a-z]{3})\b", m)
106
- if cur_match:
107
- cur = cur_match.group(2).upper()
108
- if cur:
109
- parts.append(f"AND upper(currency) = '{cur}'")
110
- why_parts.append(f"currency = {cur}")
111
-
112
- # segment filter like: "segment retail" or "for corporate"
113
- seg_match = re.search(r"(segment|for)\s+([a-z0-9_\- ]+)", m)
114
- if seg_match:
115
- seg = seg_match.group(2).strip()
116
- if seg:
117
- parts.append(f"AND lower(segments) LIKE '%{seg.lower()}%'")
118
- why_parts.append(f"segments like '{seg}'")
119
-
120
- # maybe a limit
121
- if limit:
122
- parts.append(f"LIMIT {limit}")
123
-
124
- fallback_sql = " ".join(parts) + ";"
125
- fallback_why = "; ".join(why_parts)
126
- if fallback_sql:
127
- return fallback_sql, fallback_why
128
-
129
- # 5) Super fallback: show sample rows
130
- return f"SELECT * FROM {full_table} LIMIT 20;", f"Default sample from {full_table}"
131
-
132
- # Public helpers
133
- def query_from_nl(self, message: str):
134
- sql, why = self._nl_to_sql(message)
135
- df = self.run_sql(sql)
136
- return df, sql, why
137
-
138
- def table_exists(self, schema: Optional[str] = None, table: Optional[str] = None) -> bool:
139
- schema = schema or DEFAULT_SCHEMA
140
- table = table or DEFAULT_TABLE
141
- q = f"SELECT COUNT(*) AS n FROM information_schema.tables WHERE table_schema = '{schema}' AND table_name = '{table}';"
142
- n = self.con.execute(q).fetchone()[0]
143
- return n > 0
 
1
+ # app.py
2
  import os
3
+ import pandas as pd
4
+ import gradio as gr
 
5
 
6
+ from tools.sql_tool import SQLTool
7
+ from tools.ts_preprocess import build_timeseries
8
 
9
+ # ==========================================================
10
+ # CONFIG
11
+ # ==========================================================
12
+ DUCKDB_PATH = os.getenv("DUCKDB_PATH", "alm.duckdb")
13
+ DEFAULT_SCHEMA = os.getenv("SQL_DEFAULT_SCHEMA", "my_db")
14
  DEFAULT_TABLE = os.getenv("SQL_DEFAULT_TABLE", "masterdataset_v")
15
 
16
+ sql_tool = SQLTool(DUCKDB_PATH)
17
+
18
+ INTRO = f"""
19
+ ### ALM LLM — Demo
20
+
21
+ Connected to **DuckDB** at `{DUCKDB_PATH}` using table **{DEFAULT_SCHEMA}.{DEFAULT_TABLE}**.
22
+
23
+ **Try:**
24
+ - *"show me the top 10 fds by portfolio value"*
25
+ - *"top 10 assets by portfolio value"*
26
+ - *"sum portfolio value by currency"*
27
+ """
28
+
29
+ # ==========================================================
30
+ # BACKEND HANDLERS
31
+ # ==========================================================
32
+ def run_nl(nl_query: str):
33
+ """Handle natural-language queries."""
34
+ if not nl_query or not nl_query.strip():
35
+ return pd.DataFrame(), "", "Please enter a query.", pd.DataFrame(), pd.DataFrame()
36
+
37
+ try:
38
+ df, sql, why = sql_tool.query_from_nl(nl_query)
39
+ except Exception as e:
40
+ return pd.DataFrame(), "", f"Error: {e}", pd.DataFrame(), pd.DataFrame()
41
+
42
+ try:
43
+ cf, gap = build_timeseries(df)
44
+ except Exception:
45
+ cf, gap = pd.DataFrame(), pd.DataFrame()
46
+
47
+ return df, sql.strip(), why, cf, gap
48
+
49
+
50
+ def run_sql(sql_text: str):
51
+ """Handle raw SQL execution."""
52
+ if not sql_text or not sql_text.strip():
53
+ return pd.DataFrame(), "Please paste a SQL statement.", pd.DataFrame(), pd.DataFrame()
54
+
55
+ try:
56
+ df = sql_tool.run_sql(sql_text)
57
+ except Exception as e:
58
+ return pd.DataFrame(), f"Error: {e}", pd.DataFrame(), pd.DataFrame()
59
+
60
+ try:
61
+ cf, gap = build_timeseries(df)
62
+ except Exception:
63
+ cf, gap = pd.DataFrame(), pd.DataFrame()
64
+
65
+ return df, "OK", cf, gap
66
+
67
+
68
+ # ==========================================================
69
+ # GRADIO UI
70
+ # ==========================================================
71
+ with gr.Blocks(title="ALM LLM") as demo:
72
+ gr.Markdown(INTRO)
73
+
74
+ # ---- Tab 1: Natural language ----
75
+ with gr.Tab("Ask in Natural Language"):
76
+ nl = gr.Textbox(
77
+ label="Ask a question",
78
+ placeholder="e.g., show me the top 10 fds by portfolio value",
79
+ lines=2,
80
+ )
81
+ btn = gr.Button("Run")
82
+ sql_out = gr.Textbox(label="Generated SQL", interactive=False)
83
+ why_out = gr.Textbox(label="Reasoning", interactive=False)
84
+ df_out = gr.Dataframe(label="Query Result", interactive=True)
85
+ cf_out = gr.Dataframe(label="Projected Cash-Flows (if applicable)", interactive=True)
86
+ gap_out = gr.Dataframe(label="Liquidity Gap (monthly)", interactive=True)
87
+
88
+ btn.click(
89
+ fn=run_nl,
90
+ inputs=[nl],
91
+ outputs=[df_out, sql_out, why_out, cf_out, gap_out],
92
+ )
93
+
94
+ # ---- Tab 2: Raw SQL ----
95
+ with gr.Tab("Run Raw SQL"):
96
+ sql_in = gr.Code(
97
+ label="SQL",
98
+ language="sql",
99
+ value=f"SELECT * FROM {DEFAULT_SCHEMA}.{DEFAULT_TABLE} LIMIT 20;",
100
+ )
101
+ btn2 = gr.Button("Execute")
102
+ df2 = gr.Dataframe(label="Result", interactive=True)
103
+ status = gr.Textbox(label="Status", interactive=False)
104
+ cf2 = gr.Dataframe(label="Projected Cash-Flows (if applicable)", interactive=True)
105
+ gap2 = gr.Dataframe(label="Liquidity Gap (monthly)", interactive=True)
106
+
107
+ btn2.click(
108
+ fn=run_sql,
109
+ inputs=[sql_in],
110
+ outputs=[df2, status, cf2, gap2],
111
+ )
112
+
113
+ # ==========================================================
114
+ # LAUNCH
115
+ # ==========================================================
116
+ if __name__ == "__main__":
117
+ demo.launch(server_name="0.0.0.0", server_port=int(os.environ.get("PORT", 7860)))