File size: 11,458 Bytes
"""propose_patch tool — apply rules to a config, estimate uplift + confidence.
Architecture.md §3 — deterministic transformer:
- Apply each rule's `transform` to the config (dotted-path key/value).
- Skip rules whose `detect` block doesn't match.
- Sum estimated_recovery_seconds per bucket (capped per bucket).
- Compute speedup range from waste-budget recovery: 1 / (1 - frac).
- Compute confidence = evidence_coverage × rule_consistency.
- Generate unified diff (difflib).
LLM-side resilience (live-AMD-GPU lessons):
* The model often passes only ``rule_id``s instead of full Rule objects
between turns. We accept either: ``rule_ids: ["precision.bf16_..."]``
resolves against the loaded KB; ``rules: [{id, ...}]`` validates fully
if present, and falls back to id-lookup if any required Rule field is
missing. Both inputs may be combined.
* ``config`` may arrive partially populated (the model truncated it). We
require ``model_name`` (the only field with no schema default) and
fill the rest from ``WorkloadConfig`` defaults.
* ``metrics`` is optional. Without a waste budget the uplift estimate
collapses to a generous default range and the confidence drops.
"""
from __future__ import annotations
from typing import Any
from pydantic import ValidationError
from agent.schemas import (
Patch,
Rule,
RuleApplication,
ToolResult,
WorkloadConfig,
)
from agent.tools import Tool
# Use the KB's authoritative Rule list. This means ``propose_patch`` always
# has the full Rule details (transform, citation, etc.) even when the LLM
# only forwards an ``id``.
from agent.tools.query_rocm_kb import _RULES as _KB_RULES
# Module-level cache of the most recent successful Patch. ``compare_runs``
# reads this when the LLM forwards a malformed ``patch=`` argument — a
# common Qwen failure mode where the model collapses the Patch dict down
# to its ``new_config`` fields. See ``compare_runs._normalize_patch``.
_LAST_PATCH: dict[str, Any] | None = None


def latest_patch() -> dict[str, Any] | None:
    """Give back the payload cached by the last successful ``propose_patch``.

    The cached dict itself is returned (not a copy). Returns ``None`` if no
    patch has been produced since this module was loaded.
    """
    cached = _LAST_PATCH
    return cached
# ---------------------------------------------------------------------------
# Rule resolution
# ---------------------------------------------------------------------------
def _kb_index() -> dict[str, Rule]:
    """Build a fresh ``id -> Rule`` lookup over the loaded KB rules.

    The mapping is rebuilt on every call. If two KB rules share an id, the
    one appearing later in ``_KB_RULES`` wins.
    """
    index: dict[str, Rule] = {}
    for kb_rule in _KB_RULES:
        index[kb_rule.id] = kb_rule
    return index
