from __future__ import annotations

from dataclasses import replace
from pathlib import Path
from typing import Any, Dict, Optional, Tuple

from .artifacts import ArtifactStore
from .bus import MessageBus
from .config import TrainConfig
from .agents.data_agent import DataAgent
from .agents.debugger_agent import DebuggerAgent
from .agents.evaluator_agent import EvaluatorAgent
from .agents.reporter_agent import ReporterAgent
from .agents.trainer_agent import TrainerAgent


class Orchestrator:
    """Drive one train / evaluate / debug run across the agent classes.

    Every call to :meth:`run` builds a fresh ``ArtifactStore`` under
    ``base_runs_dir`` and a fresh ``MessageBus`` shared by all agents.
    """

    def __init__(self, *, base_runs_dir: Path) -> None:
        """Remember the directory handed to ``ArtifactStore.create``.

        Args:
            base_runs_dir: Passed unchanged to ``ArtifactStore.create`` on
                each run.
        """
        self.base_runs_dir = base_runs_dir

    def run(self, *, cfg: TrainConfig, scenario: str) -> Tuple[ArtifactStore, Dict[str, Any]]:
        """Execute a full run and return its artifact store and summary.

        The trainer, evaluator and debugger agents are run in a loop of at
        most three attempts. The loop stops as soon as the debugger's
        decision has ``action == "accept"``; otherwise the debugger's
        ``next_cfg`` becomes the configuration of the following attempt.

        Args:
            cfg: Base training configuration. The data agent always receives
                this original object, even when the scenario overrides
                fields for training.
            scenario: ``"stable"`` uses ``cfg`` as-is; ``"unstable"``
                overrides ``lr``, ``grad_clip``, ``l2`` and ``loss_eps``
                before the first attempt.

        Returns:
            ``(store, summary)``. ``summary`` has the keys ``attempts``,
            ``final_cfg`` and ``eval``, plus ``note`` when no attempt was
            accepted. ``final_cfg`` is always the configuration that
            produced ``eval``.

        Raises:
            ValueError: If ``scenario`` is neither ``"stable"`` nor
                ``"unstable"``. Raised before any store, bus or agent is
                created.
        """
        # Fail fast: reject a bad scenario before creating the artifact
        # store or running the data agent, so an invalid call does no work.
        if scenario not in ("stable", "unstable"):
            raise ValueError(f"未知 scenario: {scenario}")

        store = ArtifactStore.create(self.base_runs_dir)
        bus = MessageBus()
        bus.send(sender="Orchestrator", receiver="All", kind="run_started", payload={"scenario": scenario})

        data_agent = DataAgent(bus=bus, store=store)
        trainer_agent = TrainerAgent(bus=bus, store=store)
        evaluator_agent = EvaluatorAgent(bus=bus, store=store)
        debugger_agent = DebuggerAgent(bus=bus, store=store)
        reporter_agent = ReporterAgent(bus=bus, store=store)

        # The dataset is built from the caller's cfg, not the overridden one.
        dataset = data_agent.run(cfg=cfg)

        effective_cfg = cfg
        if scenario == "unstable":
            # Learning rate becomes at least 6.0 (or 30x the configured one);
            # grad_clip is set to None and l2 / loss_eps are zeroed.
            # NOTE(review): presumably this removes the stabilisers so the
            # debugger has something to repair — confirm against TrainerAgent.
            effective_cfg = replace(
                effective_cfg,
                lr=max(6.0, effective_cfg.lr * 30),
                grad_clip=None,
                l2=0.0,
                loss_eps=0.0,
            )

        store.write_json("config.initial.json", effective_cfg.to_dict())

        summary: Optional[Dict[str, Any]] = None
        max_attempts = 3
        for attempt in range(1, max_attempts + 1):
            bus.send(
                sender="Orchestrator",
                receiver="All",
                kind="attempt_started",
                payload={"attempt": attempt, "cfg": effective_cfg.to_dict()},
            )
            metrics = trainer_agent.run(cfg=effective_cfg, dataset=dataset, attempt=attempt)
            eval_result = evaluator_agent.run(
                cfg=effective_cfg, dataset=dataset, metrics=metrics, attempt=attempt
            )
            decision = debugger_agent.run(
                cfg=effective_cfg, metrics=metrics, eval_result=eval_result, attempt=attempt
            )

            if decision["action"] == "accept":
                summary = {"attempts": attempt, "final_cfg": effective_cfg.to_dict(), "eval": eval_result}
                break

            # Only adopt the proposed config when another attempt will use
            # it. After the last attempt effective_cfg must stay the config
            # that produced eval_result, so the summary, config.json and the
            # report never pair an untested config with another's metrics.
            if attempt < max_attempts:
                effective_cfg = TrainConfig.from_dict(decision["next_cfg"])

        if summary is None:
            summary = {
                "attempts": attempt,
                "final_cfg": effective_cfg.to_dict(),
                "eval": eval_result,
                "note": "达到最大重试次数,未满足 DebuggerAgent 的通过条件。",
            }

        store.write_json("config.json", summary["final_cfg"])
        store.write_json("run_summary.json", summary)
        reporter_agent.run(cfg=TrainConfig.from_dict(summary["final_cfg"]), dataset=dataset, summary=summary)

        # Send run_finished before dumping the bus so the transcript
        # includes it.
        bus.send(sender="Orchestrator", receiver="All", kind="run_finished", payload=summary)
        store.write_text("transcript.jsonl", bus.to_jsonl())
        return store, summary