File size: 5,449 Bytes
0d9239a
f4dc602
0d9239a
 
f4dc602
0d9239a
e002acf
0d9239a
6e66f3a
0d9239a
 
6e66f3a
2e1969a
85b8a4e
0d9239a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
# tools/sql_tool.py
import os
import re
from typing import Optional, Tuple

import duckdb

# Location of the DuckDB database file. The DUCKDB_PATH environment variable
# (for example set in the Space settings) takes precedence over the default.
DUCKDB_PATH = os.environ.get("DUCKDB_PATH", "alm.duckdb")

# Schema and table queried when none is given, i.e. my_db.masterdataset_v
# unless SQL_DEFAULT_SCHEMA / SQL_DEFAULT_TABLE are set in the environment.
DEFAULT_SCHEMA = os.environ.get("SQL_DEFAULT_SCHEMA", "my_db")
DEFAULT_TABLE = os.environ.get("SQL_DEFAULT_TABLE", "masterdataset_v")


def _full_table(schema: Optional[str] = None, table: Optional[str] = None) -> str:
    """Return the qualified ``schema.table`` name.

    A missing (or empty) ``schema`` / ``table`` falls back to the module-level
    ``DEFAULT_SCHEMA`` / ``DEFAULT_TABLE`` respectively.
    """
    return "{}.{}".format(schema or DEFAULT_SCHEMA, table or DEFAULT_TABLE)


