Spaces:
Runtime error
Runtime error
| from __future__ import annotations | |
| import json | |
| import os | |
| from contextlib import contextmanager | |
| from dataclasses import asdict, dataclass, replace | |
| from datetime import UTC, datetime | |
| from pathlib import Path | |
| from typing import Any, Callable, Iterator | |
| from .grpo import DMClosedLoopConfig, GRPOLaunchConfig, run_dm_grpo, run_hero_grpo | |
@dataclass
class JointTrainingConfig:
    """Configuration for alternating hero / DM GRPO training cycles.

    The ``@dataclass`` decorator is required: callers construct this with keyword
    arguments, ``dataclasses.asdict`` serializes it into the joint manifest, and
    every annotated attribute below becomes an ``__init__`` parameter.
    """

    # Directory that receives ``cycle_NN/{hero,dm}`` phase outputs and ``joint_state.json``.
    root_dir: Path
    # Number of hero-then-DM cycles to run; must be at least 1.
    cycles: int
    # Base launch config for every hero phase (output_dir / run_name / resume path are overridden per phase).
    hero_config: GRPOLaunchConfig
    # Base launch config for every DM phase (same per-phase overrides as the hero config).
    dm_config: GRPOLaunchConfig
    # Forwarded to ``run_dm_grpo(target_ratios=...)``.
    target_ratios: list[float] | None = None
    # Forwarded to ``run_hero_grpo(world_path=...)``.
    hero_world_path: Path | None = None
    # The interface_* settings are forwarded to both the hero run and the DM closed-loop config.
    interface_provider: str | None = None
    interface_model: str | None = None
    interface_narrate: bool = False
    interface_translation_mode: str | None = None
    # Hero rollout limits; the first two are also passed to the DM closed loop.
    hero_max_game_steps: int = 40
    hero_max_tool_calls: int = 80
    # Only forwarded to the hero phase.
    hero_max_tool_calling_iterations: int = 32
def run_joint_training_loop(config: JointTrainingConfig) -> Path:
    """Alternate hero and DM GRPO phases for ``config.cycles`` cycles.

    Each cycle first runs (or resumes) the hero phase. It then runs the DM phase
    against the hero adapter produced in that cycle. Phase outputs are written
    under ``<root_dir>/cycle_NN/hero`` and ``<root_dir>/cycle_NN/dm``. Overall
    progress is rewritten to ``<root_dir>/joint_state.json`` before and after
    every phase.

    Args:
        config: Joint training configuration.

    Returns:
        ``config.root_dir``.

    Raises:
        ValueError: If ``config.cycles`` is less than 1.
        Exception: Any error from a phase is re-raised after the manifest has
            been written with ``status="failed"``.
    """
    if config.cycles < 1:
        raise ValueError("cycles must be at least 1.")
    config.root_dir.mkdir(parents=True, exist_ok=True)

    # A configured resume path is only used if it exists on disk (see _initial_adapter_path).
    latest_hero_adapter = _initial_adapter_path(config.hero_config.resume_adapter_path)
    latest_dm_adapter = _initial_adapter_path(config.dm_config.resume_adapter_path)
    phases: list[dict[str, Any]] = []

    def mark_running() -> None:
        # Single checkpoint callback shared by both phases and the loop itself.
        _write_manifest(config, phases, status="running")

    mark_running()
    try:
        for cycle_number in range(1, config.cycles + 1):
            cycle_dir = config.root_dir / f"cycle_{cycle_number:02d}"
            # Each phase returns its output directory, which is carried forward as
            # that role's adapter path for the next phase / cycle.
            latest_hero_adapter = _run_or_resume_hero_phase(
                config=config,
                cycle_number=cycle_number,
                output_dir=cycle_dir / "hero",
                resume_adapter_path=latest_hero_adapter,
                phases=phases,
                on_phase_state_change=mark_running,
            )
            mark_running()
            latest_dm_adapter = _run_or_resume_dm_phase(
                config=config,
                cycle_number=cycle_number,
                output_dir=cycle_dir / "dm",
                resume_adapter_path=latest_dm_adapter,
                hero_adapter_path=latest_hero_adapter,
                phases=phases,
                on_phase_state_change=mark_running,
            )
            mark_running()
    except Exception as exc:
        # ``str(exc)`` alone is empty for message-less exceptions and ambiguous for
        # e.g. KeyError, so include the exception type in the recorded error.
        _write_manifest(config, phases, status="failed", error=f"{type(exc).__name__}: {exc}")
        raise

    _write_manifest(
        config,
        phases,
        status="completed",
        latest_hero_adapter_path=None if latest_hero_adapter is None else str(latest_hero_adapter),
        latest_dm_adapter_path=None if latest_dm_adapter is None else str(latest_dm_adapter),
    )
    return config.root_dir
