import json
import os
import subprocess
import time
from pathlib import Path
from typing import Annotated, Optional, TypedDict

import docker
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.tools import tool
from langgraph.graph import END, StateGraph
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode

from observability.langfuse_client import get_langfuse
from observability.tool_tracing import trace_tool_execution
from resilience.circuit_breaker import CircuitBreaker
from tracking.cost_tracker import CostTracker
from ui.state_manager import AgentStateManager

# ---------------------------------------------------------------------------
# GraphState — the shared state passed between all graph nodes
# ---------------------------------------------------------------------------

# Substrings matched against the lower-cased task text; a match classifies the
# task as multi-file, which raises the routing iteration limit.
MULTI_FILE_KEYWORDS = ("3 files", "multiple files", "several files", "all files")


class GraphState(TypedDict):
    """State dict shared by every graph node; nodes return partial updates to it."""

    # Conversation history; node updates are merged via LangGraph's add_messages reducer.
    messages: Annotated[list, add_messages]
    # Absolute host path of the workspace (bind-mounted into the sandbox at /workspace).
    workspace_dir: str
    # The user's task text as given on the command line.
    current_task: str
    # Any non-empty value makes the routing functions end the run.
    error_logs: str
    # Presumably the planner agent's plan text — set outside this file, TODO confirm.
    plan: str
    # Presumably a complexity label from the manager agent — set outside this file, TODO confirm.
    complexity: Optional[str]
    # Reviewer verdict text; routing looks for "LGTM" in it.
    review_feedback: Optional[str]
    # Name of the agent currently acting (exported to the UI).
    current_agent: Optional[str]
    # Number of LLM turns taken; compared against the iteration limit.
    iteration_count: int
    # True once write_to_file has been called at least once.
    writes_performed: bool
    # Number of search_codebase calls; more than 3 marks the task as multi-file.
    search_call_count: int
    # Number of semantic_search calls (eval tracking).
    semantic_search_call_count: int
    # None = not verified since the last write; True/False = last pytest result.
    tests_passed: Optional[bool]
    # (Truncated) pytest output from the last verification.
    test_output: Optional[str]
    # How many times verify_code has run pytest.
    verification_attempts: int
    # Branch / commit created by git_workflow (None when skipped or failed).
    branch_name: Optional[str]
    commit_hash: Optional[str]
    # Running LLM spend and whether the budget cap was hit.
    total_cost_usd: float
    budget_exceeded: bool
    # Retry settings from the CLI, read by the single-agent planner.
    _retry_max: int
    _retry_delay: float
    # Fallback-chain model that produced the latest response.
    last_model_used: Optional[str]
    # Langfuse trace id; tracing/scoring is skipped while it is None.
    langfuse_trace_id: Optional[str]


# ---------------------------------------------------------------------------
# Module-level singletons
# ---------------------------------------------------------------------------

# Docker client configured from the environment, created at import time.
docker_client = docker.from_env()
# Sandbox container, cached by get_sandbox() once created or found.
_sandbox: Optional[docker.models.containers.Container] = None
# Label used to tag sandbox containers and to find ones left by earlier runs.
_SANDBOX_LABEL = "auto-swe-agent-sandbox"

# Kept outside GraphState: these are long-lived service objects shared by all
# nodes rather than per-step state values.
_cost_tracker: CostTracker = CostTracker(budget_usd=5.0)
_circuit_breaker: CircuitBreaker = CircuitBreaker(
    failure_threshold=5, recovery_timeout=300
)
_circuit_events: list[str] = []
_state_manager: AgentStateManager = AgentStateManager()

# Per-run counters
_semantic_search_call_count: int = 0
_agent_call_counts: dict[str, int] = {}  # agent_name -> call count for eval tracking
_review_feedbacks: list[str] = []  # LGTM / NEEDS_FIX outcomes for eval tracking

# ---------------------------------------------------------------------------
# Docker sandbox
# ---------------------------------------------------------------------------


def _workspace_mount_source(container) -> Optional[str]:
    """Return the host path bind-mounted at ``/workspace`` in *container*, or None."""
    for mount in container.attrs.get("Mounts") or []:
        if mount.get("Destination") == "/workspace":
            return mount.get("Source")
    return None


def _provision_sandbox(container) -> None:
    """Install the test toolchain into a fresh sandbox and check that it imports.

    Raises:
        RuntimeError: if ``pip install`` or the import health check fails.
    """
    print("[Docker] Installing packages...")
    exit_code, output = container.exec_run(
        ["pip", "install", "fastapi", "httpx", "pytest", "uvicorn"],
        demux=False,
    )
    if exit_code != 0:
        raise RuntimeError(
            f"[Docker] pip install failed (exit {exit_code}):\n"
            f"{(output or b'').decode(errors='replace')}"
        )

    exit_code, _ = container.exec_run(
        ["python", "-c", "import fastapi, pytest, httpx, uvicorn"]
    )
    if exit_code != 0:
        raise RuntimeError("[Docker] Health check failed.")


def get_sandbox(workspace_dir: str) -> docker.models.containers.Container:
    """Return the shared sandbox container, creating and provisioning it on first use.

    The first call fixes the sandbox for the rest of the process; later calls
    return the cached container whatever *workspace_dir* they pass.

    A labelled container left running by an earlier run is reused only if it
    bind-mounts this same workspace at ``/workspace``. Otherwise tests would
    run against another project's files. If none matches, a new
    ``python:3.11-slim`` container is started and provisioned.

    Raises:
        RuntimeError: if provisioning fails. The half-built container is then
            stopped and NOT cached, so a later call retries cleanly.
    """
    global _sandbox
    if _sandbox is not None:
        return _sandbox

    abs_workspace = os.path.abspath(workspace_dir)
    wanted = os.path.realpath(abs_workspace)
    existing = docker_client.containers.list(
        filters={"label": f"role={_SANDBOX_LABEL}"}
    )
    for candidate in existing:
        source = _workspace_mount_source(candidate)
        if source is not None and os.path.realpath(source) == wanted:
            _sandbox = candidate
            print(f"[Docker] Reusing sandbox: {_sandbox.short_id}")
            return _sandbox

    container = docker_client.containers.run(
        "python:3.11-slim",
        command="sleep infinity",
        detach=True,
        remove=True,
        labels={"role": _SANDBOX_LABEL},
        volumes={abs_workspace: {"bind": "/workspace", "mode": "rw"}},
        working_dir="/workspace",
    )
    print(f"[Docker] Sandbox started: {container.short_id}")

    try:
        _provision_sandbox(container)
    except Exception:
        # Never cache a half-provisioned sandbox: stop it (remove=True deletes it)
        # and let the original error propagate.
        try:
            container.stop(timeout=1)
        except Exception as stop_err:  # best-effort cleanup
            print(f"[Docker] Could not stop broken sandbox {container.short_id}: {stop_err}")
        raise

    _sandbox = container
    print("[Docker] Sandbox ready.")
    return _sandbox