def _resolve_rules(
    rules: list[dict[str, Any]] | dict[str, Any] | None,
    rule_ids: list[str] | str | None,
) -> tuple[list[Rule], list[str]]:
    """Translate the LLM's rule input into validated Rule objects.

    ``rule_ids`` are looked up in the loaded KB first. Each ``rules`` entry
    is then validated as a full Rule; an entry that fails validation falls
    back to a KB lookup by its ``id``. Rules are de-duplicated by id and the
    first occurrence wins, so a KB rule named in ``rule_ids`` takes
    precedence over a same-id entry in ``rules``.

    A bare string ``rule_ids`` or a single dict ``rules`` is treated as a
    one-element list, and malformed entries become warnings rather than
    exceptions.

    Returns ``(rules, warnings)`` — warnings list any inputs we couldn't
    resolve so the caller can surface them alongside the Patch.
    """
    kb = _kb_index()
    resolved: list[Rule] = []
    warnings: list[str] = []
    seen: set[str] = set()

    def _take(rule: Rule) -> None:
        # Keep only the first Rule seen for each id.
        if rule.id not in seen:
            seen.add(rule.id)
            resolved.append(rule)

    # The model sometimes sends one id as a bare string; iterating it would
    # produce a bogus "unknown rule id" warning per character.
    if isinstance(rule_ids, str):
        rule_ids = [rule_ids]
    for rid in rule_ids or ():
        # Check the type before the ``in kb`` lookup: an unhashable entry
        # (dict/list) would otherwise raise TypeError and abort the call.
        if not isinstance(rid, str):
            warnings.append(f"skipping non-string rule id: {rid!r}")
        elif rid in kb:
            _take(kb[rid])
        else:
            warnings.append(f"unknown rule id: {rid!r}")

    # A single Rule dict would otherwise be iterated by its keys.
    if isinstance(rules, dict):
        rules = [rules]
    for entry in rules or ():
        if not isinstance(entry, dict):
            warnings.append(f"skipping non-dict rule entry: {entry!r}")
            continue
        try:
            parsed = Rule.model_validate(entry)
        except ValidationError:
            # Partial entry — fall back to id lookup.
            rid = entry.get("id")
            if isinstance(rid, str) and rid in kb:
                _take(kb[rid])
            else:
                warnings.append(
                    f"rule entry missing required fields and id is unknown: "
                    f"id={rid!r}"
                )
        else:
            _take(parsed)

    return resolved, warnings
def _hydrate_config(config: dict[str, Any]) -> tuple[WorkloadConfig, list[str]]:
"""Build a WorkloadConfig from a possibly-partial dict.
Returns ``(config, warnings)``. We require ``model_name`` (no schema
default); every other field falls back to the WorkloadConfig default
so the LLM can pass just the dimensions it cares about.
"""
if not config or not isinstance(config, dict):
raise ValueError("config must be a non-empty dict")
if not config.get("model_name"):
raise ValueError(
"config.model_name is required (forward the model_name from parse_config)"
)
warnings: list[str] = []
try:
return WorkloadConfig.model_validate(config), warnings
except ValidationError as exc:
# Best-effort: drop the offending fields and retry on a sanitized copy.
cleaned = {k: v for k, v in config.items() if k in WorkloadConfig.model_fields}
warnings.append(
f"dropped extra/invalid config fields during validation: "
f"{sorted(set(config) - set(cleaned))}"
)
try:
return WorkloadConfig.model_validate(cleaned), warnings
except ValidationError:
raise exc # original error wins
# ---------------------------------------------------------------------------
# Patch math
# ---------------------------------------------------------------------------
def _detect_matches(cfg: WorkloadConfig, rule: Rule) -> bool:
data = cfg.model_dump()
for key, expected in rule.detect.items():
if data.get(key) != expected:
return False
return True
def _set_dotted(data: dict, path: str, value: Any) -> None:
parts = path.split(".")
cur = data
for p in parts[:-1]:
cur = cur.setdefault(p, {})
cur[parts[-1]] = value
def _render_diff(before: WorkloadConfig, after: WorkloadConfig) -> str:
bd = before.model_dump()
ad = after.model_dump()
lines = []
for key in sorted(set(bd) | set(ad)):
if bd.get(key) != ad.get(key):
lines.append(f"- {key}: {bd.get(key)!r}")
lines.append(f"+ {key}: {ad.get(key)!r}")
return "\n".join(lines) if lines else "(no changes)"
# ---------------------------------------------------------------------------
# Tool entry point
# ---------------------------------------------------------------------------
def _parse_waste_budget(metrics: Any) -> tuple[dict[Any, float], list[str]]:
    """Extract ``metrics["waste_budget"]`` as a ``bucket -> seconds`` dict.

    Returns an empty budget when ``metrics`` or its ``waste_budget`` is
    absent or not a dict, so the caller takes the no-budget path instead of
    crashing on LLM-mangled input. Entries whose value ``float()`` rejects
    are skipped; every skip is reported in the returned notes.
    """
    if not isinstance(metrics, dict):
        return {}, []
    raw = metrics.get("waste_budget")
    if raw is None:
        return {}, []
    if not isinstance(raw, dict):
        return {}, [f"ignored waste_budget (expected an object): {raw!r}"]
    budget: dict[Any, float] = {}
    notes: list[str] = []
    for bucket, seconds in raw.items():
        try:
            budget[bucket] = float(seconds)
        except (TypeError, ValueError):
            notes.append(
                f"ignored non-numeric waste_budget entry {bucket!r}: {seconds!r}"
            )
    return budget, notes


