auto-swe-agent-ui / agent.py
DevilBits's picture
fix: enforce safe empty bounds for tracking data charts and match dataframe list alignments
6085b61
import json
import os
import subprocess
import time
from pathlib import Path
from typing import Annotated, Optional, TypedDict
import docker
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.tools import tool
from langgraph.graph import END, StateGraph
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode
from observability.langfuse_client import get_langfuse
from observability.tool_tracing import trace_tool_execution
from resilience.circuit_breaker import CircuitBreaker
from tracking.cost_tracker import CostTracker
from ui.state_manager import AgentStateManager
# ---------------------------------------------------------------------------
# GraphState — the shared state passed between all graph nodes
# ---------------------------------------------------------------------------
# Task phrases that mark a request as multi-file; route_coder and
# route_planner_single then allow 20 iterations instead of 15.
MULTI_FILE_KEYWORDS = ("3 files", "multiple files", "several files", "all files")


class GraphState(TypedDict):
    """Shared state passed between all graph nodes.

    Nodes return partial dicts that LangGraph merges into this state;
    ``messages`` uses the ``add_messages`` reducer, so returned messages are
    appended to the history instead of replacing it.
    """

    # Conversation history (appended via add_messages).
    messages: Annotated[list, add_messages]
    # Absolute workspace path; mounted at /workspace in the sandbox.
    workspace_dir: str
    # Task text supplied on the command line.
    current_task: str
    # Non-empty -> the multi-agent routers end the run.
    error_logs: str
    plan: str
    complexity: Optional[str]
    # Reviewer verdict text; route_reviewer looks for "LGTM" in it.
    review_feedback: Optional[str]
    current_agent: Optional[str]
    # Model turns taken so far; compared against the iteration limit.
    iteration_count: int
    # True once write_to_file has been called at least once.
    writes_performed: bool
    search_call_count: int
    semantic_search_call_count: int
    # None = not verified since the last write; True/False = last pytest result.
    tests_passed: Optional[bool]
    test_output: Optional[str]
    verification_attempts: int
    branch_name: Optional[str]
    commit_hash: Optional[str]
    total_cost_usd: float
    budget_exceeded: bool
    # Retry settings forwarded to model calls (--retry-max / --retry-delay).
    _retry_max: int
    _retry_delay: float
    last_model_used: Optional[str]
    langfuse_trace_id: Optional[str]
    # Name of the node that produced the latest update. Nodes already return
    # this key; LangGraph only stores keys declared in the state schema, so it
    # has to be listed here to be kept between steps.
    current_node: Optional[str]
# ---------------------------------------------------------------------------
# Module-level singletons
# ---------------------------------------------------------------------------
# Created at import time from the environment's Docker configuration.
docker_client = docker.from_env()
# Cached sandbox container; None until get_sandbox() starts or adopts one.
_sandbox: Optional[docker.models.containers.Container] = None
# Label value (as "role=<value>") used to find and reuse the sandbox container.
_SANDBOX_LABEL = "auto-swe-agent-sandbox"
# Kept outside GraphState because TypedDict can't hold arbitrary class instances.
# The budget / threshold / timeout defaults below are overwritten from CLI
# flags in main().
_cost_tracker: CostTracker = CostTracker(budget_usd=5.0)
_circuit_breaker: CircuitBreaker = CircuitBreaker(
    failure_threshold=5, recovery_timeout=300
)
# Human-readable circuit-breaker events, shown in the UI snapshot and the
# final summary.
_circuit_events: list[str] = []
# Receives the UI-state snapshot via save_state().
_state_manager: AgentStateManager = AgentStateManager()
# Per-run counters (reset at the start of main()).
_semantic_search_call_count: int = 0
_agent_call_counts: dict[str, int] = {}  # agent_name -> call count for eval tracking
_review_feedbacks: list[str] = []  # LGTM / NEEDS_FIX outcomes for eval tracking
# ---------------------------------------------------------------------------
# Docker sandbox
# ---------------------------------------------------------------------------
def get_sandbox(workspace_dir: str) -> docker.models.containers.Container:
    """Return the shared Docker sandbox, creating and provisioning it on first use.

    Resolution order:

    1. the container cached in ``_sandbox`` by an earlier call;
    2. a running container labelled ``role=<_SANDBOX_LABEL>`` (for example one
       left over from a previous run), adopted as-is;
    3. a fresh ``python:3.11-slim`` container with ``workspace_dir`` bind-mounted
       read-write at ``/workspace``. fastapi/httpx/pytest/uvicorn are
       pip-installed and import-checked before it is returned.

    The module-level cache is only filled once a newly started container has
    passed provisioning, so a failed install is never handed out later.

    Args:
        workspace_dir: Host directory to mount at ``/workspace``. Ignored when
            an existing sandbox is reused (its original mount is kept).

    Returns:
        The running sandbox container.

    Raises:
        RuntimeError: If ``pip install`` or the import health check fails. The
            half-provisioned container is stopped first (it was started with
            ``remove=True``) so the label lookup cannot adopt it later.
    """
    global _sandbox
    if _sandbox is not None:
        return _sandbox
    existing = docker_client.containers.list(
        filters={"label": f"role={_SANDBOX_LABEL}"}
    )
    if existing:
        _sandbox = existing[0]
        print(f"[Docker] Reusing sandbox: {_sandbox.short_id}")
        return _sandbox
    abs_workspace = os.path.abspath(workspace_dir)
    container = docker_client.containers.run(
        "python:3.11-slim",
        command="sleep infinity",
        detach=True,
        remove=True,
        labels={"role": _SANDBOX_LABEL},
        volumes={abs_workspace: {"bind": "/workspace", "mode": "rw"}},
        working_dir="/workspace",
    )
    print(f"[Docker] Sandbox started: {container.short_id}")
    try:
        print("[Docker] Installing packages...")
        exit_code, output = container.exec_run(
            ["pip", "install", "fastapi", "httpx", "pytest", "uvicorn"],
            demux=False,
        )
        if exit_code != 0:
            log = (output or b"").decode(errors="replace")
            raise RuntimeError(
                f"[Docker] pip install failed (exit {exit_code}):\n{log}"
            )
        exit_code, _ = container.exec_run(
            ["python", "-c", "import fastapi, pytest, httpx, uvicorn"]
        )
        if exit_code != 0:
            raise RuntimeError("[Docker] Health check failed.")
    except BaseException:
        # Cleanup-and-reraise (also on Ctrl+C during the slow pip step): a
        # labelled but unprovisioned container must not stay running.
        try:
            container.stop()
        except Exception as stop_err:
            print(f"[Docker] Failed to stop sandbox {container.short_id}: {stop_err}")
        raise
    _sandbox = container
    print("[Docker] Sandbox ready.")
    return _sandbox
# ---------------------------------------------------------------------------
# Tools
# ---------------------------------------------------------------------------
# Directory names skipped while walking the workspace (virtualenvs, VCS
# metadata, bytecode caches, JS dependency/build folders).
IGNORE_DIRS = {".venv", "venv", "__pycache__", ".git", "node_modules", ".next"}