# ---------------------------------------------------------------------------
# Tools
# ---------------------------------------------------------------------------

IGNORE_DIRS = {".venv", "venv", "__pycache__", ".git", "node_modules", ".next"}


@tool
def list_files(directory: str) -> str:
    """Return a directory tree string, ignoring common non-essential directories."""
    lines = []
    abs_dir = os.path.abspath(directory)
    for root, dirs, files in os.walk(abs_dir):
        # Prune in place so os.walk does not descend into ignored directories.
        dirs[:] = [d for d in dirs if d not in IGNORE_DIRS]
        level = (
            len(os.path.relpath(root, abs_dir).split(os.sep)) - 1
            if root != abs_dir
            else 0
        )
        lines.append(f"{' ' * level}{os.path.basename(root) or root}/")
        for f in files:
            lines.append(f"{' ' * (level + 1)}{f}")
    return "\n".join(lines) if lines else "Directory is empty or does not exist."


@tool
def read_file(filepath: str) -> str:
    """Return file contents, truncated at 2000 lines with a warning if exceeded."""
    from itertools import islice

    max_lines = 2000
    # Read at most one line past the limit: enough to detect truncation
    # without loading an arbitrarily large file into memory.
    with open(filepath, "r", encoding="utf-8", errors="replace") as f:
        lines = list(islice(f, max_lines + 1))
    if len(lines) > max_lines:
        return (
            "".join(lines[:max_lines])
            + "\n\n[WARNING: File truncated at 2000 lines to save context window.]"
        )
    return "".join(lines)


@tool
def search_codebase(keyword: str, directory: str) -> str:
    """Search for a keyword in all files under directory, returning file:line matches."""
    matches = []
    for root, dirs, files in os.walk(directory):
        dirs[:] = [d for d in dirs if d not in IGNORE_DIRS]
        for fname in files:
            fpath = os.path.join(root, fname)
            try:
                with open(fpath, "r", encoding="utf-8", errors="replace") as f:
                    for i, line in enumerate(f, 1):
                        if keyword in line:
                            matches.append(f"{fpath}:{i}: {line.rstrip()}")
            except OSError:
                # Unreadable entries (permissions, broken symlinks) are skipped on purpose.
                pass
    return "\n".join(matches) if matches else "No matches found."
@tool
def write_to_file(filepath: str, content: str) -> str:
    """Write content to a file on the local filesystem (synced to the Docker sandbox via volume mount)."""
    # NOTE(review): filepath is not confined to the workspace — a model-chosen
    # absolute or "../" path can write anywhere this process can.
    parent = os.path.dirname(filepath)
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(filepath, "w", encoding="utf-8") as f:
        f.write(content)
    return f"Written to {filepath}"


# Characters of sandbox output kept before truncating (protects the context window).
_OUTPUT_CHAR_LIMIT = 2000


def _decode_output(raw) -> str:
    """Decode bytes from ``exec_run`` (``None`` -> ""), replacing invalid UTF-8 instead of raising."""
    return (raw or b"").decode("utf-8", errors="replace")


def _truncate_output(output: str, limit: int = _OUTPUT_CHAR_LIMIT) -> str:
    """Return the first *limit* characters of *output*, appending a marker when cut."""
    if len(output) > limit:
        return output[:limit] + "\n[TRUNCATED]"
    return output


@tool
def run_bash_command(command: str, workspace_dir: str = "./") -> str:
    """Execute a bash command inside the Docker sandbox (in /workspace) and return stdout + stderr."""
    container = get_sandbox(workspace_dir)
    result = container.exec_run(
        ["bash", "-c", command], workdir="/workspace", demux=True
    )
    # demux=True yields a (stdout, stderr) pair; either element may be None.
    stdout_raw, stderr_raw = result.output or (None, None)
    # The exit code lets the agent tell a failing command from a silent success.
    return (
        f"stdout:\n{_decode_output(stdout_raw)}\n"
        f"stderr:\n{_decode_output(stderr_raw)}\n"
        f"exit_code: {result.exit_code}"
    )


@tool
def run_tests(workspace_dir: str = "./") -> str:
    """Run pytest in the Docker sandbox and return the full test output."""
    container = get_sandbox(workspace_dir)
    result = container.exec_run(
        ["bash", "-c", "pytest -x -q 2>&1"], workdir="/workspace", demux=False
    )
    return _truncate_output(_decode_output(result.output))


from tools.git_tools import commit_changes, create_branch, generate_pr_description
from tools.semantic_search import semantic_search

tools = [
    list_files,
    read_file,
    search_codebase,
    semantic_search,
    write_to_file,
    run_bash_command,
    run_tests,
    create_branch,
    commit_changes,
    generate_pr_description,
]

FALLBACK_MODELS = [
    "gemini/gemini-2.0-flash",
    "gemini/gemini-2.0-flash-lite",
    "groq/llama-3.3-70b-versatile",
    "groq/llama3-8b-8192",
]

# ---------------------------------------------------------------------------
# Shared helpers (used by agent nodes and base)
# ---------------------------------------------------------------------------


