File size: 3,379 Bytes
11ac7be
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d536096
11ac7be
 
 
 
d536096
11ac7be
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
from __future__ import annotations

from dataclasses import replace
from pathlib import Path
from typing import Any, Dict, Optional, Tuple

from .artifacts import ArtifactStore
from .bus import MessageBus
from .config import TrainConfig
from .agents.data_agent import DataAgent
from .agents.debugger_agent import DebuggerAgent
from .agents.evaluator_agent import EvaluatorAgent
from .agents.reporter_agent import ReporterAgent
from .agents.trainer_agent import TrainerAgent


class Orchestrator:
    """Drive one train / evaluate / debug / report run across the agents.

    A run creates a fresh ``ArtifactStore`` under ``base_runs_dir``, wires a
    ``MessageBus`` into every agent, and retries training with the
    configuration proposed by ``DebuggerAgent`` until it answers ``"accept"``
    or the attempt budget is exhausted.
    """

    # Scenario names understood by ``run``.
    _SCENARIOS = ("stable", "unstable")

    def __init__(self, *, base_runs_dir: Path, max_attempts: int = 3) -> None:
        """Store the run settings.

        Args:
            base_runs_dir: Directory handed to ``ArtifactStore.create`` for
                every run.
            max_attempts: Maximum number of train/evaluate/debug rounds per
                run. Defaults to 3, the previously hard-coded budget.

        Raises:
            ValueError: If ``max_attempts`` is smaller than 1; at least one
                round is required to have an evaluation result to report.
        """
        if max_attempts < 1:
            raise ValueError(f"max_attempts must be >= 1, got {max_attempts}")
        self.base_runs_dir = base_runs_dir
        self.max_attempts = max_attempts

    def run(self, *, cfg: TrainConfig, scenario: str) -> Tuple[ArtifactStore, Dict[str, Any]]:
        """Execute a full run and return its artifact store and summary.

        Args:
            cfg: Base training configuration.
            scenario: ``"stable"`` uses ``cfg`` unchanged; ``"unstable"``
                derives a deliberately aggressive variant (learning rate of
                at least 6.0, no gradient clipping, ``l2`` and ``loss_eps``
                set to 0.0).

        Returns:
            ``(store, summary)``. ``summary`` holds ``"attempts"``,
            ``"final_cfg"`` and ``"eval"``, plus a ``"note"`` when no attempt
            was accepted. ``"final_cfg"`` is always the configuration that
            produced ``"eval"``.

        Raises:
            ValueError: If ``scenario`` is not a known scenario name.
        """
        # Validate before creating the store or running any agent so that a
        # mistyped scenario does not leave a half-initialised run behind.
        if scenario not in self._SCENARIOS:
            raise ValueError(f"未知 scenario: {scenario}")

        store = ArtifactStore.create(self.base_runs_dir)
        bus = MessageBus()

        bus.send(sender="Orchestrator", receiver="All", kind="run_started", payload={"scenario": scenario})

        data_agent = DataAgent(bus=bus, store=store)
        trainer_agent = TrainerAgent(bus=bus, store=store)
        evaluator_agent = EvaluatorAgent(bus=bus, store=store)
        debugger_agent = DebuggerAgent(bus=bus, store=store)
        reporter_agent = ReporterAgent(bus=bus, store=store)

        dataset = data_agent.run(cfg=cfg)

        effective_cfg = cfg
        if scenario == "unstable":
            effective_cfg = replace(
                effective_cfg,
                lr=max(6.0, effective_cfg.lr * 30),
                grad_clip=None,
                l2=0.0,
                loss_eps=0.0,
            )

        store.write_json("config.initial.json", effective_cfg.to_dict())

        summary: Optional[Dict[str, Any]] = None
        for attempt in range(1, self.max_attempts + 1):
            bus.send(
                sender="Orchestrator",
                receiver="All",
                kind="attempt_started",
                payload={"attempt": attempt, "cfg": effective_cfg.to_dict()},
            )

            metrics = trainer_agent.run(cfg=effective_cfg, dataset=dataset, attempt=attempt)
            eval_result = evaluator_agent.run(cfg=effective_cfg, dataset=dataset, metrics=metrics, attempt=attempt)

            decision = debugger_agent.run(cfg=effective_cfg, metrics=metrics, eval_result=eval_result, attempt=attempt)
            if decision["action"] == "accept":
                summary = {"attempts": attempt, "final_cfg": effective_cfg.to_dict(), "eval": eval_result}
                break

            # Only move to the proposed configuration when another attempt
            # will actually train it. After the last attempt, effective_cfg
            # must stay the configuration that produced eval_result, so the
            # summary, config.json and the report describe the same run.
            if attempt < self.max_attempts:
                effective_cfg = TrainConfig.from_dict(decision["next_cfg"])

        if summary is None:
            summary = {
                "attempts": attempt,
                "final_cfg": effective_cfg.to_dict(),
                "eval": eval_result,
                "note": "达到最大重试次数,未满足 DebuggerAgent 的通过条件。",
            }

        store.write_json("config.json", summary["final_cfg"])
        store.write_json("run_summary.json", summary)

        reporter_agent.run(cfg=TrainConfig.from_dict(summary["final_cfg"]), dataset=dataset, summary=summary)

        bus.send(sender="Orchestrator", receiver="All", kind="run_finished", payload=summary)
        # Written last so the transcript also contains the run_finished message.
        store.write_text("transcript.jsonl", bus.to_jsonl())

        return store, summary