@tool
def list_files(directory: str) -> str:
    """Return a directory tree string, ignoring common non-essential directories."""
    # The docstring above is also the tool description shown to the LLM.
    #
    # Layout: every directory is printed as "<name>/" indented one space per
    # level below `directory`; its files are indented one level further:
    #
    #   proj/
    #    main.py
    #    pkg/
    #     mod.py
    #
    # A nonexistent path produces no os.walk entries, hence the fallback text.
    abs_dir = os.path.abspath(directory)
    lines = []
    for root, dirs, files in os.walk(abs_dir):
        # Prune in place so os.walk does not descend into ignored directories.
        dirs[:] = [d for d in dirs if d not in IGNORE_DIRS]
        # Depth = number of path components below the root ("a" -> 1,
        # "a/b" -> 2); the root itself is depth 0.
        if root == abs_dir:
            level = 0
        else:
            level = len(os.path.relpath(root, abs_dir).split(os.sep))
        lines.append(f"{' ' * level}{os.path.basename(root) or root}/")
        lines.extend(f"{' ' * (level + 1)}{f}" for f in files)
    return "\n".join(lines) if lines else "Directory is empty or does not exist."
@tool
def read_file(filepath: str) -> str:
    """Return file contents, truncated at 2000 lines with a warning if exceeded."""
    from itertools import islice

    # The docstring above is also the tool description shown to the LLM.
    # Errors such as a missing file propagate to the caller.
    max_lines = 2000
    # Read at most one line past the limit: enough to detect truncation
    # without loading an arbitrarily large file into memory. Explicit UTF-8
    # (with replacement) keeps decoding independent of the host locale and
    # never raises on binary or mis-encoded bytes.
    with open(filepath, "r", encoding="utf-8", errors="replace") as f:
        lines = list(islice(f, max_lines + 1))
    if len(lines) > max_lines:
        return (
            "".join(lines[:max_lines])
            + "\n\n[WARNING: File truncated at 2000 lines to save context window.]"
        )
    return "".join(lines)
@tool
def search_codebase(keyword: str, directory: str) -> str:
    """Search for a keyword in all files under directory, returning file:line matches."""
    # The docstring above is also the tool description shown to the LLM.
    #
    # At most `max_matches` hits are listed: a common keyword (or an empty
    # one, which matches every line) would otherwise flood the model's
    # context. The number of omitted hits is reported instead.
    max_matches = 200
    matches: list[str] = []
    omitted = 0
    for root, dirs, files in os.walk(directory):
        # Prune in place so os.walk does not descend into ignored directories.
        dirs[:] = [d for d in dirs if d not in IGNORE_DIRS]
        for fname in files:
            fpath = os.path.join(root, fname)
            try:
                with open(fpath, "r", encoding="utf-8", errors="replace") as f:
                    for i, line in enumerate(f, 1):
                        if keyword not in line:
                            continue
                        if len(matches) < max_matches:
                            matches.append(f"{fpath}:{i}: {line.rstrip()}")
                        else:
                            omitted += 1
            except OSError:
                # Best-effort: unreadable entries (permissions, broken
                # symlinks, ...) are skipped so one bad file doesn't abort
                # the whole search.
                continue
    if not matches:
        return "No matches found."
    report = "\n".join(matches)
    if omitted:
        report += f"\n[TRUNCATED: {omitted} more matches not shown]"
    return report
@tool
def write_to_file(filepath: str, content: str) -> str:
    """Write content to a file on the local filesystem (synced to the Docker sandbox via volume mount)."""
    # The docstring above is also the tool description shown to the LLM.
    # Missing parent directories are created; an existing file is replaced.
    parent = os.path.dirname(filepath)
    if parent:
        os.makedirs(parent, exist_ok=True)
    # Explicit UTF-8: model-generated code often contains non-ASCII text,
    # which would raise UnicodeEncodeError under a non-UTF-8 default locale.
    with open(filepath, "w", encoding="utf-8") as f:
        f.write(content)
    return f"Written to {filepath}"
@tool
def run_bash_command(command: str, workspace_dir: str = "./") -> str:
    """Execute a bash command inside the Docker sandbox (in /workspace) and return stdout + stderr."""
    # The docstring above is also the tool description shown to the LLM.
    # The trailing exit code lets the model tell a failed command from one
    # that merely wrote to stderr. Output is decoded leniently so binary or
    # non-UTF-8 bytes cannot crash the tool.
    container = get_sandbox(workspace_dir)
    result = container.exec_run(
        ["bash", "-c", command], workdir="/workspace", demux=True
    )
    # With demux=True the output is a (stdout, stderr) pair; either side is
    # None when that stream produced nothing.
    raw_stdout, raw_stderr = result.output or (None, None)
    stdout = (raw_stdout or b"").decode(errors="replace")
    stderr = (raw_stderr or b"").decode(errors="replace")
    return f"stdout:\n{stdout}\nstderr:\n{stderr}\nexit_code: {result.exit_code}"
@tool
def run_tests(workspace_dir: str = "./") -> str:
    """Run pytest in the Docker sandbox and return the full test output."""
    # The docstring above is also the tool description shown to the LLM.
    # stderr is folded into stdout by the shell (2>&1); anything beyond 2000
    # characters is cut off and marked with "[TRUNCATED]".
    limit = 2000
    sandbox = get_sandbox(workspace_dir)
    exec_result = sandbox.exec_run(
        ["bash", "-c", "pytest -x -q 2>&1"], workdir="/workspace", demux=False
    )
    text = (exec_result.output or b"").decode()
    if len(text) <= limit:
        return text
    return text[:limit] + "\n[TRUNCATED]"
from tools.git_tools import commit_changes, create_branch, generate_pr_description
from tools.semantic_search import semantic_search
# Tools exposed to the models: bound to the LLM in _invoke_model, executed by
# executor_node (a ToolNode built from this list) and passed to
# configure_runtime for the multi-agent nodes.
tools = [
    list_files,
    read_file,
    search_codebase,
    semantic_search,
    write_to_file,
    run_bash_command,
    run_tests,
    create_branch,
    commit_changes,
    generate_pr_description,
]
# LiteLLM model ids in fallback order. planner_node_single skips a model whose
# provider API key is not set or whose circuit breaker is open, and moves on
# to the next one when a call fails. Also passed to configure_runtime.
FALLBACK_MODELS = [
    "gemini/gemini-2.0-flash",
    "gemini/gemini-2.0-flash-lite",
    "groq/llama-3.3-70b-versatile",
    "groq/llama3-8b-8192",
]
# ---------------------------------------------------------------------------
# Shared helpers (used by agent nodes and base)
# ---------------------------------------------------------------------------
def _export_ui_state(state: dict, node: str = "") -> None:
    """Build a UI snapshot of the run and persist it via ``_state_manager``.

    Args:
        state: Graph state, usually merged with a node's update
            (``{**state, **result}``).
        node: Node that just ran; takes precedence over
            ``state["current_node"]``. The value "end", or an exceeded budget,
            marks the run "completed"; otherwise it is "running".
    """
    cost_summary = _cost_tracker.get_summary()
    finished = bool(state.get("budget_exceeded")) or node == "end"
    ui_state = {
        "iteration_count": state.get("iteration_count", 0),
        # `or` rather than a .get() default: main()'s initial state sets these
        # keys to None, which would otherwise bypass the fallback value.
        "current_node": node or state.get("current_node") or "idle",
        "current_agent": state.get("current_agent") or "idle",
        "tests_passed": state.get("tests_passed"),
        "verification_attempts": state.get("verification_attempts", 0),
        "total_cost_usd": cost_summary["total_cost_usd"],
        "budget_exceeded": state.get("budget_exceeded", False),
        "total_tokens": cost_summary["total_tokens"],
        "total_calls": cost_summary["total_calls"],
        "model_breakdown": cost_summary.get("model_breakdown", {}),
        "budget_usd": cost_summary.get("budget_usd", 0.0),
        "last_model_used": state.get("last_model_used") or "unknown",
        "branch_name": state.get("branch_name"),
        "commit_hash": state.get("commit_hash"),
        "messages_count": len(state.get("messages", [])),
        "circuit_status": _circuit_breaker.get_status(),
        "circuit_events": _circuit_events[-20:],
        "status": "completed" if finished else "running",
    }
    _state_manager.save_state(ui_state)


