from __future__ import annotations

import json
import os
from contextlib import contextmanager
from dataclasses import asdict, dataclass, replace
from datetime import UTC, datetime
from pathlib import Path
from typing import Any, Callable, Iterator

from .grpo import DMClosedLoopConfig, GRPOLaunchConfig, run_dm_grpo, run_hero_grpo


@dataclass(frozen=True)
class JointTrainingConfig:
    """Immutable settings for the alternating hero / DM GRPO training loop."""

    # Directory that receives ``joint_state.json`` and one ``cycle_NN`` folder
    # per cycle; its name is also used as the W&B run group.
    root_dir: Path
    # Number of hero -> DM cycles to run; must be at least 1.
    cycles: int
    # Base launch configs; output_dir, run_name and resume_adapter_path are
    # overridden per phase.
    hero_config: GRPOLaunchConfig
    dm_config: GRPOLaunchConfig
    # Forwarded unchanged to ``run_dm_grpo``.
    target_ratios: list[float] | None = None
    # Forwarded to ``run_hero_grpo`` as ``world_path``.
    hero_world_path: Path | None = None
    # Interface settings forwarded to both the hero run and the DM closed loop.
    interface_provider: str | None = None
    interface_model: str | None = None
    interface_narrate: bool = False
    interface_translation_mode: str | None = None
    # Hero limits; the first two are also forwarded to the DM closed loop.
    hero_max_game_steps: int = 40
    hero_max_tool_calls: int = 80
    # Only forwarded to ``run_hero_grpo``.
    hero_max_tool_calling_iterations: int = 32


def run_joint_training_loop(config: JointTrainingConfig) -> Path:
    """Run ``config.cycles`` hero-then-DM cycles, skipping completed phases.

    A manifest (``joint_state.json``) is rewritten in ``config.root_dir``
    before the first phase, whenever a phase starts or finishes, and once more
    with the final status.

    Args:
        config: Joint training settings.

    Returns:
        ``config.root_dir``.

    Raises:
        ValueError: If ``config.cycles`` is less than 1.
        Exception: Any error raised by a phase is re-raised after the manifest
            has been marked ``"failed"``.
    """
    if config.cycles < 1:
        raise ValueError("cycles must be at least 1.")
    config.root_dir.mkdir(parents=True, exist_ok=True)

    latest_hero_adapter = _initial_adapter_path(config.hero_config.resume_adapter_path)
    latest_dm_adapter = _initial_adapter_path(config.dm_config.resume_adapter_path)
    phases: list[dict[str, Any]] = []

    def mark_running() -> None:
        # ``phases`` is mutated in place by the phase runners, so every call
        # snapshots the current progress.
        _write_manifest(config, phases, status="running")

    mark_running()
    try:
        for cycle_number in range(1, config.cycles + 1):
            cycle_dir = config.root_dir / f"cycle_{cycle_number:02d}"

            # Each phase returns its output directory, which becomes the
            # adapter path handed to the following phases.
            latest_hero_adapter = _run_or_resume_hero_phase(
                config=config,
                cycle_number=cycle_number,
                output_dir=cycle_dir / "hero",
                resume_adapter_path=latest_hero_adapter,
                phases=phases,
                on_phase_state_change=mark_running,
            )
            mark_running()

            latest_dm_adapter = _run_or_resume_dm_phase(
                config=config,
                cycle_number=cycle_number,
                output_dir=cycle_dir / "dm",
                resume_adapter_path=latest_dm_adapter,
                hero_adapter_path=latest_hero_adapter,
                phases=phases,
                on_phase_state_change=mark_running,
            )
            mark_running()
    except Exception as exc:
        _write_manifest(config, phases, status="failed", error=str(exc))
        raise

    _write_manifest(
        config,
        phases,
        status="completed",
        latest_hero_adapter_path=_optional_str(latest_hero_adapter),
        latest_dm_adapter_path=_optional_str(latest_dm_adapter),
    )
    return config.root_dir


