File size: 4,358 Bytes
b07564d f4dc602 b07564d f4dc602 b07564d f4dc602 e002acf f4dc602 e002acf f4dc602 e002acf f4dc602 e002acf b07564d e002acf f4dc602 e002acf b07564d 9d6bac9 b07564d 9d6bac9 b07564d 9d6bac9 e002acf b07564d 9d6bac9 b07564d 9d6bac9 b07564d 9d6bac9 b07564d e002acf f4dc602 e002acf b07564d e002acf f4dc602 e002acf b07564d f4dc602 b07564d e002acf b07564d e002acf f4dc602 e002acf f4dc602 b07564d e002acf f4dc602 b07564d f4dc602 e002acf |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 |
# space/tools/sql_tool.py
import glob
import json
import logging
import os
import re
import shutil
from typing import Optional

import pandas as pd

from utils.config import AppConfig
from utils.tracing import Tracer
class SQLTool:
    """Execute SQL (or a naive natural-language-to-SQL guess) on BigQuery or MotherDuck.

    The backend comes from ``cfg.sql_backend`` and must be ``"bigquery"`` or
    ``"motherduck"``. The client/connection is created in ``__init__``;
    ``run`` returns each query's result as a pandas DataFrame.
    """

    # DuckDB release the MotherDuck extension is pinned to in this app.
    _MOTHERDUCK_DUCKDB_VERSION = "1.3.2"

    # Explicit SQL typed by the user. ``\b`` instead of a literal space so that
    # "SELECT\n..." or "select\t..." are recognised too.
    _SELECT_RE = re.compile(r"\s*select\b", re.IGNORECASE)

    _log = logging.getLogger(__name__)

    def __init__(self, cfg: "AppConfig", tracer: "Tracer"):
        """Store config/tracer and open a client for ``cfg.sql_backend``.

        Args:
            cfg: application config; reads ``sql_backend``, ``gcp_project``,
                ``motherduck_token`` and ``motherduck_db``.
            tracer: receives a ``sql_query`` event for every ``run`` call.

        Raises:
            RuntimeError: unknown backend, missing/unusable credentials, or
                (MotherDuck) an unsupported DuckDB version.
        """
        self.cfg = cfg
        self.tracer = tracer
        self.backend = cfg.sql_backend  # "bigquery" or "motherduck"
        if self.backend == "bigquery":
            self.client = self._connect_bigquery()
        elif self.backend == "motherduck":
            self.client = self._connect_motherduck()
        else:
            raise RuntimeError(f"Unknown SQL backend: {self.backend!r}")

    def _connect_bigquery(self):
        """Build a BigQuery client from the ``GCP_SERVICE_ACCOUNT_JSON`` secret.

        The secret may hold the full service-account JSON object or a path to a
        key file on disk.
        """
        from google.cloud import bigquery
        from google.oauth2 import service_account

        key_json = os.getenv("GCP_SERVICE_ACCOUNT_JSON")
        if not key_json:
            raise RuntimeError("Missing GCP_SERVICE_ACCOUNT_JSON secret")
        key_json = key_json.strip()
        if key_json.startswith("{"):
            # Full JSON key pasted into the Space Secret (malformed JSON raises).
            info = json.loads(key_json)
            creds = service_account.Credentials.from_service_account_info(info)
        elif os.path.isfile(key_json):
            creds = service_account.Credentials.from_service_account_file(key_json)
        else:
            # Previously this fell through with ``{}`` and failed later with an
            # opaque "missing fields" error from google-auth.
            raise RuntimeError(
                "GCP_SERVICE_ACCOUNT_JSON must contain a service-account JSON "
                "object or a path to a key file"
            )
        return bigquery.Client(credentials=creds, project=self.cfg.gcp_project)

    def _connect_motherduck(self):
        """Open an in-memory DuckDB connection with the MotherDuck DB attached as ``md``."""
        import duckdb

        wanted = self._MOTHERDUCK_DUCKDB_VERSION
        # ---- Enforce supported DuckDB version for MotherDuck extension ----
        if not duckdb.__version__.startswith(wanted):
            raise RuntimeError(
                f"Incompatible DuckDB version {duckdb.__version__}. "
                f"MotherDuck currently supports DuckDB {wanted}. "
                f"Pin duckdb=={wanted} in requirements.txt and redeploy."
            )
        token = self.cfg.motherduck_token or os.getenv("MOTHERDUCK_TOKEN")
        db_name = self.cfg.motherduck_db or "default"
        if not token:
            raise RuntimeError("Missing MOTHERDUCK_TOKEN")

        self._prune_stale_duckdb_extensions(
            os.path.expanduser("~/.duckdb/extensions"), wanted
        )

        conn = duckdb.connect()  # in-memory connection; MotherDuck is ATTACHed below
        try:
            conn.execute("INSTALL motherduck;")
            conn.execute("LOAD motherduck;")
            # Values are embedded as escaped SQL string literals so a quote in
            # the token or db name cannot break (or alter) the statement.
            conn.execute(f"SET motherduck_token={self._sql_quote(token)};")
            # NOTE(review): MotherDuck docs show 'md:<db>' (no slash) — confirm
            # that 'md:/<db>' resolves to the intended database.
            conn.execute(f"ATTACH {self._sql_quote('md:/' + db_name)} AS md;")
            conn.execute("USE md;")  # subsequent queries run against 'md' by default
        except Exception:
            # Don't leak the half-initialised connection if setup fails.
            conn.close()
            raise
        return conn

    @classmethod
    def _prune_stale_duckdb_extensions(cls, ext_root: str, keep_version: str) -> None:
        """Best-effort removal of extension caches built for other DuckDB versions.

        Every entry directly under *ext_root* whose name does not contain
        *keep_version* is deleted. Failures are logged and otherwise ignored.
        """
        try:
            for entry in glob.glob(os.path.join(ext_root, "*")):
                # Compare the entry name only: checking the whole path would skip
                # cleanup entirely if a parent directory contained the version.
                if keep_version not in os.path.basename(entry):
                    shutil.rmtree(entry, ignore_errors=True)
        except Exception:
            cls._log.debug("DuckDB extension cache cleanup failed", exc_info=True)

    @staticmethod
    def _sql_quote(value: str) -> str:
        """Return *value* as a single-quoted SQL string literal (embedded quotes doubled)."""
        return "'" + value.replace("'", "''") + "'"

    def _nl_to_sql(self, message: str) -> str:
        """
        Minimal NL2SQL heuristic; replace with your own mapping or LLM prompt.
        Expect users to include table names. Example:
        "avg metric by month from analytics.events"

        Explicit SQL starting with SELECT is returned unchanged; otherwise a
        month-average template or a ``LIMIT 100`` preview query is returned.
        """
        # Pass-through is checked first: an explicit query such as
        # "SELECT AVG(x) FROM t GROUP BY y" must not be replaced by the template.
        # NOTE(review): user SQL is executed verbatim with the configured
        # credentials; a leading SELECT does not rule out further statements.
        if self._SELECT_RE.match(message):
            return message
        m = message.lower()
        # Very basic template example (edit table/columns to your schema)
        if "avg" in m and " by " in m:
            # DuckDB uses DATE_TRUNC('month', col); BigQuery uses DATE_TRUNC(col, MONTH).
            if self.backend == "bigquery":
                month_expr = "DATE_TRUNC(date_col, MONTH)"
            else:
                month_expr = "DATE_TRUNC('month', date_col)"
            return (
                "-- Example template; edit me\n"
                f"SELECT {month_expr} AS month, "
                "AVG(metric) AS avg_metric "
                "FROM analytics.table "
                "GROUP BY 1 "
                "ORDER BY 1;"
            )
        # Fallback
        return "SELECT * FROM analytics.table LIMIT 100;"

    def run(self, message: str) -> pd.DataFrame:
        """Translate *message* to SQL, run it on the configured backend, return a DataFrame."""
        sql = self._nl_to_sql(message)
        try:
            self.tracer.trace_event("sql_query", {"sql": sql, "backend": self.backend})
        except Exception:
            # Tracing is best-effort and must never block the query itself.
            self._log.debug("sql_query trace event failed", exc_info=True)
        if self.backend == "bigquery":
            return self.client.query(sql).to_dataframe()
        # DuckDB (MotherDuck)
        return self.client.execute(sql).fetch_df()
|