File size: 5,480 Bytes
2e1969a
f4dc602
 
2e1969a
 
f4dc602
2e1969a
e002acf
2e1969a
 
 
85b8a4e
2e1969a
 
 
 
85b8a4e
f4dc602
6c6d38f
2e1969a
6c6d38f
2e1969a
 
 
 
 
 
 
 
 
 
 
6c6d38f
2e1969a
 
6c6d38f
2e1969a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
# tools/sql_tool.py
import os
import re
import duckdb
from typing import Optional, Tuple

# DuckDB database file; override with the DUCKDB_PATH environment variable.
DUCKDB_PATH = os.environ.get("DUCKDB_PATH", "alm.duckdb")
# Target table defaults to main.masterdataset_v; override via the
# SQL_DEFAULT_SCHEMA / SQL_DEFAULT_TABLE environment variables (e.g. Space secrets).
DEFAULT_SCHEMA = os.environ.get("SQL_DEFAULT_SCHEMA", "main")
DEFAULT_TABLE = os.environ.get("SQL_DEFAULT_TABLE", "masterdataset_v")

def _full_table(schema: Optional[str] = None, table: Optional[str] = None) -> str:
    schema = schema or DEFAULT_SCHEMA
    table  = table  or DEFAULT_TABLE
    return f"{schema}.{table}"