def _export_ui_state(state: dict, node: str = "") -> None:
    """Snapshot run progress, cost and circuit status, and persist it for the UI.

    Args:
        state: Graph state, typically merged with the node's partial result.
        node: Node that just ran. ``"end"`` marks the run completed; an empty
            value falls back to ``state["current_node"]``.
    """
    cost_summary = _cost_tracker.get_summary()
    ui_state = {
        "iteration_count": state.get("iteration_count", 0),
        "current_node": node or state.get("current_node", "idle"),
        "current_agent": state.get("current_agent", "idle"),
        "tests_passed": state.get("tests_passed"),
        "verification_attempts": state.get("verification_attempts", 0),
        "total_cost_usd": cost_summary["total_cost_usd"],
        "budget_exceeded": state.get("budget_exceeded", False),
        "total_tokens": cost_summary["total_tokens"],
        "total_calls": cost_summary["total_calls"],
        "model_breakdown": cost_summary.get("model_breakdown", {}),
        "budget_usd": cost_summary.get("budget_usd", 0.0),
        "last_model_used": state.get("last_model_used", "unknown"),
        "branch_name": state.get("branch_name"),
        "commit_hash": state.get("commit_hash"),
        "messages_count": len(state.get("messages", [])),
        "circuit_status": _circuit_breaker.get_status(),
        "circuit_events": _circuit_events[-20:] if _circuit_events else [],
        "status": "running",
    }
    if state.get("budget_exceeded") or node == "end":
        ui_state["status"] = "completed"
    # Persisting belongs here. Placed after _score_run, `ui_state` was undefined
    # and the UI never received updates.
    _state_manager.save_state(ui_state)


def _score_run(state: dict) -> None:
    """Score the run in Langfuse based on outcome.

    Records ``tests_passed``, plus ``review_quality`` (LGTM ratio) and
    ``search_efficiency`` when the corresponding counters have data. Does
    nothing when Langfuse is disabled or the run has no trace id.
    """
    trace_id = state.get("langfuse_trace_id")
    langfuse = get_langfuse()
    if not langfuse.is_enabled() or not trace_id:
        return

    tests_passed = state.get("tests_passed")
    score_value = 1.0 if tests_passed else 0.0
    langfuse.score(
        trace_id=trace_id,
        name="tests_passed",
        value=score_value,
        comment=f"verification_attempts={state.get('verification_attempts', 0)} | branch={state.get('branch_name')} | commit={state.get('commit_hash')}",
    )

    if _review_feedbacks:
        lgtm_pct = _review_feedbacks.count("LGTM") / len(_review_feedbacks) * 100
        langfuse.score(
            trace_id=trace_id,
            name="review_quality",
            value=lgtm_pct / 100.0,
            comment=f"lgtm={_review_feedbacks.count('LGTM')} / needs_fix={_review_feedbacks.count('NEEDS_FIX')}",
        )

    if _agent_call_counts:
        total = sum(_agent_call_counts.values())
        if total > 0:
            efficiency = 1.0 - min(_semantic_search_call_count / total, 1.0)
            langfuse.score(
                trace_id=trace_id,
                name="search_efficiency",
                value=efficiency,
                comment=f"semantic_search_calls={_semantic_search_call_count} | total_calls={total}",
            )


# ---------------------------------------------------------------------------
# Configure agents base runtime
# ---------------------------------------------------------------------------

from agents.base import configure_runtime

executor_node = ToolNode(tools)

configure_runtime(
    cost_tracker=_cost_tracker,
    circuit_breaker=_circuit_breaker,
    circuit_events=_circuit_events,
    fallback_models=FALLBACK_MODELS,
    tools=tools,
    executor_node=executor_node,
    export_ui_state_fn=_export_ui_state,
)

# ---------------------------------------------------------------------------
# Multi-agent nodes
# ---------------------------------------------------------------------------

from agents.coder import coder_node
from agents.manager import manager_node
from agents.planner import planner_node as multi_planner_node
from agents.reviewer import reviewer_node

# ---------------------------------------------------------------------------
# Shared graph nodes (verify, git, executor)
# ---------------------------------------------------------------------------


def _track_tool_calls(state: GraphState) -> dict:
    """Run the pending tool calls through ToolNode, tracking usage and tracing to Langfuse.

    Updates ``writes_performed`` and the search counters. It also resets
    ``tests_passed`` to None whenever a file was written this turn, so the new
    code gets re-verified.
    """
    global _semantic_search_call_count
    last = state["messages"][-1]
    writes = state.get("writes_performed", False)
    searches = state.get("search_call_count", 0)
    semantic_searches = state.get("semantic_search_call_count", 0)
    wrote_this_turn = False
    trace_id = state.get("langfuse_trace_id")
    tool_calls_info = []

    for tc in getattr(last, "tool_calls", None) or []:
        tool_calls_info.append((tc["name"], str(tc.get("args", {}))[:200]))
        if tc["name"] == "write_to_file":
            writes = True
            wrote_this_turn = True
        if tc["name"] == "search_codebase":
            searches += 1
        if tc["name"] == "semantic_search":
            semantic_searches += 1
            _semantic_search_call_count += 1

    span = None
    if trace_id and tool_calls_info:
        span = get_langfuse().span(
            trace_id=trace_id,
            name="tool-execution-batch",
            input={"tool_calls": tool_calls_info},
        )

    try:
        result = executor_node.invoke(state)
    except Exception as exc:
        # Close the span with the failure so the trace doesn't show a dangling batch.
        if span is not None:
            span.update(output={"status": "error", "error": f"{type(exc).__name__}: {exc}"})
        raise

    result["writes_performed"] = writes
    result["search_call_count"] = searches
    result["semantic_search_call_count"] = semantic_searches
    result["current_node"] = "executor"

    # Trace individual tool results
    if trace_id:
        for tc_name, tc_input in tool_calls_info:
            trace_tool_execution(trace_id, tc_name, tc_input, result)

    if span is not None:
        span.update(output={"status": "success", "results_count": len(tool_calls_info)})

    if wrote_this_turn:
        result["tests_passed"] = None
    _export_ui_state({**state, **result}, "executor")
    return result