class SQLTool:
    """
    Minimal NL→SQL helper wired to my_db.masterdataset_v with a DuckDB runner.

    The instance owns a single DuckDB connection (``self.con``). Release it
    with ``close()`` or by using the tool as a context manager::

        with SQLTool() as tool:
            df, sql, why = tool.query_from_nl("top 5 fd by portfolio value")
    """

    # Tokens that terminate a free-text segment name, so that
    # "for retail in usd top 5" filters on "retail" rather than on the
    # whole remainder of the sentence.
    _SEGMENT_STOP_WORDS = frozenset(
        {"in", "currency", "top", "limit", "by", "with", "where"}
    )

    # Three-letter words that commonly follow "in" without being meant as a
    # currency code ("in the retail segment"). Only applied to the "in"
    # keyword; an explicit "currency xxx" is always honoured.
    _NON_CURRENCY_WORDS = frozenset({"the", "all", "any", "our"})

    def __init__(self, db_path: Optional[str] = None):
        """Open a DuckDB connection to ``db_path`` (``DUCKDB_PATH`` when omitted)."""
        self.db_path = db_path or DUCKDB_PATH
        self.con = duckdb.connect(self.db_path)

    # -------------------------
    # Connection lifecycle
    # -------------------------
    def close(self) -> None:
        """Close the underlying DuckDB connection."""
        self.con.close()

    def __enter__(self) -> "SQLTool":
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        # Always release the connection; exceptions propagate unchanged.
        self.close()

    # -------------------------
    # SQL Runner
    # -------------------------
    def run_sql(self, sql: str):
        """Execute ``sql`` on the connection and return the result of ``.df()``."""
        return self.con.execute(sql).df()

    # -------------------------
    # NL → SQL
    # -------------------------
    def _nl_to_sql(
        self, message: str, schema: Optional[str] = None, table: Optional[str] = None
    ) -> Tuple[str, str]:
        """
        Returns (sql, rationale). Small template library covering common queries.
        Falls back to a filtered SELECT or a sample.

        ``schema`` / ``table`` default to the module-level defaults. The
        message is stripped and lower-cased before matching; ``None`` is
        treated as an empty message.
        """
        full_table = _full_table(schema, table)
        text = (message or "").strip().lower()
        return self._translate(text, full_table)

    @staticmethod
    def _top_n_sql(full_table: str, product: str, n: int) -> str:
        """Build the "top N rows of one product by Portfolio_value" query."""
        return f"""
            SELECT contract_number, Portfolio_value, Interest_rate, currency, segments
            FROM {full_table}
            WHERE lower(product) = '{product}'
            ORDER BY Portfolio_value DESC
            LIMIT {n};
            """

    @classmethod
    def _translate(cls, text: str, full_table: str) -> Tuple[str, str]:
        """
        Pure translation step: map normalized (lower-cased) ``text`` to
        ``(sql, rationale)`` against ``full_table``. Touches no database.
        """

        def has_any(words) -> bool:
            return any(w in text for w in words)

        # Extract "top N"
        limit = None
        m_top = re.search(r"\btop\s+(\d+)", text)
        if m_top:
            limit = int(m_top.group(1))

        # 1) / 2) Top N FDs or Assets by Portfolio_value (FDs are checked first)
        wants_ranking = has_any(("top", "largest", "biggest")) and has_any(
            ("portfolio value", "portfolio_value")
        )
        if wants_ranking:
            rules = (
                (("fd", "fixed deposit", "deposits"), "fd", "fixed deposits"),
                (("asset", "loan", "advances"), "assets", "assets"),
            )
            for keywords, product, label in rules:
                if has_any(keywords):
                    n = limit or 10
                    why = f"Top {n} {label} by Portfolio_value from {full_table}"
                    return cls._top_n_sql(full_table, product, n), why

        # 3) Aggregate (SUM/AVG) by segment or currency
        if has_any(("sum", "total", "avg", "average")) and has_any(
            ("segment", "currency")
        ):
            agg = "SUM" if has_any(("sum", "total")) else "AVG"
            dim = "segments" if "segment" in text else "currency"
            sql = f"""
            SELECT {dim}, {agg}(Portfolio_value) AS {agg.lower()}_Portfolio_value
            FROM {full_table}
            GROUP BY 1
            ORDER BY 2 DESC;
            """
            why = f"{agg} Portfolio_value grouped by {dim} from {full_table}"
            return sql, why

        # 4) Generic filters
        product = None
        if "fd" in text or "deposit" in text:
            product = "fd"
        elif "asset" in text or "loan" in text or "advance" in text:
            product = "assets"

        parts = [f"SELECT * FROM {full_table} WHERE 1=1"]
        why_parts = [f"Filtered rows from {full_table}"]

        if product:
            parts.append(f"AND lower(product) = '{product}'")
            why_parts.append(f"product = {product}")

        # currency filter like: "in lkr", "currency usd"
        for cur_match in re.finditer(r"\b(currency|in)\s+([a-z]{3})\b", text):
            keyword, code = cur_match.groups()
            if keyword == "in" and code in cls._NON_CURRENCY_WORDS:
                # "in the ..." is ordinary English, not a currency code.
                continue
            cur = code.upper()
            parts.append(f"AND upper(currency) = '{cur}'")
            why_parts.append(f"currency = {cur}")
            break

        # segment filter like: "segment retail" or "for corporate"
        seg_match = re.search(r"\b(?:segment|for)\s+([a-z0-9_\- ]+)", text)
        if seg_match:
            # Keep only the words up to the next keyword; the character class
            # above already excludes quotes, so the value is safe to embed.
            seg_tokens = []
            for token in seg_match.group(1).split():
                if token in cls._SEGMENT_STOP_WORDS:
                    break
                seg_tokens.append(token)
            seg = " ".join(seg_tokens)
            if seg:
                parts.append(f"AND lower(segments) LIKE '%{seg}%'")
                why_parts.append(f"segments like '{seg}'")

        if limit:
            parts.append(f"LIMIT {limit}")

        fallback_sql = " ".join(parts) + ";"
        fallback_why = "; ".join(why_parts)
        return fallback_sql, fallback_why

    # Public helpers
    def query_from_nl(self, message: str):
        """Translate ``message`` to SQL, run it, and return ``(df, sql, why)``."""
        sql, why = self._nl_to_sql(message)
        df = self.run_sql(sql)
        return df, sql, why

    def table_exists(self, schema: Optional[str] = None, table: Optional[str] = None) -> bool:
        """
        Return True when ``schema.table`` is listed in information_schema.tables.

        Missing arguments fall back to the module-level defaults. The names
        are bound as query parameters rather than spliced into the SQL text,
        so quotes in a name cannot break or alter the statement.
        """
        schema = schema or DEFAULT_SCHEMA
        table = table or DEFAULT_TABLE
        q = """
        SELECT COUNT(*) AS n
        FROM information_schema.tables
        WHERE table_schema = ? AND table_name = ?;
        """
        n = self.con.execute(q, [schema, table]).fetchone()[0]
        return n > 0