File size: 10,486 Bytes
6085b61
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
import os
import time
from typing import Any, Callable, Optional

from langchain_community.chat_models import ChatLiteLLM
from langchain_core.messages import SystemMessage

from observability.langfuse_client import get_langfuse

# Substrings matched against an exception's class name in invoke_agent. A match
# marks the failure as "permanent": the model is skipped for this call without
# recording a failure on its circuit breaker.
_SKIP_ERRORS = (
    "ResourceExhausted",
    "RateLimit",
    "QuotaExceeded",
    "APIConnectionError",
    "AuthenticationError",
    "BadRequestError",
)
# Substrings matched against an exception's class name by _is_transient, which
# _call_with_retry uses to decide whether to retry the same model.
_TRANSIENT_ERROR_NAMES = (
    "RateLimitError",
    "ResourceExhausted",
    "APIConnectionError",
    "Timeout",
    "ConnectionError",
    "ServiceUnavailable",
    "InternalServerError",
)
# NOTE(review): not referenced anywhere in the visible part of this module —
# confirm whether another module imports it before removing.
_TRANSIENT_EXCEPTIONS = (Exception,)


class AgentRuntime:
    """Mutable container for the dependencies shared by the agent helpers.

    The scalar-like slots keep their ``None`` class-level defaults. The list
    slots are created per instance in ``__init__`` so that appending to one
    runtime's list can never leak into another instance (or into the class
    itself), which is what happened when they were class-level ``[]`` literals.
    """

    cost_tracker: Any = None
    circuit_breaker: Any = None
    circuit_events: list
    fallback_models: list[str]
    tools: list
    executor_node: Any = None
    export_ui_state_fn: Optional[Callable] = None

    def __init__(self) -> None:
        # Fresh lists per instance; never share mutable defaults via the class.
        self.circuit_events = []
        self.fallback_models = []
        self.tools = []


# Module-level singleton: configure_runtime() fills it in, get_runtime() exposes
# it, and the helpers in this module read their dependencies from it.
_runtime = AgentRuntime()


def configure_runtime(
    cost_tracker,
    circuit_breaker,
    circuit_events: list,
    fallback_models: list[str],
    tools: list,
    executor_node,
    export_ui_state_fn: Optional[Callable] = None,
) -> None:
    """Install the given dependencies on the module-level runtime singleton.

    Every argument is stored under the attribute of the same name. The objects
    are stored by reference (no copies are made), so lists passed in here are
    the same lists the runtime later reads and appends to.
    """
    bindings = {
        "cost_tracker": cost_tracker,
        "circuit_breaker": circuit_breaker,
        "circuit_events": circuit_events,
        "fallback_models": fallback_models,
        "tools": tools,
        "executor_node": executor_node,
        "export_ui_state_fn": export_ui_state_fn,
    }
    for attr_name, dependency in bindings.items():
        setattr(_runtime, attr_name, dependency)


def get_runtime() -> AgentRuntime:
    """Return the module-level ``AgentRuntime`` singleton (not a copy)."""
    return _runtime


def _model_available(model: str) -> bool:
    """Return False when *model* belongs to a provider whose API key is unset.

    Only the ``gemini/`` and ``groq/`` prefixes are checked; any other model
    name is reported as available. An environment variable that is set to an
    empty string counts as missing.
    """
    # Provider prefix -> environment variables, any one of which is sufficient.
    required_keys = {
        "gemini/": ("GOOGLE_API_KEY", "GEMINI_API_KEY"),
        "groq/": ("GROQ_API_KEY",),
    }
    for prefix, env_names in required_keys.items():
        if not model.startswith(prefix):
            continue
        if not any(os.environ.get(env_name) for env_name in env_names):
            return False
    return True


def _make_llm(model: str, tools_list: list):
    """Build a ``ChatLiteLLM`` client for *model* (temperature 0) with *tools_list* bound."""
    chat_model = ChatLiteLLM(model=model, temperature=0)
    return chat_model.bind_tools(tools_list)


