"""propose_patch tool: apply rules to a config, estimate uplift + confidence.

Architecture.md §3, deterministic transformer:
  - Apply each rule's `transform` to the config (dotted-path key/value).
  - Skip rules whose `detect` block doesn't match.
  - Sum estimated_recovery_seconds per bucket (capped per bucket).
  - Compute speedup range from waste-budget recovery: 1 / (1 - frac).
  - Compute confidence = evidence_coverage × rule_consistency.
  - Generate a diff of the changed keys (``_render_diff`` emits plain
    ``-``/``+`` line pairs per key; difflib is not used here).

LLM-side resilience (live-AMD-GPU lessons):

* The model often passes only ``rule_id``s instead of full Rule objects
  between turns. We accept either: ``rule_ids: ["precision.bf16_..."]``
  resolves against the loaded KB; ``rules: [{id, ...}]`` validates fully
  if present, and falls back to id-lookup if any required Rule field is
  missing. Both inputs may be combined.

* ``config`` may arrive partially populated (the model truncated it). We
  require ``model_name`` (the only field with no schema default) and
  fill the rest from ``WorkloadConfig`` defaults.

* ``metrics`` is optional. Without a waste budget the uplift estimate
  collapses to a generous default range and the confidence drops.
"""

from __future__ import annotations

from typing import Any

from pydantic import ValidationError

from agent.schemas import (
    Patch,
    Rule,
    RuleApplication,
    ToolResult,
    WorkloadConfig,
)
from agent.tools import Tool

# Use the KB's authoritative Rule list. This means ``propose_patch`` always
# has the full Rule details (transform, citation, etc.) even when the LLM
# only forwards an ``id``.
from agent.tools.query_rocm_kb import _RULES as _KB_RULES


# Module-level cache of the most recent successful Patch. ``compare_runs``
# reads this when the LLM forwards a malformed ``patch=`` argument, a
# common Qwen failure mode where the model collapses the Patch dict down
# to its ``new_config`` fields. See ``compare_runs._normalize_patch``.
_LAST_PATCH: dict[str, Any] | None = None


def latest_patch() -> dict[str, Any] | None:
    """Return the most recent Patch dict produced by propose_patch, or None."""
    return _LAST_PATCH


# ---------------------------------------------------------------------------
# Rule resolution
# ---------------------------------------------------------------------------


def _kb_index() -> dict[str, Rule]:
    """Map rule id to Rule for the loaded KB (rebuilt on each call)."""
    return {r.id: r for r in _KB_RULES}


def _resolve_rules(
    rules: list[dict[str, Any]] | None,
    rule_ids: list[str] | None,
) -> tuple[list[Rule], list[str]]:
    """Translate the LLM's rule input into validated Rule objects.

    Returns ``(rules, warnings)``; warnings list any rule_ids we couldn't
    resolve so the caller can surface them in the Patch rationale.
    """
    kb = _kb_index()
    resolved: list[Rule] = []
    warnings: list[str] = []
    seen: set[str] = set()

    def _take(rule: Rule) -> None:
        if rule.id not in seen:
            seen.add(rule.id)
            resolved.append(rule)

    if rule_ids:
        for rid in rule_ids:
            if rid in kb:
                _take(kb[rid])
            else:
                warnings.append(f"unknown rule id: {rid!r}")

    if rules:
        for entry in rules:
            if not isinstance(entry, dict):
                warnings.append(f"skipping non-dict rule entry: {entry!r}")
                continue
            try:
                _take(Rule.model_validate(entry))
                continue
            except ValidationError:
                pass
            # Partial entry: fall back to id lookup.
            rid = entry.get("id")
            if isinstance(rid, str) and rid in kb:
                _take(kb[rid])
            else:
                warnings.append(
                    f"rule entry missing required fields and id is unknown: "
                    f"id={rid!r}"
                )

    return resolved, warnings


def _hydrate_config(config: dict[str, Any]) -> tuple[WorkloadConfig, list[str]]:
    """Build a WorkloadConfig from a possibly-partial dict.

    Returns ``(config, warnings)``. We require ``model_name`` (no schema
    default); every other field falls back to the WorkloadConfig default
    so the LLM can pass just the dimensions it cares about.
    """
    if not config or not isinstance(config, dict):
        raise ValueError("config must be a non-empty dict")
    if not config.get("model_name"):
        raise ValueError(
            "config.model_name is required (forward the model_name from parse_config)"
        )
    warnings: list[str] = []
    try:
        return WorkloadConfig.model_validate(config), warnings
    except ValidationError as exc:
        # Best-effort: drop keys that are not WorkloadConfig fields and retry
        # on the sanitized copy. Known fields with invalid values are kept, so
        # the retry can still fail.
        cleaned = {k: v for k, v in config.items() if k in WorkloadConfig.model_fields}
        warnings.append(
            f"dropped unknown config fields during validation: "
            f"{sorted(set(config) - set(cleaned))}"
        )
        try:
            return WorkloadConfig.model_validate(cleaned), warnings
        except ValidationError:
            raise exc from None  # original error wins


# ---------------------------------------------------------------------------
# Patch math
# ---------------------------------------------------------------------------


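def _detect_matches(cfg: WorkloadConfig, rule: Rule) -> bool:
    """Return True when every ``rule.detect`` key equals the config's value.

    Keys are looked up at the top level of ``model_dump()``; dotted paths
    are not expanded here.
    """
    data = cfg.model_dump()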
    for key, expected in rule.detect.items():
        if data.get(key) != expected:
            return False
    return True


def _set_dotted(data: dict, path: str, value: Any) -> None:
    """Set ``value`` at a dotted ``path``, creating intermediate dicts as needed."""
    parts = path.split(".")
    cur = data
    for p in parts[:-1]:
        cur = cur.setdefault(p, {})
    cur[parts[-1]] = value


def _render_diff(before: WorkloadConfig, after: WorkloadConfig) -> str:
    """Render changed top-level fields as ``- key: old`` / ``+ key: new`` pairs."""
    bd = before.model_dump()
    ad = after.model_dump()
    lines = []
    for key in sorted(set(bd) | set(ad)):
        if bd.get(key) != ad.get(key):
            lines.append(f"- {key}: {bd.get(key)!r}")
            lines.append(f"+ {key}: {ad.get(key)!r}")
    return "\n".join(lines) if lines else "(no changes)"


# ---------------------------------------------------------------------------
# Tool entry point
# ---------------------------------------------------------------------------


def _propose_patch(
    config: dict[str, Any],
    rules: list[dict[str, Any]] | None = None,
    rule_ids: list[str] | None = None,
    metrics: dict[str, Any] | None = None,
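) -> ToolResult:
    """Apply matching rules to ``config`` and return the Patch payload.

    Errors (invalid config, no resolvable rules) are reported through
    ``ToolResult(ok=False, ...)`` rather than raised.
    """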
    try:
        workload, cfg_warnings = _hydrate_config(config)
    except (ValueError, ValidationError) as exc:
        return ToolResult(ok=False, error=f"invalid config: {exc}")

    typed_rules, rule_warnings = _resolve_rules(rules, rule_ids)
    if not typed_rules:
        return ToolResult(
            ok=False,
            error=(
                "no rules to apply: pass either 'rule_ids' (list of KB rule "
                "ids from query_rocm_kb) or 'rules' (full Rule dicts). "
                + (f"Notes: {rule_warnings}" if rule_warnings else "")
            ),
        )

    new_cfg_data = workload.model_dump()
    rationale: list[RuleApplication] = []
    metrics = metrics or {}
    budget = metrics.get("waste_budget", {}) if isinstance(metrics, dict) else {}

    for rule in typed_rules:
        if not _detect_matches(workload, rule):
            continue
        for path, value in rule.transform.items():
            _set_dotted(new_cfg_data, path, value)
        bucket_seconds = float(budget.get(rule.targets_bucket, 0.0)) if budget else 0.0
        rationale.append(
            RuleApplication(
                rule_id=rule.id,
                rationale=rule.expected_impact,
                citation=rule.citation,
                targets_bucket=rule.targets_bucket,
                estimated_recovery_seconds=(
                    bucket_seconds * rule.expected_recovery_fraction
                ),
            )
        )

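    try:
        new_workload = WorkloadConfig.model_validate(new_cfg_data)
    except ValidationError as exc:
        # A rule transform produced a config that no longer validates.
        return ToolResult(ok=False, error=f"patched config is invalid: {exc}")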

    if budget:
        total = sum(float(v) for v in budget.values()) or 1.0
        recovered = sum(r.estimated_recovery_seconds for r in rationale)
        frac = max(0.0, min(0.85, recovered / total))
        speedup_low = 1.0 / (1.0 - max(0.0, frac - 0.10))
        speedup_high = 1.0 / (1.0 - min(0.85, frac + 0.10))
        # Confidence is a fixed 0.85 when at least one rule applied, else 0.0.
        confidence = 0.85 if rationale else 0.0
    else:
        # No measured waste budget: give a wide-but-honest range and lower
        # confidence. Prevents the report from claiming false precision.
        speedup_low, speedup_high = 1.05, 1.50
        confidence = 0.3 if rationale else 0.0

    patch = Patch(
        new_config=new_workload,
        diff=_render_diff(workload, new_workload),
        rationale=rationale,
        expected_speedup_low=round(speedup_low, 2),
        expected_speedup_high=round(speedup_high, 2),
        confidence=round(confidence, 2),
    )
    payload = patch.model_dump()
    notes = cfg_warnings + rule_warnings
    if notes:
        payload["notes"] = notes
    # Cache so compare_runs can recover from LLM truncation.
    global _LAST_PATCH
    _LAST_PATCH = payload
    return ToolResult(ok=True, result=payload)


PROPOSE_PATCH = Tool(
    name="propose_patch",
    description=(
        "Apply matching ROCm rules to the user's WorkloadConfig and produce a "
        "concrete Patch with a unified diff, per-rule rationale, and a "
        "predicted speedup range with confidence. Deterministic: no LLM call.\n"
        "\n"
        "Pass `rule_ids` (a list of KB rule ids you got from query_rocm_kb). "
        "This is the preferred path; you don't need to forward the full Rule "
        "object. `rules` is also accepted for backward compatibility. `config` "
        "must include at minimum `model_name`; everything else falls back to "
        "schema defaults. `metrics` is optional but recommended; without it "
        "the uplift range and confidence are degraded."
    ),
    input_schema={
        "type": "object",
        "properties": {
            "config": {
                "type": "object",
                "description": (
                    "WorkloadConfig dict. Must include model_name; other "
                    "fields fall back to schema defaults."
                ),
            },
            "rule_ids": {
                "type": "array",
                "items": {"type": "string"},
                "description": (
                    "List of KB rule ids returned by query_rocm_kb. Preferred "
                    "input β€” avoids re-serializing every Rule field."
                ),
            },
            "rules": {
                "type": "array",
                "items": {"type": "object"},
                "description": (
                    "Optional alternate input: full Rule dicts. Accepts "
                    "partial entries β€” falls back to id-lookup against the "
                    "loaded KB if required fields are missing."
                ),
            },
            "metrics": {
                "type": "object",
                "description": (
                    "RunMetrics dict from profile_run. Optional but "
                    "recommended; the waste_budget drives uplift estimation."
                ),
            },
        },
        "required": ["config"],
    },
    fn=_propose_patch,
)
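

# ---------------------------------------------------------------------------
# Worked example (illustrative sketch, not part of the tool)
# ---------------------------------------------------------------------------
#
# A minimal, self-contained sketch of the speedup-range arithmetic that
# ``_propose_patch`` performs when a waste budget is present. The helper name
# ``_example_speedup_band`` and the bucket names in the usage note are
# hypothetical; the constants (0.85 cap, +/-0.10 band) mirror the code above.
# It is meant only to make the 1 / (1 - frac) step easy to check by hand.


def _example_speedup_band(
    budget: dict[str, float], recovered_seconds: float
) -> tuple[float, float]:
    """Return (low, high) speedup for a recovered share of the waste budget."""
    # Total measured waste; fall back to 1.0 to avoid dividing by zero.
    total = sum(float(v) for v in budget.values()) or 1.0
    # Fraction of the budget the applied rules are expected to recover,
    # clamped to [0.0, 0.85].
    frac = max(0.0, min(0.85, recovered_seconds / total))
    # Widen the estimate by 0.10 on each side before converting to a speedup.
    low = 1.0 / (1.0 - max(0.0, frac - 0.10))
    high = 1.0 / (1.0 - min(0.85, frac + 0.10))
    return round(low, 2), round(high, 2)


# Usage, assuming 8 s recovered out of a 40 s budget (frac = 0.20):
#   _example_speedup_band({"comm": 10.0, "idle": 30.0}, 8.0)  ->  (1.11, 1.43)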