"""Shared agent runtime and LLM invocation helpers.

Provides a process-wide :class:`AgentRuntime` holder plus :func:`invoke_agent`,
which walks a chain of fallback models with retry/backoff, consults a circuit
breaker, records token cost and (optionally) emits Langfuse traces.
"""

import os
import time
from typing import Any, Callable, Optional

from langchain_community.chat_models import ChatLiteLLM
from langchain_core.messages import SystemMessage, ToolMessage

from observability.langfuse_client import get_langfuse

# Exception-name fragments for failures that do NOT count against a model's
# circuit breaker (see ``_note_model_failure``); the model is still skipped.
_SKIP_ERRORS = (
    "ResourceExhausted",
    "RateLimit",
    "QuotaExceeded",
    "APIConnectionError",
    "AuthenticationError",
    "BadRequestError",
)

# Exception-name fragments that mark a failure as worth retrying.
_TRANSIENT_ERROR_NAMES = (
    "RateLimitError",
    "ResourceExhausted",
    "APIConnectionError",
    "Timeout",
    "ConnectionError",
    "ServiceUnavailable",
    "InternalServerError",
)

# Lower-case substrings of an error message that mark it as transient.
_TRANSIENT_MESSAGE_MARKERS = (
    "rate limit",
    "timeout",
    "connection",
    "503",
    "502",
    "529",
)

# Not referenced in this module; kept because other modules may import it.
_TRANSIENT_EXCEPTIONS = (Exception,)

# Tool outputs longer than this many characters are truncated before being
# sent to the model.
_MAX_TOOL_MESSAGE_CHARS = 4000

# Field names under which token counts may be reported, tried in order.
_INPUT_TOKEN_FIELDS = ("prompt_token_count", "input_tokens", "prompt_tokens")
_OUTPUT_TOKEN_FIELDS = (
    "candidates_token_count",
    "output_tokens",
    "completion_tokens",
)


class AgentRuntime:
    """Mutable holder for the collaborators used by :func:`invoke_agent`.

    The list-valued attributes are created per instance in ``__init__`` so
    that two runtimes never share (and mutate) the same list object.
    """

    cost_tracker: Any
    circuit_breaker: Any
    circuit_events: list
    fallback_models: list[str]
    tools: list
    executor_node: Any
    export_ui_state_fn: Optional[Callable]

    def __init__(self) -> None:
        self.cost_tracker = None
        self.circuit_breaker = None
        self.circuit_events = []
        self.fallback_models = []
        self.tools = []
        self.executor_node = None
        self.export_ui_state_fn = None


_runtime = AgentRuntime()


def configure_runtime(
    cost_tracker,
    circuit_breaker,
    circuit_events: list,
    fallback_models: list[str],
    tools: list,
    executor_node,
    export_ui_state_fn: Optional[Callable] = None,
) -> None:
    """Populate the module-level runtime used by :func:`invoke_agent`.

    The arguments are stored by reference (no copies are made), so
    ``circuit_events`` is appended to in place and the caller sees the
    events that :func:`invoke_agent` records.

    Args:
        cost_tracker: Object exposing ``add_call``, ``get_total_cost``,
            ``check_budget_exceeded`` and ``budget_usd``.
        circuit_breaker: Object exposing ``can_call``, ``record_success``,
            ``record_failure`` and ``get_status``.
        circuit_events: List that receives human-readable breaker events.
        fallback_models: Model identifiers, tried in order.
        tools: Tools bound to every LLM instance.
        executor_node: Stored on the runtime; not used in this module.
        export_ui_state_fn: Optional callback invoked as
            ``fn(merged_state, node_name)`` after each successful call.
    """
    _runtime.cost_tracker = cost_tracker
    _runtime.circuit_breaker = circuit_breaker
    _runtime.circuit_events = circuit_events
    _runtime.fallback_models = fallback_models
    _runtime.tools = tools
    _runtime.executor_node = executor_node
    _runtime.export_ui_state_fn = export_ui_state_fn


def get_runtime() -> AgentRuntime:
    """Return the module-level :class:`AgentRuntime` instance."""
    return _runtime


def _model_available(model: str) -> bool:
    """Return False when *model* needs an API key that is not set.

    Only the ``gemini/`` and ``groq/`` prefixes are checked; every other
    model is reported as available.
    """
    if model.startswith("gemini/"):
        return bool(
            os.environ.get("GOOGLE_API_KEY") or os.environ.get("GEMINI_API_KEY")
        )
    if model.startswith("groq/"):
        return bool(os.environ.get("GROQ_API_KEY"))
    return True


def _make_llm(model: str, tools_list: list):
    """Build a zero-temperature ``ChatLiteLLM`` for *model* with tools bound."""
    return ChatLiteLLM(model=model, temperature=0).bind_tools(tools_list)


def _is_transient(e: Exception) -> bool:
    """Return True when *e* looks like a retryable failure.

    The decision is heuristic: it matches fragments of the exception class
    name and substrings of the lower-cased message.
    """
    name = type(e).__name__
    if any(fragment in name for fragment in _TRANSIENT_ERROR_NAMES):
        return True
    msg = str(e).lower()
    return any(marker in msg for marker in _TRANSIENT_MESSAGE_MARKERS)


