File size: 5,093 Bytes
0d9239a
f4dc602
0d9239a
 
f4dc602
0d9239a
e002acf
f69d955
6e66f3a
0d9239a
f69d955
 
 
 
 
 
 
 
 
 
 
0d9239a
 
f69d955
0d9239a
 
 
f69d955
0d9239a
 
 
 
f69d955
0d9239a
f69d955
 
 
0d9239a
 
 
f69d955
0d9239a
f69d955
 
 
0d9239a
 
 
 
 
 
 
 
 
 
 
f69d955
0d9239a
 
 
 
 
 
 
 
 
 
 
 
 
 
f69d955
0d9239a
 
 
 
 
 
 
 
 
 
 
 
 
 
f69d955
0d9239a
 
 
 
 
 
 
 
 
 
 
 
 
 
f69d955
0d9239a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f69d955
 
 
0d9239a
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
# tools/sql_tool.py
import os
import re
from typing import Optional, Tuple

import duckdb

# Path of the DuckDB database file that SQLTool opens (override with $DUCKDB_PATH).
DUCKDB_PATH = os.environ.get("DUCKDB_PATH", "alm.duckdb")

# Default <db>.<schema>.<table> target for generated queries
# (my_db.main.masterdataset_v unless overridden by the env vars below).
# NOTE(review): DuckDB names the catalog of a file-backed database after the
# file (alm.duckdb -> "alm"), so "my_db" presumably has to be ATTACHed
# separately — confirm against the deployment.
DEFAULT_DB = os.environ.get("SQL_DEFAULT_DB", "my_db")
DEFAULT_SCHEMA = os.environ.get("SQL_DEFAULT_SCHEMA", "main")
DEFAULT_TABLE = os.environ.get("SQL_DEFAULT_TABLE", "masterdataset_v")

def _full_table(db: Optional[str] = None,
                schema: Optional[str] = None,
                table: Optional[str] = None) -> str:
    """Return fully qualified <db>.<schema>.<table>"""
    db = db or DEFAULT_DB
    schema = schema or DEFAULT_SCHEMA
    table = table or DEFAULT_TABLE
    return f"{db}.{schema}.{table}"


