File size: 9,009 Bytes
3fc99cf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
"""Agent loop driver — provider-agnostic tool-use loop for one audit.

`run_audit(file_path)` is an async generator that yields `SSEEvent` objects in
the order the UI should render them: thoughts, tool calls, tool results, and
finally either a `final_report` event (extracted from the most recent
successful `compare_runs` tool result) or an `error` event.

The loop itself doesn't know about Anthropic or Hugging Face — it talks to
whichever `Backend` `make_backend()` returns. The backend (Claude or Qwen-HF
today) handles all per-API translation. See `agent/backends/__init__.py`.
"""

from __future__ import annotations

import json
from collections.abc import AsyncIterator
from typing import Any

from agent import tools as tools_module
from agent.backends import Backend, ToolCall, make_backend
from agent.prompts import SYSTEM_PROMPT
from agent.schemas import SSEEvent

# Upper bound on iterations of the driver loop in `_drive`. Each iteration is
# one backend turn, and a single turn may carry more than one tool call, so
# this caps turns rather than individual tool invocations.
# NOTE(review): the trajectory listed in the note below (parse → profile →
# query_kb → patch → benchmark×2 → compare) adds up to seven calls if both
# benchmarks are counted, which would leave three calls of headroom rather
# than four — confirm the intended figures.
MAX_STEPS = 10
"""Hard cap on tool calls per audit. The canonical trajectory is six calls
(parse → profile → query_kb → patch → benchmark×2 → compare). The extra
4 calls of headroom let the model recover from common mistakes (JSON
nesting glitches, retry on ToolResult(ok=False)) without exhausting the
budget before compare_runs. Was 8; bumped after a live run hit a wall when
two misnested-arg benchmark retries ate the slack meant for compare_runs.
"""
# Forwarded to `make_backend` as its `max_tokens` argument in `run_audit`.
MAX_TOKENS = 2048


def _extract_final_report(
    tool_results: list[dict[str, Any]],
) -> dict[str, Any] | None:
    """Walk tool results in reverse and return the most recent successful
    compare_runs payload, or None if there isn't one."""
    for entry in reversed(tool_results):
        if entry["name"] == "compare_runs" and entry["ok"]:
            return entry["result"]
    return None


def _auto_compare(
    tool_results: list[dict[str, Any]],
) -> dict[str, Any] | None:
    """Synthesize a Report from whatever the audit produced when the model
    didn't reach `compare_runs` cleanly. Three recovery tiers, in order of
    fidelity:

    Tier 1 — full data: ≥2 benchmarks + ≥1 propose_patch.
        Treat first benchmark as baseline, last as patched run. Highest
        fidelity since both numbers are real.

    Tier 2 — patch but only one benchmark: ≥1 patch + 1 benchmark.
        Use the single benchmark as baseline. For the "after" side, run
        FakeRunner on the patched config to get a deterministic projection.
        Marks the report as projected so the demo is honest about it.

    Tier 3 — no patch ran but we have rules from query_rocm_kb + ≥1 benchmark.
        We *could* deterministically apply propose_patch ourselves here, but
        that's over-reaching. Return None and let the caller surface a
        clean error instead.

    Returns the Report dict, or None when no tier applies.
    """
    benchmarks = [
        e for e in tool_results if e["name"] == "benchmark" and e["ok"]
    ]
    patches = [
        e for e in tool_results if e["name"] == "propose_patch" and e["ok"]
    ]

    # Tier 1: full data path.
    if len(benchmarks) >= 2 and patches:
        latest_patch = patches[-1]["result"]
        before = benchmarks[0]["result"]
        after = benchmarks[-1]["result"]
        return _call_compare_runs(latest_patch, before, after, " (auto-synthesized compare_runs)")

    # Tier 2: patch + 1 benchmark — fill in the patched-side metrics from
    # FakeRunner so the demo still produces a Report with a clear note.
    if patches and len(benchmarks) == 1:
        latest_patch = patches[-1]["result"]
        before = benchmarks[0]["result"]
        # Project the patched run via FakeRunner. The synthetic corpus has
        # a `02_optimized` scenario the patched config typically matches.
        from agent.schemas import WorkloadConfig
        from runner.protocol import FakeRunner

        try:
            patched_cfg = WorkloadConfig.model_validate(latest_patch["new_config"])
            after_metrics = FakeRunner().run(patched_cfg, steps=before.get("steps", 50))
            after = after_metrics.model_dump()
        except Exception:
            return None
        return _call_compare_runs(
            latest_patch,
            before,
            after,
            " (auto-synthesized; patched-side projected via FakeRunner)",
        )

    return None


def _call_compare_runs(
    patch: dict[str, Any],
    before: dict[str, Any],
    after: dict[str, Any],
    suffix: str,
) -> dict[str, Any] | None:
    """Invoke the ``compare_runs`` tool directly and return its result.

    The workload name is the patch's ``new_config["model_name"]`` when that is
    present and truthy, otherwise ``"Audited Workload"``; ``suffix`` is
    appended in either case.

    Returns:
        The tool's ``result`` when the call reports ``ok``, else ``None``.
    """
    model_name = patch.get("new_config", {}).get("model_name")
    label = (model_name or "Audited Workload") + suffix

    outcome = tools_module.call(
        "compare_runs",
        workload_name=label,
        before=before,
        after=after,
        patch=patch,
    )
    if not outcome.ok:
        return None
    return outcome.result