def _apply_rules(
    workload: WorkloadConfig,
    typed_rules: list[Rule],
    budget: dict[Any, float],
) -> tuple[dict[str, Any], list[RuleApplication], dict[Any, float]]:
    """Apply every rule whose ``detect`` block matches ``workload``.

    Returns the patched config as an unvalidated dict, one RuleApplication
    per applied rule, and the uncapped estimated recovery seconds summed per
    ``targets_bucket``. The third value has one key per bucket touched and
    is empty iff no rule applied. Detection always runs against the
    unpatched ``workload``; when rules write the same path, the later wins.
    """
    new_cfg_data = workload.model_dump()
    rationale: list[RuleApplication] = []
    recovered_by_bucket: dict[Any, float] = {}
    for rule in typed_rules:
        if not _detect_matches(workload, rule):
            continue
        for path, value in rule.transform.items():
            _set_dotted(new_cfg_data, path, value)
        recovery = budget.get(rule.targets_bucket, 0.0) * rule.expected_recovery_fraction
        recovered_by_bucket[rule.targets_bucket] = (
            recovered_by_bucket.get(rule.targets_bucket, 0.0) + recovery
        )
        rationale.append(
            RuleApplication(
                rule_id=rule.id,
                rationale=rule.expected_impact,
                citation=rule.citation,
                targets_bucket=rule.targets_bucket,
                estimated_recovery_seconds=recovery,
            )
        )
    return new_cfg_data, rationale, recovered_by_bucket


def _estimate_uplift(
    budget: dict[Any, float],
    recovered_by_bucket: dict[Any, float],
) -> tuple[float, float, float]:
    """Return ``(speedup_low, speedup_high, confidence)`` for applied rules.

    - No rule applied: the patch is a no-op, so ``(1.0, 1.0, 0.0)``.
    - No usable waste budget: a wide 1.05-1.50x range at confidence 0.3.
    - Otherwise recovered seconds are capped per bucket at that bucket's
      measured seconds (two rules cannot recover the same waste twice),
      turned into a fraction of total waste capped at 0.85, and mapped to
      speedup via ``1 / (1 - frac)`` with +/-0.10 of slack on the fraction.
    """
    if not recovered_by_bucket:
        return 1.0, 1.0, 0.0
    if not budget:
        # No measured waste budget — give a wide-but-honest range and lower
        # confidence. Prevents the report from claiming false precision.
        return 1.05, 1.50, 0.3
    total = sum(budget.values()) or 1.0
    recovered = sum(
        min(seconds, budget.get(bucket, 0.0))
        for bucket, seconds in recovered_by_bucket.items()
    )
    frac = max(0.0, min(0.85, recovered / total))
    speedup_low = 1.0 / (1.0 - max(0.0, frac - 0.10))
    speedup_high = 1.0 / (1.0 - min(0.85, frac + 0.10))
    # Flat confidence whenever a measured budget backs the estimate.
    return speedup_low, speedup_high, 0.85


