Spaces:
Running
Running
| from __future__ import annotations | |
| from dataclasses import replace | |
| from pathlib import Path | |
| from typing import Any, Dict, Optional, Tuple | |
| from .artifacts import ArtifactStore | |
| from .bus import MessageBus | |
| from .config import TrainConfig | |
| from .agents.data_agent import DataAgent | |
| from .agents.debugger_agent import DebuggerAgent | |
| from .agents.evaluator_agent import EvaluatorAgent | |
| from .agents.reporter_agent import ReporterAgent | |
| from .agents.trainer_agent import TrainerAgent | |
class Orchestrator:
    """Coordinate one training run across the data, trainer, evaluator,
    debugger and reporter agents.

    A run loads the dataset once, then repeats train -> evaluate -> debug
    until the debugger's decision is ``"accept"`` or the attempt budget is
    used up. It then writes the final config, the run summary and the
    message-bus transcript through the ``ArtifactStore``.
    """

    # Scenario names accepted by ``run``.
    _SCENARIOS = ("stable", "unstable")

    def __init__(self, *, base_runs_dir: Path) -> None:
        """Store the path that ``run`` passes to ``ArtifactStore.create``."""
        self.base_runs_dir = base_runs_dir

    @staticmethod
    def _apply_scenario(cfg: TrainConfig, scenario: str) -> TrainConfig:
        """Return the config to start training with for an already-validated scenario."""
        if scenario == "unstable":
            # Override the hyper-parameters for the "unstable" scenario:
            # learning rate raised to at least 6.0, and grad_clip / l2 /
            # loss_eps switched off.
            return replace(
                cfg,
                lr=max(6.0, cfg.lr * 30),
                grad_clip=None,
                l2=0.0,
                loss_eps=0.0,
            )
        return cfg

    def run(
        self,
        *,
        cfg: TrainConfig,
        scenario: str,
        max_attempts: int = 3,
    ) -> Tuple[ArtifactStore, Dict[str, Any]]:
        """Execute the full train / evaluate / debug loop for one scenario.

        Args:
            cfg: Base training configuration.
            scenario: ``"stable"`` uses ``cfg`` as given; ``"unstable"``
                applies the overrides in ``_apply_scenario``.
            max_attempts: Maximum number of train/evaluate/debug rounds.
                Must be at least 1. Defaults to 3.

        Returns:
            ``(store, summary)``. ``summary`` has the keys ``"attempts"``,
            ``"final_cfg"`` and ``"eval"``, plus ``"note"`` when no attempt
            was accepted. ``"final_cfg"`` is always the config that produced
            ``"eval"``.

        Raises:
            ValueError: If ``scenario`` is unknown or ``max_attempts`` is
                less than 1. Both are checked before any store, bus or
                agent is created.
        """
        # Fail fast: reject bad arguments before anything is created or run.
        if scenario not in self._SCENARIOS:
            raise ValueError(f"未知 scenario: {scenario}")
        if max_attempts < 1:
            raise ValueError(f"max_attempts 必须 >= 1: {max_attempts}")

        store = ArtifactStore.create(self.base_runs_dir)
        bus = MessageBus()
        bus.send(sender="Orchestrator", receiver="All", kind="run_started", payload={"scenario": scenario})

        data_agent = DataAgent(bus=bus, store=store)
        trainer_agent = TrainerAgent(bus=bus, store=store)
        evaluator_agent = EvaluatorAgent(bus=bus, store=store)
        debugger_agent = DebuggerAgent(bus=bus, store=store)
        reporter_agent = ReporterAgent(bus=bus, store=store)

        dataset = data_agent.run(cfg=cfg)

        effective_cfg = self._apply_scenario(cfg, scenario)
        store.write_json("config.initial.json", effective_cfg.to_dict())

        accepted = False
        for attempt in range(1, max_attempts + 1):
            bus.send(
                sender="Orchestrator",
                receiver="All",
                kind="attempt_started",
                payload={"attempt": attempt, "cfg": effective_cfg.to_dict()},
            )
            metrics = trainer_agent.run(cfg=effective_cfg, dataset=dataset, attempt=attempt)
            eval_result = evaluator_agent.run(
                cfg=effective_cfg, dataset=dataset, metrics=metrics, attempt=attempt
            )
            decision = debugger_agent.run(
                cfg=effective_cfg, metrics=metrics, eval_result=eval_result, attempt=attempt
            )
            if decision["action"] == "accept":
                accepted = True
                break
            # Adopt the debugger's proposal only if another attempt will
            # train with it. After the last attempt, effective_cfg must stay
            # the config that produced eval_result so the summary is
            # self-consistent.
            if attempt < max_attempts:
                effective_cfg = TrainConfig.from_dict(decision["next_cfg"])

        summary: Dict[str, Any] = {
            "attempts": attempt,
            "final_cfg": effective_cfg.to_dict(),
            "eval": eval_result,
        }
        if not accepted:
            summary["note"] = "达到最大重试次数,未满足 DebuggerAgent 的通过条件。"

        store.write_json("config.json", summary["final_cfg"])
        store.write_json("run_summary.json", summary)
        reporter_agent.run(cfg=TrainConfig.from_dict(summary["final_cfg"]), dataset=dataset, summary=summary)
        bus.send(sender="Orchestrator", receiver="All", kind="run_finished", payload=summary)
        # Written last so the transcript includes the run_finished message.
        store.write_text("transcript.jsonl", bus.to_jsonl())
        return store, summary