def _score_run(state: dict) -> None:
    """Attach outcome scores for the run to its Langfuse trace.

    No-op unless Langfuse is enabled and the state carries a trace id.
    Scores written:

    - ``tests_passed``: 1.0 if the final state's tests passed, else 0.0;
    - ``review_quality``: fraction of reviewer verdicts that were "LGTM"
      (only when at least one verdict was recorded);
    - ``search_efficiency``: ``1 - semantic_search calls / total agent calls``,
      floored at 0 (only when recorded agent calls sum to more than zero).
    """
    trace_id = state.get("langfuse_trace_id")
    langfuse = get_langfuse()
    if not langfuse.is_enabled() or not trace_id:
        return
    langfuse.score(
        trace_id=trace_id,
        name="tests_passed",
        value=1.0 if state.get("tests_passed") else 0.0,
        comment=(
            f"verification_attempts={state.get('verification_attempts', 0)} | "
            f"branch={state.get('branch_name')} | commit={state.get('commit_hash')}"
        ),
    )
    if _review_feedbacks:
        lgtm = _review_feedbacks.count("LGTM")
        needs_fix = _review_feedbacks.count("NEEDS_FIX")
        langfuse.score(
            trace_id=trace_id,
            name="review_quality",
            value=lgtm / len(_review_feedbacks),
            comment=f"lgtm={lgtm} / needs_fix={needs_fix}",
        )
    total = sum(_agent_call_counts.values())
    if total > 0:
        efficiency = 1.0 - min(_semantic_search_call_count / total, 1.0)
        langfuse.score(
            trace_id=trace_id,
            name="search_efficiency",
            value=efficiency,
            comment=f"semantic_search_calls={_semantic_search_call_count} | total_calls={total}",
        )
# ---------------------------------------------------------------------------
# Configure agents base runtime
# ---------------------------------------------------------------------------
from agents.base import configure_runtime
# Executes the tool calls carried by the latest AI message; the graphs'
# "executor" node invokes it through the _track_tool_calls wrapper.
executor_node = ToolNode(tools)
# Share this module's singletons and helpers with agents.base, presumably so
# the multi-agent nodes imported below use the same cost tracker, circuit
# breaker, fallback chain, tool executor and UI exporter — confirm in
# agents/base.py.
configure_runtime(
    cost_tracker=_cost_tracker,
    circuit_breaker=_circuit_breaker,
    circuit_events=_circuit_events,
    fallback_models=FALLBACK_MODELS,
    tools=tools,
    executor_node=executor_node,
    export_ui_state_fn=_export_ui_state,
)
# ---------------------------------------------------------------------------
# Multi-agent nodes
# ---------------------------------------------------------------------------
from agents.coder import coder_node
from agents.manager import manager_node
from agents.planner import planner_node as multi_planner_node
from agents.reviewer import reviewer_node
# ---------------------------------------------------------------------------
# Shared graph nodes (verify, git, executor)
# ---------------------------------------------------------------------------
def _track_tool_calls(state: GraphState) -> dict:
    """Run the pending tool calls through ``executor_node`` and update counters.

    Counts the tool calls on the last message (write_to_file, search_codebase,
    semantic_search) and executes them. When the state carries a Langfuse trace
    id, it traces the batch as one span plus one entry per tool. Any write
    resets ``tests_passed`` to None so that verification runs again.

    Returns:
        The ToolNode update extended with ``writes_performed``,
        ``search_call_count``, ``semantic_search_call_count``,
        ``current_node`` and, after a write, ``tests_passed=None``.

    Raises:
        Whatever ``executor_node.invoke`` raises. The batch span is closed
        with an error status first.
    """
    global _semantic_search_call_count
    last = state["messages"][-1]
    writes = state.get("writes_performed", False)
    searches = state.get("search_call_count", 0)
    semantic_searches = state.get("semantic_search_call_count", 0)
    wrote_this_turn = False
    tool_calls_info: list[tuple[str, str]] = []
    for tc in getattr(last, "tool_calls", None) or []:
        name = tc["name"]
        # Args are stringified and clipped to keep trace payloads small.
        tool_calls_info.append((name, str(tc.get("args", {}))[:200]))
        if name == "write_to_file":
            writes = True
            wrote_this_turn = True
        elif name == "search_codebase":
            searches += 1
        elif name == "semantic_search":
            semantic_searches += 1
            _semantic_search_call_count += 1

    trace_id = state.get("langfuse_trace_id")
    span = None
    if trace_id and tool_calls_info:
        span = get_langfuse().span(
            trace_id=trace_id,
            name="tool-execution-batch",
            input={"tool_calls": tool_calls_info},
        )
    try:
        result = executor_node.invoke(state)
    except Exception as exc:
        # Close the span with the failure instead of leaving it open.
        if span is not None:
            span.update(
                output={"status": "error", "error": f"{type(exc).__name__}: {exc}"}
            )
        raise
    result["writes_performed"] = writes
    result["search_call_count"] = searches
    result["semantic_search_call_count"] = semantic_searches
    result["current_node"] = "executor"
    # Trace individual tool results
    if trace_id:
        for tc_name, tc_input in tool_calls_info:
            trace_tool_execution(trace_id, tc_name, tc_input, result)
    if span is not None:
        span.update(output={"status": "success", "results_count": len(tool_calls_info)})
    if wrote_this_turn:
        # Code changed: any previous test result is stale.
        result["tests_passed"] = None
    _export_ui_state({**state, **result}, "executor")
    return result
