File size: 1,973 Bytes
11ac7be
 
 
 
 
 
0bb564a
11ac7be
 
 
 
 
 
 
0bb564a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11ac7be
 
 
 
 
 
0bb564a
11ac7be
 
0bb564a
11ac7be
 
 
 
 
 
 
0bb564a
11ac7be
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
from __future__ import annotations

from dataclasses import asdict

from .base import BaseAgent
from ..config import TrainConfig
from ..ml.dataset import Dataset, make_synthetic_binary_classification, load_csv_dataset


class DataAgent(BaseAgent):
    """Pipeline stage that produces the train/validation dataset.

    When ``cfg.csv_path`` is set the data is loaded from that CSV file;
    otherwise a synthetic binary-classification dataset is generated.
    Summary statistics are written to ``dataset.json`` and the dataset's
    ``true_w``/``true_b`` to ``true_params.json`` in the context store, and a
    ``data_ready`` message is emitted to the Orchestrator.
    """

    name = "DataAgent"

    @staticmethod
    def _count_features(ds: Dataset) -> int:
        """Return the number of feature columns in *ds*.

        Prefers ``ds.feature_names``. When no names are available, falls back
        to the width of the first training row, and to 0 when there are no
        training rows at all.
        """
        if ds.feature_names:
            return len(ds.feature_names)
        # Without column names, reporting 0 would misdescribe data that does
        # have columns. Assumes x_train rows are indexable sequences (list of
        # lists or a 2-D array) - TODO confirm against load_csv_dataset.
        if len(ds.x_train) > 0:
            return len(ds.x_train[0])
        return 0

    def run(self, *, cfg: TrainConfig) -> Dataset:
        """Build the dataset described by *cfg* and publish its summary.

        Args:
            cfg: Training configuration. ``csv_path``/``target_col`` select a
                CSV source; otherwise ``n_samples``/``n_features`` size a
                synthetic dataset. ``seed`` and ``train_ratio`` are passed to
                either data source.

        Returns:
            The loaded or generated ``Dataset``.
        """
        self.emit(receiver="MessageBus", kind="stage", payload={"stage": "data", "cfg": cfg.to_dict()})

        if cfg.csv_path:
            ds = load_csv_dataset(
                csv_path=cfg.csv_path,
                target_col=cfg.target_col,
                seed=cfg.seed,
                train_ratio=cfg.train_ratio,
            )
            # For CSV input the feature count is only known after loading, so
            # it is derived from the data rather than taken from cfg.
            n_features = self._count_features(ds)
        else:
            ds = make_synthetic_binary_classification(
                seed=cfg.seed,
                n_samples=cfg.n_samples,
                n_features=cfg.n_features,
                train_ratio=cfg.train_ratio,
            )
            n_features = cfg.n_features

        self.ctx.store.write_json(
            "dataset.json",
            {
                "n_train": len(ds.x_train),
                "n_val": len(ds.x_val),
                "n_features": n_features,
                # max(1, ...) guards against division by zero on empty splits.
                "label_rate_train": sum(ds.y_train) / max(1, len(ds.y_train)),
                "label_rate_val": sum(ds.y_val) / max(1, len(ds.y_val)),
                "feature_names": ds.feature_names,
            },
        )
        self.ctx.store.write_json("true_params.json", {"true_w": ds.true_w, "true_b": ds.true_b})

        self.emit(
            receiver="Orchestrator",
            kind="data_ready",
            payload={"n_train": len(ds.x_train), "n_val": len(ds.x_val), "n_features": n_features},
        )
        return ds