def _run_or_resume_hero_phase(
    *,
    config: JointTrainingConfig,
    cycle_number: int,
    output_dir: Path,
    resume_adapter_path: Path | None,
    phases: list[dict[str, Any]],
    on_phase_state_change: Callable[[], None] | None = None,
) -> Path:
    """Run the hero phase of one cycle unless its state file says completed.

    Args:
        config: Joint training settings.
        cycle_number: 1-based cycle number, used in the default run name.
        output_dir: Phase directory; holds ``phase_state.json`` and artifacts.
        resume_adapter_path: Adapter to resume from, or ``None``.
        phases: Shared list that receives this phase's state dict.
        on_phase_state_change: Optional callback invoked once the phase has
            been recorded as running.

    Returns:
        ``output_dir``.
    """
    state_path = output_dir / "phase_state.json"
    existing_state = _load_phase_state(state_path)
    if existing_state is not None and existing_state.get("status") == "completed":
        phases.append(existing_state)
        return output_dir

    output_dir.mkdir(parents=True, exist_ok=True)
    run_name = config.hero_config.run_name or f"{config.root_dir.name}-hero-cycle-{cycle_number:02d}"
    resume_from = _optional_str(resume_adapter_path)
    phase_state: dict[str, Any] = {
        "phase": "hero",
        "cycle": cycle_number,
        "status": "running",
        "run_name": run_name,
        "output_dir": str(output_dir),
        "resume_adapter_path": resume_from,
        "started_at": _utc_now(),
    }
    phases.append(phase_state)

    with _phase_lifecycle(state_path, phase_state, on_phase_state_change):
        phase_config = replace(
            config.hero_config,
            output_dir=output_dir,
            run_name=run_name,
            resume_adapter_path=resume_from,
        )
        with _wandb_phase_env(group=config.root_dir.name, job_type="hero"):
            run_hero_grpo(
                phase_config,
                world_path=config.hero_world_path,
                artifacts_root=output_dir / "artifacts",
                interface_provider=config.interface_provider,
                interface_model=config.interface_model,
                interface_narrate=config.interface_narrate,
                interface_translation_mode=config.interface_translation_mode,
                max_game_steps=config.hero_max_game_steps,
                max_tool_calls=config.hero_max_tool_calls,
                max_tool_calling_iterations=config.hero_max_tool_calling_iterations,
            )
    return output_dir


def _run_or_resume_dm_phase(
    *,
    config: JointTrainingConfig,
    cycle_number: int,
    output_dir: Path,
    resume_adapter_path: Path | None,
    hero_adapter_path: Path | None,
    phases: list[dict[str, Any]],
    on_phase_state_change: Callable[[], None] | None = None,
) -> Path:
    """Run the DM phase of one cycle unless its state file says completed.

    Args:
        config: Joint training settings.
        cycle_number: 1-based cycle number, used in the default run name.
        output_dir: Phase directory; holds ``phase_state.json`` and artifacts.
        resume_adapter_path: DM adapter to resume from, or ``None``.
        hero_adapter_path: Hero adapter used by the closed loop; required.
        phases: Shared list that receives this phase's state dict.
        on_phase_state_change: Optional callback invoked once the phase has
            been recorded as running.

    Returns:
        ``output_dir``.

    Raises:
        RuntimeError: If ``hero_adapter_path`` is ``None``.
    """
    if hero_adapter_path is None:
        raise RuntimeError("DM phase requires a hero adapter path from a completed hero phase.")

    state_path = output_dir / "phase_state.json"
    existing_state = _load_phase_state(state_path)
    if existing_state is not None and existing_state.get("status") == "completed":
        phases.append(existing_state)
        return output_dir

    output_dir.mkdir(parents=True, exist_ok=True)
    run_name = config.dm_config.run_name or f"{config.root_dir.name}-dm-cycle-{cycle_number:02d}"
    resume_from = _optional_str(resume_adapter_path)
    phase_state: dict[str, Any] = {
        "phase": "dm",
        "cycle": cycle_number,
        "status": "running",
        "run_name": run_name,
        "output_dir": str(output_dir),
        "resume_adapter_path": resume_from,
        "hero_adapter_path": str(hero_adapter_path),
        "started_at": _utc_now(),
    }
    phases.append(phase_state)

    with _phase_lifecycle(state_path, phase_state, on_phase_state_change):
        phase_config = replace(
            config.dm_config,
            output_dir=output_dir,
            run_name=run_name,
            resume_adapter_path=resume_from,
        )
        closed_loop = DMClosedLoopConfig(
            hero_provider="hf_local",
            hero_model=config.hero_config.model_name,
            hero_adapter_path=str(hero_adapter_path),
            interface_provider=config.interface_provider,
            interface_model=config.interface_model,
            interface_narrate=config.interface_narrate,
            interface_translation_mode=config.interface_translation_mode,
            hero_max_game_steps=config.hero_max_game_steps,
            hero_max_tool_calls=config.hero_max_tool_calls,
        )
        with _wandb_phase_env(group=config.root_dir.name, job_type="dm"):
            run_dm_grpo(
                phase_config,
                target_ratios=config.target_ratios,
                artifacts_root=output_dir / "artifacts",
                closed_loop=closed_loop,
            )
    return output_dir