def verify_code(state: GraphState) -> dict:
    """Run pytest in the Docker sandbox and update tests_passed / test_output.

    ``verification_attempts`` is incremented on every call. On failure, the
    pytest output (capped at 2000 characters) is stored in ``test_output``. It
    is also appended to the conversation as a SystemMessage so the next coding
    turn sees the errors.
    """
    print("\n--- [NODE] VERIFY ---")
    sandbox = get_sandbox(state.get("workspace_dir", "./"))
    run = sandbox.exec_run(
        ["bash", "-c", "pytest -x -q 2>&1"], workdir="/workspace", demux=False
    )
    report = (run.output or b"").decode()
    if len(report) > 2000:
        report = report[:2000] + "\n[TRUNCATED]"
    attempt = state.get("verification_attempts", 0) + 1
    if run.exit_code == 0:
        print(f"[VERIFY] Tests PASSED (attempt {attempt})")
        return {
            "tests_passed": True,
            "test_output": "All tests passed.",
            "verification_attempts": attempt,
            "current_node": "verify",
        }
    print(f"[VERIFY] Tests FAILED (attempt {attempt}):\n{report[:300]}")
    feedback = SystemMessage(
        content=f"Tests failed. Fix the following errors:\n{report}"
    )
    return {
        "tests_passed": False,
        "test_output": report,
        "verification_attempts": attempt,
        "messages": [feedback],
        "current_node": "verify",
    }
def git_workflow(state: GraphState) -> dict:
    """Create a fresh branch in the sandbox and commit all workspace changes.

    Runs after tests pass. Every git command goes through
    ``tools.git_tools._run_in_sandbox``. Steps:

    1. Skip (no branch, no commit) when ``/workspace`` is not a git work tree.
    2. Set a repository-local committer identity.
    3. ``git checkout -b auto-swe/fix-<unix timestamp>``.
    4. ``git add -A``, then commit with ``auto-swe: <first 50 chars of task>``
       (whitespace collapsed; "fix" when the task is empty).
    5. Read the new commit's abbreviated hash with ``git rev-parse``.

    Returns:
        ``branch_name`` (None when not a repo or branch creation failed),
        ``commit_hash`` (None when the commit failed, "" when the hash could
        not be read) and ``current_node``.
    """
    import shlex

    from tools.git_tools import _run_in_sandbox

    print("\n--- [NODE] GIT WORKFLOW ---")
    workspace = state.get("workspace_dir", "./")
    branch = f"auto-swe/fix-{int(time.time())}"

    def _update(branch_name, commit_hash):
        return {
            "branch_name": branch_name,
            "commit_hash": commit_hash,
            "current_node": "git_workflow",
        }

    exit_code, _ = _run_in_sandbox("git rev-parse --is-inside-work-tree", workspace)
    if exit_code != 0:
        print("[GIT] Not a git repo — skipping git workflow.")
        return _update(None, None)
    # Repository-local config (no --global) only works inside a repo, so it
    # runs after the check above.
    _run_in_sandbox(
        'git config user.email "agent@auto-swe-agent" && git config user.name "auto-swe-agent"',
        workspace,
    )
    exit_code, out = _run_in_sandbox(f"git checkout -b {branch}", workspace)
    if exit_code != 0:
        print(f"[GIT] Branch creation failed: {out}")
        return _update(None, None)
    print(f"[GIT] Created branch: {branch}")
    # Collapse newlines/whitespace runs so the subject stays on one line.
    task_slug = " ".join((state.get("current_task") or "")[:50].split()) or "fix"
    commit_msg = f"auto-swe: {task_slug}"
    _run_in_sandbox("git add -A", workspace)
    # The message embeds user-supplied task text: shell-quote it so quotes,
    # `$(...)` or backticks can neither break nor inject into the command.
    exit_code, out = _run_in_sandbox(
        f"git commit -m {shlex.quote(commit_msg)}", workspace
    )
    if exit_code != 0:
        print(f"[GIT] Commit failed: {out}")
        return _update(branch, None)
    # Ask git for the hash rather than parsing the "[branch hash] msg" summary,
    # whose layout varies (e.g. "[branch (root-commit) hash] msg").
    exit_code, out = _run_in_sandbox("git rev-parse --short HEAD", workspace)
    hash_lines = out.strip().splitlines() if exit_code == 0 else []
    commit_hash = hash_lines[-1].strip() if hash_lines else ""
    print(f"[GIT] Committed: {commit_hash} — {commit_msg}")
    return _update(branch, commit_hash)
# ---------------------------------------------------------------------------
# Routing functions (multi-agent)
# ---------------------------------------------------------------------------
def route_manager(state: GraphState) -> str:
    """After the manager: end on recorded errors, otherwise hand off to the planner."""
    next_node = "planner"
    if state.get("error_logs"):
        next_node = "end"
    log_routing(state, "manager", next_node)
    return next_node
def route_planner(state: GraphState) -> str:
    """After the planner: end on recorded errors, otherwise hand off to the coder."""
    next_node = "coder"
    if state.get("error_logs"):
        next_node = "end"
    log_routing(state, "planner", next_node)
    return next_node
def route_coder(state: GraphState) -> str:
    """Choose the node that follows a coder turn.

    Checks, in order: recorded error or exceeded budget -> "end"; pending
    tool calls -> "executor"; iteration limit reached -> "end"; nothing
    written yet -> "coder"; otherwise -> "verify".

    The iteration limit is ``_max_iterations_override`` when non-zero.
    Otherwise it is 20 for multi-file tasks (a MULTI_FILE_KEYWORDS phrase in
    the task, or more than 3 search_codebase calls) and 15 for the rest.
    """

    def _go(destination: str) -> str:
        log_routing(state, "coder", destination)
        return destination

    if state.get("error_logs"):
        return _go("end")
    if state.get("budget_exceeded"):
        print("[COST] Budget exceeded — routing to end.")
        return _go("end")
    if getattr(state["messages"][-1], "tool_calls", None):
        return _go("executor")
    task_text = state.get("current_task", "")
    multi_file = (
        any(phrase in task_text.lower() for phrase in MULTI_FILE_KEYWORDS)
        or state.get("search_call_count", 0) > 3
    )
    if _max_iterations_override:
        limit = _max_iterations_override
    elif multi_file:
        limit = 20
    else:
        limit = 15
    if state["iteration_count"] >= limit:
        print(f"[WARNING] Iteration limit ({limit}) reached. Forcing end.")
        return _go("end")
    if not state.get("writes_performed", False):
        print("[GUARD] No files written yet — forcing back to coder.")
        return _go("coder")
    return _go("verify")
def route_verify(state: GraphState) -> str:
    """After verification: pass -> reviewer; fail -> coder (< 3 attempts) or end."""
    if state.get("tests_passed"):
        destination = "reviewer"
    elif state.get("verification_attempts", 0) < 3:
        destination = "coder"
    else:
        print("[VERIFY] Max verification attempts reached. Ending.")
        destination = "end"
    log_routing(state, "verify", destination)
    return destination