def verify_code(state: GraphState) -> dict:
    """Run pytest in the Docker sandbox and update tests_passed / test_output.

    On failure, the (truncated) output is also appended as a SystemMessage, so
    the next LLM turn sees the errors.
    """
    print("\n--- [NODE] VERIFY ---")
    workspace = state.get("workspace_dir", "./")
    container = get_sandbox(workspace)
    result = container.exec_run(
        ["bash", "-c", "pytest -x -q 2>&1"], workdir="/workspace", demux=False
    )
    output = _truncate_output(_decode_output(result.output))
    attempts = state.get("verification_attempts", 0) + 1

    if result.exit_code == 0:
        print(f"[VERIFY] Tests PASSED (attempt {attempts})")
        return {
            "tests_passed": True,
            "test_output": "All tests passed.",
            "verification_attempts": attempts,
            "current_node": "verify",
        }

    print(f"[VERIFY] Tests FAILED (attempt {attempts}):\n{output[:300]}")
    error_msg = SystemMessage(
        content=f"Tests failed. Fix the following errors:\n{output}"
    )
    return {
        "tests_passed": False,
        "test_output": output,
        "verification_attempts": attempts,
        "messages": [error_msg],
        "current_node": "verify",
    }


def _parse_commit_hash(commit_output: str) -> str:
    """Extract the short hash from ``git commit`` output.

    Handles both ``[branch abc1234] msg`` and the root-commit form
    ``[branch (root-commit) abc1234] msg``. Returns "" if no header is found.
    """
    for line in commit_output.splitlines():
        if line.startswith("["):
            header = line[1:].split("]", 1)[0].split()
            if len(header) >= 2:
                return header[-1]
    return ""


def git_workflow(state: GraphState) -> dict:
    """Auto-create a branch and commit all changes after tests pass.

    Returns ``branch_name`` / ``commit_hash`` as None when the workspace is not
    a git repo or a step fails. ``commit_hash`` is "" when the commit
    succeeded but its hash could not be parsed.
    """
    import shlex

    from tools.git_tools import _run_in_sandbox

    print("\n--- [NODE] GIT WORKFLOW ---")
    workspace = state.get("workspace_dir", "./")
    branch = f"auto-swe/fix-{int(time.time())}"

    _run_in_sandbox(
        'git config user.email "agent@auto-swe-agent" && git config user.name "auto-swe-agent"',
        workspace,
    )

    exit_code, _ = _run_in_sandbox("git rev-parse --is-inside-work-tree", workspace)
    if exit_code != 0:
        print("[GIT] Not a git repo — skipping git workflow.")
        return {
            "branch_name": None,
            "commit_hash": None,
            "current_node": "git_workflow",
        }

    exit_code, out = _run_in_sandbox(f"git checkout -b {branch}", workspace)
    if exit_code != 0:
        print(f"[GIT] Branch creation failed: {out}")
        return {
            "branch_name": None,
            "commit_hash": None,
            "current_node": "git_workflow",
        }
    print(f"[GIT] Created branch: {branch}")

    task_slug = state.get("current_task", "fix")[:50].strip()
    commit_msg = f"auto-swe: {task_slug}"
    _run_in_sandbox("git add -A", workspace)
    # The task text is free-form (quotes, $, backticks), so it must be shell-quoted.
    exit_code, out = _run_in_sandbox(
        f"git commit -m {shlex.quote(commit_msg)}", workspace
    )
    if exit_code != 0:
        print(f"[GIT] Commit failed: {out}")
        return {
            "branch_name": branch,
            "commit_hash": None,
            "current_node": "git_workflow",
        }

    commit_hash = _parse_commit_hash(out)
    print(f"[GIT] Committed: {commit_hash} — {commit_msg}")
    return {
        "branch_name": branch,
        "commit_hash": commit_hash,
        "current_node": "git_workflow",
    }


# ---------------------------------------------------------------------------
# Routing functions (multi-agent)
# ---------------------------------------------------------------------------

# verify -> fix cycles allowed before giving up.
_MAX_VERIFICATION_ATTEMPTS = 3


def _iteration_limit(state: GraphState) -> int:
    """Return the iteration cap.

    Uses the CLI override if it is set. Otherwise returns 20 for multi-file
    tasks (keyword match or more than 3 searches) and 15 for everything else.
    """
    if _max_iterations_override:
        return _max_iterations_override
    task = state.get("current_task", "").lower()
    is_multi_file = (
        any(k in task for k in MULTI_FILE_KEYWORDS)
        or state.get("search_call_count", 0) > 3
    )
    return 20 if is_multi_file else 15


def route_manager(state: GraphState) -> str:
    target = "end" if state.get("error_logs") else "planner"
    log_routing(state, "manager", target)
    return target


def route_planner(state: GraphState) -> str:
    target = "end" if state.get("error_logs") else "coder"
    log_routing(state, "planner", target)
    return target


def route_coder(state: GraphState) -> str:
    """Route after a coder turn: tools, verify, back to coder, or end."""
    if state.get("error_logs"):
        log_routing(state, "coder", "end")
        return "end"

    if state.get("budget_exceeded"):
        print("[COST] Budget exceeded — routing to end.")
        log_routing(state, "coder", "end")
        return "end"

    last = state["messages"][-1]
    if hasattr(last, "tool_calls") and last.tool_calls:
        log_routing(state, "coder", "executor")
        return "executor"

    limit = _iteration_limit(state)
    if state["iteration_count"] >= limit:
        print(f"[WARNING] Iteration limit ({limit}) reached. Forcing end.")
        log_routing(state, "coder", "end")
        return "end"

    if not state.get("writes_performed", False):
        print("[GUARD] No files written yet — forcing back to coder.")
        log_routing(state, "coder", "coder")
        return "coder"

    log_routing(state, "coder", "verify")
    return "verify"


def route_verify(state: GraphState) -> str:
    if state.get("tests_passed"):
        log_routing(state, "verify", "reviewer")
        return "reviewer"
    if state.get("verification_attempts", 0) < _MAX_VERIFICATION_ATTEMPTS:
        log_routing(state, "verify", "coder")
        return "coder"
    print("[VERIFY] Max verification attempts reached. Ending.")
    log_routing(state, "verify", "end")
    return "end"


