aarushgupta's picture
Deploy FATHOM-DM Space bundle
2803d7e verified
from __future__ import annotations
import json
import os
from contextlib import contextmanager
from dataclasses import asdict, dataclass, replace
from datetime import UTC, datetime
from pathlib import Path
from typing import Any, Callable, Iterator
from .grpo import DMClosedLoopConfig, GRPOLaunchConfig, run_dm_grpo, run_hero_grpo
@dataclass(frozen=True)
class JointTrainingConfig:
    """Settings for ``run_joint_training_loop``: alternating hero and DM GRPO phases.

    Each cycle runs a hero phase followed by a DM phase. Per-cycle outputs go under
    ``root_dir / "cycle_NN"`` and the overall manifest is ``root_dir / "joint_state.json"``.
    """

    # Directory holding every cycle's output plus the joint_state.json manifest.
    root_dir: Path
    # Number of hero -> DM cycles to run; run_joint_training_loop rejects values < 1.
    cycles: int
    # Base GRPO settings for hero phases; output_dir, run_name and resume_adapter_path
    # are overridden per cycle via dataclasses.replace.
    hero_config: GRPOLaunchConfig
    # Base GRPO settings for DM phases; same per-cycle overrides as hero_config.
    # Its model_name is NOT used for the hero; the DM closed loop reads hero_config.model_name.
    dm_config: GRPOLaunchConfig
    # Forwarded unchanged to run_dm_grpo(target_ratios=...).
    target_ratios: list[float] | None = None
    # Forwarded to run_hero_grpo(world_path=...).
    hero_world_path: Path | None = None
    # The interface_* settings are forwarded to run_hero_grpo and to the DM phase's
    # DMClosedLoopConfig.
    interface_provider: str | None = None
    interface_model: str | None = None
    interface_narrate: bool = False
    interface_translation_mode: str | None = None
    # Hero rollout limits: the first two are forwarded to both run_hero_grpo and
    # DMClosedLoopConfig; the tool-calling-iterations limit only to run_hero_grpo.
    hero_max_game_steps: int = 40
    hero_max_tool_calls: int = 80
    hero_max_tool_calling_iterations: int = 32
def run_joint_training_loop(config: JointTrainingConfig) -> Path:
    """Run ``config.cycles`` alternating hero/DM GRPO cycles, resuming finished phases.

    Cycle ``N`` writes the hero phase to ``root_dir/cycle_NN/hero`` and the DM phase to
    ``root_dir/cycle_NN/dm``. Each phase's output directory is used as the adapter path
    for the next phase of the same kind, and the DM phase is trained against the
    just-produced hero adapter. Progress is mirrored to ``root_dir/joint_state.json``.

    Args:
        config: Joint training settings.

    Returns:
        ``config.root_dir``.

    Raises:
        ValueError: If ``config.cycles`` is less than 1.
        BaseException: Any error (including KeyboardInterrupt) from a phase is re-raised
            after the manifest has been marked ``"failed"``.
    """
    if config.cycles < 1:
        raise ValueError("cycles must be at least 1.")
    config.root_dir.mkdir(parents=True, exist_ok=True)
    latest_hero_adapter = _initial_adapter_path(config.hero_config.resume_adapter_path)
    latest_dm_adapter = _initial_adapter_path(config.dm_config.resume_adapter_path)
    phases: list[dict[str, Any]] = []

    def write_manifest(status: str, error: str | None = None) -> None:
        # Always report the most recent adapters (read from the enclosing scope at call
        # time) so a failed or in-progress run still tells the reader where to resume from.
        _write_manifest(
            config,
            phases,
            status=status,
            error=error,
            latest_hero_adapter_path=None if latest_hero_adapter is None else str(latest_hero_adapter),
            latest_dm_adapter_path=None if latest_dm_adapter is None else str(latest_dm_adapter),
        )

    write_manifest("running")
    try:
        for cycle_number in range(1, config.cycles + 1):
            cycle_dir = config.root_dir / f"cycle_{cycle_number:02d}"
            latest_hero_adapter = _run_or_resume_hero_phase(
                config=config,
                cycle_number=cycle_number,
                output_dir=cycle_dir / "hero",
                resume_adapter_path=latest_hero_adapter,
                phases=phases,
                on_phase_state_change=lambda: write_manifest("running"),
            )
            write_manifest("running")
            latest_dm_adapter = _run_or_resume_dm_phase(
                config=config,
                cycle_number=cycle_number,
                output_dir=cycle_dir / "dm",
                resume_adapter_path=latest_dm_adapter,
                hero_adapter_path=latest_hero_adapter,
                phases=phases,
                on_phase_state_change=lambda: write_manifest("running"),
            )
            write_manifest("running")
    except BaseException as exc:
        # Catch BaseException (and re-raise) so Ctrl-C / SystemExit do not leave the
        # manifest claiming "running"; fall back to the type name for message-less errors.
        write_manifest("failed", error=str(exc) or type(exc).__name__)
        raise
    write_manifest("completed")
    return config.root_dir