def route_reviewer(state: GraphState) -> str:
    """After review: an "LGTM" verdict -> git_workflow, anything else -> coder.

    The verdict comes from ``review_feedback``. When that is None (main()
    initialises it to None), the content of the last message is used instead;
    list-style message content is flattened to its text parts. Each outcome
    ("LGTM" / "NEEDS_FIX") is appended to ``_review_feedbacks`` for the
    end-of-run summary and scoring.
    """
    review = state.get("review_feedback")
    if review is None:
        messages = state.get("messages") or []
        review = getattr(messages[-1], "content", "") if messages else ""
    if isinstance(review, list):
        review = " ".join(
            part.get("text", "") if isinstance(part, dict) else str(part)
            for part in review
        )
    if "LGTM" in str(review).upper():
        _review_feedbacks.append("LGTM")
        log_routing(state, "reviewer", "git_workflow")
        return "git_workflow"
    _review_feedbacks.append("NEEDS_FIX")
    log_routing(state, "reviewer", "coder")
    return "coder"
def route_git(state: GraphState) -> str:
    """git_workflow is terminal: record the transition and end the run."""
    destination = "end"
    log_routing(state, "git_workflow", destination)
    return destination
def log_routing(state: GraphState, source: str, target: str) -> None:
    """Record a routing decision as a Langfuse span named ``routing-<source>-><target>``.

    No-op when the state has no trace id or Langfuse is disabled. The span
    input captures the source, target and a few state flags; it is closed
    immediately with ``{"status": "routed"}``.
    """
    trace_id = state.get("langfuse_trace_id")
    if not trace_id:
        return
    client = get_langfuse()
    if not client.is_enabled():
        return
    snapshot = {
        "source": source,
        "target": target,
        "iteration": state.get("iteration_count"),
        "tests_passed": state.get("tests_passed"),
        "writes_performed": state.get("writes_performed"),
        "budget_exceeded": state.get("budget_exceeded"),
    }
    routing_span = client.span(
        trace_id=trace_id,
        name=f"routing-{source}->{target}",
        input=snapshot,
    )
    routing_span.update(output={"status": "routed"})
# ---------------------------------------------------------------------------
# Single-agent routing (backward compatible)
# ---------------------------------------------------------------------------
def route_planner_single(state: GraphState) -> str:
    """Single-agent routing after a planner turn.

    Checks, in order: budget exceeded -> "end"; iteration limit reached ->
    "end"; pending tool calls -> "executor"; nothing written yet ->
    "planner"; tests not yet run since the last write -> "verify";
    otherwise -> "end". The limit follows the same rule as route_coder:
    ``_max_iterations_override`` if non-zero, else 20 for multi-file tasks
    and 15 otherwise.
    """
    if state.get("budget_exceeded"):
        print("[COST] Budget exceeded — routing to end.")
        return "end"
    task_text = state.get("current_task", "")
    multi_file = (
        any(phrase in task_text.lower() for phrase in MULTI_FILE_KEYWORDS)
        or state.get("search_call_count", 0) > 3
    )
    limit = _max_iterations_override or (20 if multi_file else 15)
    if state["iteration_count"] >= limit:
        print(f"[WARNING] Iteration limit ({limit}) reached. Forcing end.")
        return "end"
    if getattr(state["messages"][-1], "tool_calls", None):
        return "executor"
    if not state.get("writes_performed", False):
        print("[GUARD] No files written yet — forcing back to planner.")
        return "planner"
    return "verify" if state.get("tests_passed") is None else "end"
def route_verify_single(state: GraphState) -> str:
    """Single-agent routing after verify: pass -> git_workflow; fail -> planner or end after 3 attempts."""
    if state.get("tests_passed"):
        return "git_workflow"
    if state.get("verification_attempts", 0) >= 3:
        print("[VERIFY] Max verification attempts reached. Ending.")
        return "end"
    return "planner"
# ---------------------------------------------------------------------------
# Single-agent node (old planner logic, kept for backward compat)
# ---------------------------------------------------------------------------
# Nudge that planner_node_single inserts after the first iteration while no
# file has been written yet.
NO_WRITE_MSG = SystemMessage(
    content=(
        "You have not written any files yet. You MUST use write_to_file to implement "
        "the changes before finishing."
    )
)
# System prompt for the single-agent planner.
# NOTE(review): the prompt says "two categories" but lists three (search,
# file/execution, git). The iteration limits in the last rule are fixed text,
# whereas route_planner_single uses _max_iterations_override when non-zero.
SINGLE_AGENT_SYSTEM = """You are an autonomous coding agent that fixes bugs and implements features. \
You have access to the following tools in two categories:
=== SEARCH TOOLS ===
1. search_codebase(keyword, directory) — Exact text matching. \
Use for finding specific strings, variable names, function references.
2. semantic_search(query, k=5) — Semantic (meaning-based) search. \
Use for finding code by concept or functionality. \
Example: 'find where user authentication is handled' → semantic_search. \
'find all occurrences of password' → search_codebase.
=== FILE / EXECUTION TOOLS ===
3. list_files(directory) — Show directory tree.
4. read_file(filepath) — Read a file (truncated at 2000 lines).
5. write_to_file(filepath, content) — Write/replace a file.
6. run_bash_command(command, workspace_dir) — Execute bash inside Docker sandbox.
7. run_tests(workspace_dir) — Run pytest inside Docker sandbox.
=== GIT TOOLS ===
8. create_branch(branch_name) — Create a new git branch.
9. commit_changes(message) — Stage and commit all changes.
10. generate_pr_description() — Generate a PR description from the diff.
=== RULES ===
- Always run tests (run_tests) after writing code to verify correctness.
- If tests fail, read the error output and fix the code.
- You MUST use write_to_file at least once before declaring the task done.
- Prefer semantic_search for understanding code structure and finding logic; \
use search_codebase only for exact string lookups.
- You have up to 15 iterations (20 for multi-file tasks) to complete the task."""
def _invoke_model(
    model: str, msgs: list, max_retries: int, base_delay: float, max_delay: float
):
    """Invoke ``model`` on ``msgs`` with every agent tool bound, retrying failures.

    Args:
        model: LiteLLM model identifier (e.g. "groq/llama-3.3-70b-versatile").
        msgs: Messages sent to the model as-is.
        max_retries: Retry count passed to ``with_retry``.
        base_delay: Initial backoff delay in seconds (exponential, base 2).
        max_delay: Upper bound for a single backoff delay in seconds.

    Returns:
        The model's response message (may contain tool calls).

    Raises:
        Presumably the last exception once ``with_retry`` gives up — depends
        on ``resilience.retry.with_retry``; TODO confirm.
    """
    from langchain_community.chat_models import ChatLiteLLM
    from resilience.retry import with_retry

    llm = ChatLiteLLM(model=model, temperature=0).bind_tools(tools)

    # NOTE(review): retryable_exceptions=(Exception,) retries every failure,
    # including permanent ones such as authentication errors, before the
    # caller falls back to the next model. with_retry appears to select
    # retries by exception class, so limiting them to transient errors
    # (agents.base._is_transient) would need a dedicated wrapper exception.
    @with_retry(
        max_retries=max_retries,
        base_delay=base_delay,
        max_delay=max_delay,
        exponential_base=2.0,
        retryable_exceptions=(Exception,),
    )
    def _call():
        return llm.invoke(msgs)

    return _call()