def _is_transient(e: Exception) -> bool:
    """Heuristically decide whether *e* is worth retrying on the same model.

    An error is treated as transient when its class name contains one of the
    ``_TRANSIENT_ERROR_NAMES`` fragments, or when its lower-cased message
    contains one of a few fixed markers (rate limiting, timeouts, connection
    problems, or the status codes 502/503/529).
    """
    class_name = type(e).__name__
    text = str(e).lower()
    if any(fragment in class_name for fragment in _TRANSIENT_ERROR_NAMES):
        return True
    message_markers = ("rate limit", "timeout", "connection", "503", "502", "529")
    return any(marker in text for marker in message_markers)


def _call_with_retry(model: str, msgs: list, max_retries: int, base_delay: float):
    """Invoke *model* on *msgs*, retrying transient failures with exponential backoff.

    Args:
        model: Model identifier passed to ``_make_llm``.
        msgs: Message list handed to the client's ``invoke``.
        max_retries: Number of retries after the first attempt. Negative values
            are treated as 0 so the model is always tried at least once.
        base_delay: Delay in seconds before the first retry; it doubles on each
            further retry and is capped at 30 seconds.

    Returns:
        Whatever the bound client's ``invoke`` returns.

    Raises:
        Exception: The original error, re-raised unchanged, when it is not
            classified as transient by ``_is_transient`` or when the retry
            budget is used up.
    """
    # A negative budget used to skip the loop entirely and surface a misleading
    # "Unreachable" error without ever calling the model.
    max_retries = max(0, max_retries)
    for attempt in range(max_retries + 1):
        try:
            # The client is rebuilt on every attempt, inside the try, so that a
            # transient failure while constructing it is retried as well.
            llm = _make_llm(model, _runtime.tools)
            return llm.invoke(msgs)
        except Exception as e:
            if attempt >= max_retries or not _is_transient(e):
                raise
            # Clamp at 0: time.sleep raises ValueError for negative durations.
            delay = max(0.0, min(base_delay * (2.0**attempt), 30.0))
            print(
                f"[RETRY] {model} attempt {attempt + 1}/{max_retries} failed ({type(e).__name__}). Retrying in {delay:.1f}s..."
            )
            time.sleep(delay)
    raise RuntimeError("Unreachable")


def _first_token_count(usage, attr_names, key_names):
    """Return the first truthy count found on *usage*, or 0 when none is present.

    Attributes are tried first (in *attr_names* order); when *usage* is a dict
    its keys are then tried in *key_names* order.
    """
    for attr in attr_names:
        value = getattr(usage, attr, None)
        if value:
            return value
    if isinstance(usage, dict):
        for key in key_names:
            value = usage.get(key)
            if value:
                return value
    return 0


def _extract_usage(response) -> Optional[dict]:
    """Pull token usage out of a model response.

    The usage record is taken from ``response.usage_metadata`` when that is
    truthy, otherwise from ``response.response_metadata["usage"]``. The record
    may be an object or a dict; the prompt/input and completion/output counts
    are looked up under the ``prompt_token_count``/``candidates_token_count``,
    ``input_tokens``/``output_tokens`` and ``prompt_tokens``/``completion_tokens``
    spellings, both as attributes and (for dicts) as keys.

    Returns:
        ``{"input": ..., "output": ..., "unit": "TOKENS"}``, with 0 for a count
        that cannot be found, or ``None`` when the response carries no usage
        record at all.
    """
    usage = getattr(response, "usage_metadata", None)
    if not usage:
        # response_metadata may be absent or None; only call .get when it
        # actually behaves like a mapping.
        metadata = getattr(response, "response_metadata", None)
        usage = metadata.get("usage") if hasattr(metadata, "get") else None
    if usage is None:
        return None
    # Dict-shaped usage must be read by key: getattr() on a dict never sees its
    # keys, which previously turned such usage into 0/0 tokens.
    input_tokens = _first_token_count(
        usage,
        ("prompt_token_count", "input_tokens", "prompt_tokens"),
        ("prompt_tokens", "input_tokens", "prompt_token_count"),
    )
    output_tokens = _first_token_count(
        usage,
        ("candidates_token_count", "output_tokens", "completion_tokens"),
        ("completion_tokens", "output_tokens", "candidates_token_count"),
    )
    return {"input": input_tokens, "output": output_tokens, "unit": "TOKENS"}