def _run_or_resume_hero_phase(
    *,
    config: JointTrainingConfig,
    cycle_number: int,
    output_dir: Path,
    resume_adapter_path: Path | None,
    phases: list[dict[str, Any]],
    on_phase_state_change: Callable[[], None] | None = None,
) -> Path:
    """Run one hero GRPO phase, or skip it if its saved state says it already completed.

    Progress is tracked in ``<output_dir>/phase_state.json``. The same state dict
    is appended to ``phases``, so the caller's manifest reflects it.

    Args:
        config: Joint training configuration; ``hero_config`` is the base launch config.
        cycle_number: 1-based cycle index, used in the default run name.
        output_dir: Directory for this phase's outputs and state file.
        resume_adapter_path: Adapter to start training from, or None.
        phases: Running list of phase-state dicts; this function appends to it.
        on_phase_state_change: Optional callback invoked once the "running" state is saved.

    Returns:
        ``output_dir``, which the caller treats as the new hero adapter path.

    Raises:
        Exception: Anything raised while configuring or running ``run_hero_grpo``,
            re-raised after the phase state is persisted as ``"failed"``.
    """
    state_path = output_dir / "phase_state.json"
    existing_state = _load_phase_state(state_path)
    if existing_state is not None and existing_state.get("status") == "completed":
        phases.append(existing_state)
        return output_dir

    output_dir.mkdir(parents=True, exist_ok=True)
    run_name = config.hero_config.run_name or f"{config.root_dir.name}-hero-cycle-{cycle_number:02d}"
    resume_adapter = None if resume_adapter_path is None else str(resume_adapter_path)
    phase_state: dict[str, Any] = {
        "phase": "hero",
        "cycle": cycle_number,
        "status": "running",
        "run_name": run_name,
        "output_dir": str(output_dir),
        "resume_adapter_path": resume_adapter,
        "started_at": _utc_now(),
    }
    phases.append(phase_state)
    _write_json(state_path, phase_state)
    if on_phase_state_change is not None:
        on_phase_state_change()

    try:
        phase_config = replace(
            config.hero_config,
            output_dir=output_dir,
            run_name=run_name,
            resume_adapter_path=resume_adapter,
        )
        with _wandb_phase_env(group=config.root_dir.name, job_type="hero"):
            run_hero_grpo(
                phase_config,
                world_path=config.hero_world_path,
                artifacts_root=output_dir / "artifacts",
                interface_provider=config.interface_provider,
                interface_model=config.interface_model,
                interface_narrate=config.interface_narrate,
                interface_translation_mode=config.interface_translation_mode,
                max_game_steps=config.hero_max_game_steps,
                max_tool_calls=config.hero_max_tool_calls,
                max_tool_calling_iterations=config.hero_max_tool_calling_iterations,
            )
    except Exception as exc:
        # Persist the failure so neither phase_state.json nor the shared manifest
        # entry keeps claiming "running". Resume only skips "completed" phases,
        # so a failed phase is retried on the next run.
        phase_state["status"] = "failed"
        phase_state["error"] = f"{type(exc).__name__}: {exc}"
        phase_state["failed_at"] = _utc_now()
        _write_json(state_path, phase_state)
        raise

    phase_state["status"] = "completed"
    phase_state["completed_at"] = _utc_now()
    _write_json(state_path, phase_state)
    return output_dir
