from __future__ import annotations

import math
from dataclasses import replace
from typing import Any, Dict, List

from .base import BaseAgent
from ..config import TrainConfig


class DebuggerAgent(BaseAgent):
    """Inspect a finished training attempt and accept it or propose a retry.

    The agent looks at the final evaluation result and the trailing metric
    records, then either accepts the run or returns an adjusted training
    configuration for the next attempt.

    NOTE(review): ``self.emit`` and ``self.ctx`` are not defined in this
    file; they are presumably provided by ``BaseAgent`` -- confirm there.
    """

    name = "DebuggerAgent"

    # Minimum validation accuracy required to accept a stable run.
    ACCEPT_VAL_ACC = 0.82
    # Floor used when scaling the learning rate down.
    MIN_LR = 1e-4
    # Number of trailing metric records inspected for instability / trend.
    _WINDOW = 3

    def run(
        self,
        *,
        cfg: TrainConfig,
        metrics: List[Dict[str, Any]],
        eval_result: Dict[str, Any],
        attempt: int,
    ) -> Dict[str, Any]:
        """Return an ``accept`` or ``retry`` decision for one attempt.

        Args:
            cfg: Configuration used for the attempt being judged.
            metrics: Metric records; only the last three are inspected, and
                only their ``"val_loss"`` / ``"train_loss"`` entries.
            eval_result: Must contain ``"val_acc"`` and ``"val_loss"``.
            attempt: Attempt number, used in the stored file name and in the
                emitted stage message.

        Returns:
            A dict with ``"action"`` and ``"reason"``. Retry decisions also
            carry ``"next_cfg"`` (``next_cfg.to_dict()``) and ``"observed"``.

        Raises:
            KeyError: If ``eval_result`` lacks ``"val_acc"`` or ``"val_loss"``.
        """
        self.emit(
            receiver="MessageBus",
            kind="stage",
            payload={"stage": "debug", "attempt": attempt},
        )

        val_acc = float(eval_result["val_acc"])
        val_loss = float(eval_result["val_loss"])

        # A plain negative slice already copes with fewer than three records.
        recent = metrics[-self._WINDOW:]
        unstable = self._is_unstable(val_loss, recent)
        increasing = self._is_increasing(recent)

        if not unstable and val_acc >= self.ACCEPT_VAL_ACC:
            decision: Dict[str, Any] = {
                "action": "accept",
                "reason": "验证集指标达到阈值且训练稳定。",
            }
            return self._publish(attempt, decision)

        reasons: List[str] = []
        if unstable:
            # Divergence: cut the learning rate hard, make sure gradient
            # clipping is on (at most 5.0), and ensure some L2 and a positive
            # loss epsilon.
            next_cfg = replace(
                cfg,
                lr=self._lowered_lr(cfg.lr, 0.1),
                grad_clip=5.0 if cfg.grad_clip is None else min(5.0, cfg.grad_clip),
                l2=max(cfg.l2, 1e-4),
                loss_eps=1e-12 if cfg.loss_eps <= 0 else cfg.loss_eps,
            )
            reasons.append("检测到 loss 非有限值/疑似发散:降低学习率、启用梯度裁剪、加入轻微 L2。")
        elif increasing:
            # Validation loss rose over the last three records.
            next_cfg = replace(
                cfg,
                lr=self._lowered_lr(cfg.lr, 0.5),
                l2=max(cfg.l2, 5e-4),
            )
            reasons.append("验证集 loss 连续上升:降低学习率、增强正则。")
        else:
            # Stable but below the accuracy threshold.
            next_cfg = replace(
                cfg,
                lr=self._lowered_lr(cfg.lr, 0.7),
                l2=max(cfg.l2, 1e-4),
            )
            reasons.append("未达标:轻微降学习率并加入正则,提升收敛与泛化稳定性。")

        decision = {
            "action": "retry",
            "reason": "\n".join(reasons),
            "next_cfg": next_cfg.to_dict(),
            "observed": {"val_acc": val_acc, "val_loss": val_loss},
        }
        return self._publish(attempt, decision)

    @staticmethod
    def _is_unstable(val_loss: float, recent: List[Dict[str, Any]]) -> bool:
        """Return True if the final or any recent loss value is non-finite."""
        if not math.isfinite(val_loss):
            return True
        # A missing entry defaults to 0.0 and therefore counts as finite.
        return any(
            not math.isfinite(float(record.get(key, 0.0)))
            for record in recent
            for key in ("val_loss", "train_loss")
        )

    @classmethod
    def _is_increasing(cls, recent: List[Dict[str, Any]]) -> bool:
        """Return True if val_loss strictly rose across a full window."""
        if len(recent) < cls._WINDOW:
            return False
        # A missing val_loss becomes NaN, so the window is reported as
        # "no trend" instead of raising KeyError; this matches the tolerant
        # lookup used by the instability check.
        losses = [float(record.get("val_loss", math.nan)) for record in recent]
        if not all(math.isfinite(value) for value in losses):
            return False
        return all(earlier < later for earlier, later in zip(losses, losses[1:]))

    @classmethod
    def _lowered_lr(cls, lr: float, factor: float) -> float:
        """Scale ``lr`` by ``factor`` with a floor, never increasing it."""
        scaled = max(cls.MIN_LR, lr * factor)
        # When lr is positive but already under the floor, clamping to the
        # floor would raise it; keep the current value instead. Non-positive
        # or NaN lr still resolves to the floor.
        if 0.0 < lr < scaled:
            return lr
        return scaled

    def _publish(self, attempt: int, decision: Dict[str, Any]) -> Dict[str, Any]:
        """Store the decision, send it to the Orchestrator, and return it."""
        self.ctx.store.write_json(f"debug.attempt{attempt}.json", decision)
        self.emit(receiver="Orchestrator", kind="debug_decision", payload=decision)
        return decision