def _review_text(state: GraphState) -> str:
    """Return the reviewer's verdict text.

    Uses ``review_feedback`` when it is set; otherwise falls back to the last
    message's content. The initial state stores ``review_feedback=None``, so a
    plain ``.get(key, default)`` would return None and crash on ``.upper()``.
    """
    review = state.get("review_feedback")
    if review is None:
        messages = state.get("messages") or []
        review = messages[-1].content if messages else ""
    if isinstance(review, list):
        # Multimodal content: keep only the text parts.
        review = " ".join(p.get("text", "") for p in review if isinstance(p, dict))
    return str(review)


def route_reviewer(state: GraphState) -> str:
    """Send LGTM'd work to git_workflow, anything else back to the coder."""
    if "LGTM" in _review_text(state).upper():
        _review_feedbacks.append("LGTM")
        log_routing(state, "reviewer", "git_workflow")
        return "git_workflow"
    _review_feedbacks.append("NEEDS_FIX")
    log_routing(state, "reviewer", "coder")
    return "coder"


def route_git(state: GraphState) -> str:
    log_routing(state, "git_workflow", "end")
    return "end"


def log_routing(state: GraphState, source: str, target: str) -> None:
    """Trace routing decisions to Langfuse."""
    trace_id = state.get("langfuse_trace_id")
    if trace_id:
        langfuse = get_langfuse()
        if langfuse.is_enabled():
            span = langfuse.span(
                trace_id=trace_id,
                name=f"routing-{source}->{target}",
                input={
                    "source": source,
                    "target": target,
                    "iteration": state.get("iteration_count"),
                    "tests_passed": state.get("tests_passed"),
                    "writes_performed": state.get("writes_performed"),
                    "budget_exceeded": state.get("budget_exceeded"),
                },
            )
            span.update(output={"status": "routed"})


# ---------------------------------------------------------------------------
# Single-agent routing (backward compatible)
# ---------------------------------------------------------------------------


def route_planner_single(state: GraphState) -> str:
    """Route after a single-agent planner turn."""
    if state.get("budget_exceeded"):
        print("[COST] Budget exceeded — routing to end.")
        return "end"

    limit = _iteration_limit(state)
    if state["iteration_count"] >= limit:
        print(f"[WARNING] Iteration limit ({limit}) reached. Forcing end.")
        return "end"

    last = state["messages"][-1]
    if hasattr(last, "tool_calls") and last.tool_calls:
        return "executor"

    if not state.get("writes_performed", False):
        print("[GUARD] No files written yet — forcing back to planner.")
        return "planner"

    if state.get("tests_passed") is None:
        return "verify"

    return "end"


def route_verify_single(state: GraphState) -> str:
    if state.get("tests_passed"):
        return "git_workflow"
    if state.get("verification_attempts", 0) < _MAX_VERIFICATION_ATTEMPTS:
        return "planner"
    print("[VERIFY] Max verification attempts reached. Ending.")
    return "end"


# ---------------------------------------------------------------------------
# Single-agent node (old planner logic, kept for backward compat)
# ---------------------------------------------------------------------------

NO_WRITE_MSG = SystemMessage(
    content=(
        "You have not written any files yet. You MUST use write_to_file to implement "
        "the changes before finishing."
    )
)

SINGLE_AGENT_SYSTEM = """You are an autonomous coding agent that fixes bugs and implements features. \
You have access to the following tools in two categories:

=== SEARCH TOOLS ===
1. search_codebase(keyword, directory) — Exact text matching. \
Use for finding specific strings, variable names, function references.
2. semantic_search(query, k=5) — Semantic (meaning-based) search. \
Use for finding code by concept or functionality. \
Example: 'find where user authentication is handled' → semantic_search. \
'find all occurrences of password' → search_codebase.

=== FILE / EXECUTION TOOLS ===
3. list_files(directory) — Show directory tree.
4. read_file(filepath) — Read a file (truncated at 2000 lines).
5. write_to_file(filepath, content) — Write/replace a file.
6. run_bash_command(command, workspace_dir) — Execute bash inside Docker sandbox.
7. run_tests(workspace_dir) — Run pytest inside Docker sandbox.

=== GIT TOOLS ===
8. create_branch(branch_name) — Create a new git branch.
9. commit_changes(message) — Stage and commit all changes.
10. generate_pr_description() — Generate a PR description from the diff.

=== RULES ===
- Always run tests (run_tests) after writing code to verify correctness.
- If tests fail, read the error output and fix the code.
- You MUST use write_to_file at least once before declaring the task done.
- Prefer semantic_search for understanding code structure and finding logic; \
use search_codebase only for exact string lookups.
- You have up to 15 iterations (20 for multi-file tasks) to complete the task."""


class _TransientLLMError(Exception):
    """Internal marker wrapping a provider error that ``_is_transient`` deems retryable."""


def _invoke_model(
    model: str, msgs: list, max_retries: int, base_delay: float, max_delay: float
):
    """Invoke *model* (LiteLLM, tools bound) on *msgs*, retrying only transient errors.

    Errors that ``agents.base._is_transient`` accepts are retried with
    exponential backoff. Any other error propagates immediately, so the
    caller can fall back to the next model instead of sleeping through
    pointless retries.

    Assumes ``with_retry`` re-raises the last error once retries are
    exhausted. In that case the provider's original exception is re-raised,
    so callers can still classify it by type name.
    """
    from langchain_community.chat_models import ChatLiteLLM

    from agents.base import _is_transient
    from resilience.retry import with_retry

    llm = ChatLiteLLM(model=model, temperature=0).bind_tools(tools)

    @with_retry(
        max_retries=max_retries,
        base_delay=base_delay,
        max_delay=max_delay,
        exponential_base=2.0,
        retryable_exceptions=(_TransientLLMError,),
    )
    def _call():
        try:
            return llm.invoke(msgs)
        except Exception as exc:
            if _is_transient(exc):
                raise _TransientLLMError(str(exc)) from exc
            raise

    try:
        return _call()
    except _TransientLLMError as exc:
        cause = exc.__cause__
        if cause is None:
            raise
        raise cause from None


