Spaces:
Running
Running
| from __future__ import annotations | |
| import math | |
| from dataclasses import replace | |
| from typing import Any, Dict, List | |
| from .base import BaseAgent | |
| from ..config import TrainConfig | |
class DebuggerAgent(BaseAgent):
    """Decide whether a training attempt is accepted or must be retried.

    The agent inspects the final evaluation result together with the most
    recent metric records. A run is accepted when no non-finite loss was
    observed and the validation accuracy reaches ``ACCEPT_VAL_ACC``.
    Otherwise a retry is requested with an adjusted copy of the config.
    """

    name = "DebuggerAgent"

    # Validation accuracy a stable run must reach to be accepted.
    ACCEPT_VAL_ACC = 0.82
    # Lower bound used when scaling the learning rate down.
    MIN_LR = 1e-4

    @staticmethod
    def _shrink_lr(lr: float, factor: float, floor: float) -> float:
        """Scale ``lr`` by ``factor`` without going below ``floor`` or above ``lr``.

        The plain ``max(floor, lr * factor)`` form would *raise* the learning
        rate whenever ``lr`` is already below ``floor``; clamping with
        ``min(lr, ...)`` keeps the adjustment a reduction in every case.
        A non-positive or NaN ``lr`` falls back to ``floor``.
        """
        if not (lr > 0):  # also true for NaN
            return floor
        return min(lr, max(floor, lr * factor))

    @staticmethod
    def _val_loss_rising(records: List[Dict[str, Any]]) -> bool:
        """Return True when exactly three records show strictly rising finite val_loss."""
        if len(records) != 3:
            return False
        # A record without "val_loss" cannot contribute to a trend; treat it
        # as "no trend" instead of raising KeyError.
        if any("val_loss" not in rec for rec in records):
            return False
        first, second, third = (float(rec["val_loss"]) for rec in records)
        if not all(math.isfinite(v) for v in (first, second, third)):
            return False
        return first < second < third

    def _publish(self, attempt: int, decision: Dict[str, Any]) -> Dict[str, Any]:
        """Persist the decision for this attempt, emit it to the Orchestrator, return it."""
        self.ctx.store.write_json(f"debug.attempt{attempt}.json", decision)
        self.emit(receiver="Orchestrator", kind="debug_decision", payload=decision)
        return decision

    def run(
        self,
        *,
        cfg: TrainConfig,
        metrics: List[Dict[str, Any]],
        eval_result: Dict[str, Any],
        attempt: int,
    ) -> Dict[str, Any]:
        """Evaluate one training attempt and return an accept/retry decision.

        Args:
            cfg: Config used for the attempt; never mutated, retries get a copy
                built with ``dataclasses.replace``.
            metrics: Metric records; only the last three are inspected, reading
                the ``"val_loss"`` and ``"train_loss"`` keys when present.
            eval_result: Must contain ``"val_acc"`` and ``"val_loss"``.
            attempt: Attempt number, used in the stage event and the stored
                file name.

        Returns:
            ``{"action": "accept", "reason": ...}`` or
            ``{"action": "retry", "reason": ..., "next_cfg": ..., "observed": ...}``.

        Raises:
            KeyError: If ``eval_result`` lacks ``"val_acc"`` or ``"val_loss"``.
        """
        self.emit(
            receiver="MessageBus",
            kind="stage",
            payload={"stage": "debug", "attempt": attempt},
        )

        val_acc = float(eval_result["val_acc"])
        val_loss = float(eval_result["val_loss"])

        # Slicing already copes with fewer than three records.
        recent = metrics[-3:]

        # Missing loss keys default to 0.0, i.e. they are not evidence of divergence.
        unstable = (not math.isfinite(val_loss)) or any(
            not math.isfinite(float(rec.get(key, 0.0)))
            for rec in recent
            for key in ("val_loss", "train_loss")
        )

        if (not unstable) and val_acc >= self.ACCEPT_VAL_ACC:
            return self._publish(
                attempt,
                {"action": "accept", "reason": "验证集指标达到阈值且训练稳定。"},
            )

        floor = self.MIN_LR
        if unstable:
            # Divergence: strong LR cut, gradient clipping capped at 5.0,
            # at least light L2, and a positive loss epsilon.
            next_cfg = replace(
                cfg,
                lr=self._shrink_lr(cfg.lr, 0.1, floor),
                grad_clip=5.0 if cfg.grad_clip is None else min(5.0, cfg.grad_clip),
                l2=max(cfg.l2, 1e-4),
                loss_eps=1e-12 if cfg.loss_eps <= 0 else cfg.loss_eps,
            )
            reason = "检测到 loss 非有限值/疑似发散:降低学习率、启用梯度裁剪、加入轻微 L2。"
        elif self._val_loss_rising(recent):
            # Validation loss rose three records in a row: halve LR, stronger L2.
            next_cfg = replace(
                cfg,
                lr=self._shrink_lr(cfg.lr, 0.5, floor),
                l2=max(cfg.l2, 5e-4),
            )
            reason = "验证集 loss 连续上升:降低学习率、增强正则。"
        else:
            # Stable but below the accuracy threshold: gentle adjustment.
            next_cfg = replace(
                cfg,
                lr=self._shrink_lr(cfg.lr, 0.7, floor),
                l2=max(cfg.l2, 1e-4),
            )
            reason = "未达标:轻微降学习率并加入正则,提升收敛与泛化稳定性。"

        return self._publish(
            attempt,
            {
                "action": "retry",
                "reason": reason,
                "next_cfg": next_cfg.to_dict(),
                "observed": {"val_acc": val_acc, "val_loss": val_loss},
            },
        )