File size: 3,648 Bytes
b07564d
f4dc602
 
b07564d
f4dc602
 
 
 
e002acf
f4dc602
 
 
 
 
e002acf
f4dc602
 
 
 
 
 
e002acf
 
b07564d
e002acf
f4dc602
e002acf
 
b07564d
54614e9
9d6bac9
 
b07564d
 
9d6bac9
e002acf
 
54614e9
e002acf
 
 
54614e9
 
 
 
 
 
 
 
 
 
 
 
f4dc602
 
 
 
e002acf
 
54614e9
e002acf
f4dc602
e002acf
54614e9
f4dc602
e002acf
 
 
 
 
b07564d
 
e002acf
 
 
f4dc602
 
e002acf
 
f4dc602
 
 
b07564d
 
 
 
e002acf
f4dc602
54614e9
f4dc602
b07564d
54614e9
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
# space/tools/sql_tool.py
import json
import logging
import os
import re

import pandas as pd

from utils.config import AppConfig
from utils.tracing import Tracer


class SQLTool:
    """Run natural-language or explicit SQL queries against BigQuery or MotherDuck.

    The backend is chosen by ``cfg.sql_backend`` ("bigquery" or "motherduck");
    the matching client is created eagerly in ``__init__`` and stored on
    ``self.client``.
    """

    def __init__(self, cfg: "AppConfig", tracer: "Tracer"):
        """Create the client for the backend selected by ``cfg.sql_backend``.

        Args:
            cfg: Application config. Reads ``sql_backend`` plus the
                backend-specific fields (``gcp_project`` for BigQuery,
                ``motherduck_token`` / ``motherduck_db`` for MotherDuck).
            tracer: Receives a ``"sql_query"`` event for every query run.

        Raises:
            RuntimeError: If the backend is unknown, credentials are missing
                or malformed, or the installed DuckDB version is incompatible.
        """
        self.cfg = cfg
        self.tracer = tracer
        self.backend = cfg.sql_backend  # "bigquery" or "motherduck"

        if self.backend == "bigquery":
            self.client = self._connect_bigquery()
        elif self.backend == "motherduck":
            self.client = self._connect_motherduck()
        else:
            raise RuntimeError(f"Unknown SQL backend: {self.backend!r}")

    def _connect_bigquery(self):
        """Build a BigQuery client from the GCP_SERVICE_ACCOUNT_JSON secret."""
        from google.cloud import bigquery
        from google.oauth2 import service_account

        key_json = (os.getenv("GCP_SERVICE_ACCOUNT_JSON") or "").strip()
        if not key_json:
            raise RuntimeError("Missing GCP_SERVICE_ACCOUNT_JSON secret")

        if key_json.startswith("{"):
            # Full key JSON pasted into the Space Secret.
            try:
                info = json.loads(key_json)
            except json.JSONDecodeError as err:
                raise RuntimeError(
                    "GCP_SERVICE_ACCOUNT_JSON is not valid JSON"
                ) from err
            creds = service_account.Credentials.from_service_account_info(info)
        elif os.path.isfile(key_json):
            # Also accept a path to a key file on disk.
            creds = service_account.Credentials.from_service_account_file(key_json)
        else:
            # Previously this fell through to an empty dict and failed later
            # with an opaque error; fail early with an actionable message.
            raise RuntimeError(
                "GCP_SERVICE_ACCOUNT_JSON must contain the service-account key "
                "JSON or a path to a key file"
            )

        # Fall back to the key's own project when the config does not set one.
        project = self.cfg.gcp_project or creds.project_id
        return bigquery.Client(credentials=creds, project=project)

    def _connect_motherduck(self):
        """Open a DuckDB connection to MotherDuck after checking the pinned version."""
        import duckdb

        # MotherDuck currently supports DuckDB 1.3.2 broadly across hosts.
        # The negative lookahead rejects versions such as "1.3.20", which a
        # plain startswith("1.3.2") would wrongly accept.
        if not re.match(r"1\.3\.2(?!\d)", duckdb.__version__):
            raise RuntimeError(
                f"Incompatible DuckDB version {duckdb.__version__}. "
                "Pin duckdb==1.3.2 in requirements.txt and redeploy."
            )

        token = self.cfg.motherduck_token or os.getenv("MOTHERDUCK_TOKEN")
        if not token:
            raise RuntimeError("Missing MOTHERDUCK_TOKEN")
        db_name = (self.cfg.motherduck_db or "workspace").strip()

        # Connect directly to the MotherDuck database; the extension is
        # auto-loaded, so no manual INSTALL/LOAD/ATTACH is needed.
        #   "md:"          -> whole workspace (all DBs); follow with
        #                     "USE <database>;" to pick a default.
        #   "md:<db_name>" -> a specific DB
        # NOTE: the URI embeds the token, so never log or trace it.
        uri = f"md:{db_name}?motherduck_token={token}"
        return duckdb.connect(uri)

    def _nl_to_sql(self, message: str) -> str:
        """Translate a user message into SQL.

        Minimal NL2SQL heuristic; replace with your own mapping or LLM prompt.
        Edit table/column names to your schema.

        A message that starts with ``SELECT`` followed by whitespace is treated
        as explicit SQL and returned unchanged. It takes priority over the
        heuristics below.
        """
        m = message.lower()

        # Pass-through if the user typed SQL explicitly. This check comes first
        # so that a typed query containing "avg" and " by " (e.g. GROUP BY) is
        # not replaced by the template. \s also accepts a newline/tab after
        # SELECT.
        if re.match(r"\s*select\s", m):
            return message

        # Simple example (DuckDB/MotherDuck DATE_TRUNC flavor)
        if "avg" in m and " by " in m:
            return (
                "-- Example template; edit me\n"
                "SELECT DATE_TRUNC('month', date_col) AS month, "
                "AVG(metric) AS avg_metric "
                "FROM analytics.table "
                "GROUP BY 1 "
                "ORDER BY 1;"
            )

        return "SELECT * FROM analytics.table LIMIT 100;"

    def run(self, message: str) -> "pd.DataFrame":
        """Translate *message* to SQL, trace it, execute it and return the rows.

        Args:
            message: Natural-language request or an explicit ``SELECT`` statement.

        Returns:
            The query result as a pandas DataFrame.
        """
        sql = self._nl_to_sql(message)
        try:
            self.tracer.trace_event("sql_query", {"sql": sql, "backend": self.backend})
        except Exception:
            # Tracing is best-effort and must never block the query, but a
            # broken tracer should be visible rather than silently ignored.
            logging.getLogger(__name__).warning(
                "Failed to trace sql_query event", exc_info=True
            )

        if self.backend == "bigquery":
            return self.client.query(sql).to_dataframe()
        # DuckDB (MotherDuck)
        return self.client.execute(sql).fetch_df()