Spaces:
Running
Running
| from __future__ import annotations | |
| from dataclasses import asdict | |
| from .base import BaseAgent | |
| from ..config import TrainConfig | |
| from ..ml.dataset import Dataset, make_synthetic_binary_classification, load_csv_dataset | |
class DataAgent(BaseAgent):
    """Agent that obtains the dataset for a run and records a summary of it.

    The dataset is loaded from a CSV file when ``cfg.csv_path`` is truthy;
    otherwise a synthetic binary-classification dataset is generated from
    the sizes given in the config.
    """

    name = "DataAgent"

    def run(self, *, cfg: TrainConfig) -> Dataset:
        """Build the dataset described by *cfg*, persist its summary, return it.

        Steps, in order:

        1. Emit a ``stage`` message to ``MessageBus`` carrying ``cfg.to_dict()``.
        2. Load the CSV dataset or generate the synthetic one.
        3. Write ``dataset.json`` (split sizes, feature count, label rates,
           feature names) and ``true_params.json`` (``true_w`` / ``true_b``)
           through ``self.ctx.store``.
        4. Emit a ``data_ready`` message to ``Orchestrator`` with the sizes.

        Args:
            cfg: Run configuration. Reads ``csv_path``, ``target_col``,
                ``seed``, ``train_ratio``, ``n_samples`` and ``n_features``.

        Returns:
            The loaded or generated dataset.
        """
        self.emit(receiver="MessageBus", kind="stage", payload={"stage": "data", "cfg": cfg.to_dict()})

        if cfg.csv_path:
            ds = load_csv_dataset(
                csv_path=cfg.csv_path,
                target_col=cfg.target_col,
                seed=cfg.seed,
                train_ratio=cfg.train_ratio,
            )
            # For CSV input the feature count comes from the loaded feature
            # names rather than from cfg; it is reported as 0 when the
            # dataset carries no feature names.
            # NOTE: cfg itself is NOT updated with the loaded sizes. Consumers
            # that need the real feature/sample counts must take them from the
            # returned dataset or from the "data_ready" payload.
            n_features = len(ds.feature_names) if ds.feature_names else 0
        else:
            ds = make_synthetic_binary_classification(
                seed=cfg.seed,
                n_samples=cfg.n_samples,
                n_features=cfg.n_features,
                train_ratio=cfg.train_ratio,
            )
            n_features = cfg.n_features

        n_train = len(ds.x_train)
        n_val = len(ds.x_val)

        self.ctx.store.write_json(
            "dataset.json",
            {
                "n_train": n_train,
                "n_val": n_val,
                "n_features": n_features,
                # max(1, ...) keeps an empty split from dividing by zero.
                "label_rate_train": sum(ds.y_train) / max(1, len(ds.y_train)),
                "label_rate_val": sum(ds.y_val) / max(1, len(ds.y_val)),
                "feature_names": ds.feature_names,
            },
        )
        self.ctx.store.write_json("true_params.json", {"true_w": ds.true_w, "true_b": ds.true_b})

        self.emit(
            receiver="Orchestrator",
            kind="data_ready",
            payload={"n_train": n_train, "n_val": n_val, "n_features": n_features},
        )
        return ds