File size: 2,991 Bytes
f4dc602 e3b4d13 f4dc602 0f166dc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 |
import os
import re
import pandas as pd
from typing import Optional
from utils.config import AppConfig
from utils.tracing import Tracer
class SQLTool:
    """Run SQL against BigQuery or MotherDuck and return pandas DataFrames.

    The backend is chosen by ``cfg.sql_backend`` ("bigquery" or
    "motherduck"). Each backend's client library is imported only when that
    backend is selected.
    """

    # Recognises a message that starts with the SELECT keyword, whatever the
    # case and whatever whitespace (space, tab, newline) follows it.
    _SELECT_RE = re.compile(r"\s*select\b", re.IGNORECASE)

    def __init__(self, cfg: "AppConfig", tracer: "Tracer"):
        """Create the backend client described by ``cfg``.

        Args:
            cfg: Application config. Reads ``sql_backend`` and, depending on
                the backend, ``gcp_project`` or ``motherduck_token`` /
                ``motherduck_db``.
            tracer: Tracer whose ``trace_event`` is called for each query.

        Raises:
            RuntimeError: If the backend is unknown or a required secret
                (``GCP_SERVICE_ACCOUNT_JSON`` / ``MOTHERDUCK_TOKEN``) is
                missing.
            ValueError: If ``GCP_SERVICE_ACCOUNT_JSON`` is not a JSON object.
        """
        self.cfg = cfg
        self.tracer = tracer
        self.backend = cfg.sql_backend  # "bigquery" or "motherduck"
        if self.backend == "bigquery":
            self.client = self._connect_bigquery()
        elif self.backend == "motherduck":
            self.client = self._connect_motherduck()
        else:
            raise RuntimeError("Unknown SQL backend")

    @staticmethod
    def _parse_service_account_json(raw: str) -> dict:
        """Decode a service-account key from its JSON text.

        The text is parsed with ``json.loads`` rather than ``eval``: the value
        comes from the environment and must never be executed, and JSON
        literals such as ``true``/``null`` are not valid Python expressions.

        Raises:
            ValueError: If ``raw`` is not valid JSON or is not a JSON object.
        """
        import json

        try:
            info = json.loads(raw)
        except json.JSONDecodeError as err:
            raise ValueError("GCP_SERVICE_ACCOUNT_JSON is not valid JSON") from err
        if not isinstance(info, dict):
            raise ValueError("GCP_SERVICE_ACCOUNT_JSON must be a JSON object")
        return info

    @staticmethod
    def _sql_literal(value: str) -> str:
        """Return ``value`` as a single-quoted SQL string literal.

        Embedded single quotes are doubled so the value cannot terminate the
        literal early.
        """
        return "'" + value.replace("'", "''") + "'"

    def _connect_bigquery(self):
        """Build a BigQuery client from the ``GCP_SERVICE_ACCOUNT_JSON`` secret."""
        from google.cloud import bigquery
        from google.oauth2 import service_account

        key_json = os.getenv("GCP_SERVICE_ACCOUNT_JSON")
        if not key_json:
            raise RuntimeError("Missing GCP_SERVICE_ACCOUNT_JSON secret")
        creds = service_account.Credentials.from_service_account_info(
            self._parse_service_account_json(key_json)
        )
        return bigquery.Client(credentials=creds, project=self.cfg.gcp_project)

    def _connect_motherduck(self):
        """Open a DuckDB connection with the MotherDuck database attached as ``md``."""
        import duckdb

        token = self.cfg.motherduck_token or os.getenv("MOTHERDUCK_TOKEN")
        db_name = self.cfg.motherduck_db or "default"
        # Fail before opening a connection so a missing token leaks nothing.
        if not token:
            raise RuntimeError("Missing MOTHERDUCK_TOKEN")

        # Start a plain DuckDB connection
        client = duckdb.connect()
        try:
            # Ensure the MotherDuck extension is available and loaded
            # (DuckDB will download it automatically in this environment)
            client.execute("INSTALL motherduck;")
            client.execute("LOAD motherduck;")
            # Provide token and attach the remote MotherDuck database as 'md'
            client.execute(f"SET motherduck_token={self._sql_literal(token)};")
            client.execute(f"ATTACH {self._sql_literal('md:/' + db_name)} AS md;")
            # Subsequent queries run against 'md' by default
            client.execute("USE md;")
        except Exception:
            # Do not leave a half-configured connection open on failure.
            client.close()
            raise
        return client

    def _nl_to_sql(self, message: str) -> str:
        """Translate a user message into a SQL string.

        Minimal NL2SQL heuristic; replace with your own mapping or LLM prompt.
        Expect users to include table names, e.g.
        "avg revenue by month from dataset.sales".

        Args:
            message: Free-text request, or a literal SELECT statement.

        Returns:
            ``message`` unchanged when it starts with SELECT; an example
            aggregation template when it mentions "avg" and " by "; otherwise
            a default ``SELECT * ... LIMIT 100`` query.
        """
        # Explicit SQL wins: checking it first keeps a real query such as
        # "SELECT AVG(x) ... GROUP BY y" from being replaced by the template.
        if self._SELECT_RE.match(message):
            return message

        lowered = message.lower()
        if "avg" in lowered and " by " in lowered:
            return (
                "-- Example template; edit me\n"
                "SELECT DATE_TRUNC(month, date_col) AS month, "
                "AVG(metric) AS avg_metric "
                "FROM dataset.table GROUP BY 1 ORDER BY 1;"
            )
        return "SELECT * FROM dataset.table LIMIT 100;"

    def run(self, message: str) -> pd.DataFrame:
        """Convert ``message`` to SQL, trace it, execute it, and return the rows.

        A ``"sql_query"`` trace event carrying the SQL text and backend name
        is emitted before the query is executed.

        Args:
            message: Free-text request or SELECT statement; see ``_nl_to_sql``.

        Returns:
            The query result as a DataFrame.
        """
        sql = self._nl_to_sql(message)
        self.tracer.trace_event("sql_query", {"sql": sql, "backend": self.backend})
        if self.backend == "bigquery":
            return self.client.query(sql).to_dataframe()
        return self.client.execute(sql).fetch_df()