@contextmanager
def _phase_lifecycle(
    state_path: Path,
    phase_state: dict[str, Any],
    on_phase_state_change: Callable[[], None] | None,
) -> Iterator[None]:
    """Persist a phase's state around its body.

    On entry the (already ``"running"``) state is written and the callback is
    invoked. On normal exit the state becomes ``"completed"``. If the body
    raises an ``Exception`` the state becomes ``"failed"`` with the error
    message, so the on-disk phase state agrees with the failed manifest, and
    the exception is re-raised.
    """
    _write_json(state_path, phase_state)
    if on_phase_state_change is not None:
        on_phase_state_change()
    try:
        yield
    except Exception as exc:
        phase_state["status"] = "failed"
        phase_state["error"] = str(exc)
        phase_state["failed_at"] = _utc_now()
        _write_json(state_path, phase_state)
        raise
    phase_state["status"] = "completed"
    phase_state["completed_at"] = _utc_now()
    _write_json(state_path, phase_state)


def _write_manifest(
    config: JointTrainingConfig,
    phases: list[dict[str, Any]],
    *,
    status: str,
    error: str | None = None,
    latest_hero_adapter_path: str | None = None,
    latest_dm_adapter_path: str | None = None,
) -> None:
    """Write the overall run manifest to ``<root_dir>/joint_state.json``."""
    payload = {
        "status": status,
        "updated_at": _utc_now(),
        "error": error,
        "latest_hero_adapter_path": latest_hero_adapter_path,
        "latest_dm_adapter_path": latest_dm_adapter_path,
        # _write_json converts Path values throughout the payload, so no
        # separate conversion is needed here.
        "config": asdict(config),
        "phases": phases,
    }
    _write_json(config.root_dir / "joint_state.json", payload)


@contextmanager
def _wandb_phase_env(*, group: str, job_type: str) -> Iterator[None]:
    """Set ``WANDB_RUN_GROUP`` / ``WANDB_JOB_TYPE`` for the body, then restore."""
    previous_group = os.environ.get("WANDB_RUN_GROUP")
    previous_job_type = os.environ.get("WANDB_JOB_TYPE")
    os.environ["WANDB_RUN_GROUP"] = group
    os.environ["WANDB_JOB_TYPE"] = job_type
    try:
        yield
    finally:
        _restore_env("WANDB_RUN_GROUP", previous_group)
        _restore_env("WANDB_JOB_TYPE", previous_job_type)


def _restore_env(name: str, value: str | None) -> None:
    """Reset environment variable ``name`` to ``value``; ``None`` unsets it."""
    if value is None:
        os.environ.pop(name, None)
    else:
        os.environ[name] = value


def _load_phase_state(path: Path) -> dict[str, Any] | None:
    """Return the phase state stored at ``path``, or ``None`` if unusable.

    A missing file, undecodable or invalid JSON, and a non-object payload all
    yield ``None`` so the caller re-runs the phase instead of crashing on
    every resume attempt.
    """
    try:
        state = json.loads(path.read_text(encoding="utf-8"))
    except (FileNotFoundError, NotADirectoryError):
        return None
    except ValueError:
        # Covers json.JSONDecodeError and UnicodeDecodeError, e.g. a file
        # truncated by an interrupted write.
        return None
    return state if isinstance(state, dict) else None


def _write_json(path: Path, payload: dict[str, Any]) -> None:
    """Serialize ``payload`` to ``path`` as sorted, indented JSON.

    The text is written to a sibling temporary file and then moved over
    ``path`` with ``os.replace`` so a crash mid-write cannot leave a
    truncated state file behind.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    text = json.dumps(_to_jsonable(payload), indent=2, sort_keys=True) + "\n"
    tmp_path = path.with_name(f"{path.name}.tmp")
    tmp_path.write_text(text, encoding="utf-8")
    os.replace(tmp_path, path)


def _to_jsonable(value: Any) -> Any:
    """Recursively convert ``Path`` objects to strings inside dicts/lists/tuples.

    Dict keys are stringified; tuples come back as lists (which is how JSON
    would encode them anyway). Any other value is returned unchanged.
    """
    if isinstance(value, Path):
        return str(value)
    if isinstance(value, dict):
        return {str(key): _to_jsonable(item) for key, item in value.items()}
    if isinstance(value, (list, tuple)):
        return [_to_jsonable(item) for item in value]
    return value


def _optional_str(path: Path | None) -> str | None:
    """Return ``str(path)``, passing ``None`` through."""
    return None if path is None else str(path)


def _initial_adapter_path(raw_path: str | None) -> Path | None:
    """Return ``raw_path`` as a ``Path`` if it exists on disk, else ``None``."""
    if raw_path is None:
        return None
    path = Path(raw_path)
    return path if path.exists() else None


def _utc_now() -> str:
    """Return the current UTC time as ISO-8601, second precision, ``Z`` suffix."""
    return datetime.now(UTC).replace(microsecond=0).isoformat().replace("+00:00", "Z")