3v324v23's picture
feat: 实现多智能体协作的机器学习训练与调试闭环
11ac7be
from __future__ import annotations
from typing import Any, Dict, List
from .base import BaseAgent
from ..config import TrainConfig
from ..ml.dataset import Dataset
from ..ml.train import train_logreg_sgd
class TrainerAgent(BaseAgent):
    """Agent that runs a single training attempt and records its results.

    For each call to :meth:`run` it announces the training stage, invokes
    ``train_logreg_sgd``, stores the resulting model parameters, appends
    every history row to the metrics log, and reports completion.
    """

    name = "TrainerAgent"

    def run(self, *, cfg: TrainConfig, dataset: Dataset, attempt: int) -> List[Dict[str, Any]]:
        """Execute one training attempt and return its enriched history rows.

        Args:
            cfg: Training configuration. ``seed``, ``epochs``, ``lr``, ``l2``,
                ``grad_clip`` and ``loss_eps`` are read, and ``to_dict()`` is
                used to snapshot the configuration.
            dataset: Source of ``x_train``/``y_train``/``x_val``/``y_val``.
            attempt: Attempt number. It is added to ``cfg.seed``, embedded in
                the model file name, and stamped on every returned row.

        Returns:
            One dict per entry of the training result's ``history``: a copy
            of that entry plus ``attempt``, ``lr``, ``l2``, ``grad_clip``
            and ``loss_eps``.
        """
        # Tell the message bus that the training stage is starting.
        stage_info = {"stage": "train", "attempt": attempt, "cfg": cfg.to_dict()}
        self.emit(receiver="MessageBus", kind="stage", payload=stage_info)

        # Offsetting the seed by the attempt number gives each attempt its
        # own seed value.
        attempt_seed = cfg.seed + attempt
        outcome = train_logreg_sgd(
            seed=attempt_seed,
            xs_train=dataset.x_train,
            ys_train=dataset.y_train,
            xs_val=dataset.x_val,
            ys_val=dataset.y_val,
            epochs=cfg.epochs,
            lr=cfg.lr,
            l2=cfg.l2,
            grad_clip=cfg.grad_clip,
            loss_eps=cfg.loss_eps,
        )

        # Persist the learned parameters together with the config that
        # produced them, one file per attempt.
        trained = outcome.model
        model_snapshot = {
            "w": trained.w,
            "b": trained.b,
            "attempt": attempt,
            "cfg": cfg.to_dict(),
        }
        store = self.ctx.store
        store.write_json(f"model.attempt{attempt}.json", model_snapshot)

        # Copy each history entry, tag it with the attempt and the
        # hyperparameters in effect, and append it to the metrics log.
        rows: List[Dict[str, Any]] = []
        for entry in outcome.history:
            record = dict(entry)
            record.update(
                attempt=attempt,
                lr=float(cfg.lr),
                l2=float(cfg.l2),
                grad_clip=float(cfg.grad_clip) if cfg.grad_clip is not None else None,
                loss_eps=float(cfg.loss_eps),
            )
            rows.append(record)
            store.append_jsonl("metrics.jsonl", record)

        # Summarise for the orchestrator; an empty history yields an empty
        # "last" entry.
        if rows:
            final_row = rows[-1]
        else:
            final_row = {}
        summary = {
            "attempt": attempt,
            "epochs_ran": len(rows),
            "last": final_row,
        }
        self.emit(receiver="Orchestrator", kind="train_finished", payload=summary)
        return rows