|
|
import os |
|
|
import re |
|
|
import pandas as pd |
|
|
from typing import Optional |
|
|
from utils.config import AppConfig |
|
|
from utils.tracing import Tracer |
|
|
|
|
|
class SQLTool:
    """Run SQL against the configured warehouse backend.

    The backend is selected by ``cfg.sql_backend`` and must be either
    ``"bigquery"`` or ``"motherduck"``. The backend client is created in
    ``__init__`` and stored on ``self.client``.
    """

    # Canned query returned when a message mentions "avg" and " by ".
    _AVG_TEMPLATE = (
        "-- Example template; edit me\n"
        "SELECT DATE_TRUNC(month, date_col) AS month, AVG(metric) AS avg_metric "
        "FROM dataset.table GROUP BY 1 ORDER BY 1;"
    )
    # Fallback query when the message matches no heuristic.
    _DEFAULT_SQL = "SELECT * FROM dataset.table LIMIT 100;"
    # A message that begins with the SELECT keyword is treated as raw SQL.
    # ``\b`` (rather than a literal space) also accepts "select\n..." etc.
    _RAW_SELECT = re.compile(r"\s*select\b", re.IGNORECASE)

    def __init__(self, cfg: "AppConfig", tracer: "Tracer"):
        """Store the configuration and connect to the selected backend.

        Args:
            cfg: Application config. Reads ``sql_backend`` and, depending on
                the backend, ``gcp_project`` or ``motherduck_token`` /
                ``motherduck_db``.
            tracer: Tracer whose ``trace_event`` is called for every query.

        Raises:
            RuntimeError: If the backend is unknown or a required secret is
                missing or malformed.
        """
        self.cfg = cfg
        self.tracer = tracer
        self.backend = cfg.sql_backend

        if self.backend == "bigquery":
            self.client = self._connect_bigquery()
        elif self.backend == "motherduck":
            self.client = self._connect_motherduck()
        else:
            raise RuntimeError("Unknown SQL backend")

    @staticmethod
    def _sql_literal(value) -> str:
        """Return *value* as a single-quoted SQL string literal.

        Embedded single quotes are doubled so the value cannot terminate the
        literal early.
        """
        return "'" + str(value).replace("'", "''") + "'"

    @staticmethod
    def _parse_service_account(raw: str) -> dict:
        """Parse the service-account secret into a dict without using eval.

        JSON is the expected format. A Python dict literal is accepted as a
        fallback so values that previously only worked under ``eval`` keep
        working, but arbitrary expressions are rejected.

        Raises:
            RuntimeError: If *raw* is neither valid JSON nor a dict literal,
                or does not describe an object/dict.
        """
        import ast
        import json

        try:
            info = json.loads(raw)
        except json.JSONDecodeError as json_err:
            try:
                info = ast.literal_eval(raw.strip())
            except (ValueError, SyntaxError, MemoryError, RecursionError):
                raise RuntimeError(
                    "GCP_SERVICE_ACCOUNT_JSON is not valid JSON"
                ) from json_err
        if not isinstance(info, dict):
            raise RuntimeError("GCP_SERVICE_ACCOUNT_JSON must be a JSON object")
        return info

    def _connect_bigquery(self):
        """Create a BigQuery client from the GCP_SERVICE_ACCOUNT_JSON secret."""
        key_json = os.getenv("GCP_SERVICE_ACCOUNT_JSON")
        if not key_json:
            raise RuntimeError("Missing GCP_SERVICE_ACCOUNT_JSON secret")
        # Validate the secret before importing the (heavy, optional) SDK so
        # configuration mistakes surface with a clear message.
        info = self._parse_service_account(key_json)

        from google.cloud import bigquery
        from google.oauth2 import service_account

        creds = service_account.Credentials.from_service_account_info(info)
        return bigquery.Client(credentials=creds, project=self.cfg.gcp_project)

    def _connect_motherduck(self):
        """Open a DuckDB connection attached to the MotherDuck database."""
        token = self.cfg.motherduck_token or os.getenv("MOTHERDUCK_TOKEN")
        if not token:
            # Fail before connecting or installing the extension.
            raise RuntimeError("Missing MOTHERDUCK_TOKEN")
        db_name = self.cfg.motherduck_db or "default"

        import duckdb

        client = duckdb.connect()
        try:
            client.execute("INSTALL motherduck;")
            client.execute("LOAD motherduck;")
            client.execute(f"SET motherduck_token={self._sql_literal(token)};")
            client.execute(f"ATTACH {self._sql_literal('md:/' + db_name)} AS md;")
            client.execute("USE md;")
        except Exception:
            # Do not leak a half-configured connection.
            client.close()
            raise
        return client

    def _nl_to_sql(self, message: str) -> str:
        """Translate a user message into a SQL string using simple heuristics.

        Order of checks:
          1. A message that starts with ``SELECT`` (any case, leading
             whitespace allowed) is returned unchanged.
          2. A message containing ``avg`` and `` by `` returns the example
             aggregation template.
          3. Anything else returns a default ``SELECT * ... LIMIT 100`` query.

        Raw SQL is checked first so that a user-written query such as
        ``SELECT AVG(x) ... GROUP BY y`` is not replaced by the template.
        """
        if self._RAW_SELECT.match(message):
            return message

        lowered = message.lower()
        if "avg" in lowered and " by " in lowered:
            return self._AVG_TEMPLATE

        return self._DEFAULT_SQL

    def run(self, message: str) -> pd.DataFrame:
        """Convert *message* to SQL, trace it, execute it, and return the result.

        Emits a ``"sql_query"`` trace event carrying the SQL text and backend
        name before executing.
        """
        sql = self._nl_to_sql(message)
        self.tracer.trace_event("sql_query", {"sql": sql, "backend": self.backend})

        # NOTE(review): a message starting with SELECT is executed verbatim;
        # confirm callers only pass trusted input or restrict permissions.
        if self.backend == "bigquery":
            return self.client.query(sql).to_dataframe()
        return self.client.execute(sql).fetch_df()