"""
Executes tool calls returned by the model.

The `ctx` dict carries per-request runtime state (chat_id, send_fn, etc.).
Requires Python 3.10+ for the `match` statement.
"""
import json
import asyncio
from tools import github_tool, search_tool


async def execute(tool_name: str, tool_args: dict, ctx: dict) -> str:
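    """
    Run a tool and return its result as a string.

    Unknown tools and exceptions are reported in the returned string rather
    than raised, so the model can see them. `ctx` may carry `send_fn` (an
    async callable, used by quick_reply) and `chat_id`.
    """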
    try:
        match tool_name:

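            case "web_search":
                query       = tool_args["query"]
                max_results = int(tool_args.get("max_results", 5))
                result = await asyncio.to_thread(
                    search_tool.search, query, max_results
                )
                # Auto-save to research/. A failed save must not discard the
                # search result the model asked for, so report it inline.
                try:
                    await asyncio.to_thread(
                        github_tool.save_research, query, result
                    )
                except Exception as e:
                    result += f"\n\n(Note: saving to research/ failed: {e})"
                return result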

            case "github_read":
                path = tool_args["path"]
                return await asyncio.to_thread(github_tool.read_file, path)

            case "github_write":
                path    = tool_args["path"]
                content = tool_args["content"]
                msg     = tool_args.get("commit_message", "")
                sha = await asyncio.to_thread(
                    github_tool.write_file, path, content, msg
                )
                return f"Written to {path} (commit {sha[:7]})"

            case "github_list":
                folder = tool_args["folder"]
                files  = await asyncio.to_thread(github_tool.list_files, folder)
                return "\n".join(files) if files else "(empty)"

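            case "quick_reply":
                message = tool_args["message"]
                send_fn = ctx.get("send_fn")
                if send_fn is None:
                    # Don't claim success when there is no way to send
                    return "Quick reply not sent: no send_fn in ctx"
                await send_fn(f"⚡ {message}")
                return f"Quick reply sent: {message}"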

            case _:
                return f"Unknown tool: {tool_name}"

    except Exception as e:
        return f"Tool error ({tool_name}): {e}"
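

# Illustrative sketch (an assumption, not part of the original module): the
# argument contract execute() expects, written as OpenAI-style function-calling
# schemas. The schemas actually sent to the model are defined elsewhere and may
# differ. The required/optional split below simply mirrors which keys execute()
# reads with tool_args[...] (required) versus tool_args.get(...) (optional).
def _example_schema(name, description, properties, required):
    """Build one OpenAI-style function schema (hypothetical helper)."""
    return {
        "type": "function",
        "function": {
            "name": name,
            "description": description,
            "parameters": {
                "type": "object",
                "properties": properties,
                "required": required,
            },
        },
    }


_EXAMPLE_TOOL_SCHEMAS = [
    _example_schema(
        "web_search", "Search the web; results are also saved under research/.",
        {"query": {"type": "string"},
         "max_results": {"type": "integer", "default": 5}},
        ["query"],
    ),
    _example_schema(
        "github_read", "Read a file from the GitHub repo.",
        {"path": {"type": "string"}},
        ["path"],
    ),
    _example_schema(
        "github_write", "Write a file to the GitHub repo and commit it.",
        {"path": {"type": "string"},
         "content": {"type": "string"},
         "commit_message": {"type": "string"}},
        ["path", "content"],
    ),
    _example_schema(
        "github_list", "List the files in a repo folder.",
        {"folder": {"type": "string"}},
        ["folder"],
    ),
    _example_schema(
        "quick_reply", "Send a short interim message to the chat.",
        {"message": {"type": "string"}},
        ["message"],
    ),
]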


def parse_tool_calls(response_message) -> list[dict]:
    """
    Extract tool calls from a model response message.
    Returns a list of dicts: {id, name, args}. Missing or malformed
    arguments are replaced with an empty dict.
    """
    calls = []
    tool_calls = getattr(response_message, "tool_calls", None)
    if not tool_calls:
        return calls
    for tc in tool_calls:
        try:
            args = json.loads(tc.function.arguments)
        except (json.JSONDecodeError, TypeError, AttributeError):
            # arguments missing (None) or not valid JSON
            args = {}
        if not isinstance(args, dict):
            # execute() expects a JSON object, not a list or scalar
            args = {}
        calls.append({
            "id":   tc.id,
            "name": tc.function.name,
            "args": args
        })
    return calls


def tool_result_message(call_id: str, tool_name: str, result: str) -> dict:
    """Format a tool result as a message for the model's context."""
    return {
        "role":         "tool",
        "tool_call_id": call_id,
        "name":         tool_name,
        "content":      result
    }
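

# Hypothetical convenience helper (not in the original module): run a batch of
# parsed tool calls and collect their results in call order. `executor` is
# assumed to have the same signature as execute() above. Calls run one after
# another by default; concurrent=True uses asyncio.gather, which is only safe
# when the calls are independent (two github_write calls to the same branch,
# for example, could race).
async def run_tool_calls(calls, executor, ctx, concurrent=False):
    """Return the result string of each call in `calls`, in the same order."""
    if concurrent:
        # gather() returns results in the order the awaitables were passed
        return list(await asyncio.gather(
            *(executor(c["name"], c["args"], ctx) for c in calls)
        ))
    results = []
    for c in calls:
        results.append(await executor(c["name"], c["args"], ctx))
    return results


# Usage sketch (response_message, messages and ctx come from the caller):
#     calls   = parse_tool_calls(response_message)
#     results = await run_tool_calls(calls, execute, ctx)
#     messages.extend(
#         tool_result_message(c["id"], c["name"], r)
#         for c, r in zip(calls, results)
#     )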