def _trim_history(messages: list) -> list:
    """Return a bounded message window for the single-agent planner.

    - Keeps the last 10 messages and truncates ToolMessage payloads over 4000
      characters.
    - Drops leading ToolMessages whose originating AI tool call fell out of
      the window, since providers reject orphaned tool results.
    - Re-attaches the first message (the task prompt built in main()) when it
      has scrolled out, so the model never loses the task.
    """
    from langchain_core.messages import ToolMessage

    window = messages[-10:]
    trimmed = []
    for msg in window:
        if (
            isinstance(msg, ToolMessage)
            and isinstance(msg.content, str)
            and len(msg.content) > 4000
        ):
            msg = ToolMessage(
                content=msg.content[:4000] + "\n[TRUNCATED]",
                tool_call_id=msg.tool_call_id,
            )
        trimmed.append(msg)

    while trimmed and isinstance(trimmed[0], ToolMessage):
        trimmed.pop(0)

    if len(messages) > len(window):
        trimmed.insert(0, messages[0])
    return trimmed


def _model_has_api_key(model: str) -> bool:
    """Return False for gemini/groq models whose provider API key is not set."""
    if model.startswith("gemini/"):
        return bool(
            os.environ.get("GOOGLE_API_KEY") or os.environ.get("GEMINI_API_KEY")
        )
    if model.startswith("groq/"):
        return bool(os.environ.get("GROQ_API_KEY"))
    return True


# Error type-name fragments treated as permanent (they don't count toward the circuit breaker).
_PERMANENT_ERROR_MARKERS = (
    "ResourceExhausted",
    "RateLimit",
    "QuotaExceeded",
    "APIConnectionError",
    "AuthenticationError",
    "BadRequestError",
)


def _is_permanent_error(exc: Exception) -> bool:
    """Heuristically decide whether *exc* should be excluded from circuit-breaker failures."""
    err_name = type(exc).__name__
    message = str(exc)
    return (
        any(marker in err_name for marker in _PERMANENT_ERROR_MARKERS)
        or "Missing" in message
        or "key" in message.lower()
    )


def _token_count(usage, *names: str) -> int:
    """Return the first truthy token count stored under any of *names*, else 0.

    *usage* may be an object (read via attributes) or a dict. LangChain's
    ``usage_metadata`` is a dict keyed ``input_tokens`` / ``output_tokens``,
    which attribute lookups alone would miss.
    """
    for name in names:
        value = usage.get(name) if isinstance(usage, dict) else getattr(usage, name, None)
        if value:
            return value
    return 0


def _token_usage(response, msgs: list) -> tuple[int, int, bool]:
    """Return ``(input_tokens, output_tokens, estimated)`` for an LLM response.

    Falls back to a rough estimate (500 tokens per input message, 4 chars per
    output token) when the response carries no usage data.
    """
    usage = getattr(response, "usage_metadata", None) or getattr(
        response, "response_metadata", {}
    ).get("usage", None)
    if usage:
        return (
            _token_count(usage, "prompt_token_count", "input_tokens", "prompt_tokens"),
            _token_count(
                usage, "candidates_token_count", "output_tokens", "completion_tokens"
            ),
            False,
        )
    return len(msgs) * 500, len(str(response.content)) // 4, True


def planner_node_single(state: GraphState) -> dict:
    """Single-agent LLM turn: walk the fallback chain until one model responds.

    Models are skipped when their API key is missing or their circuit is
    open. Each success is cost-tracked. Exceeding the budget appends a halt
    message and sets ``budget_exceeded``.

    Raises:
        RuntimeError: if every model in the fallback chain is skipped or fails.
    """
    trimmed = _trim_history(state["messages"])
    extra = (
        [NO_WRITE_MSG]
        if not state.get("writes_performed", False) and state["iteration_count"] > 0
        else []
    )
    msgs = [SystemMessage(content=SINGLE_AGENT_SYSTEM)] + extra + trimmed

    for model in FALLBACK_MODELS:
        if not _model_has_api_key(model):
            print(f"[SKIP] {model} — no API key set.")
            continue
        if not _circuit_breaker.can_call(model):
            event = f"[CIRCUIT OPEN] Skipping {model} (cooldown active)"
            print(event)
            _circuit_events.append(event)
            continue

        print(f"\n--- [NODE] PLANNER | model={model} ---")
        try:
            response = _invoke_model(
                model,
                msgs,
                max_retries=state.get("_retry_max", 3),
                base_delay=state.get("_retry_delay", 2.0),
                max_delay=30.0,
            )
        except Exception as e:
            if not _is_permanent_error(e):
                _circuit_breaker.record_failure(model)
                status = _circuit_breaker.get_status().get(model, {})
                if status.get("state") == "open":
                    event = f"[CIRCUIT OPENED] {model} after {status.get('failures')} failures"
                    _circuit_events.append(event)
            print(f"[FALLBACK] {model} failed: {type(e).__name__}. Trying next model...")
            continue

        _circuit_breaker.record_success(model)

        input_tokens, output_tokens, estimated = _token_usage(response, msgs)
        if estimated:
            print(
                f"[COST] Token counts unavailable — using estimates (in={input_tokens}, out={output_tokens})"
            )
        _cost_tracker.add_call(model, input_tokens, output_tokens, "planner", estimated)
        total_cost = _cost_tracker.get_total_cost()
        print(
            f"[COST] ${total_cost:.6f} total | this call: in={input_tokens} out={output_tokens} tokens"
        )

        # last_model_used must be returned: mutating `state` in place is not persisted by LangGraph.
        result = {
            "messages": [response],
            "iteration_count": state["iteration_count"] + 1,
            "total_cost_usd": total_cost,
            "budget_exceeded": False,
            "last_model_used": model,
            "current_node": "planner",
        }
        if _cost_tracker.check_budget_exceeded():
            print(
                f"[COST] Budget exceeded (${total_cost:.4f} > ${_cost_tracker.budget_usd}). Halting."
            )
            budget_msg = SystemMessage(
                content=f"Budget exceeded (${total_cost:.4f} > ${_cost_tracker.budget_usd}). Halting execution."
            )
            result["messages"] = [response, budget_msg]
            result["budget_exceeded"] = True
            result["tests_passed"] = False
        _export_ui_state({**state, **result}, "planner")
        return result

    raise RuntimeError("All models in fallback chain exhausted.")


# ---------------------------------------------------------------------------
# Build graph (architecture selected by --single-agent flag)
# ---------------------------------------------------------------------------

