File size: 4,358 Bytes
b07564d
f4dc602
 
b07564d
 
 
f4dc602
 
b07564d
f4dc602
 
 
e002acf
f4dc602
 
 
 
 
e002acf
f4dc602
 
 
e002acf
f4dc602
 
 
e002acf
 
b07564d
e002acf
f4dc602
e002acf
 
b07564d
9d6bac9
b07564d
9d6bac9
 
b07564d
 
 
9d6bac9
e002acf
 
 
 
 
 
b07564d
9d6bac9
 
 
b07564d
9d6bac9
 
b07564d
 
9d6bac9
b07564d
 
e002acf
 
 
 
 
 
 
f4dc602
 
 
 
e002acf
 
 
b07564d
e002acf
f4dc602
e002acf
b07564d
f4dc602
b07564d
 
e002acf
 
 
 
 
b07564d
 
e002acf
 
 
f4dc602
 
e002acf
 
 
f4dc602
 
 
b07564d
 
 
 
e002acf
f4dc602
 
 
b07564d
f4dc602
e002acf
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
# space/tools/sql_tool.py
import os
import re
import json
import shutil
import glob
import pandas as pd
from typing import Optional

from utils.config import AppConfig
from utils.tracing import Tracer


class SQLTool:
    """Translate a chat message into SQL and run it on the configured backend.

    The backend is selected by ``cfg.sql_backend`` and must be either
    ``"bigquery"`` or ``"motherduck"``. The resulting client object is kept
    on ``self.client`` and reused by every call to :meth:`run`.
    """

    # DuckDB release this tool requires for the MotherDuck extension.
    _DUCKDB_VERSION = "1.3.2"

    # A message is treated as explicit SQL when it starts with the SELECT
    # keyword followed by any word boundary (space, newline, tab, "*", ...).
    _SELECT_RE = re.compile(r"\s*select\b", re.IGNORECASE)

    def __init__(self, cfg: "AppConfig", tracer: "Tracer"):
        """Create the backend client described by *cfg*.

        Args:
            cfg: Application config; ``sql_backend`` is always read, plus
                ``gcp_project`` (BigQuery) or ``motherduck_token`` /
                ``motherduck_db`` (MotherDuck).
            tracer: Tracer whose ``trace_event`` is called for each query.

        Raises:
            RuntimeError: Unknown backend, missing secret/token, or an
                unsupported DuckDB version.
            ValueError: ``GCP_SERVICE_ACCOUNT_JSON`` is not a JSON object.
        """
        self.cfg = cfg
        self.tracer = tracer
        self.backend = cfg.sql_backend  # "bigquery" or "motherduck"

        if self.backend == "bigquery":
            self.client = self._connect_bigquery()
        elif self.backend == "motherduck":
            self.client = self._connect_motherduck()
        else:
            raise RuntimeError("Unknown SQL backend")

    @staticmethod
    def _quote_literal(value: str) -> str:
        """Return *value* as a single-quoted SQL string literal.

        Embedded single quotes are doubled so the value cannot terminate the
        literal early.
        """
        return "'" + str(value).replace("'", "''") + "'"

    @staticmethod
    def _parse_service_account_info(key_json: str) -> dict:
        """Decode the service-account secret into a dict.

        Raises:
            ValueError: The secret is not valid JSON or is not a JSON object
                (for example a file path was supplied instead of the key
                file's contents).
        """
        try:
            info = json.loads(key_json)
        except json.JSONDecodeError as err:
            raise ValueError(
                "GCP_SERVICE_ACCOUNT_JSON must contain the full service "
                "account JSON document (not a file path)"
            ) from err
        if not isinstance(info, dict):
            raise ValueError(
                "GCP_SERVICE_ACCOUNT_JSON must be a JSON object, got "
                f"{type(info).__name__}"
            )
        return info

    def _connect_bigquery(self):
        """Build a BigQuery client from the GCP_SERVICE_ACCOUNT_JSON secret."""
        # Imported lazily so the MotherDuck backend does not need these packages.
        from google.cloud import bigquery
        from google.oauth2 import service_account

        key_json = os.getenv("GCP_SERVICE_ACCOUNT_JSON")
        if not key_json:
            raise RuntimeError("Missing GCP_SERVICE_ACCOUNT_JSON secret")

        # Accept full JSON string from Space Secret
        info = self._parse_service_account_info(key_json)
        creds = service_account.Credentials.from_service_account_info(info)
        return bigquery.Client(credentials=creds, project=self.cfg.gcp_project)

    def _connect_motherduck(self):
        """Open an in-memory DuckDB connection with MotherDuck attached as ``md``."""
        # Imported lazily so the BigQuery backend does not need duckdb.
        import duckdb

        # ---- Enforce supported DuckDB version for MotherDuck extension ----
        supported = self._DUCKDB_VERSION
        if not duckdb.__version__.startswith(supported):
            raise RuntimeError(
                f"Incompatible DuckDB version {duckdb.__version__}. "
                f"MotherDuck currently supports DuckDB {supported}. "
                f"Pin duckdb=={supported} in requirements.txt and redeploy."
            )

        token = self.cfg.motherduck_token or os.getenv("MOTHERDUCK_TOKEN")
        db_name = self.cfg.motherduck_db or "default"
        if not token:
            raise RuntimeError("Missing MOTHERDUCK_TOKEN")

        self._purge_stale_extension_caches()

        # ---- Connect & load MotherDuck extension ----
        client = duckdb.connect()  # in-memory connection; we'll ATTACH MotherDuck
        client.execute("INSTALL motherduck;")
        client.execute("LOAD motherduck;")

        # Attach the remote MotherDuck database and use it. Both values are
        # quoted through _quote_literal so a stray "'" cannot break the statement.
        client.execute(f"SET motherduck_token={self._quote_literal(token)};")
        client.execute(f"ATTACH {self._quote_literal('md:/' + db_name)} AS md;")
        client.execute("USE md;")  # subsequent queries run against 'md' by default
        return client

    @classmethod
    def _purge_stale_extension_caches(cls) -> None:
        """Best-effort removal of extension caches for other DuckDB versions."""
        try:
            ext_root = os.path.expanduser("~/.duckdb/extensions")
            for path in glob.glob(os.path.join(ext_root, "*")):
                # Match on the entry name only; the parent path (e.g. the home
                # directory) must not influence the decision.
                if cls._DUCKDB_VERSION not in os.path.basename(path):
                    shutil.rmtree(path, ignore_errors=True)
        except Exception:
            # Deliberately swallowed: cleanup is optional and must never
            # prevent the connection from being attempted.
            pass

    def _nl_to_sql(self, message: str) -> str:
        """
        Minimal NL2SQL heuristic; replace with your own mapping or LLM prompt.
        Expect users to include table names. Example:
          "avg metric by month from analytics.events"

        Resolution order:
          1. A message that already starts with SELECT is returned unchanged.
          2. A message containing "avg" and " by " yields the example template.
          3. Anything else yields a ``SELECT * ... LIMIT 100`` fallback.
        """
        # Pass-through if the user typed SQL explicitly. This must run before
        # the keyword heuristic, otherwise real SQL such as
        # "SELECT AVG(x) ... GROUP BY y" would be replaced by the template.
        if self._SELECT_RE.match(message):
            return message

        m = message.lower()

        # Very basic template example (edit table/columns to your schema)
        if "avg" in m and " by " in m:
            # DuckDB uses DATE_TRUNC('month', col); BigQuery uses DATE_TRUNC(col, MONTH).
            # This generic SQL should work in DuckDB/MotherDuck; adapt if using BigQuery.
            return (
                "-- Example template; edit me\n"
                "SELECT DATE_TRUNC('month', date_col) AS month, "
                "AVG(metric) AS avg_metric "
                "FROM analytics.table "
                "GROUP BY 1 "
                "ORDER BY 1;"
            )

        # Fallback
        return "SELECT * FROM analytics.table LIMIT 100;"

    def run(self, message: str) -> pd.DataFrame:
        """Convert *message* to SQL, execute it, and return the rows as a DataFrame.

        The generated SQL is reported to the tracer first; tracer failures are
        ignored so they cannot block the query itself. Errors raised by the
        backend client propagate to the caller.
        """
        sql = self._nl_to_sql(message)
        try:
            self.tracer.trace_event("sql_query", {"sql": sql, "backend": self.backend})
        except Exception:
            # Tracing is best-effort only.
            pass

        if self.backend == "bigquery":
            return self.client.query(sql).to_dataframe()

        # DuckDB (MotherDuck)
        return self.client.execute(sql).fetch_df()