# codebook/potato/coding_agent_backends/openai_backend.py
# davidjurgens's picture
# Deploy: Potato — Codebook Annotation
# aceb1b2 verified
# Raw
# History Blame Contribute Delete
# 11.5 kB
"""
OpenAI Tool Use Backend
Custom coding-agent loop using any OpenAI-compatible chat-completions
server (OpenAI, vLLM, llama.cpp, etc.) with function/tool calling.
vLLM ignores the API key but the OpenAI SDK rejects an empty string, so
a non-empty placeholder is substituted for local servers. A configured
base_url is honored and normalized to the ".../v1" form the SDK expects
(accepts either the server root or an explicit "/v1" base_url).
"""
import json
import logging
import time
import threading
from typing import Dict, Iterator, List
from ..coding_agent_backend import (
CodingAgentBackend,
CodingAgentEvent,
CodingAgentEventType,
CODING_TOOLS,
execute_tool,
)
logger = logging.getLogger(__name__)
def _to_openai_tools(tools: list) -> list:
"""CODING_TOOLS is in Anthropic shape ({name, description,
input_schema}); the OpenAI/vLLM API needs
{type:"function", function:{name, description, parameters}}."""
converted = []
for t in tools:
if t.get("type") == "function" and "function" in t:
converted.append(t) # already OpenAI shape
continue
converted.append({
"type": "function",
"function": {
"name": t["name"],
"description": t.get("description", ""),
"parameters": t.get("input_schema", {"type": "object", "properties": {}}),
},
})
return converted
def _normalize_base_url(raw: str) -> str:
"""The OpenAI SDK appends '/chat/completions' to base_url, so it must
end at the '/v1' root. Accept either the server root or a '/v1' URL."""
if not raw:
return raw
u = raw.rstrip("/")
if not u.endswith("/v1"):
u = u + "/v1"
return u
class OpenAIToolUseBackend(CodingAgentBackend):
    """Agent loop using an OpenAI-compatible API with tool calling.

    ``start`` launches a daemon worker thread that repeatedly calls the
    chat-completions endpoint, runs any requested tools through
    ``execute_tool``, and records progress as ``CodingAgentEvent`` objects
    which ``get_events`` streams to the caller.
    """

    # States in which the run has ended (finished, stopped, or failed).
    _TERMINAL_STATES = ("completed", "error")

    def __init__(self, config: dict):
        """Read settings from ``config`` and initialise idle run state.

        Args:
            config: Backend configuration. ``config["ai_config"]`` may
                provide ``model``, ``base_url``, ``api_key``,
                ``max_tokens``, ``temperature`` and ``timeout``;
                ``config["max_turns"]`` caps the number of model turns.
        """
        import os

        self._config = config
        # `or {}` so an explicit null ai_config behaves like a missing one.
        ai = config.get("ai_config") or {}
        self._model = ai.get("model", "gpt-4o-mini")
        self._base_url = _normalize_base_url(ai.get("base_url", "")) or None
        # vLLM/local servers ignore the key; SDK requires non-empty.
        self._api_key = (
            ai.get("api_key")
            or os.environ.get("OPENAI_API_KEY")
            or "EMPTY"
        )
        self._max_tokens = ai.get("max_tokens", 8192)
        self._temperature = ai.get("temperature", 0.3)
        self._timeout = ai.get("timeout", 120)
        self._max_turns = config.get("max_turns", 50)
        self._tools = _to_openai_tools(CODING_TOOLS)

        self._state = "idle"
        self._working_dir = ""
        self._messages: List[Dict] = []
        self._system_prompt = ""
        self._events: list = []
        self._event_idx = 0  # index of the next event get_events will yield
        self._pause_event = threading.Event()
        self._pause_event.set()  # set == not paused
        self._stop_flag = False
        self._instruction_queue: list = []
        self._lock = threading.Lock()
        self._client = None

    def _get_client(self):
        """Return the OpenAI client, creating it on first use."""
        if self._client is not None:
            return self._client
        # Imported lazily so the module loads without the openai package.
        from openai import OpenAI

        kwargs = {"api_key": self._api_key, "timeout": self._timeout}
        if self._base_url:
            kwargs["base_url"] = self._base_url
        self._client = OpenAI(**kwargs)
        return self._client

    def start(self, task: str, working_dir: str, system_prompt: str = "") -> None:
        """Reset run state and launch the agent loop in a daemon thread.

        Args:
            task: Sent as the first user message.
            working_dir: Directory passed to ``execute_tool`` for every
                tool call.
            system_prompt: Optional system message; a built-in default is
                used when empty.
        """
        self._working_dir = working_dir
        self._system_prompt = system_prompt or (
            "You are a coding agent. You have access to tools for reading, "
            "editing, and creating files, running bash commands, and searching code. "
            "Use these tools to complete the task. When you are done, stop calling tools "
            "and summarize what you did."
        )
        with self._lock:
            self._messages = [
                {"role": "system", "content": self._system_prompt},
                {"role": "user", "content": task},
            ]
            self._events = []
            self._event_idx = 0
            self._state = "running"
        self._stop_flag = False
        # Reopen the pause gate: otherwise a pause() issued before this
        # call would leave the new worker blocked while state says "running".
        self._pause_event.set()
        thread = threading.Thread(target=self._run_loop, daemon=True)
        thread.start()

    def _run_loop(self):
        """Main agent loop using the OpenAI chat API with tools.

        Each turn: wait while paused, append at most one queued
        instruction, request a completion, execute the returned tool
        calls in order, and emit THINKING / TOOL_CALL_START /
        TOOL_CALL_END / TURN_END events. The loop ends when the model
        returns no tool calls, ``max_turns`` is reached, or ``stop`` is
        called; any failure ends the run with an ERROR event.
        """
        turn_index = 0
        try:
            client = self._get_client()
            while not self._stop_flag and turn_index < self._max_turns:
                self._pause_event.wait()
                if self._stop_flag:
                    break

                with self._lock:
                    if self._instruction_queue:
                        instruction = self._instruction_queue.pop(0)
                        self._messages.append({"role": "user", "content": instruction})

                self._emit(CodingAgentEventType.THINKING, {
                    "turn_index": turn_index,
                    "text": "Thinking...",
                })

                try:
                    resp = client.chat.completions.create(
                        model=self._model,
                        messages=self._messages,
                        tools=self._tools,
                        tool_choice="auto",
                        max_tokens=self._max_tokens,
                        temperature=self._temperature,
                    )
                except Exception as e:
                    # Includes models/servers that don't support tools --
                    # surface a clear error instead of stalling the loop.
                    self._finish("error", CodingAgentEventType.ERROR, {
                        "message": f"OpenAI-compatible request failed: {e}"
                    })
                    return

                choice = resp.choices[0].message
                content = choice.content or ""
                tool_calls_raw = choice.tool_calls or []

                if content:
                    self._emit(CodingAgentEventType.THINKING, {
                        "turn_index": turn_index,
                        "text": content,
                    })

                # Append the assistant message verbatim (must include
                # tool_calls so the following tool messages pair by id).
                try:
                    assistant_msg = choice.model_dump(exclude_none=True)
                except Exception:
                    assistant_msg = {"role": "assistant", "content": content}
                self._messages.append(assistant_msg)

                tool_calls = []
                for tc_raw in tool_calls_raw:
                    if self._stop_flag:
                        break
                    self._pause_event.wait()
                    if self._stop_flag:
                        break

                    fn = tc_raw.function
                    tool_name = fn.name or "unknown"
                    tool_input = self._parse_tool_arguments(fn.arguments)

                    self._emit(CodingAgentEventType.TOOL_CALL_START, {
                        "turn_index": turn_index,
                        "tool": tool_name,
                        "input": tool_input,
                    })

                    output = execute_tool(tool_name, tool_input, self._working_dir)
                    output_type = self._classify_output_type(tool_name)

                    tc = {
                        "tool": tool_name,
                        "input": tool_input,
                        "output": output,
                        "output_type": output_type,
                    }
                    tool_calls.append(tc)

                    self._emit(CodingAgentEventType.TOOL_CALL_END, {
                        "turn_index": turn_index,
                        "tool_index": len(tool_calls) - 1,
                        **tc,
                    })

                    # OpenAI requires the tool result to reference the
                    # originating tool_call_id.
                    self._messages.append({
                        "role": "tool",
                        "tool_call_id": tc_raw.id,
                        "content": output,
                    })

                self._emit(CodingAgentEventType.TURN_END, {
                    "turn_index": turn_index,
                    "content": content,
                    "tool_calls": tool_calls,
                })

                turn_index += 1

                if not tool_calls_raw:
                    break

            self._finish(
                "completed",
                CodingAgentEventType.COMPLETE,
                {"total_turns": turn_index},
            )
        except Exception as e:
            logger.exception("OpenAI agent loop error")
            self._finish("error", CodingAgentEventType.ERROR, {"message": str(e)})

    @staticmethod
    def _parse_tool_arguments(raw_args):
        """Turn a tool call's raw ``arguments`` into the tool input.

        A dict is used as-is; a JSON string is decoded (empty string ->
        ``{}``); a non-JSON string is passed through as
        ``{"command": raw_args}``; anything else yields ``{}``.
        """
        if isinstance(raw_args, dict):
            return raw_args
        if not isinstance(raw_args, str) or not raw_args:
            return {}
        try:
            return json.loads(raw_args)
        except json.JSONDecodeError:
            return {"command": raw_args}

    def _classify_output_type(self, tool_name: str) -> str:
        """Map a tool name to an output type: "terminal", "diff" or "code"."""
        name = tool_name.lower()
        if name in ("bash", "terminal", "shell"):
            return "terminal"
        if name in ("edit", "replace"):
            return "diff"
        return "code"

    def _emit(self, event_type: CodingAgentEventType, data: dict):
        """Append a timestamped event to the event list under the lock."""
        event = CodingAgentEvent(event_type=event_type, timestamp=time.time(), data=data)
        with self._lock:
            self._events.append(event)

    def _finish(self, state: str, event_type: CodingAgentEventType, data: dict) -> None:
        """Record the terminal event and terminal state in one locked step.

        Doing both under the lock means ``get_events`` can never observe
        the terminal state before the matching COMPLETE/ERROR event is
        queued, which would make it return without delivering that event.
        """
        event = CodingAgentEvent(event_type=event_type, timestamp=time.time(), data=data)
        with self._lock:
            self._events.append(event)
            self._state = state

    def get_events(self) -> Iterator[CodingAgentEvent]:
        """Yield events in order, polling every 0.1 s for new ones.

        Returns after yielding a COMPLETE or ERROR event, or when the
        state is terminal and no events are pending.
        """
        terminal_events = (CodingAgentEventType.COMPLETE, CodingAgentEventType.ERROR)
        while True:
            event = None
            with self._lock:
                if self._event_idx < len(self._events):
                    event = self._events[self._event_idx]
                    self._event_idx += 1
                finished = self._state in self._TERMINAL_STATES
            # Yield only after releasing the lock: it is non-reentrant, so
            # holding it while the consumer runs would block _emit and
            # deadlock a consumer that calls back into this object.
            if event is not None:
                yield event
                if event.event_type in terminal_events:
                    return
                continue
            if finished:
                return
            time.sleep(0.1)

    def pause(self) -> None:
        """Make the worker wait before its next turn or tool call.

        No-op once the run is in a terminal state, so a late pause cannot
        block a worker that is shutting down or overwrite the final state.
        """
        if self._state in self._TERMINAL_STATES:
            return
        self._pause_event.clear()
        self._state = "paused"

    def resume(self) -> None:
        """Release a paused worker; no-op once the run is in a terminal state."""
        if self._state in self._TERMINAL_STATES:
            return
        self._state = "running"
        self._pause_event.set()

    def inject_instruction(self, text: str) -> None:
        """Queue ``text`` to be appended as a user message on a later turn."""
        with self._lock:
            self._instruction_queue.append(text)

    def stop(self) -> None:
        """Request the loop to stop and mark the run completed.

        The pause gate is opened so a paused worker can see the stop flag.
        """
        self._stop_flag = True
        self._pause_event.set()
        self._state = "completed"

    def get_conversation_history(self) -> List[Dict]:
        """Return a shallow copy of the message list."""
        with self._lock:
            return list(self._messages)

    def get_state(self) -> str:
        """Return the current state string."""
        return self._state

    def truncate_history(self, to_step: int) -> None:
        """Drop recorded events belonging to turns >= ``to_step``.

        Best-effort: only the event list is truncated. Messages are left
        untouched because OpenAI needs every tool message to stay paired
        with its ``tool_call_id``. Events without a ``turn_index`` are
        always kept.
        """
        def keep(event) -> bool:
            turn = event.data.get("turn_index", -1)
            return turn == -1 or turn < to_step

        with self._lock:
            consumed = self._events[:self._event_idx]
            self._events = [e for e in self._events if keep(e)]
            # Count the retained events that were already delivered, so
            # retained-but-undelivered events are not skipped.
            self._event_idx = sum(1 for e in consumed if keep(e))