def _trim_history_single(messages: list) -> list:
    """Select and shrink the conversation history sent to the single-agent model.

    Keeps the last 10 messages, then:

    - drops ToolMessages at the start of that window. Their originating
      tool-call message was cut off, and chat APIs reject tool results that
      don't follow a matching tool call;
    - re-adds the first message when it is a HumanMessage (the "Task: ..."
      message built in main()) that fell out of the window, so the model
      never loses the task statement;
    - truncates ToolMessage contents longer than 4000 characters.
    """
    from langchain_core.messages import ToolMessage

    window = list(messages[-10:])
    while window and isinstance(window[0], ToolMessage):
        window.pop(0)
    if len(messages) > 10 and isinstance(messages[0], HumanMessage):
        window.insert(0, messages[0])
    trimmed = []
    for msg in window:
        if (
            isinstance(msg, ToolMessage)
            and isinstance(msg.content, str)
            and len(msg.content) > 4000
        ):
            msg = ToolMessage(
                content=msg.content[:4000] + "\n[TRUNCATED]",
                tool_call_id=msg.tool_call_id,
                name=msg.name,
            )
        trimmed.append(msg)
    return trimmed


def _extract_token_counts(response) -> Optional[tuple[int, int]]:
    """Return ``(input_tokens, output_tokens)`` reported on a model response.

    Reads ``usage_metadata``, falling back to ``response_metadata["usage"]``
    and then ``response_metadata["token_usage"]``. The usage value may be a
    dict (LangChain's ``usage_metadata`` is one) or an object. Gemini-style
    (prompt_token_count / candidates_token_count), LangChain-style
    (input_tokens / output_tokens) and OpenAI-style (prompt_tokens /
    completion_tokens) field names are accepted. Returns None when no usage
    is reported, or when both counts are zero, so the caller can estimate.
    """
    metadata = getattr(response, "response_metadata", None) or {}
    usage = (
        getattr(response, "usage_metadata", None)
        or metadata.get("usage")
        or metadata.get("token_usage")
    )
    if not usage:
        return None

    def _field(*names: str) -> int:
        for name in names:
            if isinstance(usage, dict):
                value = usage.get(name)
            else:
                value = getattr(usage, name, None)
            if value:
                return int(value)
        return 0

    input_tokens = _field("prompt_token_count", "input_tokens", "prompt_tokens")
    output_tokens = _field(
        "candidates_token_count", "output_tokens", "completion_tokens"
    )
    if not input_tokens and not output_tokens:
        return None
    return input_tokens, output_tokens


def _is_permanent_model_error(exc: Exception) -> bool:
    """Heuristically classify a model-call failure as provider-side/permanent.

    Matches known exception class names (quota, rate limit, connection,
    auth, bad request), or messages containing "Missing" or "key".
    planner_node_single does not count such failures against the model's
    circuit breaker.
    """
    err_name = type(exc).__name__
    message = str(exc)
    permanent_markers = (
        "ResourceExhausted",
        "RateLimit",
        "QuotaExceeded",
        "APIConnectionError",
        "AuthenticationError",
        "BadRequestError",
    )
    return (
        any(marker in err_name for marker in permanent_markers)
        or "Missing" in message
        # NOTE(review): very broad — matches any message containing "key".
        or "key" in message.lower()
    )


def _single_planner_update(
    state: GraphState, model: str, msgs: list, response
) -> dict:
    """Cost a successful planner response and build the node's state update.

    Uses reported token usage when available. Otherwise it estimates usage
    (500 tokens per input message, 4 characters per output token) and flags
    the call as estimated. If the cost tracker reports the budget exceeded, a
    halting SystemMessage is appended and ``budget_exceeded`` /
    ``tests_passed=False`` are set, so route_planner_single ends the run.
    """
    counts = _extract_token_counts(response)
    estimated = counts is None
    if counts is None:
        input_tokens = len(msgs) * 500
        output_tokens = len(str(response.content)) // 4
        print(
            f"[COST] Token counts unavailable — using estimates (in={input_tokens}, out={output_tokens})"
        )
    else:
        input_tokens, output_tokens = counts
    _cost_tracker.add_call(model, input_tokens, output_tokens, "planner", estimated)
    total_cost = _cost_tracker.get_total_cost()
    print(
        f"[COST] ${total_cost:.6f} total | this call: in={input_tokens} out={output_tokens} tokens"
    )
    result = {
        "messages": [response],
        "iteration_count": state["iteration_count"] + 1,
        "total_cost_usd": total_cost,
        "budget_exceeded": False,
        # Returned (not written into `state`) so LangGraph actually stores it.
        "last_model_used": model,
        "current_node": "planner",
    }
    if _cost_tracker.check_budget_exceeded():
        print(
            f"[COST] Budget exceeded (${total_cost:.4f} > ${_cost_tracker.budget_usd}). Halting."
        )
        budget_msg = SystemMessage(
            content=f"Budget exceeded (${total_cost:.4f} > ${_cost_tracker.budget_usd}). Halting execution."
        )
        result["messages"] = [response, budget_msg]
        result["budget_exceeded"] = True
        result["tests_passed"] = False
    _export_ui_state({**state, **result}, "planner")
    return result


def planner_node_single(state: GraphState) -> dict:
    """Single-agent step: ask the first usable model for the next action.

    The prompt is the system prompt, plus a "you haven't written anything"
    nudge after the first iteration, plus the trimmed recent history. It is
    sent to each entry of FALLBACK_MODELS in turn, skipping models whose
    provider API key is missing or whose circuit is open. Only the model call
    sits in the fallback ``try``, so a bookkeeping error after a successful
    call does not trigger (and pay for) another model's call.

    Returns:
        The update built by ``_single_planner_update`` for the first model
        that answers.

    Raises:
        RuntimeError: If every model was skipped or failed.
    """
    history = _trim_history_single(state["messages"])
    msgs = [SystemMessage(content=SINGLE_AGENT_SYSTEM)]
    if not state.get("writes_performed", False) and state["iteration_count"] > 0:
        msgs.append(NO_WRITE_MSG)
    msgs.extend(history)

    for model in FALLBACK_MODELS:
        if (
            model.startswith("gemini/")
            and not os.environ.get("GOOGLE_API_KEY")
            and not os.environ.get("GEMINI_API_KEY")
        ):
            print(f"[SKIP] {model} — no API key set.")
            continue
        if model.startswith("groq/") and not os.environ.get("GROQ_API_KEY"):
            print(f"[SKIP] {model} — no API key set.")
            continue
        if not _circuit_breaker.can_call(model):
            event = f"[CIRCUIT OPEN] Skipping {model} (cooldown active)"
            print(event)
            _circuit_events.append(event)
            continue
        print(f"\n--- [NODE] PLANNER | model={model} ---")
        try:
            response = _invoke_model(
                model,
                msgs,
                max_retries=state.get("_retry_max", 3),
                base_delay=state.get("_retry_delay", 2.0),
                max_delay=30.0,
            )
        except Exception as e:
            if not _is_permanent_model_error(e):
                _circuit_breaker.record_failure(model)
                status = _circuit_breaker.get_status().get(model, {})
                if status.get("state") == "open":
                    event = f"[CIRCUIT OPENED] {model} after {status.get('failures')} failures"
                    _circuit_events.append(event)
            print(f"[FALLBACK] {model} failed: {type(e).__name__}. Trying next model...")
            continue
        _circuit_breaker.record_success(model)
        return _single_planner_update(state, model, msgs, response)
    raise RuntimeError("All models in fallback chain exhausted.")