def _run_or_resume_dm_phase(
    *,
    config: JointTrainingConfig,
    cycle_number: int,
    output_dir: Path,
    resume_adapter_path: Path | None,
    hero_adapter_path: Path | None,
    phases: list[dict[str, Any]],
    on_phase_state_change: Callable[[], None] | None = None,
) -> Path:
    """Run one DM GRPO phase against a hero adapter, or skip it if it already completed.

    Progress is tracked in ``<output_dir>/phase_state.json``. The same state dict
    is appended to ``phases``, so the caller's manifest reflects it.

    Args:
        config: Joint training configuration; ``dm_config`` is the base launch config.
        cycle_number: 1-based cycle index, used in the default run name.
        output_dir: Directory for this phase's outputs and state file.
        resume_adapter_path: DM adapter to start training from, or None.
        hero_adapter_path: Hero adapter loaded (provider "hf_local") for closed-loop play.
        phases: Running list of phase-state dicts; this function appends to it.
        on_phase_state_change: Optional callback invoked once the "running" state is saved.

    Returns:
        ``output_dir``, which the caller treats as the new DM adapter path.

    Raises:
        RuntimeError: If ``hero_adapter_path`` is None.
        Exception: Anything raised while configuring or running ``run_dm_grpo``,
            re-raised after the phase state is persisted as ``"failed"``.
    """
    if hero_adapter_path is None:
        raise RuntimeError("DM phase requires a hero adapter path from a completed hero phase.")
    state_path = output_dir / "phase_state.json"
    existing_state = _load_phase_state(state_path)
    if existing_state is not None and existing_state.get("status") == "completed":
        phases.append(existing_state)
        return output_dir

    output_dir.mkdir(parents=True, exist_ok=True)
    run_name = config.dm_config.run_name or f"{config.root_dir.name}-dm-cycle-{cycle_number:02d}"
    resume_adapter = None if resume_adapter_path is None else str(resume_adapter_path)
    phase_state: dict[str, Any] = {
        "phase": "dm",
        "cycle": cycle_number,
        "status": "running",
        "run_name": run_name,
        "output_dir": str(output_dir),
        "resume_adapter_path": resume_adapter,
        "hero_adapter_path": str(hero_adapter_path),
        "started_at": _utc_now(),
    }
    phases.append(phase_state)
    _write_json(state_path, phase_state)
    if on_phase_state_change is not None:
        on_phase_state_change()

    try:
        phase_config = replace(
            config.dm_config,
            output_dir=output_dir,
            run_name=run_name,
            resume_adapter_path=resume_adapter,
        )
        # NOTE(review): unlike the hero phase, hero_max_tool_calling_iterations is
        # not passed here; confirm whether DMClosedLoopConfig supports it.
        closed_loop = DMClosedLoopConfig(
            hero_provider="hf_local",
            hero_model=config.hero_config.model_name,
            hero_adapter_path=str(hero_adapter_path),
            interface_provider=config.interface_provider,
            interface_model=config.interface_model,
            interface_narrate=config.interface_narrate,
            interface_translation_mode=config.interface_translation_mode,
            hero_max_game_steps=config.hero_max_game_steps,
            hero_max_tool_calls=config.hero_max_tool_calls,
        )
        with _wandb_phase_env(group=config.root_dir.name, job_type="dm"):
            run_dm_grpo(
                phase_config,
                target_ratios=config.target_ratios,
                artifacts_root=output_dir / "artifacts",
                closed_loop=closed_loop,
            )
    except Exception as exc:
        # Persist the failure so neither phase_state.json nor the shared manifest
        # entry keeps claiming "running". Resume only skips "completed" phases,
        # so a failed phase is retried on the next run.
        phase_state["status"] = "failed"
        phase_state["error"] = f"{type(exc).__name__}: {exc}"
        phase_state["failed_at"] = _utc_now()
        _write_json(state_path, phase_state)
        raise

    phase_state["status"] = "completed"
    phase_state["completed_at"] = _utc_now()
    _write_json(state_path, phase_state)
    return output_dir