class SQLTool:
    """
    Minimal NL→SQL helper with a DuckDB runner.

    By default it targets the table named by DEFAULT_SCHEMA / DEFAULT_TABLE
    (``main.masterdataset_v`` unless overridden). ``_nl_to_sql`` maps a
    free-text request onto a small library of SQL templates using keyword
    matching; ``run_sql`` executes SQL on the DuckDB connection opened in
    ``__init__``. Release the connection with ``close()`` or by using the
    tool as a context manager (``with SQLTool() as tool: ...``).
    """

    # Patterns are compiled once and always applied to the lower-cased request.
    _TOP_N_RE = re.compile(r"\btop\s+(\d+)")
    # NOTE: any three-letter word after "in" is read as a currency code,
    # so e.g. "in the ..." yields currency "THE".
    _CURRENCY_RE = re.compile(r"\b(currency|in)\s+([a-z]{3})\b")
    # Greedy so multi-word segment names survive; trimmed with _SEGMENT_STOP_RE.
    _SEGMENT_RE = re.compile(r"(segment|for)\s+([a-z0-9_\- ]+)")
    # Start of a trailing clause handled by another filter ("... in usd", "... top 5");
    # the segment capture is cut here so it does not absorb those words.
    _SEGMENT_STOP_RE = re.compile(r"(?:^|\s+)(?:in|currency|top|limit)(?=\s|$)")

    _RANK_WORDS = ("top", "largest", "biggest")
    _VALUE_WORDS = ("portfolio value", "portfolio_value")

    def __init__(self, db_path: Optional[str] = None):
        """Open a DuckDB connection to *db_path* (defaults to DUCKDB_PATH)."""
        self.db_path = db_path or DUCKDB_PATH
        self.con = duckdb.connect(self.db_path)

    def __enter__(self) -> "SQLTool":
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        # Always release the connection; exceptions still propagate.
        self.close()

    def close(self) -> None:
        """Close the underlying DuckDB connection."""
        self.con.close()

    def run_sql(self, sql: str):
        """Execute *sql* on the open connection and return the result as a DataFrame (``.df()``)."""
        return self.con.execute(sql).df()

    # -------------------------
    # NL → SQL
    # -------------------------
    @staticmethod
    def _has_any(text: str, words) -> bool:
        """Return True if any of *words* occurs as a substring of *text*."""
        return any(w in text for w in words)

    @staticmethod
    def _top_n_sql(full_table: str, product: str, n: int) -> str:
        """Return the "top *n* rows of *product* by Portfolio_value" query."""
        return f"""
            SELECT contract_number, Portfolio_value, Interest_rate, currency, segments
            FROM {full_table}
            WHERE lower(product) = '{product}'
            ORDER BY Portfolio_value DESC
            LIMIT {n};
            """

    def _nl_to_sql(self, message: str, schema: Optional[str] = None, table: Optional[str] = None) -> Tuple[str, str]:
        """
        Translate a free-text request into ``(sql, rationale)``.

        Templates are tried in order:
          1. top-N fixed deposits by Portfolio_value (N defaults to 10),
          2. top-N assets/loans by Portfolio_value (N defaults to 10),
          3. SUM/AVG of Portfolio_value grouped by segment or currency,
          4. a row filter on product / currency / segment, plus LIMIT N when
             "top N" appears.
        If no filter in (4) applies and no "top N" was given, a 20-row sample
        of the table is returned rather than the whole table.

        Matching is plain substring / regex matching on the lower-cased message.
        Only fixed keywords, three-letter currency codes, digit-only limits and
        segment text restricted to letters, digits, '_', '-' and spaces are
        interpolated into the SQL, so the request cannot inject quotes.
        """
        full_table = _full_table(schema, table)
        m = message.strip().lower()

        # Extract a "top N"
        m_top = self._TOP_N_RE.search(m)
        limit = int(m_top.group(1)) if m_top else None

        is_ranked = self._has_any(m, self._RANK_WORDS) and self._has_any(m, self._VALUE_WORDS)

        # 1) Top N FDs by Portfolio_value
        if is_ranked and self._has_any(m, ("fd", "fixed deposit", "deposits")):
            n = limit or 10
            return (
                self._top_n_sql(full_table, "fd", n),
                f"Top {n} fixed deposits by Portfolio_value from {full_table}",
            )

        # 2) Top N Assets by Portfolio_value
        if is_ranked and self._has_any(m, ("asset", "loan", "advances")):
            n = limit or 10
            return (
                self._top_n_sql(full_table, "assets", n),
                f"Top {n} assets by Portfolio_value from {full_table}",
            )

        # 3) Aggregate (SUM/AVG) by segment or currency
        if self._has_any(m, ("sum", "total", "avg", "average")) and self._has_any(m, ("segment", "currency")):
            agg = "SUM" if self._has_any(m, ("sum", "total")) else "AVG"
            dim = "segments" if "segment" in m else "currency"
            sql = f"""
            SELECT {dim}, {agg}(Portfolio_value) AS {agg.lower()}_Portfolio_value
            FROM {full_table}
            GROUP BY 1
            ORDER BY 2 DESC;
            """
            return sql, f"{agg} Portfolio_value grouped by {dim} from {full_table}"

        # 4) Filter by product, currency, or segment (or fall back to a sample)
        return self._filter_sql(m, full_table, limit)

    def _filter_sql(self, m: str, full_table: str, limit: Optional[int]) -> Tuple[str, str]:
        """
        Build a filtered ``SELECT *`` for the lower-cased request *m*.

        Returns a bounded 20-row sample query when no filter was recognised
        and no limit was requested, so an unmatched request never selects
        the entire table.
        """
        conditions = []
        why_parts = [f"Filtered rows from {full_table}"]

        if "fd" in m or "deposit" in m:
            product = "fd"
        elif "asset" in m or "loan" in m or "advance" in m:
            product = "assets"
        else:
            product = None
        if product:
            conditions.append(f"AND lower(product) = '{product}'")
            why_parts.append(f"product = {product}")

        # currency filter like: "in lkr", "currency usd"
        cur_match = self._CURRENCY_RE.search(m)
        if cur_match:
            cur = cur_match.group(2).upper()
            conditions.append(f"AND upper(currency) = '{cur}'")
            why_parts.append(f"currency = {cur}")

        # segment filter like: "segment retail" or "for corporate"
        seg_match = self._SEGMENT_RE.search(m)
        if seg_match:
            # Cut "corporate in usd" / "retail top 5" down to the segment name itself.
            seg = self._SEGMENT_STOP_RE.split(seg_match.group(2), maxsplit=1)[0].strip()
            if seg:
                conditions.append(f"AND lower(segments) LIKE '%{seg.lower()}%'")
                why_parts.append(f"segments like '{seg}'")

        # 5) Nothing recognised: return a bounded sample instead of the whole table.
        if not conditions and not limit:
            return f"SELECT * FROM {full_table} LIMIT 20;", f"Default sample from {full_table}"

        parts = [f"SELECT * FROM {full_table} WHERE 1=1", *conditions]
        if limit:
            parts.append(f"LIMIT {limit}")
        return " ".join(parts) + ";", "; ".join(why_parts)

    # Public helpers
    def query_from_nl(self, message: str, schema: Optional[str] = None, table: Optional[str] = None):
        """
        Translate *message* to SQL, run it, and return ``(df, sql, why)``.

        *schema* / *table* optionally override the default target table.
        """
        sql, why = self._nl_to_sql(message, schema, table)
        df = self.run_sql(sql)
        return df, sql, why

    def table_exists(self, schema: Optional[str] = None, table: Optional[str] = None) -> bool:
        """Return True if ``information_schema.tables`` lists *schema*.*table* (defaults applied)."""
        schema = schema or DEFAULT_SCHEMA
        table = table or DEFAULT_TABLE
        # Bound parameters instead of string interpolation: names may come from callers/config.
        row = self.con.execute(
            "SELECT COUNT(*) AS n FROM information_schema.tables "
            "WHERE table_schema = ? AND table_name = ?;",
            [schema, table],
        ).fetchone()
        return row[0] > 0