_single_agent_mode = False
_max_iterations_override = 0  # 0 = automatic limit (see _iteration_limit)
app = None


def _build_multi_agent_graph():
    """Compile manager → planner → coder ⇄ executor → verify → reviewer → git."""
    workflow = StateGraph(GraphState)

    workflow.add_node("manager", manager_node)
    workflow.add_node("planner", multi_planner_node)
    workflow.add_node("coder", coder_node)
    workflow.add_node("executor", _track_tool_calls)
    workflow.add_node("verify", verify_code)
    workflow.add_node("reviewer", reviewer_node)
    workflow.add_node("git_workflow", git_workflow)

    workflow.set_entry_point("manager")

    workflow.add_conditional_edges(
        "manager",
        route_manager,
        {
            "planner": "planner",
            "end": END,
        },
    )
    workflow.add_conditional_edges(
        "planner",
        route_planner,
        {
            "coder": "coder",
            "end": END,
        },
    )
    workflow.add_conditional_edges(
        "coder",
        route_coder,
        {
            "executor": "executor",
            "verify": "verify",
            "end": END,
        },
    )
    workflow.add_edge("executor", "coder")
    workflow.add_conditional_edges(
        "verify",
        route_verify,
        {
            "reviewer": "reviewer",
            "coder": "coder",
            "end": END,
        },
    )
    workflow.add_conditional_edges(
        "reviewer",
        route_reviewer,
        {
            "git_workflow": "git_workflow",
            "coder": "coder",
        },
    )
    workflow.add_conditional_edges(
        "git_workflow",
        route_git,
        {
            "end": END,
        },
    )

    return workflow.compile()


def _build_single_agent_graph():
    """Compile the legacy planner ⇄ executor → verify → git graph."""
    workflow = StateGraph(GraphState)

    workflow.add_node("planner", planner_node_single)
    workflow.add_node("executor", _track_tool_calls)
    workflow.add_node("verify", verify_code)
    workflow.add_node("git_workflow", git_workflow)

    workflow.set_entry_point("planner")

    workflow.add_conditional_edges(
        "planner",
        route_planner_single,
        {
            "executor": "executor",
            "end": END,
            "planner": "planner",
            "verify": "verify",
        },
    )
    workflow.add_edge("executor", "planner")
    workflow.add_conditional_edges(
        "verify",
        route_verify_single,
        {
            "planner": "planner",
            "git_workflow": "git_workflow",
            "end": END,
        },
    )
    workflow.add_edge("git_workflow", END)

    return workflow.compile()


# ---------------------------------------------------------------------------
# Main entry point
# ---------------------------------------------------------------------------


def _build_arg_parser():
    """Return the CLI argument parser."""
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "task", nargs="?", default=None, help="Issue description to solve"
    )
    parser.add_argument(
        "--task",
        dest="task_alias",
        default=None,
        help="Issue description to solve (alias for positional)",
    )
    parser.add_argument("--workspace", default="./", help="Workspace directory")
    parser.add_argument(
        "--output-dir", default=None, help="Directory to write final answer and patch"
    )
    parser.add_argument("--budget", type=float, default=5.0)
    parser.add_argument(
        "--max-iterations",
        type=int,
        default=0,
        help="Max iterations (0=auto based on complexity)",
    )
    parser.add_argument("--retry-max", type=int, default=3)
    parser.add_argument("--retry-delay", type=float, default=2.0)
    parser.add_argument("--circuit-threshold", type=int, default=5)
    parser.add_argument("--circuit-timeout", type=int, default=300)
    parser.add_argument(
        "--single-agent",
        action="store_true",
        help="Use single-agent mode (backward-compatible planner-only)",
    )
    return parser


def _run_git(git_args: list, workspace: str) -> str:
    """Run a read-only git command on the host workspace; return stdout, or "" on failure."""
    try:
        proc = subprocess.run(
            ["git", *git_args],
            cwd=workspace,
            capture_output=True,
            encoding="utf-8",
            errors="replace",
            timeout=30,
        )
    except (OSError, subprocess.TimeoutExpired) as exc:
        # git missing or hung: report and carry on, since the run itself already finished.
        print(f"[OUTPUT] git {' '.join(git_args)} failed: {exc}")
        return ""
    return proc.stdout


def _new_file_diff(rel_path: str, text: str) -> str:
    """Render *text* as a ``git apply``-compatible "new file" diff for *rel_path*."""
    header = f"diff --git a/{rel_path} b/{rel_path}\nnew file mode 100644\n"
    if not text:
        return header
    lines = text.split("\n")
    ends_with_newline = lines[-1] == ""
    if ends_with_newline:
        lines.pop()
    body = "".join(f"+{line}\n" for line in lines)
    if not ends_with_newline:
        body += "\\ No newline at end of file\n"
    return (
        header
        + f"--- /dev/null\n+++ b/{rel_path}\n@@ -0,0 +1,{len(lines)} @@\n"
        + body
    )


def _untracked_files_patch(workspace: str) -> str:
    """Build a patch that adds every untracked, non-ignored file in *workspace*."""
    # -z avoids git's C-style quoting of paths with spaces or non-ASCII characters.
    listing = _run_git(["ls-files", "-z", "--others", "--exclude-standard"], workspace)
    chunks = []
    for rel_path in filter(None, listing.split("\0")):
        fpath = Path(workspace) / rel_path
        if not fpath.is_file():
            continue
        # newline="" keeps CRLF endings intact in the generated patch.
        with open(fpath, encoding="utf-8", errors="replace", newline="") as fh:
            chunks.append(_new_file_diff(rel_path, fh.read()))
    return "".join(chunks)


def _collect_patch(workspace: str, committed: bool) -> str:
    """Return the agent's changes as a patch ("" when there are none).

    Once git_workflow has committed, the working tree is clean relative to
    HEAD. In that case the committed change itself is exported. Otherwise the
    working-tree diff is used. Untracked files are appended either way.
    """
    tracked_cmd = ["show", "--format=", "HEAD"] if committed else ["diff", "HEAD"]
    tracked = _run_git(tracked_cmd, workspace)
    patch = tracked if tracked.strip() else ""
    return patch + _untracked_files_patch(workspace)


