File size: 3,111 Bytes
11ac7be
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
from __future__ import annotations

import math
from dataclasses import replace
from typing import Any, Dict, List

from .base import BaseAgent
from ..config import TrainConfig


class DebuggerAgent(BaseAgent):
    """Inspect one training attempt and decide to accept it or retry it.

    ``run`` looks at the final evaluation result and the trailing metric
    records, then either accepts the attempt or proposes an adjusted
    configuration (built with ``dataclasses.replace``) for another attempt.
    Every decision is written through ``self.ctx.store.write_json`` and
    emitted to the ``"Orchestrator"`` receiver.
    """

    name = "DebuggerAgent"

    # Minimum validation accuracy for an attempt to be accepted.
    _ACCEPT_VAL_ACC = 0.82
    # Learning-rate decay never goes below this floor.
    _MIN_LR = 1e-4
    # Number of trailing metric records checked for instability and trend.
    _TREND_WINDOW = 3

    def run(
        self,
        *,
        cfg: TrainConfig,
        metrics: List[Dict[str, Any]],
        eval_result: Dict[str, Any],
        attempt: int,
    ) -> Dict[str, Any]:
        """Return an ``accept`` or ``retry`` decision for this attempt.

        Args:
            cfg: Configuration used for the attempt; never mutated, a modified
                copy is produced via ``dataclasses.replace`` when retrying.
            metrics: Per-step metric records. Only the last three are read,
                using the ``"val_loss"`` and ``"train_loss"`` keys.
            eval_result: Must contain ``"val_acc"`` and ``"val_loss"``.
            attempt: Attempt index, used in the stage payload and in the name
                of the persisted decision file.

        Returns:
            ``{"action": "accept", "reason": ...}`` when the run is stable and
            the accuracy threshold is met; otherwise a ``"retry"`` decision
            that additionally carries ``"next_cfg"`` and ``"observed"``.

        Raises:
            KeyError: If ``eval_result`` lacks ``"val_acc"`` or ``"val_loss"``.
        """
        self.emit(
            receiver="MessageBus",
            kind="stage",
            payload={"stage": "debug", "attempt": attempt},
        )

        val_acc = float(eval_result["val_acc"])
        val_loss = float(eval_result["val_loss"])

        # Slicing already copes with fewer than _TREND_WINDOW records.
        recent = metrics[-self._TREND_WINDOW:]

        # A non-finite final loss, or any non-finite recorded loss in the
        # recent window, marks the run as unstable. Missing values are not
        # treated as evidence of instability.
        unstable = (not math.isfinite(val_loss)) or any(
            self._is_recorded_non_finite(record, key)
            for record in recent
            for key in ("val_loss", "train_loss")
        )

        if (not unstable) and val_acc >= self._ACCEPT_VAL_ACC:
            return self._record_decision(
                attempt,
                {"action": "accept", "reason": "验证集指标达到阈值且训练稳定。"},
            )

        if unstable:
            next_cfg = replace(
                cfg,
                lr=self._decayed_lr(cfg.lr, 0.1),
                # Enable clipping if it was off; otherwise only ever tighten it.
                grad_clip=5.0 if cfg.grad_clip is None else min(5.0, cfg.grad_clip),
                l2=max(cfg.l2, 1e-4),
                # A non-positive epsilon is replaced with a small positive one.
                loss_eps=1e-12 if cfg.loss_eps <= 0 else cfg.loss_eps,
            )
            reason = "检测到 loss 非有限值/疑似发散:降低学习率、启用梯度裁剪、加入轻微 L2。"
        elif self._val_loss_increasing(recent):
            next_cfg = replace(
                cfg,
                lr=self._decayed_lr(cfg.lr, 0.5),
                l2=max(cfg.l2, 5e-4),
            )
            reason = "验证集 loss 连续上升:降低学习率、增强正则。"
        else:
            next_cfg = replace(
                cfg,
                lr=self._decayed_lr(cfg.lr, 0.7),
                l2=max(cfg.l2, 1e-4),
            )
            reason = "未达标:轻微降学习率并加入正则,提升收敛与泛化稳定性。"

        return self._record_decision(
            attempt,
            {
                "action": "retry",
                "reason": reason,
                "next_cfg": next_cfg.to_dict(),
                "observed": {"val_acc": val_acc, "val_loss": val_loss},
            },
        )

    def _record_decision(self, attempt: int, decision: Dict[str, Any]) -> Dict[str, Any]:
        """Persist ``decision`` for ``attempt``, emit it, and return it."""
        self.ctx.store.write_json(f"debug.attempt{attempt}.json", decision)
        self.emit(receiver="Orchestrator", kind="debug_decision", payload=decision)
        return decision

    def _decayed_lr(self, lr: float, factor: float) -> float:
        """Scale ``lr`` by ``factor`` without going below ``_MIN_LR``.

        A positive ``lr`` that is already under the floor is returned
        unchanged: clamping it up to the floor would raise the learning rate
        in exactly the situations where it is supposed to be reduced.
        """
        if 0.0 < lr < self._MIN_LR:
            return lr
        return max(self._MIN_LR, lr * factor)

    @staticmethod
    def _is_recorded_non_finite(record: Dict[str, Any], key: str) -> bool:
        """Return True when ``record[key]`` is present and is NaN or infinite."""
        value = record.get(key)
        if value is None:
            return False
        return not math.isfinite(float(value))

    def _val_loss_increasing(self, recent: List[Dict[str, Any]]) -> bool:
        """Return True when the window's validation losses rise strictly.

        Requires a full window of ``_TREND_WINDOW`` records whose
        ``"val_loss"`` values are all present and finite; anything else is
        reported as "not increasing" rather than raising.
        """
        if len(recent) < self._TREND_WINDOW:
            return False
        raw = [record.get("val_loss") for record in recent]
        if any(value is None for value in raw):
            return False
        losses = [float(value) for value in raw]
        if not all(math.isfinite(loss) for loss in losses):
            return False
        return all(earlier < later for earlier, later in zip(losses, losses[1:]))