# ---------------------------------------------------------------------------
# Build graph (architecture selected by --single-agent flag)
# ---------------------------------------------------------------------------
# Run configuration (module-level so routers and main() share it).
# True -> main() builds the single-agent graph (--single-agent).
_single_agent_mode = False
# Non-zero replaces the automatic 15/20 iteration limit used by route_coder
# and route_planner_single.
_max_iterations_override = 0
# Compiled LangGraph application; built in main().
app = None
def _build_multi_agent_graph():
    """Compile the multi-agent workflow.

    manager -> planner -> coder; coder <-> executor while tool calls are
    pending; coder -> verify -> reviewer -> git_workflow -> END. Failed
    verification and NEEDS_FIX reviews loop back to the coder. Every router
    may also return "end", which maps to END.
    """
    graph = StateGraph(GraphState)
    node_table = (
        ("manager", manager_node),
        ("planner", multi_planner_node),
        ("coder", coder_node),
        ("executor", _track_tool_calls),
        ("verify", verify_code),
        ("reviewer", reviewer_node),
        ("git_workflow", git_workflow),
    )
    for node_name, node_fn in node_table:
        graph.add_node(node_name, node_fn)
    graph.set_entry_point("manager")
    # (source node, router, router return values). Each value names the
    # target node, except "end", which maps to END.
    route_table = (
        ("manager", route_manager, ("planner", "end")),
        ("planner", route_planner, ("coder", "end")),
        ("coder", route_coder, ("executor", "verify", "end")),
        ("verify", route_verify, ("reviewer", "coder", "end")),
        ("reviewer", route_reviewer, ("git_workflow", "coder")),
        ("git_workflow", route_git, ("end",)),
    )
    for source, router, targets in route_table:
        path_map = {t: (END if t == "end" else t) for t in targets}
        graph.add_conditional_edges(source, router, path_map)
    graph.add_edge("executor", "coder")
    return graph.compile()
def _build_single_agent_graph():
    """Compile the legacy single-agent workflow.

    planner <-> executor while tool calls are pending; planner -> verify once
    files were written; verify -> planner (retry), git_workflow (pass) or
    END; git_workflow -> END.
    """
    graph = StateGraph(GraphState)
    for node_name, node_fn in (
        ("planner", planner_node_single),
        ("executor", _track_tool_calls),
        ("verify", verify_code),
        ("git_workflow", git_workflow),
    ):
        graph.add_node(node_name, node_fn)
    graph.set_entry_point("planner")
    planner_targets = {
        "executor": "executor",
        "end": END,
        "planner": "planner",
        "verify": "verify",
    }
    graph.add_conditional_edges("planner", route_planner_single, planner_targets)
    graph.add_edge("executor", "planner")
    verify_targets = {
        "planner": "planner",
        "git_workflow": "git_workflow",
        "end": END,
    }
    graph.add_conditional_edges("verify", route_verify_single, verify_targets)
    graph.add_edge("git_workflow", END)
    return graph.compile()
# ---------------------------------------------------------------------------
# Main entry point
# ---------------------------------------------------------------------------
def _run_git(workspace: str, *args: str) -> str:
    """Run ``git <args>`` in ``workspace`` and return its stdout ('' on failure).

    Best-effort: a missing git binary, a timeout or a non-zero exit (e.g. the
    workspace is not a repository) all yield '' so output writing carries on.
    """
    try:
        proc = subprocess.run(
            ["git", *args],
            cwd=workspace,
            capture_output=True,
            encoding="utf-8",
            errors="replace",
            timeout=30,
        )
    except (OSError, subprocess.TimeoutExpired) as exc:
        print(f"[OUTPUT] git {' '.join(args)} failed: {exc}")
        return ""
    return proc.stdout if proc.returncode == 0 else ""


def _new_file_patch(workspace: str, rel_path: str) -> str:
    """Return a ``git apply``-compatible patch that creates text file ``rel_path``.

    Includes the ``diff --git`` / ``new file mode`` headers and an
    ``@@ -0,0 +1,N @@`` hunk, plus the "No newline at end of file" marker when
    needed. Returns '' for paths that are not regular files or that look
    binary (contain a NUL byte).
    """
    fpath = Path(workspace) / rel_path
    if not fpath.is_file():
        return ""
    data = fpath.read_bytes()
    if b"\0" in data:
        print(f"[OUTPUT] Skipping binary untracked file: {rel_path}")
        return ""
    text = data.decode("utf-8", errors="replace")
    # Split on "\n" only (not splitlines) so "\r" in CRLF files is preserved.
    lines = text.split("\n")
    ends_with_newline = lines[-1] == ""
    if ends_with_newline:
        lines.pop()
    header = f"diff --git a/{rel_path} b/{rel_path}\nnew file mode 100644\n"
    if not lines:
        return header  # empty file: headers only, no hunk
    body = "".join(f"+{line}\n" for line in lines)
    if not ends_with_newline:
        body += "\\ No newline at end of file\n"
    return (
        header
        + f"--- /dev/null\n+++ b/{rel_path}\n@@ -0,0 +1,{len(lines)} @@\n"
        + body
    )


def _collect_workspace_patch(workspace: str, commit_hash: Optional[str]) -> str:
    """Return a unified diff of the agent's changes in ``workspace`` ('' if none).

    When git_workflow committed (``commit_hash`` set), that commit's patch is
    returned: the working tree is clean afterwards, so ``git diff HEAD`` would
    be empty. Otherwise, the patch is ``git diff HEAD`` for tracked files plus
    a new-file patch for every untracked, non-ignored file.
    """
    if commit_hash:
        committed = _run_git(workspace, "show", "--no-color", "--format=", commit_hash)
        if committed.strip():
            return committed
    parts = []
    tracked = _run_git(workspace, "diff", "--no-color", "HEAD")
    if tracked.strip():
        parts.append(tracked)
    # -z: NUL-separated, unquoted paths (safe for spaces / non-ASCII names).
    untracked = _run_git(workspace, "ls-files", "--others", "--exclude-standard", "-z")
    for rel_path in filter(None, untracked.split("\0")):
        patch = _new_file_patch(workspace, rel_path)
        if patch:
            parts.append(patch)
    return "".join(parts)


