ALM_LLM / tools /sql_tool.py
AshenH's picture
Update tools/sql_tool.py
30413d2 verified
raw
history blame
8.86 kB
# tools/sql_tool.py
import os
import re
from typing import Optional, Tuple, List
import duckdb
import pandas as pd
# ------------------------------------------------------------
# Connection config
# ------------------------------------------------------------
# Path of the DuckDB database file to open.
DUCKDB_PATH = os.getenv("DUCKDB_PATH", "alm.duckdb")

# If you need to attach a catalog (e.g., MotherDuck), put the full ATTACH
# statement in the DUCKDB_ATTACH_SQL environment variable.
# Example:
#   DUCKDB_ATTACH_SQL=ATTACH 'md:my_db' AS my_db;
# Unset or blank means "nothing to attach"; SQLTool skips the ATTACH step then.
DUCKDB_ATTACH_SQL = os.getenv("DUCKDB_ATTACH_SQL", "").strip()

# Preferred identifiers (we will fall back automatically if they don't exist)
PREF_CATALOG = os.getenv("SQL_DEFAULT_DB", "my_db")  # catalog (optional)
PREF_SCHEMA = os.getenv("SQL_DEFAULT_SCHEMA", "main")  # schema
PREF_TABLE = os.getenv("SQL_DEFAULT_TABLE", "masterdataset_v")  # table
class SQLTool:
    """
    NL→SQL helper for DuckDB with:
    - optional pre-attach SQL (DUCKDB_ATTACH_SQL)
    - robust table path resolution (tries 3-part → 2-part → 1-part → information_schema scan)
    - explicit connection cleanup through close() or a ``with`` block

    Attributes:
        db_path: database path handed to ``duckdb.connect``.
        con: the open DuckDB connection (set to None once close() has run).
        full_table: table path resolved at construction time; every generated
            query selects from it.
    """

    def __init__(self, db_path: Optional[str] = None):
        """
        Open the database, run the optional ATTACH statement and resolve the table path.

        Args:
            db_path: database path; the module-level DUCKDB_PATH is used when
                this is omitted or empty.
        """
        self.db_path = db_path or DUCKDB_PATH
        self.con = duckdb.connect(self.db_path)
        # Optional: run user-supplied ATTACH (safe no-op if empty)
        if DUCKDB_ATTACH_SQL:
            try:
                self.con.execute(DUCKDB_ATTACH_SQL)
            except Exception as e:
                # Don't crash the app on attach issues; we still try local tables
                print(f"[WARN] DUCKDB_ATTACH_SQL failed: {e}")
        self.full_table = self._resolve_full_table(PREF_CATALOG, PREF_SCHEMA, PREF_TABLE)

    # ------------------------------------------------------------
    # Lifecycle
    # ------------------------------------------------------------
    def close(self) -> None:
        """Close the DuckDB connection; calling it again is a no-op."""
        con = getattr(self, "con", None)
        if con is not None:
            con.close()
            self.con = None

    def __enter__(self) -> "SQLTool":
        """Allow ``with SQLTool() as tool:`` usage."""
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        """Close the connection on leaving the ``with`` block; exceptions propagate."""
        self.close()

    # ------------------------------------------------------------
    # Resolution helpers
    # ------------------------------------------------------------
    def _try_probe(self, path: str) -> bool:
        """Return True if SELECT * FROM <path> LIMIT 1 succeeds."""
        try:
            self.con.execute(f"SELECT * FROM {path} LIMIT 1")
            return True
        except Exception:
            # Any failure simply means "this path is not usable"; the caller
            # moves on to the next candidate.
            return False

    def _scan_information_schema(self, table_name: str) -> Optional[str]:
        """
        Look for <schema>.<table> (and <catalog>.<schema>.<table> if available)
        in information_schema. Return a best guess path string or None.

        Candidates are probed in this order, and the first one that can be
        queried wins:
        1) rows whose catalog and schema both match PREF_CATALOG / PREF_SCHEMA
        2) rows whose schema matches PREF_SCHEMA, addressed as <schema>.<table>
        3) every remaining row, fully qualified when a catalog is present
        """
        q = """
            SELECT table_catalog, table_schema, table_name
            FROM information_schema.tables
            WHERE lower(table_name) = ?
            ORDER BY table_catalog, table_schema
        """
        rows = self.con.execute(q, [table_name.lower()]).fetchall()
        if not rows:
            return None

        pref_catalog = (PREF_CATALOG or "").lower()
        pref_schema = (PREF_SCHEMA or "").lower()

        def dotted(*parts: Optional[str]) -> str:
            # Skip NULL/empty components so we never emit "None.tbl" or ".tbl".
            return ".".join(p for p in parts if p)

        candidates: List[str] = []
        # 1) exact catalog+schema
        for cat, sch, t in rows:
            if (cat or "").lower() == pref_catalog and (sch or "").lower() == pref_schema:
                candidates.append(dotted(cat, sch, t))
        # 2) exact schema (2-part)
        for cat, sch, t in rows:
            if (sch or "").lower() == pref_schema:
                candidates.append(dotted(sch, t))
        # 3) first working row (prefer 3-part if catalog present)
        candidates.extend(dotted(cat, sch, t) for cat, sch, t in rows)

        # dict.fromkeys keeps first-seen order while dropping repeats, so a
        # path that already failed is not probed a second time.
        for candidate in dict.fromkeys(candidates):
            if self._try_probe(candidate):
                return candidate
        return None

    def _resolve_full_table(self, catalog: Optional[str], schema: Optional[str], table: str) -> str:
        """
        Return a working fully qualified path for the table by trying:
        - <catalog>.<schema>.<table> (3-part)
        - <schema>.<table> (2-part)
        - <table> (1-part)
        - information_schema scan (best effort)

        Empty or None catalog/schema components are left out of the path
        instead of being formatted into it. When nothing can be queried, the
        most qualified candidate is returned anyway so the first real query
        raises a meaningful error.
        """
        candidates: List[str] = []
        if catalog and schema:
            candidates.append(f"{catalog}.{schema}.{table}")
        elif catalog:
            candidates.append(f"{catalog}.{table}")
        if schema:
            candidates.append(f"{schema}.{table}")
        candidates.append(table)

        for path in candidates:
            if self._try_probe(path):
                print(f"[INFO] Using table path: {path}")
                return path

        # Fallback: scan information_schema
        scanned = self._scan_information_schema(table)
        if scanned:
            print(f"[INFO] Using table path (scanned): {scanned}")
            return scanned

        # Last resort: keep the most qualified path (will raise on first query)
        fallback = candidates[0]
        print(f"[WARN] Could not resolve table path; falling back to: {fallback}")
        return fallback

    # ------------------------------------------------------------
    # Run SQL directly
    # ------------------------------------------------------------
    def run_sql(self, sql: str) -> pd.DataFrame:
        """Execute ``sql`` on the connection and return the result as a DataFrame."""
        return self.con.execute(sql).df()

    # ------------------------------------------------------------
    # NL → SQL
    # ------------------------------------------------------------
    def _nl_to_sql(self, message: str) -> Tuple[str, str]:
        """
        Translate a natural-language request into a SQL string by keyword rules.

        Rules are tried in order and the first match wins:
        1. top/largest/biggest + "portfolio value" + FD keywords  → top-N fixed deposits
        2. same, with asset/loan/advances keywords                → top-N assets
        3. sum/total/avg/average + segment/currency               → grouped aggregate
        4. otherwise a ``SELECT *`` with optional product, currency, segment
           and LIMIT filters

        Args:
            message: the user's request; None is treated as an empty string.

        Returns:
            ``(sql, why)``: the generated SQL and a short description of it.
        """
        full_table = self.full_table
        m = (message or "").strip().lower()

        def has_any(txt, words):
            return any(w in txt for w in words)

        # Extract "top N". None means no limit was requested; an explicit
        # "top 0" stays 0 instead of being mistaken for "not given".
        limit = None
        m_top = re.search(r"\btop\s+(\d+)", m)
        if m_top:
            limit = int(m_top.group(1))
        top_n = 10 if limit is None else limit

        # 1./2. Top N FDs / Top N Assets (FD keywords are checked first)
        ranked_products = (
            (["fd", "fixed deposit", "deposits"], "fd", "fixed deposits"),
            (["asset", "loan", "advances"], "assets", "assets"),
        )
        wants_ranking = has_any(m, ["top", "largest", "biggest"]) and has_any(
            m, ["portfolio value", "portfolio_value"]
        )
        if wants_ranking:
            for keywords, product_value, label in ranked_products:
                if has_any(m, keywords):
                    sql = f"""
                        SELECT contract_number, Portfolio_value, Interest_rate, currency, segments
                        FROM {full_table}
                        WHERE lower(product) = '{product_value}'
                        ORDER BY Portfolio_value DESC
                        LIMIT {top_n};
                    """
                    why = f"Top {top_n} {label} by Portfolio_value from {full_table}"
                    return sql, why

        # 3. Aggregate by segment/currency
        if has_any(m, ["sum", "total", "avg", "average"]) and has_any(
            m, ["segment", "currency"]
        ):
            agg = "SUM" if has_any(m, ["sum", "total"]) else "AVG"
            dim = "segments" if "segment" in m else "currency"
            sql = f"""
                SELECT {dim}, {agg}(Portfolio_value) AS {agg.lower()}_Portfolio_value
                FROM {full_table}
                GROUP BY 1
                ORDER BY 2 DESC;
            """
            why = f"{agg} Portfolio_value grouped by {dim} from {full_table}"
            return sql, why

        # 4. Generic filters
        product = None
        if "fd" in m or "deposit" in m:
            product = "fd"
        elif "asset" in m or "loan" in m or "advance" in m:
            product = "assets"

        parts = [f"SELECT * FROM {full_table} WHERE 1=1"]
        why_parts = [f"Filtered rows from {full_table}"]

        if product:
            parts.append(f"AND lower(product) = '{product}'")
            why_parts.append(f"product = {product}")

        # The values interpolated below come out of character classes that
        # exclude quotes, so they cannot break out of the SQL string literal.
        cur_match = re.search(r"\b(currency|in)\s+([a-z]{3})\b", m)
        if cur_match:
            cur = cur_match.group(2).upper()
            parts.append(f"AND upper(currency) = '{cur}'")
            why_parts.append(f"currency = {cur}")

        seg_match = re.search(r"\b(?:segment|for)\s+([a-z0-9_\- ]+)", m)
        if seg_match:
            # The capture runs to the end of the phrase, so cut it where a
            # currency or top-N clause (handled separately) begins.
            seg = re.split(
                r"\b(?:(?:currency|in)\s+[a-z]{3}\b|top\s+\d+)",
                seg_match.group(1),
                maxsplit=1,
            )[0].strip()
            if seg:
                parts.append(f"AND lower(segments) LIKE '%{seg}%'")
                why_parts.append(f"segments like '{seg}'")

        if limit is not None:
            parts.append(f"LIMIT {limit}")

        fallback_sql = " ".join(parts) + ";"
        fallback_why = "; ".join(why_parts)
        return fallback_sql, fallback_why

    # ------------------------------------------------------------
    # Public wrappers
    # ------------------------------------------------------------
    def query_from_nl(self, message: str):
        """
        Translate ``message`` to SQL, run it, and return ``(df, sql, why)``.

        ``df`` is the query result, ``sql`` the statement that was executed and
        ``why`` the short explanation produced alongside it.
        """
        sql, why = self._nl_to_sql(message)
        df = self.run_sql(sql)
        return df, sql, why

    def get_full_table_path(self) -> str:
        """Return the table path resolved when the tool was constructed."""
        return self.full_table