def _write_manifest(
    config: JointTrainingConfig,
    phases: list[dict[str, Any]],
    *,
    status: str,
    error: str | None = None,
    latest_hero_adapter_path: str | None = None,
    latest_dm_adapter_path: str | None = None,
) -> None:
    """Snapshot the joint-training state into ``<root_dir>/joint_state.json``.

    The snapshot records:
    - the overall status and a UTC timestamp;
    - an optional error message;
    - the latest adapter paths (None unless supplied);
    - the config converted to JSON-friendly data;
    - the list of per-phase state dicts.
    """
    manifest = dict(
        status=status,
        updated_at=_utc_now(),
        error=error,
        latest_hero_adapter_path=latest_hero_adapter_path,
        latest_dm_adapter_path=latest_dm_adapter_path,
        config=_to_jsonable(asdict(config)),
        phases=phases,
    )
    manifest_path = config.root_dir / "joint_state.json"
    _write_json(manifest_path, manifest)
@contextmanager
def _wandb_phase_env(*, group: str, job_type: str) -> Iterator[None]:
    """Set ``WANDB_RUN_GROUP`` and ``WANDB_JOB_TYPE`` for the duration of a ``with`` block.

    The previous values are captured first and restored on exit, even if the body
    raises. A variable that was unset beforehand is removed again.

    ``@contextmanager`` is required: callers use ``with _wandb_phase_env(...)``,
    and a bare generator function does not support the context-manager protocol.
    """
    previous_group = os.getenv("WANDB_RUN_GROUP")
    previous_job_type = os.getenv("WANDB_JOB_TYPE")
    try:
        # Assign inside ``try`` so that a failure while setting either variable
        # still restores both to their prior state.
        os.environ["WANDB_RUN_GROUP"] = group
        os.environ["WANDB_JOB_TYPE"] = job_type
        yield
    finally:
        _restore_env("WANDB_RUN_GROUP", previous_group)
        _restore_env("WANDB_JOB_TYPE", previous_job_type)
| def _restore_env(name: str, value: str | None) -> None: | |
| if value is None: | |
| os.environ.pop(name, None) | |
| else: | |
| os.environ[name] = value | |
def _load_phase_state(path: Path) -> dict[str, Any] | None:
    """Decode the JSON phase state stored at *path*, or return None if nothing is there yet."""
    if not path.exists():
        return None
    with path.open(encoding="utf-8") as handle:
        return json.load(handle)
def _write_json(path: Path, payload: dict[str, Any]) -> None:
    """Atomically write *payload* to *path* as indented, key-sorted UTF-8 JSON.

    Parent directories are created as needed. The text goes to a hidden
    temporary file in the same directory first, and is then moved over *path*
    with ``os.replace``. A crash mid-write therefore never leaves a truncated
    state file, which would break the next resume when it is re-read.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    # Serialize before touching the filesystem, so an unserializable payload leaves *path* alone.
    text = json.dumps(_to_jsonable(payload), indent=2, sort_keys=True) + "\n"
    tmp_path = path.with_name(f".{path.name}.tmp")
    try:
        tmp_path.write_text(text, encoding="utf-8")
        os.replace(tmp_path, path)
    finally:
        # No-op after a successful replace; cleans up the temp file on failure.
        tmp_path.unlink(missing_ok=True)
def _to_jsonable(value: Any) -> Any:
    """Recursively convert *value* into data that ``json.dumps`` can serialize.

    - ``Path`` objects become strings.
    - Dict keys are stringified.
    - Lists and tuples become lists of converted items.
    - Any other value is returned unchanged, so ``json.dumps`` may still reject it.
    """
    if isinstance(value, Path):
        return str(value)
    if isinstance(value, dict):
        return {str(key): _to_jsonable(item) for key, item in value.items()}
    # ``dataclasses.asdict`` leaves tuples as tuples. Convert them too, so Path
    # values nested inside tuple fields are not left for json.dumps to reject.
    if isinstance(value, (list, tuple)):
        return [_to_jsonable(item) for item in value]
    return value
| def _initial_adapter_path(raw_path: str | None) -> Path | None: | |
| if raw_path is None: | |
| return None | |
| path = Path(raw_path) | |
| return path if path.exists() else None | |
| def _utc_now() -> str: | |
| return datetime.now(UTC).replace(microsecond=0).isoformat().replace("+00:00", "Z") | |