File size: 9,009 Bytes
"""Agent loop driver — provider-agnostic tool-use loop for one audit.
`run_audit(file_path)` is an async generator that yields `SSEEvent` objects in
the order the UI should render them: thoughts, tool calls, tool results, and
finally either a `final_report` event (extracted from the most recent
successful `compare_runs` tool result) or an `error` event.
The loop itself doesn't know about Anthropic or Hugging Face — it talks to
whichever `Backend` `make_backend()` returns. The backend (Claude or Qwen-HF
today) handles all per-API translation. See `agent/backends/__init__.py`.
"""
from __future__ import annotations
import json
from collections.abc import AsyncIterator
from typing import Any
from agent import tools as tools_module
from agent.backends import Backend, ToolCall, make_backend
from agent.prompts import SYSTEM_PROMPT
from agent.schemas import SSEEvent
# Upper bound on the number of iterations of the loop in `_drive`. Each
# iteration performs one `backend.next_turn(...)` call and executes every tool
# call contained in that turn.
# NOTE(review): the note below calls this a cap on tool calls, but `_drive`
# counts turns, so a turn carrying several tool calls consumes a single step —
# confirm which is intended.
# NOTE(review): the listed trajectory reads as seven calls if "benchmark×2"
# counts twice, not six — confirm the headroom arithmetic.
MAX_STEPS = 10
"""Hard cap on tool calls per audit. The canonical trajectory is six calls
(parse → profile → query_kb → patch → benchmark×2 → compare). The extra
4 calls of headroom let the model recover from common mistakes (JSON
nesting glitches, retry on ToolResult(ok=False)) without exhausting the
budget before compare_runs. Was 8; bumped after a live run hit a wall when
two misnested-arg benchmark retries ate the slack meant for compare_runs.
"""
# Forwarded to `make_backend` as its `max_tokens` argument in `run_audit`.
MAX_TOKENS = 2048
def _extract_final_report(
tool_results: list[dict[str, Any]],
) -> dict[str, Any] | None:
"""Walk tool results in reverse and return the most recent successful
compare_runs payload, or None if there isn't one."""
for entry in reversed(tool_results):
if entry["name"] == "compare_runs" and entry["ok"]:
return entry["result"]
return None
def _auto_compare(
tool_results: list[dict[str, Any]],
) -> dict[str, Any] | None:
"""Synthesize a Report from whatever the audit produced when the model
didn't reach `compare_runs` cleanly. Three recovery tiers, in order of
fidelity:
Tier 1 — full data: ≥2 benchmarks + ≥1 propose_patch.
Treat first benchmark as baseline, last as patched run. Highest
fidelity since both numbers are real.
Tier 2 — patch but only one benchmark: ≥1 patch + 1 benchmark.
Use the single benchmark as baseline. For the "after" side, run
FakeRunner on the patched config to get a deterministic projection.
Marks the report as projected so the demo is honest about it.
Tier 3 — no patch ran but we have rules from query_rocm_kb + ≥1 benchmark.
We *could* deterministically apply propose_patch ourselves here, but
that's over-reaching. Return None and let the caller surface a
clean error instead.
Returns the Report dict, or None when no tier applies.
"""
benchmarks = [
e for e in tool_results if e["name"] == "benchmark" and e["ok"]
]
patches = [
e for e in tool_results if e["name"] == "propose_patch" and e["ok"]
]
# Tier 1: full data path.
if len(benchmarks) >= 2 and patches:
latest_patch = patches[-1]["result"]
before = benchmarks[0]["result"]
after = benchmarks[-1]["result"]
return _call_compare_runs(latest_patch, before, after, " (auto-synthesized compare_runs)")
# Tier 2: patch + 1 benchmark — fill in the patched-side metrics from
# FakeRunner so the demo still produces a Report with a clear note.
if patches and len(benchmarks) == 1:
latest_patch = patches[-1]["result"]
before = benchmarks[0]["result"]
# Project the patched run via FakeRunner. The synthetic corpus has
# a `02_optimized` scenario the patched config typically matches.
from agent.schemas import WorkloadConfig
from runner.protocol import FakeRunner
try:
patched_cfg = WorkloadConfig.model_validate(latest_patch["new_config"])
after_metrics = FakeRunner().run(patched_cfg, steps=before.get("steps", 50))
after = after_metrics.model_dump()
except Exception:
return None
return _call_compare_runs(
latest_patch,
before,
after,
" (auto-synthesized; patched-side projected via FakeRunner)",
)
return None
def _call_compare_runs(
    patch: dict[str, Any],
    before: dict[str, Any],
    after: dict[str, Any],
    suffix: str,
) -> dict[str, Any] | None:
    """Invoke the `compare_runs` tool directly and return its payload.

    The workload name is the patch's `new_config["model_name"]` when that is
    present and truthy, otherwise "Audited Workload"; `suffix` is appended in
    both cases. Returns the tool's `result` when the call reports `ok`, else
    None.
    """
    patched_config = patch.get("new_config", {})
    base_name = patched_config.get("model_name")
    if not base_name:
        base_name = "Audited Workload"

    outcome = tools_module.call(
        "compare_runs",
        workload_name=base_name + suffix,
        before=before,
        after=after,
        patch=patch,
    )
    if not outcome.ok:
        return None
    return outcome.result
