File size: 7,131 Bytes
1cdb3e3
 
 
6a2abaa
1cdb3e3
 
 
 
6a2abaa
1cdb3e3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6a2abaa
1cdb3e3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6a2abaa
1cdb3e3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6a2abaa
1cdb3e3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6a2abaa
1cdb3e3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
"""Custom ReACT engine using LiteLLM's native function calling with Z.ai GLM-5.

CrewAI doesn't reliably route tools through to Z.ai via LiteLLM, so this module
bypasses CrewAI entirely for the tool-calling loop.
"""

import json
import logging
import time

import litellm

from code_tribunal.config import TribunalConfig

log = logging.getLogger("code_tribunal.react")

_MAX_RETRIES = 5
_BASE_DELAY = 4.0


def _completion_with_retry(**kwargs):
    """Call litellm.completion, retrying with exponential backoff when rate-limited.

    Up to _MAX_RETRIES attempts are made; the wait doubles after each failed
    attempt, starting from _BASE_DELAY seconds. The last failure re-raises
    the RateLimitError to the caller.
    """
    attempt = 0
    while True:
        try:
            return litellm.completion(**kwargs)
        except litellm.RateLimitError:
            attempt += 1
            if attempt >= _MAX_RETRIES:
                # Out of retries — surface the rate-limit error.
                raise
            wait = _BASE_DELAY * 2 ** (attempt - 1)
            log.warning(
                "[RETRY] Rate limited (attempt %d/%d), waiting %.0fs...",
                attempt, _MAX_RETRIES, wait,
            )
            time.sleep(wait)


def _build_tool_schemas(tools: list) -> list[dict]:
    """Convert CrewAI BaseTool instances to OpenAI function calling schema."""
    schemas = []
    for tool in tools:
        schema = tool.args_schema.model_json_schema()
        properties = schema.get("properties", {})
        for prop in properties.values():
            prop.pop("title", None)
        schemas.append({
            "type": "function",
            "function": {
                "name": tool.name,
                "description": tool.description.split("\n")[0],
                "parameters": {
                    "type": "object",
                    "properties": properties,
                    "required": schema.get("required", []),
                },
            },
        })
    return schemas


def _execute_tool(tools: list, tool_name: str, arguments: dict) -> str:
    """Find and execute a tool by name."""
    for tool in tools:
        if tool.name == tool_name:
            return tool._run(**arguments)
    return f"Error: Unknown tool '{tool_name}'"


def react_loop(
    config: TribunalConfig,
    task_description: str,
    agent_role: str,
    agent_goal: str,
    tools: list,
    max_iterations: int = 10,
) -> str:
    """Run a full ReACT loop using function calling.

    The agent receives tools, decides which to call, observes the results,
    and iterates until it has enough information to answer. If the iteration
    budget runs out, one final tool-less completion produces the answer.

    Returns the final text output.
    """
    schemas = _build_tool_schemas(tools)

    system_prompt = (
        f"You are {agent_role}. {agent_goal}\n\n"
        "You have access to tools for investigating code. Use them actively:\n"
        "- Call file_reader to read specific files\n"
        "- Call pattern_search to run GritQL patterns\n"
        "- Call code_graph_query to trace call chains and dependencies\n"
        "- Call finding_context to see surrounding code for a finding\n\n"
        "Always call at least one tool before giving your final answer. "
        "After gathering information, provide a detailed analysis."
    )

    convo: list[dict] = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": task_description},
    ]

    step = 0
    while step < max_iterations:
        reply = _completion_with_retry(
            model=config.model_name,
            messages=convo,
            tools=schemas,
            tool_choice="auto",
            api_key=config.api_key,
            api_base=config.api_base,
            temperature=config.temperature,
        ).choices[0].message
        # Record the assistant turn (with any tool_calls) in the history.
        convo.append(reply.model_dump())

        if not reply.tool_calls:
            # The model answered directly — we are done.
            return reply.content or ""

        for call in reply.tool_calls:
            name = call.function.name
            try:
                args = json.loads(call.function.arguments)
            except json.JSONDecodeError:
                # Malformed arguments from the model: run the tool with none.
                args = {}

            log.debug("  [ReACT %d] %s(%s)", step + 1, name, args)
            observation = _execute_tool(tools, name, args)
            convo.append({
                "role": "tool",
                "content": str(observation),
                "tool_call_id": call.id,
            })

        step += 1

    # Iteration budget exhausted: ask for a final answer without tools.
    final = _completion_with_retry(
        model=config.model_name,
        messages=convo,
        api_key=config.api_key,
        api_base=config.api_base,
        temperature=config.temperature,
    )
    return final.choices[0].message.content or ""


def react_loop_stream(
    config: TribunalConfig,
    task_description: str,
    agent_role: str,
    agent_goal: str,
    tools: list,
    max_iterations: int = 10,
):
    """Streaming ReACT loop — yields (role, delta_text, is_tool_call) tuples.

    Same control flow as react_loop, but progress is surfaced incrementally:
    each tool invocation is announced as a small delta (is_tool_call=True),
    and the final answer is yielded as one large delta (is_tool_call=False).
    Note this is per-event streaming, not token-level — each underlying
    completion call is itself non-streaming.
    """
    tool_schemas = _build_tool_schemas(tools)

    system_prompt = (
        f"You are {agent_role}. {agent_goal}\n\n"
        "You have access to tools for investigating code. Use them actively:\n"
        "- Call file_reader to read specific files\n"
        "- Call pattern_search to run GritQL patterns\n"
        "- Call code_graph_query to trace call chains and dependencies\n"
        "- Call finding_context to see surrounding code for a finding\n\n"
        "Always call at least one tool before giving your final answer. "
        "After gathering information, provide a detailed analysis."
    )

    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": task_description},
    ]

    for iteration in range(max_iterations):
        response = _completion_with_retry(
            model=config.model_name,
            messages=messages,
            tools=tool_schemas,
            tool_choice="auto",
            api_key=config.api_key,
            api_base=config.api_base,
            temperature=config.temperature,
        )

        message = response.choices[0].message
        # Record the assistant turn (including any tool_calls) before acting on it.
        messages.append(message.model_dump())

        if message.tool_calls:
            for tool_call in message.tool_calls:
                func_name = tool_call.function.name
                try:
                    func_args = json.loads(tool_call.function.arguments)
                except json.JSONDecodeError:
                    # Malformed arguments from the model: run the tool with none.
                    func_args = {}
                # Announce the tool call to the consumer before executing it.
                yield (agent_role, f"\n[Using tool: {func_name}({json.dumps(func_args)})]\n", True)

                result = _execute_tool(tools, func_name, func_args)
                messages.append({
                    "role": "tool",
                    "content": str(result),
                    "tool_call_id": tool_call.id,
                })

        if not message.tool_calls:
            # The model answered directly — emit it and stop.
            if message.content:
                yield (agent_role, message.content, False)
            break
    else:
        # for/else: only reached when the iteration budget ran out without a
        # direct answer — ask for a final completion with no tools offered.
        response = _completion_with_retry(
            model=config.model_name,
            messages=messages,
            api_key=config.api_key,
            api_base=config.api_base,
            temperature=config.temperature,
        )
        yield (agent_role, response.choices[0].message.content or "", False)