File size: 3,017 Bytes
5d30bdc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
"""Quick end-to-end chat script for testing the agent locally.

Wires inspect_schema and run_sql to a LangGraph ReAct agent with the semantic layer as context.
Uses MODEL_BACKEND / MINIMAX_API_KEY from .env (or environment).

Usage:
    PARQUET_DIR=data/dev uv run python scripts/chat.py
"""

import os
import sys
import warnings
from pathlib import Path

from dotenv import load_dotenv
from langchain_core.tools import tool

with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    from langgraph.prebuilt import create_react_agent

load_dotenv()

# Repo root so imports work when run from any directory.
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))

from agent.client import get_llm  # noqa: E402
from agent.tools.inspect_schema import inspect_schema  # noqa: E402
from agent.tools.run_sql import build_connection, run_sql  # noqa: E402
from agent.tools.schemas import InspectSchemaInput, RunSqlInput  # noqa: E402

# Both locations can be overridden through the environment (or the .env file
# loaded above); the defaults are relative to the current working directory.
SEMANTIC_LAYER_PATH = os.getenv("SEMANTIC_LAYER_PATH", "data/semantic_layer.yml")
PARQUET_DIR = os.getenv("PARQUET_DIR", "data/parquet")

# Path handed to inspect_schema on every inspect_schema_tool call.
_sl_path = SEMANTIC_LAYER_PATH
# Read the semantic layer once; its text is embedded into the system prompt.
# Decode explicitly as UTF-8 so the result does not depend on the platform's
# locale encoding.
_semantic_layer = Path(SEMANTIC_LAYER_PATH).read_text(encoding="utf-8")
# Handle returned by build_connection; passed to run_sql on every run_sql_tool call.
_conn = build_connection(PARQUET_DIR)


@tool
def inspect_schema_tool(table: str | None = None) -> dict:
    """List all tables / metrics / dimensions (no arg), or describe one table's columns and joins.

    Always call this first — before run_sql — to confirm table and column names.
    Returns: {tables, metrics, dimensions} overview or {table: {columns, joins}} detail.
    """
    # The docstring above is kept verbatim: the @tool decorator presumably
    # exposes it to the model as the tool description — confirm before editing.
    request = InspectSchemaInput(table=table)
    # Delegate to the project helper and hand back a plain dict.
    return inspect_schema(request, _sl_path).model_dump()


@tool
def run_sql_tool(query: str, max_rows: int = 100) -> dict:
    """Execute a read-only SELECT (or WITH … SELECT) against DuckDB.

    Call this to answer quantitative questions about the dataset.
    Returns: {rows, truncated, row_count, execution_ms} on success,
             {error, hint} on failure — use the hint to self-correct.
    """
    # The docstring above is kept verbatim: the @tool decorator presumably
    # exposes it to the model as the tool description — confirm before editing.
    request = RunSqlInput(query=query, max_rows=max_rows)
    # Run against the module-level connection and hand back a plain dict.
    return run_sql(request, _conn).model_dump()


# System prompt given to the agent. The full semantic-layer text read at
# import time is interpolated at the end so the model receives it as context.
# NOTE(review): the prompt names only run_sql_tool, while inspect_schema_tool's
# description asks to be called first — confirm the two instructions are meant
# to coexist.
SYSTEM_PROMPT = f"""You are a data analyst agent. You answer questions by querying a DuckDB database.

Always use run_sql_tool to fetch data before answering. Do not guess values.

## Semantic layer (schema + business context)

{_semantic_layer}
"""

# Built once at import time; chat() reuses this agent for every question.
llm = get_llm()
agent = create_react_agent(llm, tools=[inspect_schema_tool, run_sql_tool], prompt=SYSTEM_PROMPT)


def chat(question: str) -> None:
    """Send one question to the agent and print its final message.

    Prints the question and a 60-dash separator, invokes the module-level
    agent with the question as a single human message, then prints the
    content of the last message in the result followed by a blank line.
    """
    # Header is written before the agent runs so it shows while waiting.
    print(f"\nQ: {question}", "-" * 60, sep="\n")
    state = agent.invoke({"messages": [("human", question)]})
    final_message = state["messages"][-1]
    # Content plus one trailing blank line.
    print(final_message.content, end="\n\n")


if __name__ == "__main__":
    if len(sys.argv) > 1:
        # One-shot mode: all command-line words are joined into one question.
        chat(" ".join(sys.argv[1:]))
    else:
        # Interactive mode: keep prompting until the user interrupts or stdin ends.
        print("Type your question (Ctrl-C to quit):\n")
        while True:
            try:
                q = input("> ").strip()
                if q:
                    chat(q)
            except (KeyboardInterrupt, EOFError):
                # input() raises EOFError when stdin is closed (Ctrl-D or
                # exhausted piped input); exit quietly, same as Ctrl-C,
                # instead of ending with a traceback.
                break