Spaces:
Running
Running
| from __future__ import annotations | |
| from typing import Any, Dict, List | |
| from .base import BaseAgent | |
| from ..config import TrainConfig | |
| from ..ml.dataset import Dataset | |
| from ..ml.train import train_logreg_sgd | |
class TrainerAgent(BaseAgent):
    """Agent that runs one logistic-regression training attempt.

    It announces the training stage, fits the model with
    ``train_logreg_sgd``, persists the learned parameters, and logs the
    training history. Each logged entry is tagged with the attempt number
    and the hyperparameters used.
    """

    name = "TrainerAgent"

    def run(
        self,
        *,
        cfg: TrainConfig,
        dataset: Dataset,
        attempt: int,
    ) -> List[Dict[str, Any]]:
        """Train a single attempt and return its annotated history rows.

        Args:
            cfg: Training configuration. Supplies the base seed, epoch count
                and the ``lr`` / ``l2`` / ``grad_clip`` / ``loss_eps``
                hyperparameters.
            dataset: Source of the ``x_train``/``y_train`` and
                ``x_val``/``y_val`` splits passed to the trainer.
            attempt: Attempt index. It is added to ``cfg.seed`` and recorded
                in every artifact this call produces.

        Returns:
            One dict per entry of the trainer's history. Each dict is a copy
            of that entry extended with ``attempt``, ``lr``, ``l2``,
            ``grad_clip`` (``None`` when unset) and ``loss_eps``.

        Side effects:
            - Emits a ``"stage"`` message to ``"MessageBus"`` before
              training.
            - Writes ``model.attempt<attempt>.json`` via ``self.ctx.store``.
            - Appends every returned row to ``metrics.jsonl``.
            - Emits a ``"train_finished"`` message to ``"Orchestrator"``
              afterwards.
        """
        self.emit(
            receiver="MessageBus",
            kind="stage",
            payload={"stage": "train", "attempt": attempt, "cfg": cfg.to_dict()},
        )

        # The seed is offset by the attempt index, so each retry trains
        # with a different seed than the previous one.
        result = train_logreg_sgd(
            seed=cfg.seed + attempt,
            xs_train=dataset.x_train,
            ys_train=dataset.y_train,
            xs_val=dataset.x_val,
            ys_val=dataset.y_val,
            epochs=cfg.epochs,
            lr=cfg.lr,
            l2=cfg.l2,
            grad_clip=cfg.grad_clip,
            loss_eps=cfg.loss_eps,
        )

        # Persist the fitted parameters together with the config that
        # produced them.
        self.ctx.store.write_json(
            f"model.attempt{attempt}.json",
            {"w": result.model.w, "b": result.model.b, "attempt": attempt, "cfg": cfg.to_dict()},
        )

        history_rows: List[Dict[str, Any]] = []
        for entry in result.history:
            # Copy the entry and stamp it with this run's settings, so each
            # metrics line is self-describing.
            record = dict(
                entry,
                attempt=attempt,
                lr=float(cfg.lr),
                l2=float(cfg.l2),
                grad_clip=float(cfg.grad_clip) if cfg.grad_clip is not None else None,
                loss_eps=float(cfg.loss_eps),
            )
            history_rows.append(record)
            self.ctx.store.append_jsonl("metrics.jsonl", record)

        self.emit(
            receiver="Orchestrator",
            kind="train_finished",
            payload={
                "attempt": attempt,
                # NOTE(review): this counts history rows; it presumably
                # equals epochs only if the trainer records one row per
                # epoch. Confirm.
                "epochs_ran": len(history_rows),
                "last": history_rows[-1] if history_rows else {},
            },
        )
        return history_rows