def _run_or_resume_hero_phase(
    *,
    config: JointTrainingConfig,
    cycle_number: int,
    output_dir: Path,
    resume_adapter_path: Path | None,
    phases: list[dict[str, Any]],
    on_phase_state_change: Callable[[], None] | None = None,
) -> Path:
    """Run the hero GRPO phase for one cycle, or reuse it if already completed.

    If ``output_dir/phase_state.json`` records status ``"completed"``, that record is
    appended to ``phases`` and nothing is retrained. Otherwise a ``"running"`` record is
    appended and persisted, ``on_phase_state_change`` is notified, and ``run_hero_grpo``
    is invoked with a copy of ``config.hero_config`` whose ``output_dir``, ``run_name`` and
    ``resume_adapter_path`` are overridden for this cycle.

    Args:
        config: Joint training settings.
        cycle_number: 1-based cycle index, used in the default run name.
        output_dir: Directory for this phase's state file and training outputs.
        resume_adapter_path: Adapter to resume hero training from, if any.
        phases: Shared list of phase records; mutated in place.
        on_phase_state_change: Optional callback invoked after the "running" record is saved.

    Returns:
        ``output_dir``, which the caller treats as the new hero adapter path
        (presumably where ``run_hero_grpo`` saves its adapter — not verified here).

    Raises:
        BaseException: Anything raised by training is re-raised after the phase record
            is marked ``"failed"``.
    """
    state_path = output_dir / "phase_state.json"
    existing_state = _load_phase_state(state_path)
    if existing_state is not None and existing_state.get("status") == "completed":
        phases.append(existing_state)
        return output_dir

    output_dir.mkdir(parents=True, exist_ok=True)
    run_name = config.hero_config.run_name or f"{config.root_dir.name}-hero-cycle-{cycle_number:02d}"
    resume_adapter = None if resume_adapter_path is None else str(resume_adapter_path)
    phase_state: dict[str, Any] = {
        "phase": "hero",
        "cycle": cycle_number,
        "status": "running",
        "run_name": run_name,
        "output_dir": str(output_dir),
        "resume_adapter_path": resume_adapter,
        "started_at": _utc_now(),
    }
    phases.append(phase_state)
    _write_json(state_path, phase_state)
    if on_phase_state_change is not None:
        on_phase_state_change()

    try:
        phase_config = replace(
            config.hero_config,
            output_dir=output_dir,
            run_name=run_name,
            resume_adapter_path=resume_adapter,
        )
        with _wandb_phase_env(group=config.root_dir.name, job_type="hero"):
            run_hero_grpo(
                phase_config,
                world_path=config.hero_world_path,
                artifacts_root=output_dir / "artifacts",
                interface_provider=config.interface_provider,
                interface_model=config.interface_model,
                interface_narrate=config.interface_narrate,
                interface_translation_mode=config.interface_translation_mode,
                max_game_steps=config.hero_max_game_steps,
                max_tool_calls=config.hero_max_tool_calls,
                max_tool_calling_iterations=config.hero_max_tool_calling_iterations,
            )
    except BaseException as exc:
        # Without this the state file (and the shared record in `phases`) would claim
        # "running" forever. A later resume still retrains, since only "completed" is skipped.
        phase_state["status"] = "failed"
        phase_state["error"] = str(exc) or type(exc).__name__
        phase_state["failed_at"] = _utc_now()
        _write_json(state_path, phase_state)
        raise

    phase_state["status"] = "completed"
    phase_state["completed_at"] = _utc_now()
    _write_json(state_path, phase_state)
    return output_dir