def main():
    """CLI entry point: parse args, reset per-run state, run the graph, report results."""
    global _single_agent_mode, _max_iterations_override, app
    global _semantic_search_call_count, _agent_call_counts, _review_feedbacks

    args = _build_arg_parser().parse_args()

    task = args.task or args.task_alias or input("Enter task: ")
    workspace = os.path.abspath(args.workspace)
    _single_agent_mode = args.single_agent
    _semantic_search_call_count = 0
    _agent_call_counts = {}
    _review_feedbacks = []

    from indexing.build_index import ensure_index_built

    ensure_index_built(workspace)

    _cost_tracker.reset()
    _cost_tracker.budget_usd = args.budget
    _circuit_breaker.reset()
    _circuit_breaker.failure_threshold = args.circuit_threshold
    _circuit_breaker.recovery_timeout = args.circuit_timeout
    _circuit_events.clear()

    # Declared global above. Without that, this assignment binds a local and
    # routing silently keeps the automatic limit.
    _max_iterations_override = args.max_iterations

    mode = "single-agent" if _single_agent_mode else "multi-agent"
    iter_info = (
        f"max_iterations={args.max_iterations}"
        if args.max_iterations
        else "max_iterations=auto"
    )
    print(
        f"Starting agent for task: {task}\nWorkspace: {workspace}\n"
        f"Mode: {mode}\n"
        f"Budget: {'disabled' if args.budget == 0 else f'${args.budget:.2f}'}\n"
        f"{iter_info} | "
        f"Retry: max={args.retry_max} delay={args.retry_delay}s "
        f"| Circuit: threshold={args.circuit_threshold} timeout={args.circuit_timeout}s\n"
    )

    # Build the appropriate graph
    if _single_agent_mode:
        app = _build_single_agent_graph()
    else:
        app = _build_multi_agent_graph()

    initial_state: GraphState = {
        "messages": [HumanMessage(content=f"Task: {task}")],
        "workspace_dir": workspace,
        "current_task": task,
        "error_logs": "",
        "plan": "",
        "complexity": None,
        "review_feedback": None,
        "current_agent": None,
        "iteration_count": 0,
        "writes_performed": False,
        "search_call_count": 0,
        "semantic_search_call_count": 0,
        "tests_passed": None,
        "test_output": None,
        "verification_attempts": 0,
        "branch_name": None,
        "commit_hash": None,
        "total_cost_usd": 0.0,
        "budget_exceeded": False,
        "_retry_max": args.retry_max,
        "_retry_delay": args.retry_delay,
        "last_model_used": None,
        "langfuse_trace_id": None,
    }

    final_state = app.invoke(initial_state)

    print("\n=== FINAL ANSWER ===\n")
    content = final_state["messages"][-1].content
    if isinstance(content, list):
        content = " ".join(p.get("text", "") for p in content if isinstance(p, dict))
    print(content)

    summary = _cost_tracker.get_summary()
    most_used = summary.get("most_used_model") or "unknown"

    circuit_status = _circuit_breaker.get_status()
    open_circuits = [m for m, s in circuit_status.items() if s["state"] == "open"]
    if _circuit_events:
        print(f"\n[CIRCUIT EVENTS] ({len(_circuit_events)} total)")
        for ev in _circuit_events[-5:]:
            print(f" {ev}")
        if open_circuits:
            print(f" Circuits still open: {', '.join(open_circuits)}")

    _export_ui_state({**final_state, "status": "completed"}, "end")

    lgtm_count = _review_feedbacks.count("LGTM")
    needs_fix_count = _review_feedbacks.count("NEEDS_FIX")

    print(
        f"\n[SUMMARY] tests_passed={final_state.get('tests_passed')} | "
        f"verification_attempts={final_state.get('verification_attempts', 0)} | "
        f"branch_name={final_state.get('branch_name')} | "
        f"commit_hash={final_state.get('commit_hash')} | "
        f"total_cost_usd={summary['total_cost_usd']:.6f} | "
        f"total_tokens={summary['total_tokens']} | "
        f"most_used_model={most_used} | "
        f"circuit_events={len(_circuit_events)} | "
        f"circuits_open={len(open_circuits)} | "
        f"semantic_search_calls={_semantic_search_call_count} | "
        f"lgtm={lgtm_count} | "
        f"needs_fix={needs_fix_count}"
    )

    # Write patch and final answer to output-dir if specified
    if args.output_dir:
        out_dir = Path(args.output_dir)
        out_dir.mkdir(parents=True, exist_ok=True)

        answer_path = out_dir / "final_answer.txt"
        answer_path.write_text(str(content), encoding="utf-8")
        print(f"[OUTPUT] Final answer -> {answer_path}")

        # commit_hash is None when git_workflow was skipped or the commit failed.
        patch = _collect_patch(
            workspace, committed=final_state.get("commit_hash") is not None
        )
        if patch.strip():
            patch_path = out_dir / "patch.diff"
            patch_path.write_text(patch, encoding="utf-8")
            print(f"[OUTPUT] Patch -> {patch_path}")

        state_path = out_dir / "state.json"
        state_path.write_text(
            json.dumps(
                {
                    "tests_passed": final_state.get("tests_passed"),
                    "verification_attempts": final_state.get(
                        "verification_attempts", 0
                    ),
                    "branch_name": final_state.get("branch_name"),
                    "commit_hash": final_state.get("commit_hash"),
                    "total_cost_usd": summary["total_cost_usd"],
                    "total_tokens": summary["total_tokens"],
                    "most_used_model": most_used,
                    "lgtm_count": lgtm_count,
                    "needs_fix_count": needs_fix_count,
                    "semantic_search_calls": _semantic_search_call_count,
                },
                indent=2,
            ),
            encoding="utf-8",
        )
        print(f"[OUTPUT] State summary -> {state_path}")

    # Langfuse observability: score and flush
    _score_run(final_state)
    trace_id = final_state.get("langfuse_trace_id")
    if trace_id:
        host = os.getenv("LANGFUSE_HOST", "https://cloud.langfuse.com")
        print(f"\n[LANGFUSE] Trace: {host}/trace/{trace_id}")
    get_langfuse().flush()


if __name__ == "__main__":
    main()