def _propose_patch(
    config: dict[str, Any],
    rules: list[dict[str, Any]] | None = None,
    rule_ids: list[str] | None = None,
    metrics: dict[str, Any] | None = None,
) -> ToolResult:
    """Tool body for ``propose_patch``: resolve rules, patch config, estimate uplift.

    Returns ``ToolResult(ok=False, error=...)`` when the config is unusable,
    no rule resolves, or the patched config fails validation. On success the
    Patch is dumped to a dict, warnings are attached under ``"notes"``, and
    the dict is cached for ``latest_patch()`` and returned as the result.
    """
    global _LAST_PATCH

    try:
        workload, cfg_warnings = _hydrate_config(config)
    except (ValueError, ValidationError) as exc:
        return ToolResult(ok=False, error=f"invalid config: {exc}")

    typed_rules, rule_warnings = _resolve_rules(rules, rule_ids)
    if not typed_rules:
        return ToolResult(
            ok=False,
            error=(
                "no rules to apply — pass either 'rule_ids' (list of KB rule "
                "ids from query_rocm_kb) or 'rules' (full Rule dicts). "
                + (f"Notes: {rule_warnings}" if rule_warnings else "")
            ),
        )

    budget, budget_warnings = _parse_waste_budget(metrics)
    new_cfg_data, rationale, recovered_by_bucket = _apply_rules(
        workload, typed_rules, budget
    )
    try:
        new_workload = WorkloadConfig.model_validate(new_cfg_data)
    except ValidationError as exc:
        # A rule transform wrote a path/value the schema rejects; report it
        # like other input problems instead of raising out of the tool.
        return ToolResult(
            ok=False, error=f"rule transforms produced an invalid config: {exc}"
        )

    speedup_low, speedup_high, confidence = _estimate_uplift(
        budget, recovered_by_bucket
    )
    patch = Patch(
        new_config=new_workload,
        diff=_render_diff(workload, new_workload),
        rationale=rationale,
        expected_speedup_low=round(speedup_low, 2),
        expected_speedup_high=round(speedup_high, 2),
        confidence=round(confidence, 2),
    )
    payload = patch.model_dump()
    notes = cfg_warnings + rule_warnings + budget_warnings
    if not rationale:
        notes.append(
            "no rule's detect block matched this config; the patch is a no-op"
        )
    if notes:
        payload["notes"] = notes
    # Cache so compare_runs can recover from LLM truncation.
    _LAST_PATCH = payload
    return ToolResult(ok=True, result=payload)
# Tool registration. The description and input_schema are written for the
# calling LLM (they address it as "you"); ``fn`` is the deterministic body.
PROPOSE_PATCH = Tool(
    name="propose_patch",
    description=(
        "Apply matching ROCm rules to the user's WorkloadConfig and produce a "
        "concrete Patch with a +/- diff of the changed config fields, per-rule "
        "rationale, and a predicted speedup range with confidence. "
        "Deterministic — no LLM call.\n"
        "\n"
        "Pass `rule_ids` (a list of KB rule ids you got from query_rocm_kb) — "
        "this is the preferred path; you don't need to forward the full Rule "
        "object. `rules` is also accepted for backward compatibility. `config` "
        "must include at minimum `model_name`; everything else falls back to "
        "schema defaults. `metrics` is optional but recommended — without it "
        "the uplift range and confidence are degraded."
    ),
    input_schema={
        "type": "object",
        "properties": {
            "config": {
                "type": "object",
                "description": (
                    "WorkloadConfig dict. Must include model_name; other "
                    "fields fall back to schema defaults."
                ),
            },
            "rule_ids": {
                "type": "array",
                "items": {"type": "string"},
                "description": (
                    "List of KB rule ids returned by query_rocm_kb. Preferred "
                    "input — avoids re-serializing every Rule field."
                ),
            },
            "rules": {
                "type": "array",
                "items": {"type": "object"},
                "description": (
                    "Optional alternate input: full Rule dicts. Accepts "
                    "partial entries — falls back to id-lookup against the "
                    "loaded KB if required fields are missing."
                ),
            },
            "metrics": {
                "type": "object",
                "description": (
                    "RunMetrics dict from profile_run. Optional but "
                    "recommended; the waste_budget drives uplift estimation."
                ),
            },
        },
        "required": ["config"],
    },
    fn=_propose_patch,
)
|