def _call_with_retry(model: str, msgs: list, max_retries: int, base_delay: float):
    """Invoke *model* on *msgs*, retrying transient errors with backoff.

    Args:
        model: Model identifier passed to :func:`_make_llm`.
        msgs: Messages passed to ``llm.invoke``.
        max_retries: Number of retries after the first attempt.
        base_delay: Initial delay in seconds; doubled per attempt and capped
            at 30 seconds.

    Returns:
        Whatever ``llm.invoke`` returns.

    Raises:
        Exception: The last error, once retries are used up or the error is
            not transient.
        RuntimeError: If ``max_retries`` is negative, so no attempt is made.
    """
    for attempt in range(max_retries + 1):
        try:
            llm = _make_llm(model, _runtime.tools)
            return llm.invoke(msgs)
        except Exception as e:
            if attempt >= max_retries or not _is_transient(e):
                raise
            delay = min(base_delay * (2.0**attempt), 30.0)
            print(
                f"[RETRY] {model} attempt {attempt + 1}/{max_retries} failed ({type(e).__name__}). Retrying in {delay:.1f}s..."
            )
            time.sleep(delay)
    raise RuntimeError("Unreachable")


def _usage_field(usage: Any, names: tuple) -> int:
    """Return the first truthy value found under any of *names*, else 0.

    Dicts are read by key and other objects by attribute, so the same field
    names work for both shapes of usage payload.
    """
    for name in names:
        if isinstance(usage, dict):
            value = usage.get(name)
        else:
            value = getattr(usage, name, None)
        if value:
            return value
    return 0


def _extract_usage(response) -> Optional[dict]:
    """Extract token counts from an LLM response.

    Looks at ``response.usage_metadata`` first and falls back to
    ``response.response_metadata["usage"]``.

    Returns:
        ``{"input": int, "output": int, "unit": "TOKENS"}``, or ``None`` when
        the response carries no usage information at all.
    """
    usage = getattr(response, "usage_metadata", None)
    if not usage:
        # response_metadata may be absent, None, or not a dict.
        metadata = getattr(response, "response_metadata", None)
        usage = metadata.get("usage") if isinstance(metadata, dict) else None
    if usage is None:
        return None
    return {
        "input": _usage_field(usage, _INPUT_TOKEN_FIELDS),
        "output": _usage_field(usage, _OUTPUT_TOKEN_FIELDS),
        "unit": "TOKENS",
    }


def _trim_messages(messages: list, context_window: int) -> list:
    """Return the last *context_window* messages with long tool output cut.

    Only ``ToolMessage`` instances whose string content exceeds
    ``_MAX_TOOL_MESSAGE_CHARS`` are replaced; the replacement keeps just the
    truncated content and the original ``tool_call_id``.
    """
    # NOTE: context_window=0 selects the whole history, because
    # ``messages[-0:]`` is the full list.
    trimmed = []
    for msg in messages[-context_window:]:
        content = getattr(msg, "content", None)
        if (
            isinstance(msg, ToolMessage)
            and isinstance(content, str)
            and len(content) > _MAX_TOOL_MESSAGE_CHARS
        ):
            msg = ToolMessage(
                content=content[:_MAX_TOOL_MESSAGE_CHARS] + "\n[TRUNCATED]",
                tool_call_id=msg.tool_call_id,
            )
        trimmed.append(msg)
    return trimmed


def _note_model_failure(
    model: str, error: Exception, circuit_breaker, circuit_events: list, agent_span
) -> None:
    """Record a failed model call on the breaker, the log and the span."""
    err_name = type(error).__name__
    message = str(error)
    # Errors matching these markers are not counted by the circuit breaker.
    skip_breaker = (
        any(fragment in err_name for fragment in _SKIP_ERRORS)
        or "Missing" in message
        or "key" in message.lower()
    )
    if not skip_breaker:
        circuit_breaker.record_failure(model)
        status = circuit_breaker.get_status().get(model, {})
        if status.get("state") == "open":
            event = f"[CIRCUIT OPENED] {model} after {status.get('failures')} failures"
            circuit_events.append(event)
    print(f"[FALLBACK] {model} failed: {err_name}. Trying next model...")
    if agent_span is not None:
        agent_span.update(
            output={"status": "fallback", "error": err_name, "model": model}
        )


