File size: 4,041 Bytes
35c0d38
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
"""Assistant tools.

Two tools, exposed via LangChain's @tool decorator so either assistant can call
them through the standard tool-calling interface:

  - calculator:  safe arithmetic. Evaluated by walking a parsed AST (NOT Python's
                 eval), so it can never execute arbitrary code.
  - web_search:  Tavily web search, returning a short text digest of top hits.

`TOOLS` is the list bound to each model. `run_tool_call` executes a single
tool call (by name + args) and returns its string result.
"""

from __future__ import annotations

import ast
import operator

from langchain_core.tools import tool

from src.config import settings
from src.observability import annotate_span, observe

# --- Calculator -----------------------------------------------------------

# Only these AST node types / operators are allowed. Anything else (names,
# function calls, attribute access, etc.) is rejected, so the calculator can
# only ever do arithmetic on literal numbers.

# Cap on the estimated size (in bits) of an integer power. Python ints are
# unbounded, so without a cap an input like "9**9**9" would try to build a
# number with hundreds of millions of digits and stall the process.
_MAX_POW_BITS = 100_000


def _checked_pow(base: float, exponent: float) -> float:
    """Return ``base ** exponent``, refusing integer powers whose result is huge.

    The result size is estimated as ``bit_length(|base|) * exponent``, an upper
    bound on the bits needed, so oversized powers are rejected before any work
    is done. Bases 0, 1 and -1 are always cheap and are never rejected. Float
    (or mixed) powers need no guard: they raise ``OverflowError`` instead of
    growing without bound.

    Raises:
        ValueError: if an integer result would exceed ``_MAX_POW_BITS`` bits.
    """
    if (
        isinstance(base, int)
        and isinstance(exponent, int)
        and exponent > 0
        and abs(base) > 1
        and abs(base).bit_length() * exponent > _MAX_POW_BITS
    ):
        raise ValueError("Power result is too large to compute.")
    return operator.pow(base, exponent)


_ALLOWED_BINOPS = {
    ast.Add: operator.add,
    ast.Sub: operator.sub,
    ast.Mult: operator.mul,
    ast.Div: operator.truediv,
    ast.FloorDiv: operator.floordiv,
    ast.Mod: operator.mod,
    # Guarded, because a plain operator.pow can be made to run (effectively) forever.
    ast.Pow: _checked_pow,
}
_ALLOWED_UNARYOPS = {
    ast.UAdd: operator.pos,
    ast.USub: operator.neg,
}


def _eval_node(node: ast.AST) -> float:
    """Recursively evaluate a whitelisted arithmetic AST node.

    Only int/float literals and the operators in ``_ALLOWED_BINOPS`` /
    ``_ALLOWED_UNARYOPS`` are accepted. Binary operands are evaluated left
    operand first, then right.

    Raises:
        ValueError: for any other node (names, calls, strings, booleans, ...)
            or for an integer power larger than ``_MAX_POW_BITS`` bits.
        ZeroDivisionError, OverflowError: propagated from the arithmetic itself.
    """
    if isinstance(node, ast.Constant):
        value = node.value
        # bool subclasses int, so exclude it explicitly: "True + 1" is not
        # arithmetic on numbers and must not silently evaluate to 2.
        if isinstance(value, (int, float)) and not isinstance(value, bool):
            return value
    elif isinstance(node, ast.BinOp) and type(node.op) in _ALLOWED_BINOPS:
        return _ALLOWED_BINOPS[type(node.op)](
            _eval_node(node.left), _eval_node(node.right)
        )
    elif isinstance(node, ast.UnaryOp) and type(node.op) in _ALLOWED_UNARYOPS:
        return _ALLOWED_UNARYOPS[type(node.op)](_eval_node(node.operand))
    raise ValueError("Only basic arithmetic (+ - * / // % **) is allowed.")


@tool
def calculator(expression: str) -> str:
    """Evaluate a basic arithmetic expression and return the result.

    Supports + - * / // % ** and parentheses on numbers only. Use this for any
    math instead of computing it yourself.
    """
    # NOTE: @tool exposes the docstring above to the model as this tool's
    # description, so its wording is effectively part of the tool's behavior.
    try:
        # Parse in "eval" mode so only a single expression is accepted, then
        # walk it with the whitelist evaluator; str() stays inside the try so
        # any conversion failure is also reported as text.
        body = ast.parse(expression, mode="eval").body
        return str(_eval_node(body))
    except Exception as err:  # noqa: BLE001 - surface a clean message to the model
        return f"Calculator error: {err}"


# --- Web search -----------------------------------------------------------


@tool
def web_search(query: str) -> str:
    """Search the web for current information and return a short text digest.

    Use this for facts that may be recent, niche, or beyond your training data.
    """
    # NOTE: @tool exposes the docstring above to the model as this tool's
    # description; keep its wording stable. Every failure path below returns a
    # message string instead of raising.
    if not settings.tavily_api_key:
        return "Web search is unavailable (TAVILY_API_KEY is not configured)."

    # Imported lazily so the tool module is importable without the dependency
    # being exercised (and without network calls at import time). A missing
    # package is reported as text rather than propagating out of the tool.
    try:
        from tavily import TavilyClient
    except ImportError:
        return "Web search is unavailable (the tavily package is not installed)."

    try:
        client = TavilyClient(api_key=settings.tavily_api_key)
        resp = client.search(query=query, max_results=3)
        results = resp.get("results", [])
        if not results:
            return "No web results found."
        # `or` fallbacks also cover keys that are present but null, which would
        # otherwise make `.strip()` fail and discard every other hit.
        lines = [
            f"- {r.get('title') or 'untitled'}: {(r.get('content') or '').strip()}"
            for r in results
        ]
        return "\n".join(lines)
    except Exception as exc:  # noqa: BLE001 - surface a clean message to the model
        return f"Web search error: {exc}"


# Tools bound to both assistants.
TOOLS = [
    calculator,
    web_search,
]

# Name -> tool index consulted by run_tool_call when dispatching a call.
TOOLS_BY_NAME = {registered.name: registered for registered in TOOLS}


@observe(as_type="tool", name="tool_call")
def run_tool_call(name: str, args: dict) -> str:
    """Execute one tool call by name and return its result as a string.

    Failures are returned as messages rather than raised, matching how the
    tools themselves report errors.

    Args:
        name: Tool name to dispatch on (looked up in ``TOOLS_BY_NAME``).
        args: Arguments for the tool, passed straight to ``invoke``.

    Returns:
        The tool's output converted to ``str``, or an error message when the
        name is unknown or the invocation raises (e.g. malformed arguments).
    """
    annotate_span(metadata={"tool": name, "args": args})
    tool_obj = TOOLS_BY_NAME.get(name)
    if tool_obj is None:
        return f"Unknown tool: {name}"
    try:
        return str(tool_obj.invoke(args))
    except Exception as exc:  # noqa: BLE001 - surface a clean message to the model
        # The tools already turn their own failures into strings, so what
        # reaches here is mostly argument/invocation errors; report them the
        # same way an unknown tool name is reported.
        return f"Tool error ({name}): {exc}"