|
|
|
|
|
import os |
|
|
import re |
|
|
from typing import Optional, Tuple, List |
|
|
|
|
|
import duckdb |
|
|
import pandas as pd |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Path of the DuckDB database file; override with env var DUCKDB_PATH.
DUCKDB_PATH = os.getenv("DUCKDB_PATH", "alm.duckdb")

# Optional SQL executed right after connecting (e.g. attach a MotherDuck db).
# NOTE(review): this must be a Python string — the previous bare
# `ATTACH 'md:my_db' AS my_db;` line was a syntax error. Set the env var
# DUCKDB_ATTACH_SQL to "" to disable the pre-attach entirely.
DUCKDB_ATTACH_SQL = os.getenv("DUCKDB_ATTACH_SQL", "ATTACH 'md:my_db' AS my_db;")

# Preferred catalog / schema / table used when resolving the master table path.
PREF_CATALOG = os.getenv("SQL_DEFAULT_DB", "my_db")
PREF_SCHEMA = os.getenv("SQL_DEFAULT_SCHEMA", "main")
PREF_TABLE = os.getenv("SQL_DEFAULT_TABLE", "masterdataset_v")
|
|
|
|
|
|
|
|
class SQLTool:
    """
    NL→SQL helper for DuckDB.

    Features:
    - optional pre-attach SQL (DUCKDB_ATTACH_SQL), e.g. attaching a MotherDuck db
    - robust table path resolution (tries 3-part → 2-part → 1-part →
      information_schema scan)
    - a small keyword-driven NL→SQL translator over a single master table

    Instances own a DuckDB connection; call ``close()`` (or use the instance
    as a context manager) when done.
    """

    def __init__(self, db_path: Optional[str] = None):
        """Open a DuckDB connection and resolve the preferred master table path.

        Args:
            db_path: Path of the DuckDB file; defaults to DUCKDB_PATH.
        """
        self.db_path = db_path or DUCKDB_PATH
        self.con = duckdb.connect(self.db_path)

        # Best-effort pre-attach: a failure is logged rather than raised,
        # because the local database may still be able to serve queries.
        if DUCKDB_ATTACH_SQL:
            try:
                self.con.execute(DUCKDB_ATTACH_SQL)
            except Exception as e:
                print(f"[WARN] DUCKDB_ATTACH_SQL failed: {e}")

        self.full_table = self._resolve_full_table(PREF_CATALOG, PREF_SCHEMA, PREF_TABLE)

    def close(self) -> None:
        """Close the underlying DuckDB connection (best effort, idempotent)."""
        try:
            self.con.close()
        except Exception:
            pass

    def __enter__(self) -> "SQLTool":
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        self.close()

    def _try_probe(self, path: str) -> bool:
        """Return True if SELECT * FROM <path> LIMIT 1 succeeds."""
        try:
            self.con.execute(f"SELECT * FROM {path} LIMIT 1")
            return True
        except Exception:
            return False

    def _scan_information_schema(self, table_name: str) -> Optional[str]:
        """
        Look for <schema>.<table> (and <catalog>.<schema>.<table> if available)
        in information_schema. Return a best guess path string or None.

        Three passes, from most to least specific: preferred catalog+schema,
        preferred schema in any catalog, then anything that probes successfully.
        """
        q = """
            SELECT table_catalog, table_schema, table_name
            FROM information_schema.tables
            WHERE lower(table_name) = ?
            ORDER BY table_catalog, table_schema
        """
        rows = self.con.execute(q, [table_name.lower()]).fetchall()
        if not rows:
            return None

        tried = set()  # avoid re-probing the same path across passes

        # Pass 1: exact preferred catalog + schema match.
        for cat, sch, t in rows:
            if (cat or "").lower() == (PREF_CATALOG or "").lower() and sch.lower() == PREF_SCHEMA.lower():
                candidate = f"{cat}.{sch}.{t}" if cat else f"{sch}.{t}"
                if candidate not in tried:
                    tried.add(candidate)
                    if self._try_probe(candidate):
                        return candidate

        # Pass 2: preferred schema in any catalog, addressed with a 2-part path.
        for _cat, sch, t in rows:
            if sch.lower() == PREF_SCHEMA.lower():
                candidate = f"{sch}.{t}"
                if candidate not in tried:
                    tried.add(candidate)
                    if self._try_probe(candidate):
                        return candidate

        # Pass 3: any row whose path probes successfully.
        for cat, sch, t in rows:
            candidate = f"{cat}.{sch}.{t}" if cat else f"{sch}.{t}"
            if candidate not in tried:
                tried.add(candidate)
                if self._try_probe(candidate):
                    return candidate

        return None

    def _resolve_full_table(self, catalog: Optional[str], schema: Optional[str], table: str) -> str:
        """
        Return a working fully qualified path for the table by trying:
        - <catalog>.<schema>.<table> (3-part)
        - <schema>.<table>           (2-part)
        - <table>                    (1-part)
        - information_schema scan    (best effort)

        Falls back to the most qualified spelling even if nothing probed OK,
        so later queries fail with a meaningful table name in the error.
        """
        candidates: List[str] = []

        if catalog:
            candidates.append(f"{catalog}.{schema}.{table}")
        if schema:
            candidates.append(f"{schema}.{table}")
        candidates.append(table)

        for path in candidates:
            if self._try_probe(path):
                print(f"[INFO] Using table path: {path}")
                return path

        scanned = self._scan_information_schema(table)
        if scanned:
            print(f"[INFO] Using table path (scanned): {scanned}")
            return scanned

        fallback = f"{catalog}.{schema}.{table}" if catalog else f"{schema}.{table}"
        print(f"[WARN] Could not resolve table path; falling back to: {fallback}")
        return fallback

    def run_sql(self, sql: str) -> pd.DataFrame:
        """Execute *sql* and return the result as a pandas DataFrame."""
        return self.con.execute(sql).df()

    def _nl_to_sql(self, message: str) -> Tuple[str, str]:
        """Translate a natural-language *message* into (sql, explanation).

        Keyword-driven: recognizes "top N <product> by portfolio value" and
        "sum/avg by segment/currency" intents, otherwise builds a filtered
        SELECT * as a fallback. Interpolated values are safe against SQL
        injection only because the regexes below restrict their character
        sets (digits, [a-z]{3}, [a-z0-9_\\- ]) — keep that invariant.
        """
        full_table = self.full_table
        m = (message or "").strip().lower()

        def has_any(txt, words):
            return any(w in txt for w in words)

        # "top N" → LIMIT N (digits only, so safe to interpolate).
        limit = None
        m_top = re.search(r"\btop\s+(\d+)", m)
        if m_top:
            limit = int(m_top.group(1))

        # Intent 1: top fixed deposits by portfolio value.
        if has_any(m, ["fd", "fixed deposit", "deposits"]) and has_any(
            m, ["top", "largest", "biggest"]
        ) and has_any(m, ["portfolio value", "portfolio_value"]):
            n = limit or 10
            sql = f"""
                SELECT contract_number, Portfolio_value, Interest_rate, currency, segments
                FROM {full_table}
                WHERE lower(product) = 'fd'
                ORDER BY Portfolio_value DESC
                LIMIT {n};
            """
            why = f"Top {n} fixed deposits by Portfolio_value from {full_table}"
            return sql, why

        # Intent 2: top assets/loans by portfolio value.
        if has_any(m, ["asset", "loan", "advances"]) and has_any(
            m, ["top", "largest", "biggest"]
        ) and has_any(m, ["portfolio value", "portfolio_value"]):
            n = limit or 10
            sql = f"""
                SELECT contract_number, Portfolio_value, Interest_rate, currency, segments
                FROM {full_table}
                WHERE lower(product) = 'assets'
                ORDER BY Portfolio_value DESC
                LIMIT {n};
            """
            why = f"Top {n} assets by Portfolio_value from {full_table}"
            return sql, why

        # Intent 3: aggregate Portfolio_value grouped by segment or currency.
        if has_any(m, ["sum", "total", "avg", "average"]) and has_any(
            m, ["segment", "currency"]
        ):
            agg = "SUM" if has_any(m, ["sum", "total"]) else "AVG"
            dim = "segments" if "segment" in m else "currency"
            sql = f"""
                SELECT {dim}, {agg}(Portfolio_value) AS {agg.lower()}_Portfolio_value
                FROM {full_table}
                GROUP BY 1
                ORDER BY 2 DESC;
            """
            why = f"{agg} Portfolio_value grouped by {dim} from {full_table}"
            return sql, why

        # Fallback: SELECT * with optional product/currency/segment filters.
        product = None
        if "fd" in m or "deposit" in m:
            product = "fd"
        elif "asset" in m or "loan" in m or "advance" in m:
            product = "assets"

        parts = [f"SELECT * FROM {full_table} WHERE 1=1"]
        why_parts = [f"Filtered rows from {full_table}"]

        if product:
            parts.append(f"AND lower(product) = '{product}'")
            why_parts.append(f"product = {product}")

        # Three-letter code after "currency"/"in" — e.g. "in usd".
        cur_match = re.search(r"\b(currency|in)\s+([a-z]{3})\b", m)
        if cur_match:
            cur = cur_match.group(2).upper()
            parts.append(f"AND upper(currency) = '{cur}'")
            why_parts.append(f"currency = {cur}")

        # Segment phrase after "segment"/"for"; charset excludes quotes.
        seg_match = re.search(r"(segment|for)\s+([a-z0-9_\- ]+)", m)
        if seg_match:
            seg = seg_match.group(2).strip()
            if seg:
                parts.append(f"AND lower(segments) LIKE '%{seg.lower()}%'")
                why_parts.append(f"segments like '{seg}'")

        if limit:
            parts.append(f"LIMIT {limit}")

        fallback_sql = " ".join(parts) + ";"
        fallback_why = "; ".join(why_parts)
        return fallback_sql, fallback_why

    def query_from_nl(self, message: str):
        """Translate *message* to SQL, run it, and return (df, sql, why)."""
        sql, why = self._nl_to_sql(message)
        df = self.run_sql(sql)
        return df, sql, why

    def get_full_table_path(self) -> str:
        """Return the resolved fully qualified master-table path."""
        return self.full_table
|
|
|