def _safe_json(value: Any) -> str:
"""Serialize a tool result for inclusion in a tool_result content block.
Falls back to ``str(value)`` if json can't represent the value (e.g. a
Pydantic model already coerced upstream — shouldn't happen, but defensive).
"""
try:
return json.dumps(value, default=str)
except Exception:
return str(value)
async def _drive(backend: Backend) -> AsyncIterator[SSEEvent]:
    """Run the tool-use loop against ``backend`` and yield UI events.

    For up to ``MAX_STEPS`` turns the backend is asked for its next turn.
    Every non-empty text block is yielded as a ``thought`` event and every
    tool call in the turn is executed via ``_execute_tool_call``, which
    yields the ``tool_call``/``tool_result`` pair and records the outcome.
    The loop stops early when a turn reports ``stop_reason == "end_turn"``.

    After the loop exactly one terminal outcome is produced:

    * ``final_report`` with the latest successful ``compare_runs`` payload;
    * otherwise an explanatory ``thought`` followed by a ``final_report``
      synthesized by ``_auto_compare`` from the recorded tool results;
    * otherwise an ``error`` event.

    Exceptions raised by the backend or by tools are not handled here; they
    propagate to the caller.
    """
    tool_results_log: list[dict[str, Any]] = []

    for _step in range(MAX_STEPS):
        turn = await backend.next_turn(tools_module.tool_schemas())

        for text in turn.text_blocks:
            if text:
                yield SSEEvent(type="thought", data={"text": text})

        for tc in turn.tool_calls:
            async for ev in _execute_tool_call(backend, tc, tool_results_log):
                yield ev

        # Tool calls of the final turn are still executed before stopping.
        if turn.stop_reason == "end_turn":
            break

    report = _extract_final_report(tool_results_log)
    if report is not None:
        yield SSEEvent(type="final_report", data={"report": report})
        return

    # Fallback: there is no successful compare_runs result. Either the model
    # never called it (for instance the call landed inside a thinking block
    # where the parser couldn't extract it) or the call failed. Synthesize
    # the report from the tool log if there is enough material; see
    # _auto_compare for the prerequisites.
    auto = _auto_compare(tool_results_log)
    if auto is not None:
        # The wording must hold for both _auto_compare paths: two measured
        # benchmarks, or one benchmark plus a FakeRunner projection.
        yield SSEEvent(
            type="thought",
            data={
                "text": (
                    "Note: no successful compare_runs result was recorded "
                    "(the model likely left the call inside a <think> "
                    "block). Synthesizing the final report from the latest "
                    "propose_patch and the recorded benchmark results; the "
                    "patched side is projected when only one benchmark "
                    "succeeded."
                )
            },
        )
        yield SSEEvent(type="final_report", data={"report": auto})
        return

    # _auto_compare needs a successful patch and at least one successful
    # benchmark, so the message states that minimum rather than two.
    yield SSEEvent(
        type="error",
        data={
            "message": (
                "Audit completed without producing a final report (and "
                "auto-synthesis fallback couldn't run — need at least one "
                "successful propose_patch and at least one successful "
                "benchmark)."
            )
        },
    )
async def _execute_tool_call(
    backend: Backend,
    tc: ToolCall,
    tool_results_log: list[dict[str, Any]],
) -> AsyncIterator[SSEEvent]:
    """Run one tool call, emitting its events and recording the outcome.

    Yields a ``tool_call`` event, dispatches the call through
    ``tools_module.call`` with the call's input expanded as keyword
    arguments, then yields the matching ``tool_result`` event. The outcome is
    appended to ``tool_results_log`` and handed back to ``backend``: the
    JSON-rendered result on success, or the error text (``"tool failed"``
    when there is none) on failure.
    """
    announcement = {"id": tc.id, "name": tc.name, "input": tc.input}
    yield SSEEvent(type="tool_call", data=announcement)

    outcome = tools_module.call(tc.name, **tc.input)

    record = {
        "id": tc.id,
        "name": tc.name,
        "ok": outcome.ok,
        "result": outcome.result,
        "error": outcome.error,
    }
    # The event gets its own dict so the log entry is never shared with it.
    yield SSEEvent(type="tool_result", data=dict(record))
    tool_results_log.append(record)

    if outcome.ok:
        content = _safe_json(outcome.result)
    else:
        content = outcome.error or "tool failed"

    backend.add_tool_result(
        tool_call_id=tc.id,
        name=tc.name,
        content=content,
        is_error=not outcome.ok,
    )
async def run_audit(file_path: str) -> AsyncIterator[SSEEvent]:
    """Run one audit and yield SSE events as they happen.

    A backend is obtained from ``make_backend`` (backend selection happens
    there; per the project notes it is driven by the ``GOBLIN_AGENT_BACKEND``
    environment variable — confirm in ``agent/backends``). The backend is
    seeded with a user message naming ``file_path`` and then driven by
    ``_drive``, whose events are re-yielded unchanged.

    Any ``Exception`` raised while creating the backend, seeding the user
    message, or running the loop is converted into one ``error`` event, after
    which the generator stops. Events yielded before the failure have
    already been delivered.
    """
    try:
        backend = make_backend(system_prompt=SYSTEM_PROMPT, max_tokens=MAX_TOKENS)
        # Seeding happens inside the guarded region so that a failure here is
        # reported as an error event too, instead of escaping to the consumer.
        backend.add_user_message(f"Audit this fine-tuning workload: {file_path}")
        async for ev in _drive(backend):
            yield ev
    except Exception as exc:
        # Some exceptions stringify to "" (e.g. a bare TimeoutError()); fall
        # back to the class name so the UI never shows a blank message.
        message = str(exc) or type(exc).__name__
        yield SSEEvent(type="error", data={"message": message})
|