File size: 3,093 Bytes
1de7838
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import ast
import json
import operator

from . import database
from .config import DB_PATH

# Upper bound on the bit length of an integer power result. Keeps inputs such
# as "9 ** 9 ** 9" from consuming unbounded CPU time and memory.
_MAX_POW_RESULT_BITS = 100_000


def _bounded_pow(base, exponent):
    """Return ``base ** exponent``, refusing integer powers with huge results.

    Only the int ** positive-int case is guarded, since that is the one whose
    result size grows without limit; every other combination is handed
    straight to ``operator.pow``.

    Raises:
        ValueError: if both operands are integers and the result could need
            more than ``_MAX_POW_RESULT_BITS`` bits.
    """
    if (
        isinstance(base, int)
        and isinstance(exponent, int)
        and exponent > 0
        and abs(base) > 1
        # bit_length(base) * exponent is an upper bound on the result's bits.
        and base.bit_length() * exponent > _MAX_POW_RESULT_BITS
    ):
        raise ValueError("Exponent too large")
    return operator.pow(base, exponent)


# Maps the AST operator node types that `calculate` accepts to the callables
# that implement them. Any operator type missing from this table is rejected.
_OPERATORS = {
    ast.Add: operator.add,
    ast.Sub: operator.sub,
    ast.Mult: operator.mul,
    ast.Div: operator.truediv,
    ast.Pow: _bounded_pow,
    ast.Mod: operator.mod,
    ast.USub: operator.neg,
    ast.UAdd: operator.pos,
}


def calculate(expression):
    """Evaluate a basic arithmetic expression and return its numeric value.

    The text is parsed with ``ast.parse`` in ``eval`` mode and the resulting
    tree is walked by hand: only int/float literals combined with the unary
    and binary operators listed in ``_OPERATORS`` are evaluated. Nothing is
    passed to ``eval``.

    Args:
        expression: Expression source, e.g. ``"2 * (3 + 4)"``.

    Returns:
        The numeric result of the expression.

    Raises:
        ValueError: if the expression contains anything other than numeric
            literals and supported operators (names, calls, strings,
            booleans, ...), or an integer power whose result would be
            excessively large.
        SyntaxError: if the text is not a valid Python expression.

    Arithmetic errors raised by the operators themselves (for example
    ``ZeroDivisionError``) propagate unchanged.
    """
    if isinstance(expression, str):
        # Strip surrounding whitespace so that a leading space or tab is not
        # seen by the parser as unexpected indentation.
        expression = expression.strip()

    def _eval(node):
        if isinstance(node, ast.Constant):
            value = node.value
            # bool is a subclass of int; exclude it so "True + 1" is not
            # silently treated as arithmetic.
            if isinstance(value, (int, float)) and not isinstance(value, bool):
                return value
        if isinstance(node, ast.BinOp) and type(node.op) in _OPERATORS:
            return _OPERATORS[type(node.op)](_eval(node.left), _eval(node.right))
        if isinstance(node, ast.UnaryOp) and type(node.op) in _OPERATORS:
            return _OPERATORS[type(node.op)](_eval(node.operand))
        raise ValueError("Unsupported expression")

    return _eval(ast.parse(expression, mode="eval").body)


class Toolbox:
    """Dispatch named tool calls to the database helpers and the calculator.

    Every tool result is returned as a string. Failures are reported as a
    string starting with ``"Error: "`` instead of being raised.
    """

    # Descriptors for the tools `execute` understands: a name, a
    # human-readable description and a JSON-Schema-style `input_schema`.
    # NOTE(review): presumably handed to the model API as its tool list;
    # confirm against the caller.
    DEFINITIONS = [
        {
            "name": "list_tables",
            "description": "List all tables in the database.",
            "input_schema": {"type": "object", "properties": {}},
        },
        {
            "name": "get_schema",
            "description": "Return the columns and types of a given table.",
            "input_schema": {
                "type": "object",
                "properties": {"table": {"type": "string", "description": "Table name"}},
                "required": ["table"],
            },
        },
        {
            "name": "run_sql",
            "description": "Run a read-only SQL SELECT query and return the rows as JSON.",
            "input_schema": {
                "type": "object",
                "properties": {"query": {"type": "string", "description": "A single SELECT statement"}},
                "required": ["query"],
            },
        },
        {
            "name": "calculator",
            "description": "Evaluate a basic arithmetic expression.",
            "input_schema": {
                "type": "object",
                "properties": {"expression": {"type": "string"}},
                "required": ["expression"],
            },
        },
    ]

    def __init__(self, db_path=DB_PATH):
        """Remember the database path used by the database-backed tools."""
        self.db_path = db_path

    def execute(self, name, arguments):
        """Run the tool called *name* with *arguments* and return a string.

        Args:
            name: One of the tool names listed in ``DEFINITIONS``.
            arguments: Mapping holding the tool's inputs. It is not read for
                ``list_tables``.

        Returns:
            The tool output as a string (JSON text for the database tools,
            ``str()`` of the result for the calculator), or a message
            starting with ``"Error: "`` when the tool is unknown, a required
            argument is missing, or the tool raises.
        """
        try:
            if name == "list_tables":
                return self._query(lambda c: json.dumps(database.list_tables(c)))
            if name == "get_schema":
                # Resolve the argument before opening a connection so that a
                # malformed call fails fast with a readable message.
                table = self._required(arguments, "table")
                return self._query(lambda c: json.dumps(database.table_schema(c, table)))
            if name == "run_sql":
                query = self._required(arguments, "query")
                return self._query(
                    lambda c: json.dumps(database.run_select(c, query), default=str)
                )
            if name == "calculator":
                return str(calculate(self._required(arguments, "expression")))
            return f"Error: unknown tool '{name}'"
        except Exception as error:
            # Errors are returned as text rather than raised so the caller
            # always receives a tool result.
            return f"Error: {error}"

    @staticmethod
    def _required(arguments, key):
        """Return ``arguments[key]`` or raise a descriptive ``ValueError``.

        A bare ``KeyError`` stringifies to just the quoted key, and a
        non-mapping ``arguments`` (such as ``None``) raises an opaque
        ``TypeError``; both are replaced by an explicit message.
        """
        try:
            return arguments[key]
        except (KeyError, TypeError):
            raise ValueError(f"missing required argument '{key}'") from None

    def _query(self, function):
        """Call *function* with a connection from ``database.connect``.

        The connection is requested with ``read_only=True`` and used inside a
        ``with`` block; the return value of *function* is passed through.
        """
        with database.connect(self.db_path, read_only=True) as connection:
            return function(connection)