from __future__ import annotations

from dataclasses import asdict

from .base import BaseAgent
from ..config import TrainConfig
from ..ml.dataset import Dataset, make_synthetic_binary_classification, load_csv_dataset


class DataAgent(BaseAgent):
    """Agent that obtains the dataset for a run and records summary metadata.

    The dataset is loaded from CSV when ``cfg.csv_path`` is truthy; otherwise
    a synthetic binary-classification dataset is generated.
    """

    name = "DataAgent"

    def run(self, *, cfg: TrainConfig) -> Dataset:
        """Build the dataset described by *cfg* and report it.

        Steps, in order:

        1. Emit a ``"stage"`` message addressed to ``"MessageBus"`` carrying
           the serialized config.
        2. Load the CSV dataset (``cfg.csv_path`` truthy) or generate a
           synthetic one.
        3. Write ``dataset.json`` (split sizes, feature count, label rates,
           feature names) and ``true_params.json`` (``true_w``/``true_b``)
           through ``self.ctx.store``.
        4. Emit a ``"data_ready"`` message addressed to ``"Orchestrator"``.

        Args:
            cfg: Training configuration. Reads ``csv_path``, ``target_col``,
                ``seed``, ``train_ratio``, ``n_samples`` and ``n_features``.
                The config object is not modified.

        Returns:
            The dataset produced by the loader or the synthetic generator.
        """
        self.emit(
            receiver="MessageBus",
            kind="stage",
            payload={"stage": "data", "cfg": cfg.to_dict()},
        )

        if cfg.csv_path:
            ds = load_csv_dataset(
                csv_path=cfg.csv_path,
                target_col=cfg.target_col,
                seed=cfg.seed,
                train_ratio=cfg.train_ratio,
            )
            # For CSV input the feature count comes from the loaded data, not
            # from cfg.n_features; it is reported as 0 when the loader
            # supplies no feature names. cfg itself is left untouched.
            n_features = len(ds.feature_names) if ds.feature_names else 0
        else:
            ds = make_synthetic_binary_classification(
                seed=cfg.seed,
                n_samples=cfg.n_samples,
                n_features=cfg.n_features,
                train_ratio=cfg.train_ratio,
            )
            n_features = cfg.n_features

        n_train = len(ds.x_train)
        n_val = len(ds.x_val)

        self.ctx.store.write_json(
            "dataset.json",
            {
                "n_train": n_train,
                "n_val": n_val,
                "n_features": n_features,
                # max(1, ...) keeps the division defined for an empty split.
                "label_rate_train": sum(ds.y_train) / max(1, len(ds.y_train)),
                "label_rate_val": sum(ds.y_val) / max(1, len(ds.y_val)),
                "feature_names": ds.feature_names,
            },
        )
        self.ctx.store.write_json(
            "true_params.json",
            {"true_w": ds.true_w, "true_b": ds.true_b},
        )

        self.emit(
            receiver="Orchestrator",
            kind="data_ready",
            payload={"n_train": n_train, "n_val": n_val, "n_features": n_features},
        )
        return ds