def invoke_agent(
    system_prompt: str,
    state: dict,
    node_name: str,
    *,
    extra_messages: Optional[list] = None,
    context_window: int = 10,
) -> dict:
    """Run one agent step against the first model that answers.

    Models from the configured fallback chain are tried in order. A model is
    skipped when its API key is missing or its circuit is open; a model whose
    call fails is reported and the next one is tried.

    Only the LLM call itself is treated as a model failure. Errors raised by
    the bookkeeping that follows a successful call (cost tracking, tracing,
    UI export) propagate to the caller instead of triggering another model.

    Args:
        system_prompt: Content of the leading ``SystemMessage``.
        state: Graph state. Reads ``messages`` (required) and, optionally,
            ``langfuse_trace_id``, ``current_task``, ``workspace_dir``,
            ``iteration_count``, ``_retry_max`` and ``_retry_delay``.
        node_name: Name of the calling node, used in logs, spans and cost
            records.
        extra_messages: Messages inserted between the system prompt and the
            trimmed history.
        context_window: Number of trailing history messages to send.

    Returns:
        A partial state update containing the model response, the incremented
        iteration count, the running cost and the budget flag.

    Raises:
        RuntimeError: When no model in the fallback chain produced a
            response. The last model error, if any, is attached as the cause.
    """
    cost_tracker = _runtime.cost_tracker
    circuit_breaker = _runtime.circuit_breaker
    circuit_events = _runtime.circuit_events
    export_ui_state_fn = _runtime.export_ui_state_fn

    langfuse = get_langfuse()
    trace_id = state.get("langfuse_trace_id")
    if langfuse.is_enabled() and not trace_id:
        task = state.get("current_task", "unknown")
        trace = langfuse.create_trace(
            name="auto-swe-agent",
            metadata={
                # A present-but-None task must not break trace creation.
                "task": str(task)[:200] if task is not None else "unknown",
                "workspace": state.get("workspace_dir", "unknown"),
                "mode": "multi-agent",
            },
        )
        if trace is not None and hasattr(trace, "id"):
            trace_id = trace.id

    trimmed = _trim_messages(state["messages"], context_window)
    msgs = [SystemMessage(content=system_prompt)] + (extra_messages or []) + trimmed
    last_input = getattr(trimmed[-1], "content", "") if trimmed else ""

    agent_span = None
    if langfuse.is_enabled() and trace_id:
        agent_span = langfuse.span(
            trace_id=trace_id,
            name=f"agent-{node_name}",
            input={
                "messages_count": len(msgs),
                "context_window": context_window,
                "last_input_preview": str(last_input)[:300],
            },
        )

    last_error: Optional[Exception] = None
    for model in _runtime.fallback_models:
        if not _model_available(model):
            print(f"[SKIP] {model} — no API key set.")
            continue
        if not circuit_breaker.can_call(model):
            event = f"[CIRCUIT OPEN] Skipping {model} (cooldown active)"
            print(event)
            circuit_events.append(event)
            continue

        print(f"\n--- [NODE] {node_name.upper()} | model={model} ---")
        try:
            response = _call_with_retry(
                model,
                msgs,
                max_retries=state.get("_retry_max", 3),
                base_delay=state.get("_retry_delay", 2.0),
            )
        except Exception as e:
            last_error = e
            _note_model_failure(
                model, e, circuit_breaker, circuit_events, agent_span
            )
            continue

        circuit_breaker.record_success(model)
        response_text = str(getattr(response, "content", response))

        usage_dict = _extract_usage(response)
        if usage_dict:
            input_tokens = usage_dict["input"]
            output_tokens = usage_dict["output"]
            estimated = False
        else:
            # Rough fallback: 500 tokens per message, 4 characters per token.
            input_tokens = len(msgs) * 500
            output_tokens = len(response_text) // 4
            estimated = True
            print(
                f"[COST] Token counts unavailable — using estimates (in={input_tokens}, out={output_tokens})"
            )

        # Langfuse generation trace
        if langfuse.is_enabled() and trace_id:
            gen_params = {
                "trace_id": trace_id,
                "name": f"llm-{node_name}",
                "model": model,
                "input": str(last_input)[:500],
                "output": response_text[:1000],
            }
            if usage_dict:
                gen_params["usage"] = usage_dict
            langfuse.generation(**gen_params)

        cost_tracker.add_call(model, input_tokens, output_tokens, node_name, estimated)
        total_cost = cost_tracker.get_total_cost()
        print(
            f"[COST] ${total_cost:.6f} total | this call: in={input_tokens} out={output_tokens} tokens"
        )

        if agent_span is not None:
            agent_span.update(output={"status": "success", "model_used": model})

        budget_exceeded = bool(cost_tracker.check_budget_exceeded())
        result = {
            "messages": [response],
            # A state without a counter starts counting from zero.
            "iteration_count": state.get("iteration_count", 0) + 1,
            "total_cost_usd": total_cost,
            "budget_exceeded": budget_exceeded,
            "current_node": node_name,
            "current_agent": node_name,
        }
        if budget_exceeded:
            print(
                f"[COST] Budget exceeded (${total_cost:.4f} > ${cost_tracker.budget_usd})."
            )
            result["messages"].append(
                SystemMessage(
                    content=f"Budget exceeded (${total_cost:.4f} > ${cost_tracker.budget_usd}). Halting."
                )
            )
            result["tests_passed"] = False
        if trace_id:
            result["langfuse_trace_id"] = trace_id
        if export_ui_state_fn:
            export_ui_state_fn({**state, **result}, node_name)
        return result

    if agent_span is not None:
        agent_span.update(output={"status": "error", "error": "All models exhausted"})
    # Chain the last model error so the root cause survives in the traceback.
    raise RuntimeError("All models in fallback chain exhausted.") from last_error