def _safe_json(value: Any) -> str:
    """Render a tool result as text for a tool_result content block.

    JSON is preferred, with ``str`` used as the encoder for values the json
    module cannot represent. If encoding still fails (for example on
    unsupported dict keys), the plain ``str(value)`` form is returned instead.
    """
    try:
        rendered = json.dumps(value, default=str)
    except Exception:
        # Defensive fallback: never let serialization break the loop.
        rendered = str(value)
    return rendered


async def _drive(backend: Backend) -> AsyncIterator[SSEEvent]:
    """Run the tool-use loop against ``backend`` and yield UI events.

    Each iteration asks the backend for its next turn, emits one ``thought``
    event per non-empty text block, then executes the turn's tool calls in
    order (each producing a ``tool_call``/``tool_result`` pair). The loop ends
    when a turn reports ``stop_reason == "end_turn"`` or after ``MAX_STEPS``
    turns, whichever comes first.

    After the loop one terminal outcome is emitted:

    * ``final_report`` carrying the most recent successful ``compare_runs``
      result, if the log contains one;
    * otherwise a ``thought`` note followed by a ``final_report`` built by
      ``_auto_compare``, if that fallback yields a report;
    * otherwise an ``error`` event.

    Args:
        backend: Conversation backend; it owns all per-API message state.

    Yields:
        ``SSEEvent`` objects in the order the UI should render them.
    """
    tool_results_log: list[dict[str, Any]] = []

    for _step in range(MAX_STEPS):
        turn = await backend.next_turn(tools_module.tool_schemas())

        for text in turn.text_blocks:
            if text:
                yield SSEEvent(type="thought", data={"text": text})

        for tc in turn.tool_calls:
            async for ev in _execute_tool_call(backend, tc, tool_results_log):
                yield ev

        if turn.stop_reason == "end_turn":
            break

    report = _extract_final_report(tool_results_log)
    if report is not None:
        yield SSEEvent(type="final_report", data={"report": report})
        return

    # Fallback: the model didn't call compare_runs (or its tool_call landed
    # inside a thinking block where the parser couldn't extract it).
    # Synthesize the report deterministically from the tool log if we have
    # enough material. See _auto_compare for the prerequisites.
    auto = _auto_compare(tool_results_log)
    if auto is not None:
        # The wording stays neutral about how many benchmarks were used:
        # _auto_compare also succeeds with a single benchmark by projecting
        # the patched side.
        yield SSEEvent(
            type="thought",
            data={
                "text": (
                    "Note: model did not emit a compare_runs tool call (likely "
                    "left it inside a <think> block). Synthesizing the final "
                    "report from the latest propose_patch + the benchmark "
                    "results on record."
                )
            },
        )
        yield SSEEvent(type="final_report", data={"report": auto})
        return

    # The stated prerequisites mirror _auto_compare: one successful patch and
    # one successful benchmark are the minimum it can work with.
    yield SSEEvent(
        type="error",
        data={
            "message": (
                "Audit completed without producing a final report (and "
                "auto-synthesis fallback couldn't produce one — it needs at "
                "least one successful propose_patch and one successful "
                "benchmark)."
            )
        },
    )


async def _execute_tool_call(
    backend: Backend,
    tc: ToolCall,
    tool_results_log: list[dict[str, Any]],
) -> AsyncIterator[SSEEvent]:
    """Run one tool call, emit its event pair, and record what happened.

    Order of effects: a ``tool_call`` event is yielded, the tool is invoked
    with ``tc.input`` as keyword arguments, a ``tool_result`` event is
    yielded, the outcome is appended to ``tool_results_log``, and finally the
    backend is told the result (JSON text on success, the error text or
    ``"tool failed"`` otherwise).
    """
    announcement = {"id": tc.id, "name": tc.name, "input": tc.input}
    yield SSEEvent(type="tool_call", data=announcement)

    outcome = tools_module.call(tc.name, **tc.input)

    summary = {
        "id": tc.id,
        "name": tc.name,
        "ok": outcome.ok,
        "result": outcome.result,
        "error": outcome.error,
    }
    # The event gets its own copy so the logged record is a distinct dict.
    yield SSEEvent(type="tool_result", data=dict(summary))
    tool_results_log.append(summary)

    if outcome.ok:
        payload = _safe_json(outcome.result)
    else:
        payload = outcome.error or "tool failed"

    backend.add_tool_result(
        tool_call_id=tc.id,
        name=tc.name,
        content=payload,
        is_error=not outcome.ok,
    )


async def run_audit(file_path: str) -> AsyncIterator[SSEEvent]:
    """Run one audit and yield SSE events as they happen.

    A backend is created through ``make_backend`` with the module's system
    prompt and ``MAX_TOKENS``, seeded with a user message naming
    ``file_path``, and then handed to ``_drive``. Backend selection happens
    inside ``make_backend``; earlier notes describe it as driven by the
    ``GOBLIN_AGENT_BACKEND`` env var (defaulting to ``claude``, with ``qwen``
    routed through HF Inference Providers), which is not visible in this
    file.

    Any exception raised while creating the backend, seeding the
    conversation, or driving the loop is turned into a single ``error`` event
    carrying ``str(exc)``, after which the generator stops.

    Args:
        file_path: Path of the workload file, interpolated into the initial
            user message.

    Yields:
        The events produced by ``_drive``, or one ``error`` event on failure.
    """
    try:
        backend = make_backend(system_prompt=SYSTEM_PROMPT, max_tokens=MAX_TOKENS)
        # Seeding is inside the guarded region so a backend failure here is
        # reported as an error event, as the contract above promises.
        backend.add_user_message(f"Audit this fine-tuning workload: {file_path}")

        async for ev in _drive(backend):
            yield ev
    except Exception as exc:
        yield SSEEvent(type="error", data={"message": str(exc)})