def _run_or_resume_dm_phase(
    *,
    config: JointTrainingConfig,
    cycle_number: int,
    output_dir: Path,
    resume_adapter_path: Path | None,
    hero_adapter_path: Path | None,
    phases: list[dict[str, Any]],
    on_phase_state_change: Callable[[], None] | None = None,
) -> Path:
    """Run the DM GRPO phase for one cycle against a hero adapter, or reuse it if done.

    If ``output_dir/phase_state.json`` records status ``"completed"``, that record is
    appended to ``phases`` and nothing is retrained. Otherwise a ``"running"`` record is
    appended and persisted, ``on_phase_state_change`` is notified, and ``run_dm_grpo`` is
    invoked with a per-cycle copy of ``config.dm_config``. The closed loop drives a local
    (``"hf_local"``) hero built from ``config.hero_config.model_name`` plus ``hero_adapter_path``.

    Args:
        config: Joint training settings.
        cycle_number: 1-based cycle index, used in the default run name.
        output_dir: Directory for this phase's state file and training outputs.
        resume_adapter_path: Adapter to resume DM training from, if any.
        hero_adapter_path: Hero adapter to play against; required.
        phases: Shared list of phase records; mutated in place.
        on_phase_state_change: Optional callback invoked after the "running" record is saved.

    Returns:
        ``output_dir``, which the caller treats as the new DM adapter path.

    Raises:
        RuntimeError: If ``hero_adapter_path`` is None.
        BaseException: Anything raised by training is re-raised after the phase record
            is marked ``"failed"``.
    """
    if hero_adapter_path is None:
        raise RuntimeError("DM phase requires a hero adapter path from a completed hero phase.")
    state_path = output_dir / "phase_state.json"
    existing_state = _load_phase_state(state_path)
    if existing_state is not None and existing_state.get("status") == "completed":
        phases.append(existing_state)
        return output_dir

    output_dir.mkdir(parents=True, exist_ok=True)
    run_name = config.dm_config.run_name or f"{config.root_dir.name}-dm-cycle-{cycle_number:02d}"
    resume_adapter = None if resume_adapter_path is None else str(resume_adapter_path)
    phase_state: dict[str, Any] = {
        "phase": "dm",
        "cycle": cycle_number,
        "status": "running",
        "run_name": run_name,
        "output_dir": str(output_dir),
        "resume_adapter_path": resume_adapter,
        "hero_adapter_path": str(hero_adapter_path),
        "started_at": _utc_now(),
    }
    phases.append(phase_state)
    _write_json(state_path, phase_state)
    if on_phase_state_change is not None:
        on_phase_state_change()

    try:
        phase_config = replace(
            config.dm_config,
            output_dir=output_dir,
            run_name=run_name,
            resume_adapter_path=resume_adapter,
        )
        closed_loop = DMClosedLoopConfig(
            hero_provider="hf_local",
            hero_model=config.hero_config.model_name,
            hero_adapter_path=str(hero_adapter_path),
            interface_provider=config.interface_provider,
            interface_model=config.interface_model,
            interface_narrate=config.interface_narrate,
            interface_translation_mode=config.interface_translation_mode,
            hero_max_game_steps=config.hero_max_game_steps,
            hero_max_tool_calls=config.hero_max_tool_calls,
        )
        with _wandb_phase_env(group=config.root_dir.name, job_type="dm"):
            run_dm_grpo(
                phase_config,
                target_ratios=config.target_ratios,
                artifacts_root=output_dir / "artifacts",
                closed_loop=closed_loop,
            )
    except BaseException as exc:
        # Without this the state file (and the shared record in `phases`) would claim
        # "running" forever. A later resume still retrains, since only "completed" is skipped.
        phase_state["status"] = "failed"
        phase_state["error"] = str(exc) or type(exc).__name__
        phase_state["failed_at"] = _utc_now()
        _write_json(state_path, phase_state)
        raise

    phase_state["status"] = "completed"
    phase_state["completed_at"] = _utc_now()
    _write_json(state_path, phase_state)
    return output_dir