def invoke_agent(
    system_prompt: str,
    state: dict,
    node_name: str,
    *,
    extra_messages: Optional[list] = None,
    context_window: int = 10,
) -> dict:
    """Run one agent turn, walking the fallback model chain until a model answers.

    The prompt is ``system_prompt`` + ``extra_messages`` + the last
    ``context_window`` entries of ``state["messages"]`` (string tool outputs
    longer than 4000 characters are truncated). Each model in the runtime's
    ``fallback_models`` is tried in order, skipping models without an API key
    or whose circuit breaker refuses the call. The first successful response
    is cost-tracked, optionally traced to Langfuse, and returned.

    Args:
        system_prompt: Content of the leading system message.
        state: Graph state. ``messages`` and ``iteration_count`` are required;
            ``langfuse_trace_id``, ``current_task``, ``workspace_dir``,
            ``_retry_max`` and ``_retry_delay`` are read when present.
        node_name: Name of the calling node, used in logs, spans, cost records
            and the returned state.
        extra_messages: Messages inserted between the system prompt and the
            trimmed history.
        context_window: Number of trailing history messages to include.

    Returns:
        A state update with ``messages``, ``iteration_count``,
        ``total_cost_usd``, ``budget_exceeded``, ``current_node`` and
        ``current_agent``; plus ``langfuse_trace_id`` when a trace id is known,
        and ``tests_passed=False`` (with an extra halting system message) when
        the cost budget is exceeded.

    Raises:
        RuntimeError: When every model was skipped or failed. The last model
            error, if any, is attached as ``__cause__``.
        Exception: Errors raised by the bookkeeping that follows a successful
            model call (cost tracking, tracing, UI export) propagate as-is;
            they are no longer mistaken for a model failure.
    """
    # Imported once per call instead of once per over-long message.
    from langchain_core.messages import ToolMessage

    cost_tracker = _runtime.cost_tracker
    circuit_breaker = _runtime.circuit_breaker
    circuit_events = _runtime.circuit_events
    export_ui_state_fn = _runtime.export_ui_state_fn

    langfuse = get_langfuse()
    trace_id = state.get("langfuse_trace_id")

    if langfuse.is_enabled() and not trace_id:
        # A key that is present but None would otherwise crash on slicing.
        task = state.get("current_task")
        if task is None:
            task = "unknown"
        trace = langfuse.create_trace(
            name="auto-swe-agent",
            metadata={
                "task": str(task)[:200],
                "workspace": state.get("workspace_dir", "unknown"),
                "mode": "multi-agent",
            },
        )
        if trace is not None and hasattr(trace, "id"):
            trace_id = trace.id

    # NOTE(review): context_window=0 slices as [-0:], i.e. the whole history —
    # confirm that is the intended meaning of 0.
    trimmed = []
    for msg in state["messages"][-context_window:]:
        if (
            isinstance(msg, ToolMessage)
            and isinstance(msg.content, str)
            and len(msg.content) > 4000
        ):
            # Only content and tool_call_id are carried over to the shortened copy.
            msg = ToolMessage(
                content=msg.content[:4000] + "\n[TRUNCATED]",
                tool_call_id=msg.tool_call_id,
            )
        trimmed.append(msg)

    msgs = [SystemMessage(content=system_prompt)] + (extra_messages or []) + trimmed
    last_input = trimmed[-1].content if trimmed else ""

    agent_span = None
    if langfuse.is_enabled() and trace_id:
        agent_span = langfuse.span(
            trace_id=trace_id,
            name=f"agent-{node_name}",
            input={
                "messages_count": len(msgs),
                "context_window": context_window,
                "last_input_preview": str(last_input)[:300],
            },
        )

    last_error: Optional[Exception] = None
    for model in _runtime.fallback_models:
        if not _model_available(model):
            print(f"[SKIP] {model} — no API key set.")
            continue

        if not circuit_breaker.can_call(model):
            event = f"[CIRCUIT OPEN] Skipping {model} (cooldown active)"
            print(event)
            circuit_events.append(event)
            continue

        print(f"\n--- [NODE] {node_name.upper()} | model={model} ---")

        # Only the model call is guarded. A failure in the bookkeeping below
        # must not count against the model's circuit breaker, nor trigger a
        # second, billed call to the next model.
        try:
            response = _call_with_retry(
                model,
                msgs,
                max_retries=state.get("_retry_max", 3),
                base_delay=state.get("_retry_delay", 2.0),
            )
        except Exception as e:
            last_error = e
            err_name = type(e).__name__
            # "Permanent" errors (quota, auth, bad request, missing key, ...)
            # skip the model without recording a circuit-breaker failure.
            is_permanent = (
                any(t in err_name for t in _SKIP_ERRORS)
                or "Missing" in str(e)
                or "key" in str(e).lower()
            )
            if not is_permanent:
                circuit_breaker.record_failure(model)
                status = circuit_breaker.get_status().get(model, {})
                if status.get("state") == "open":
                    event = f"[CIRCUIT OPENED] {model} after {status.get('failures')} failures"
                    circuit_events.append(event)
            print(f"[FALLBACK] {model} failed: {err_name}. Trying next model...")
            if agent_span is not None:
                agent_span.update(
                    output={"status": "fallback", "error": err_name, "model": model}
                )
            continue

        circuit_breaker.record_success(model)

        usage_dict = _extract_usage(response)
        if usage_dict:
            input_tokens = usage_dict["input"]
            output_tokens = usage_dict["output"]
            estimated = False
        else:
            # Rough fallback: 500 tokens per prompt message, 4 chars per output token.
            input_tokens = len(msgs) * 500
            output_tokens = len(str(response.content)) // 4
            estimated = True
            print(
                f"[COST] Token counts unavailable — using estimates (in={input_tokens}, out={output_tokens})"
            )

        # Langfuse generation trace
        if langfuse.is_enabled() and trace_id:
            gen_params = {
                "trace_id": trace_id,
                "name": f"llm-{node_name}",
                "model": model,
                "input": str(last_input)[:500],
                "output": (
                    str(response.content)[:1000]
                    if hasattr(response, "content")
                    else str(response)[:1000]
                ),
            }
            if usage_dict:
                gen_params["usage"] = usage_dict
            langfuse.generation(**gen_params)

        cost_tracker.add_call(model, input_tokens, output_tokens, node_name, estimated)
        total_cost = cost_tracker.get_total_cost()
        print(
            f"[COST] ${total_cost:.6f} total | this call: in={input_tokens} out={output_tokens} tokens"
        )

        if agent_span is not None:
            agent_span.update(output={"status": "success", "model_used": model})

        budget_exceeded = bool(cost_tracker.check_budget_exceeded())
        result = {
            "messages": [response],
            "iteration_count": state["iteration_count"] + 1,
            "total_cost_usd": total_cost,
            "budget_exceeded": budget_exceeded,
            "current_node": node_name,
            "current_agent": node_name,
        }
        if budget_exceeded:
            print(
                f"[COST] Budget exceeded (${total_cost:.4f} > ${cost_tracker.budget_usd})."
            )
            # Append a halting notice and mark the run as not passing.
            result["messages"].append(
                SystemMessage(
                    content=f"Budget exceeded (${total_cost:.4f} > ${cost_tracker.budget_usd}). Halting."
                )
            )
            result["tests_passed"] = False
        if trace_id:
            result["langfuse_trace_id"] = trace_id
        if export_ui_state_fn:
            export_ui_state_fn({**state, **result}, node_name)
        return result

    if agent_span is not None:
        agent_span.update(output={"status": "error", "error": "All models exhausted"})
    raise RuntimeError("All models in fallback chain exhausted.") from last_error