new-human
feat: 支持真实CSV数据训练与模型推理闭环
0bb564a
from __future__ import annotations
from dataclasses import asdict
from .base import BaseAgent
from ..config import TrainConfig
from ..ml.dataset import Dataset, make_synthetic_binary_classification, load_csv_dataset
class DataAgent(BaseAgent):
    """Agent that produces the train/validation ``Dataset`` for a run.

    When ``cfg.csv_path`` is set, data is loaded from that CSV file via
    ``load_csv_dataset``; otherwise a synthetic binary-classification dataset
    is generated with ``make_synthetic_binary_classification``. A summary is
    written to ``dataset.json`` and ``true_params.json`` through
    ``self.ctx.store``, and a ``data_ready`` message is emitted to the
    Orchestrator.
    """

    name = "DataAgent"

    def run(self, *, cfg: TrainConfig) -> Dataset:
        """Build the dataset described by *cfg* and publish its summary.

        Args:
            cfg: Training configuration. ``csv_path``, ``target_col``,
                ``seed`` and ``train_ratio`` are used for CSV loading;
                ``seed``, ``n_samples``, ``n_features`` and ``train_ratio``
                are used for synthetic generation.

        Returns:
            The loaded or generated ``Dataset``.
        """
        self.emit(receiver="MessageBus", kind="stage", payload={"stage": "data", "cfg": cfg.to_dict()})

        if cfg.csv_path:
            ds = load_csv_dataset(
                csv_path=cfg.csv_path,
                target_col=cfg.target_col,
                seed=cfg.seed,
                train_ratio=cfg.train_ratio,
            )
            # For CSV data the feature count is determined by the file, not by
            # cfg.n_features. NOTE(review): cfg itself is not updated here; if
            # downstream agents read cfg.n_features they must be given this
            # value explicitly (it is reported in dataset.json and the
            # data_ready payload below).
            n_features = self._infer_n_features(ds)
        else:
            ds = make_synthetic_binary_classification(
                seed=cfg.seed,
                n_samples=cfg.n_samples,
                n_features=cfg.n_features,
                train_ratio=cfg.train_ratio,
            )
            n_features = cfg.n_features

        self.ctx.store.write_json(
            "dataset.json",
            {
                "n_train": len(ds.x_train),
                "n_val": len(ds.x_val),
                "n_features": n_features,
                # max(1, ...) guards against division by zero on an empty split.
                "label_rate_train": sum(ds.y_train) / max(1, len(ds.y_train)),
                "label_rate_val": sum(ds.y_val) / max(1, len(ds.y_val)),
                "feature_names": ds.feature_names,
            },
        )
        # NOTE(review): for CSV-loaded data these are presumably None, since
        # real data has no generating parameters; confirm against load_csv_dataset.
        self.ctx.store.write_json("true_params.json", {"true_w": ds.true_w, "true_b": ds.true_b})

        self.emit(
            receiver="Orchestrator",
            kind="data_ready",
            payload={"n_train": len(ds.x_train), "n_val": len(ds.x_val), "n_features": n_features},
        )
        return ds

    @staticmethod
    def _infer_n_features(ds: Dataset) -> int:
        """Return the number of feature columns in *ds*.

        Prefers ``len(ds.feature_names)``. If the names are empty or missing,
        falls back to the length of the first training row, then the first
        validation row, and returns 0 only when both splits are empty.
        Assumes each row of ``x_train`` / ``x_val`` is a sized sequence of
        feature values (TODO confirm against ``Dataset``).
        """
        if ds.feature_names:
            return len(ds.feature_names)
        for rows in (ds.x_train, ds.x_val):
            # len(...) > 0 rather than truthiness so array-like rows also work.
            if len(rows) > 0:
                return len(rows[0])
        return 0