def _write_manifest(
    config: JointTrainingConfig,
    phases: list[dict[str, Any]],
    *,
    status: str,
    error: str | None = None,
    latest_hero_adapter_path: str | None = None,
    latest_dm_adapter_path: str | None = None,
) -> None:
    """Persist the joint-run manifest to ``config.root_dir / "joint_state.json"``.

    The manifest carries the overall ``status``, an ``updated_at`` timestamp, the optional
    ``error`` message and latest adapter paths, a JSON-safe dump of ``config``, and the
    list of per-phase records.
    """
    manifest_path = config.root_dir / "joint_state.json"
    manifest = dict(
        status=status,
        updated_at=_utc_now(),
        error=error,
        latest_hero_adapter_path=latest_hero_adapter_path,
        latest_dm_adapter_path=latest_dm_adapter_path,
        config=_to_jsonable(asdict(config)),
        phases=phases,
    )
    _write_json(manifest_path, manifest)
@contextmanager
def _wandb_phase_env(*, group: str, job_type: str) -> Iterator[None]:
    """Temporarily set the W&B run-group and job-type environment variables.

    ``WANDB_RUN_GROUP`` and ``WANDB_JOB_TYPE`` hold ``group`` and ``job_type`` for the
    duration of the ``with`` body. On exit, even when the body raises, each variable is
    restored to its prior value, or removed if it was previously unset.
    """
    overrides = {"WANDB_RUN_GROUP": group, "WANDB_JOB_TYPE": job_type}
    saved = {name: os.getenv(name) for name in overrides}
    for name, new_value in overrides.items():
        os.environ[name] = new_value
    try:
        yield
    finally:
        for name, old_value in saved.items():
            _restore_env(name, old_value)
def _restore_env(name: str, value: str | None) -> None:
if value is None:
os.environ.pop(name, None)
else:
os.environ[name] = value
def _load_phase_state(path: Path) -> dict[str, Any] | None:
if not path.exists():
return None
return json.loads(path.read_text(encoding="utf-8"))
def _write_json(path: Path, payload: dict[str, Any]) -> None:
    """Atomically write ``payload`` to ``path`` as pretty-printed, key-sorted JSON.

    Parent directories are created as needed. The text is first written to a sibling
    temporary file and then moved over ``path`` with ``os.replace``. A crash mid-write
    therefore cannot leave a truncated state file, which would otherwise make
    ``json.loads`` fail on every later resume.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    # Serialize before touching the filesystem so an unserializable payload leaves no file.
    text = json.dumps(_to_jsonable(payload), indent=2, sort_keys=True) + "\n"
    tmp_path = path.with_name(f".{path.name}.{os.getpid()}.tmp")
    try:
        tmp_path.write_text(text, encoding="utf-8")
        os.replace(tmp_path, path)
    except BaseException:
        # Do not leave a stray temp file behind if the write or rename failed.
        tmp_path.unlink(missing_ok=True)
        raise
def _to_jsonable(value: Any) -> Any:
if isinstance(value, Path):
return str(value)
if isinstance(value, dict):
return {str(key): _to_jsonable(item) for key, item in value.items()}
if isinstance(value, list):
return [_to_jsonable(item) for item in value]
return value
def _initial_adapter_path(raw_path: str | None) -> Path | None:
if raw_path is None:
return None
path = Path(raw_path)
return path if path.exists() else None
def _utc_now() -> str:
return datetime.now(UTC).replace(microsecond=0).isoformat().replace("+00:00", "Z")