class SQLTool:
    """Natural-language → SQL helper for DuckDB.

    Maps a handful of English request patterns onto SQL against one fully
    qualified table (built by ``_full_table``) and runs the result on a
    DuckDB connection. The connection is opened in ``__init__``; release it
    with :meth:`close`, or use the instance as a context manager
    (``with SQLTool() as tool: ...``).
    """

    # Three-letter words that often follow "in" in a request but are not meant
    # as currency codes. Without this, "loans in the retail segment" became a
    # filter on currency 'THE' and silently returned nothing. Only the
    # "in <xxx>" form is filtered; an explicit "currency <xxx>" is always
    # honoured (so "currency all" still selects ALL).
    _NON_CURRENCY_WORDS = frozenset({
        "the", "all", "any", "our", "its", "his", "her",
        "and", "top", "one", "two", "six", "ten",
    })

    # "currency usd" / "in usd" -> group(2) == "usd"
    _CURRENCY_RE = re.compile(r"\b(currency|in)\s+([a-z]{3})\b")

    # One word of a segment name. Clause keywords are excluded so the capture
    # stops at them: "for retail in usd" yields "retail" (previously
    # "retail in usd"), and "for retail segment" yields "retail".
    _SEG_WORD = r"(?!(?:in|currency|top|limit|with|and|by|segments?)\b)[a-z0-9_\-]+"
    _SEGMENT_RE = re.compile(
        rf"\b(?:segment|for)\s+(?:the\s+)?({_SEG_WORD}(?:\s+{_SEG_WORD})*)"
    )

    # "top 5" -> 5
    _TOP_RE = re.compile(r"\btop\s+(\d+)")

    def __init__(self, db_path: Optional[str] = None):
        # Falls back to the module-level DUCKDB_PATH when no path is given.
        self.db_path = db_path or DUCKDB_PATH
        self.con = duckdb.connect(self.db_path)
        self.full_table = _full_table()

    # ------------------------------------------------------------
    # Connection lifecycle
    # ------------------------------------------------------------
    def close(self) -> None:
        """Close the underlying DuckDB connection."""
        self.con.close()

    def __enter__(self) -> "SQLTool":
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        # Returning None lets any in-flight exception propagate.
        self.close()

    # ------------------------------------------------------------
    # Run SQL directly
    # ------------------------------------------------------------
    def run_sql(self, sql: str):
        """Execute *sql* verbatim and return the result via DuckDB's ``.df()``.

        The statement is not validated or sanitised; only pass SQL that is
        trusted or was produced by :meth:`_nl_to_sql`.
        """
        return self.con.execute(sql).df()

    # ------------------------------------------------------------
    # NL → SQL
    # ------------------------------------------------------------
    def _top_by_portfolio_value(self, product: str, n: int,
                                label: str) -> Tuple[str, str]:
        """Build the "top *n* rows of *product* by Portfolio_value" query.

        Returns ``(sql, explanation)``; *label* is the human-readable product
        name used in the explanation.
        """
        full_table = self.full_table
        sql = f"""
            SELECT contract_number, Portfolio_value, Interest_rate, currency, segments
            FROM {full_table}
            WHERE lower(product) = '{product}'
            ORDER BY Portfolio_value DESC
            LIMIT {n};
            """
        why = f"Top {n} {label} by Portfolio_value from {full_table}"
        return sql, why

    def _extract_currency(self, text: str) -> Optional[str]:
        """Return the first upper-cased currency code mentioned in *text*, or None.

        "in <word>" matches whose word is in ``_NON_CURRENCY_WORDS`` are
        skipped, and the search continues with the next candidate.
        """
        for match in self._CURRENCY_RE.finditer(text):
            keyword, code = match.groups()
            if keyword == "in" and code in self._NON_CURRENCY_WORDS:
                continue
            return code.upper()
        return None

    def _nl_to_sql(self, message: str) -> Tuple[str, str]:
        """Translate an English request into ``(sql, explanation)``.

        Rules are tried in order and the first match wins:

        1. top-N fixed deposits by portfolio value,
        2. top-N assets/loans/advances by portfolio value,
        3. SUM/AVG of Portfolio_value grouped by segment or currency,
        4. otherwise ``SELECT *`` with optional product, currency, segment
           and LIMIT filters.

        Matching is case-insensitive. Values interpolated into the SQL come
        only from fixed literals or regex captures limited to
        ``[a-z0-9_-]`` and spaces, so they cannot contain quote characters.
        """
        full_table = self.full_table
        m = (message or "").strip().lower()

        def has_any(words) -> bool:
            return any(w in m for w in words)

        # "top N" -> N, else None. Kept as None (not 0) so an explicit
        # "top 0" is honoured instead of being read as "no limit".
        m_top = self._TOP_RE.search(m)
        limit = int(m_top.group(1)) if m_top else None

        wants_top = has_any(["top", "largest", "biggest"])
        wants_pv = has_any(["portfolio value", "portfolio_value"])
        n = limit if limit is not None else 10

        # 1. Top N FDs ("deposit" also covers "deposits"/"fixed deposit",
        #    matching the keyword set used by rule 4).
        if wants_top and wants_pv and has_any(["fd", "fixed deposit", "deposit"]):
            return self._top_by_portfolio_value("fd", n, "fixed deposits")

        # 2. Top N Assets ("advance" also covers "advances").
        if wants_top and wants_pv and has_any(["asset", "loan", "advance"]):
            return self._top_by_portfolio_value("assets", n, "assets")

        # 3. Aggregate by segment/currency
        if has_any(["sum", "total", "avg", "average"]) and has_any(
            ["segment", "currency"]
        ):
            agg = "SUM" if has_any(["sum", "total"]) else "AVG"
            dim = "segments" if "segment" in m else "currency"
            sql = f"""
            SELECT {dim}, {agg}(Portfolio_value) AS {agg.lower()}_Portfolio_value
            FROM {full_table}
            GROUP BY 1
            ORDER BY 2 DESC;
            """
            why = f"{agg} Portfolio_value grouped by {dim} from {full_table}"
            return sql, why

        # 4. Generic filters
        product = None
        if has_any(["fd", "deposit"]):
            product = "fd"
        elif has_any(["asset", "loan", "advance"]):
            product = "assets"

        # "WHERE 1=1" lets every optional filter be appended uniformly as "AND ...".
        parts = [f"SELECT * FROM {full_table} WHERE 1=1"]
        why_parts = [f"Filtered rows from {full_table}"]

        if product:
            parts.append(f"AND lower(product) = '{product}'")
            why_parts.append(f"product = {product}")

        cur = self._extract_currency(m)
        if cur:
            parts.append(f"AND upper(currency) = '{cur}'")
            why_parts.append(f"currency = {cur}")

        seg_match = self._SEGMENT_RE.search(m)
        if seg_match:
            seg = seg_match.group(1).strip()
            if seg:
                # m is already lower-cased, so seg is too.
                parts.append(f"AND lower(segments) LIKE '%{seg}%'")
                why_parts.append(f"segments like '{seg}'")

        if limit is not None:
            parts.append(f"LIMIT {limit}")

        fallback_sql = " ".join(parts) + ";"
        fallback_why = "; ".join(why_parts)
        return fallback_sql, fallback_why

    # ------------------------------------------------------------
    # Public wrappers
    # ------------------------------------------------------------
    def query_from_nl(self, message: str):
        """Translate *message* to SQL, run it, and return ``(df, sql, why)``."""
        sql, why = self._nl_to_sql(message)
        df = self.run_sql(sql)
        return df, sql, why