from __future__ import annotations from typing import Any, Dict, List from .base import BaseAgent from ..config import TrainConfig from ..ml.dataset import Dataset from ..ml.train import train_logreg_sgd class TrainerAgent(BaseAgent): name = "TrainerAgent" def run(self, *, cfg: TrainConfig, dataset: Dataset, attempt: int) -> List[Dict[str, Any]]: self.emit( receiver="MessageBus", kind="stage", payload={"stage": "train", "attempt": attempt, "cfg": cfg.to_dict()}, ) res = train_logreg_sgd( seed=cfg.seed + attempt, xs_train=dataset.x_train, ys_train=dataset.y_train, xs_val=dataset.x_val, ys_val=dataset.y_val, epochs=cfg.epochs, lr=cfg.lr, l2=cfg.l2, grad_clip=cfg.grad_clip, loss_eps=cfg.loss_eps, ) self.ctx.store.write_json( f"model.attempt{attempt}.json", {"w": res.model.w, "b": res.model.b, "attempt": attempt, "cfg": cfg.to_dict()}, ) rows: List[Dict[str, Any]] = [] for row in res.history: payload = dict(row) payload["attempt"] = attempt payload["lr"] = float(cfg.lr) payload["l2"] = float(cfg.l2) payload["grad_clip"] = None if cfg.grad_clip is None else float(cfg.grad_clip) payload["loss_eps"] = float(cfg.loss_eps) rows.append(payload) self.ctx.store.append_jsonl("metrics.jsonl", payload) last = rows[-1] if rows else {} self.emit( receiver="Orchestrator", kind="train_finished", payload={ "attempt": attempt, "epochs_ran": len(rows), "last": last, }, ) return rows