3v324v23's picture
feat: 实现多智能体协作的机器学习训练与调试闭环
11ac7be
from __future__ import annotations
import math
from dataclasses import replace
from typing import Any, Dict, List
from .base import BaseAgent
from ..config import TrainConfig
class DebuggerAgent(BaseAgent):
    """Inspect one training attempt and decide whether to accept it or retry.

    The agent looks at the final evaluation result plus the last three
    per-epoch metric records. It either accepts the run, or proposes an
    adjusted ``TrainConfig`` for the next attempt. Every decision is written
    via ``self.ctx.store.write_json`` and sent to ``"Orchestrator"`` via
    ``self.emit``.
    """

    name = "DebuggerAgent"

    # Minimum validation accuracy for a stable run to be accepted.
    accept_val_acc: float = 0.82
    # Floor used when decaying the learning rate. A learning rate already
    # below this floor is kept as-is rather than raised to it.
    min_lr: float = 1e-4

    def run(
        self,
        *,
        cfg: TrainConfig,
        metrics: List[Dict[str, Any]],
        eval_result: Dict[str, Any],
        attempt: int,
    ) -> Dict[str, Any]:
        """Diagnose attempt ``attempt`` and return an accept/retry decision.

        Args:
            cfg: Configuration that produced this attempt (must support
                ``dataclasses.replace`` and ``to_dict()``).
            metrics: Per-epoch metric records. Only the last three are
                inspected, reading their ``"train_loss"`` / ``"val_loss"`` keys.
            eval_result: Final evaluation; must contain ``"val_acc"`` and
                ``"val_loss"``.
            attempt: Attempt index, used in the persisted file name.

        Returns:
            ``{"action": "accept", "reason": ...}`` when training is stable and
            ``val_acc >= accept_val_acc``. Otherwise a dict with
            ``"action": "retry"``, ``"reason"``, ``"next_cfg"`` (the adjusted
            config's ``to_dict()``) and ``"observed"`` (val_acc / val_loss).

        Raises:
            KeyError: if ``eval_result`` lacks ``"val_acc"`` or ``"val_loss"``.
        """
        self.emit(
            receiver="MessageBus",
            kind="stage",
            payload={"stage": "debug", "attempt": attempt},
        )

        val_acc = float(eval_result["val_acc"])
        val_loss = float(eval_result["val_loss"])
        # Slicing already copes with histories shorter than three entries.
        recent = metrics[-3:]

        # Any non-finite loss (final or recent) is treated as divergence.
        # Missing per-epoch keys default to 0.0, i.e. they count as finite.
        unstable = (not math.isfinite(val_loss)) or any(
            (not math.isfinite(float(m.get("val_loss", 0.0))))
            or (not math.isfinite(float(m.get("train_loss", 0.0))))
            for m in recent
        )

        # Three strictly rising validation losses signal a worsening trend.
        # A missing "val_loss" becomes NaN, which simply disables this check
        # instead of raising KeyError.
        increasing = False
        if len(recent) >= 3:
            v0, v1, v2 = (float(m.get("val_loss", math.nan)) for m in recent)
            increasing = (
                math.isfinite(v0)
                and math.isfinite(v1)
                and math.isfinite(v2)
                and (v2 > v1 > v0)
            )

        if (not unstable) and val_acc >= self.accept_val_acc:
            decision: Dict[str, Any] = {"action": "accept", "reason": "验证集指标达到阈值且训练稳定。"}
            self._publish(attempt, decision)
            return decision

        if unstable:
            next_cfg = replace(
                cfg,
                lr=self._decay_lr(cfg.lr, 0.1),
                grad_clip=5.0 if cfg.grad_clip is None else min(5.0, cfg.grad_clip),
                l2=max(cfg.l2, 1e-4),
                loss_eps=1e-12 if cfg.loss_eps <= 0 else cfg.loss_eps,
            )
            reason = "检测到 loss 非有限值/疑似发散:降低学习率、启用梯度裁剪、加入轻微 L2。"
        elif increasing:
            next_cfg = replace(
                cfg,
                lr=self._decay_lr(cfg.lr, 0.5),
                l2=max(cfg.l2, 5e-4),
            )
            reason = "验证集 loss 连续上升:降低学习率、增强正则。"
        else:
            next_cfg = replace(cfg, lr=self._decay_lr(cfg.lr, 0.7), l2=max(cfg.l2, 1e-4))
            reason = "未达标:轻微降学习率并加入正则,提升收敛与泛化稳定性。"

        decision = {
            "action": "retry",
            "reason": reason,
            "next_cfg": next_cfg.to_dict(),
            "observed": {"val_acc": val_acc, "val_loss": val_loss},
        }
        self._publish(attempt, decision)
        return decision

    def _decay_lr(self, lr: float, factor: float) -> float:
        """Return ``lr * factor`` floored at ``min_lr``, never exceeding ``lr``.

        Clamping with ``max(min_lr, lr * factor)`` alone would *raise* a
        learning rate that is already below ``min_lr``. That contradicts the
        goal of lowering it, especially when training diverges.
        """
        if not (math.isfinite(lr) and lr > 0):
            # Invalid or non-positive learning rate: reset to the floor.
            return self.min_lr
        return min(lr, max(self.min_lr, lr * factor))

    def _publish(self, attempt: int, decision: Dict[str, Any]) -> None:
        """Persist ``decision`` as ``debug.attempt{attempt}.json`` and send it to the Orchestrator."""
        self.ctx.store.write_json(f"debug.attempt{attempt}.json", decision)
        self.emit(receiver="Orchestrator", kind="debug_decision", payload=decision)