def main():
    """CLI entry point: run the agent on a task and report the outcome.

    Steps:

    - Parse arguments and reset the per-run module state (counters, cost
      tracker, circuit breaker).
    - Build the multi-agent graph, or the single-agent one with
      ``--single-agent``, and invoke it on the task.
    - Print the final answer and a one-line summary.
    - With ``--output-dir``, write ``final_answer.txt``, ``patch.diff`` and
      ``state.json`` there.
    - Score and flush Langfuse.
    """
    # _max_iterations_override is read by route_coder / route_planner_single;
    # without this declaration the assignment below would only bind a local
    # and --max-iterations would be silently ignored.
    global _single_agent_mode, app, _max_iterations_override
    global _semantic_search_call_count, _agent_call_counts, _review_feedbacks
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "task", nargs="?", default=None, help="Issue description to solve"
    )
    parser.add_argument(
        "--task",
        dest="task_alias",
        default=None,
        help="Issue description to solve (alias for positional)",
    )
    parser.add_argument("--workspace", default="./", help="Workspace directory")
    parser.add_argument(
        "--output-dir", default=None, help="Directory to write final answer and patch"
    )
    parser.add_argument("--budget", type=float, default=5.0)
    parser.add_argument(
        "--max-iterations",
        type=int,
        default=0,
        help="Max iterations (0=auto based on complexity)",
    )
    parser.add_argument("--retry-max", type=int, default=3)
    parser.add_argument("--retry-delay", type=float, default=2.0)
    parser.add_argument("--circuit-threshold", type=int, default=5)
    parser.add_argument("--circuit-timeout", type=int, default=300)
    parser.add_argument(
        "--single-agent",
        action="store_true",
        help="Use single-agent mode (backward-compatible planner-only)",
    )
    args = parser.parse_args()
    task = args.task or args.task_alias or input("Enter task: ")
    workspace = os.path.abspath(args.workspace)
    _single_agent_mode = args.single_agent
    _semantic_search_call_count = 0
    _agent_call_counts = {}
    _review_feedbacks = []
    from indexing.build_index import ensure_index_built

    ensure_index_built(workspace)
    _cost_tracker.reset()
    _cost_tracker.budget_usd = args.budget
    _circuit_breaker.reset()
    _circuit_breaker.failure_threshold = args.circuit_threshold
    _circuit_breaker.recovery_timeout = args.circuit_timeout
    _circuit_events.clear()
    # Override iteration max if specified via CLI (0 = automatic limit).
    _max_iterations_override = args.max_iterations
    mode = "single-agent" if _single_agent_mode else "multi-agent"
    iter_info = (
        f"max_iterations={args.max_iterations}"
        if args.max_iterations
        else "max_iterations=auto"
    )
    print(
        f"Starting agent for task: {task}\nWorkspace: {workspace}\n"
        f"Mode: {mode}\n"
        f"Budget: {'disabled' if args.budget == 0 else f'${args.budget:.2f}'}\n"
        f"{iter_info} | "
        f"Retry: max={args.retry_max} delay={args.retry_delay}s "
        f"| Circuit: threshold={args.circuit_threshold} timeout={args.circuit_timeout}s\n"
    )
    # Build the appropriate graph
    if _single_agent_mode:
        app = _build_single_agent_graph()
    else:
        app = _build_multi_agent_graph()
    initial_state: GraphState = {
        "messages": [HumanMessage(content=f"Task: {task}")],
        "workspace_dir": workspace,
        "current_task": task,
        "error_logs": "",
        "plan": "",
        "complexity": None,
        "review_feedback": None,
        "current_agent": None,
        "iteration_count": 0,
        "writes_performed": False,
        "search_call_count": 0,
        "semantic_search_call_count": 0,
        "tests_passed": None,
        "test_output": None,
        "verification_attempts": 0,
        "branch_name": None,
        "commit_hash": None,
        "total_cost_usd": 0.0,
        "budget_exceeded": False,
        "_retry_max": args.retry_max,
        "_retry_delay": args.retry_delay,
        "last_model_used": None,
        "langfuse_trace_id": None,
    }
    final_state = app.invoke(initial_state)
    print("\n=== FINAL ANSWER ===\n")
    content = final_state["messages"][-1].content
    if isinstance(content, list):
        # Multimodal content: keep only the text parts.
        content = " ".join(p.get("text", "") for p in content if isinstance(p, dict))
    print(content)
    summary = _cost_tracker.get_summary()
    most_used = summary.get("most_used_model") or "unknown"
    circuit_status = _circuit_breaker.get_status()
    open_circuits = [m for m, s in circuit_status.items() if s.get("state") == "open"]
    if _circuit_events:
        print(f"\n[CIRCUIT EVENTS] ({len(_circuit_events)} total)")
        for ev in _circuit_events[-5:]:
            print(f" {ev}")
        if open_circuits:
            print(f" Circuits still open: {', '.join(open_circuits)}")
    _export_ui_state({**final_state, "status": "completed"}, "end")
    lgtm_count = _review_feedbacks.count("LGTM")
    needs_fix_count = _review_feedbacks.count("NEEDS_FIX")
    print(
        f"\n[SUMMARY] tests_passed={final_state.get('tests_passed')} | "
        f"verification_attempts={final_state.get('verification_attempts', 0)} | "
        f"branch_name={final_state.get('branch_name')} | "
        f"commit_hash={final_state.get('commit_hash')} | "
        f"total_cost_usd={summary['total_cost_usd']:.6f} | "
        f"total_tokens={summary['total_tokens']} | "
        f"most_used_model={most_used} | "
        f"circuit_events={len(_circuit_events)} | "
        f"circuits_open={len(open_circuits)} | "
        f"semantic_search_calls={_semantic_search_call_count} | "
        f"lgtm={lgtm_count} | "
        f"needs_fix={needs_fix_count}"
    )
    # Write patch and final answer to output-dir if specified
    if args.output_dir:
        out_dir = Path(args.output_dir)
        out_dir.mkdir(parents=True, exist_ok=True)
        answer_path = out_dir / "final_answer.txt"
        answer_path.write_text(str(content), encoding="utf-8")
        print(f"[OUTPUT] Final answer -> {answer_path}")
        patch = _collect_workspace_patch(workspace, final_state.get("commit_hash"))
        if patch:
            patch_path = out_dir / "patch.diff"
            patch_path.write_text(patch, encoding="utf-8")
            print(f"[OUTPUT] Patch -> {patch_path}")
        state_path = out_dir / "state.json"
        state_path.write_text(
            json.dumps(
                {
                    "tests_passed": final_state.get("tests_passed"),
                    "verification_attempts": final_state.get(
                        "verification_attempts", 0
                    ),
                    "branch_name": final_state.get("branch_name"),
                    "commit_hash": final_state.get("commit_hash"),
                    "total_cost_usd": summary["total_cost_usd"],
                    "total_tokens": summary["total_tokens"],
                    "most_used_model": most_used,
                    "lgtm_count": lgtm_count,
                    "needs_fix_count": needs_fix_count,
                    "semantic_search_calls": _semantic_search_call_count,
                },
                indent=2,
            ),
            encoding="utf-8",
        )
        print(f"[OUTPUT] State summary -> {state_path}")
    # Langfuse observability: score and flush
    _score_run(final_state)
    trace_id = final_state.get("langfuse_trace_id")
    if trace_id:
        host = os.getenv("LANGFUSE_HOST", "https://cloud.langfuse.com")
        print(f"\n[LANGFUSE] Trace: {host}/trace/{trace_id}")
    get_langfuse().flush()


if __name__ == "__main__":
    main()