# FATHOM-Hero / agents/train/grpo.py
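"""GRPO training utilities for the FATHOM agents.

Summarizing the code below: this module builds prompt datasets for the DM
(master) policy that writes world definitions and for the Hero policy that
plays them, defines the reward functions used to score both, wraps
HeroEnvironment as a tool-calling environment for rollouts, and assembles
TRL GRPOTrainer runs with LoRA adapters and optional 4-bit quantization.
"""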
from __future__ import annotations
import hashlib
import importlib.util
import json
import os
import re
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Callable
from agents.hero.cli import parse_cli_command
from agents.hero.env import HeroEnvironment
from agents.hero.policy import HeroLLMPolicy
from agents.hero.runner import HeroRunner
from agents.master.base import normalize_answer_text, parser_safe_text
from agents.master.check import validate_and_normalize
from agents.hero.prompt import format_hero_grpo_system_prompt
from agents.hero.schema import validate_hero_action
from agents.master.env import DMEnvironment
from agents.master.prompt import build_dm_world_messages
from agents.master.sample import load_world, sample_world_definition
from agents.master.schema import WorldDefinition
from agents.shared.runtime import (
build_interface_adapter,
create_structured_client,
resolve_interface_config,
resolve_structured_client_config,
)
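# The training extras (torch, datasets, peft, trl, transformers) are optional at
# import time; when missing, the error is stashed in TRAINING_IMPORT_ERROR so it
# can be surfaced later, once training is actually requested.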
try:
import torch
from datasets import Dataset
from peft import LoraConfig
from trl.chat_template_utils import qwen3_chat_template, qwen3_schema
from trl.rewards import get_soft_overlong_punishment
from trl import GRPOConfig, GRPOTrainer
from transformers import AutoTokenizer, BitsAndBytesConfig
TRAINING_IMPORT_ERROR: Exception | None = None
except Exception as exc: # pragma: no cover - exercised when train extras are unavailable
torch = None # type: ignore[assignment]
Dataset = None # type: ignore[assignment]
LoraConfig = None # type: ignore[assignment]
GRPOConfig = None # type: ignore[assignment]
GRPOTrainer = None # type: ignore[assignment]
AutoTokenizer = None # type: ignore[assignment]
BitsAndBytesConfig = None # type: ignore[assignment]
qwen3_chat_template = None # type: ignore[assignment]
qwen3_schema = None # type: ignore[assignment]
get_soft_overlong_punishment = None # type: ignore[assignment]
TRAINING_IMPORT_ERROR = exc
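# Difficulty ratios cycled across prompts when callers do not pass explicit targets.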
_DEFAULT_TARGET_RATIOS = [1.25, 1.5, 1.75, 2.0]
_DM_REQUIRED_TOP_LEVEL_FIELDS = ("meta", "nodes", "edges", "items", "clues", "recipes", "quest_chain")
_DM_ALLOWED_NODE_TYPES = {"location", "junction", "container", "door", "readable", "fixture", "npc"}
_DM_ALLOWED_EDGE_TYPES = {"passage", "locked_passage"}
_DM_ALLOWED_ITEM_TYPES = {"key", "puzzle"}
_HERO_TOOL_NAMES = {"act", "scratchpad_read", "scratchpad_write"}
_TOOL_CALL_RE = re.compile(r"<tool_call>\s*(\{.*?\})\s*</tool_call>", re.DOTALL)
_EMPTY_THINK_RE = re.compile(r"<think>\s*</think>\s*", re.DOTALL)
_LOWERCASE_ANSWER_RE = re.compile(r"^[a-z0-9]+(?: [a-z0-9]+)*$")
_HERO_TASK_PROMPTS = (
"Solve the dungeon by using tools until the episode ends.\nInitial observation:\n",
"Play the dungeon to completion through tool calls only.\nInitial observation:\n",
"Gather every clue and solve the dungeon via tools only.\nInitial observation:\n",
)
SUPPORTED_GRPO_LOSS_TYPES = ("grpo", "dapo", "bnpo", "dr_grpo", "cispo", "sapo", "luspo")
SUPPORTED_IMPORTANCE_SAMPLING_LEVELS = ("token", "sequence")
@dataclass(frozen=True)
class GRPOLaunchConfig:
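    """Validated launch parameters for a single GRPO run, shared by DM and Hero training."""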
model_name: str
output_dir: Path
resume_adapter_path: str | None = None
max_steps: int = 10
num_prompts: int = 16
learning_rate: float = 1e-5
per_device_train_batch_size: int = 2
gradient_accumulation_steps: int = 8
num_generations: int = 2
max_completion_length: int = 512
logging_steps: int = 1
save_steps: int = 10
seed: int = 42
rank: int = 16
alpha: int = 32
dropout: float = 0.05
temperature: float = 0.6
top_p: float = 0.95
top_k: int = 20
min_p: float | None = None
repetition_penalty: float = 1.0
use_wandb: bool = True
run_name: str | None = None
trust_remote_code: bool = False
load_in_4bit: bool = True
loss_type: str = "dapo"
importance_sampling_level: str = "token"
use_transformers_paged: bool = False
cache_implementation: str | None = None
use_vllm: bool = False
vllm_mode: str = "colocate"
vllm_gpu_memory_utilization: float = 0.2
vllm_enable_sleep_mode: bool = True
def __post_init__(self) -> None:
if self.loss_type not in SUPPORTED_GRPO_LOSS_TYPES:
raise ValueError(
f"loss_type must be one of {SUPPORTED_GRPO_LOSS_TYPES!r}; got {self.loss_type!r}."
)
if self.importance_sampling_level not in SUPPORTED_IMPORTANCE_SAMPLING_LEVELS:
raise ValueError(
"importance_sampling_level must be one of "
f"{SUPPORTED_IMPORTANCE_SAMPLING_LEVELS!r}; got {self.importance_sampling_level!r}."
)
if self.loss_type == "luspo" and self.importance_sampling_level != "sequence":
raise ValueError("luspo requires importance_sampling_level='sequence'.")
if self.per_device_train_batch_size < 1:
raise ValueError("per_device_train_batch_size must be at least 1.")
if self.gradient_accumulation_steps < 1:
raise ValueError("gradient_accumulation_steps must be at least 1.")
if self.num_generations < 2:
raise ValueError("num_generations must be at least 2 for GRPO.")
if self.max_steps < 1:
raise ValueError("max_steps must be at least 1.")
if self.num_prompts < 1:
raise ValueError("num_prompts must be at least 1.")
if self.temperature <= 0.0:
raise ValueError("temperature must be greater than 0.")
if not 0.0 < self.top_p <= 1.0:
raise ValueError("top_p must be in the interval (0, 1].")
if self.top_k < 0:
raise ValueError("top_k must be non-negative.")
if self.min_p is not None and not 0.0 <= self.min_p <= 1.0:
raise ValueError("min_p must be in the interval [0, 1] when provided.")
if self.repetition_penalty < 1.0:
raise ValueError("repetition_penalty must be at least 1.0.")
if self.vllm_mode not in {"server", "colocate"}:
raise ValueError("vllm_mode must be 'server' or 'colocate'.")
if not 0.0 < self.vllm_gpu_memory_utilization < 1.0:
raise ValueError("vllm_gpu_memory_utilization must be in the interval (0, 1).")
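        # Mirror TRL's batching constraint: the generation batch (per-device batch size
        # * WORLD_SIZE) must split evenly into groups of num_generations completions,
        # and there must be enough prompt rows for at least one full optimizer step.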
world_size = max(1, int(os.environ.get("WORLD_SIZE", "1")))
generation_batch_size = self.per_device_train_batch_size * world_size
if generation_batch_size % self.num_generations != 0:
raise ValueError(
"generation_batch_size "
f"({generation_batch_size}) must be divisible by num_generations ({self.num_generations}). "
"Increase --per-device-train-batch-size, reduce --num-generations, or launch with more processes."
)
minimum_prompt_rows = generation_batch_size * self.gradient_accumulation_steps
if self.num_prompts < minimum_prompt_rows:
raise ValueError(
"num_prompts "
f"({self.num_prompts}) must be at least generation_batch_size * gradient_accumulation_steps "
f"({minimum_prompt_rows}) so GRPO can complete one optimizer step."
)
@dataclass(frozen=True)
class DMClosedLoopConfig:
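    """Hero-policy and interface settings used when scoring DM worlds closed-loop."""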
hero_provider: str | None = None
hero_model: str | None = None
hero_adapter_path: str | None = None
interface_provider: str | None = None
interface_model: str | None = None
interface_narrate: bool = False
interface_translation_mode: str | None = None
hero_max_game_steps: int = 40
hero_max_tool_calls: int = 80
@dataclass(frozen=True)
class DMRolloutMetrics:
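    """Flattened outcome of one DM world rollout, shared by all DM reward functions."""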
reward: float
compile_error: str | None
requested_ratio: float
player_won: bool
steps_taken: int | None
min_steps: int | None
ratio: float | None
efficiency_score: float
quality_score: float
invalid_command_count: int
wrong_submit_count: int
hero_player_won: bool
hero_total_reward: float
hero_dense_return: float
hero_steps_taken: int
hero_tool_calls_total: int
hero_policy_error: str | None
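# Per-optimizer-step rollout cache: every DM reward function below reuses one
# expensive hero rollout per completion instead of re-simulating the episode
# for each reward head.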
_DM_ROLLOUT_CACHE_STEP = -1
_DM_ROLLOUT_CACHE: dict[tuple[Any, ...], DMRolloutMetrics] = {}
def build_dm_grpo_dataset(
*,
num_prompts: int = 8,
target_ratios: list[float] | None = None,
) -> list[dict[str, Any]]:
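    """Build DM GRPO prompt rows, one sampled reference world per prompt, cycling the target ratios."""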
ratios = target_ratios or _DEFAULT_TARGET_RATIOS
rows: list[dict[str, Any]] = []
for index in range(num_prompts):
target_ratio = ratios[index % len(ratios)]
reference_world = sample_world_definition(seed=index, difficulty_target=target_ratio)
prompt = [
{"role": message.role, "content": message.content}
for message in build_dm_world_messages(
target_ratio=target_ratio,
reference_world=reference_world,
prompt_style=index,
)
]
rows.append({"prompt": prompt, "target_ratio": target_ratio, "seed": index})
return rows
def build_hero_grpo_dataset(
*,
num_prompts: int = 8,
world_input: dict[str, Any] | None = None,
target_ratios: list[float] | None = None,
max_game_steps: int = 40,
max_tool_calls: int = 80,
) -> list[dict[str, Any]]:
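    """Build Hero GRPO prompt rows; each row carries the serialized world plus episode budgets."""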
ratios = target_ratios or _DEFAULT_TARGET_RATIOS
rows: list[dict[str, Any]] = []
for index in range(num_prompts):
target_ratio = ratios[index % len(ratios)]
world = world_input or sample_world_definition(seed=index, difficulty_target=target_ratio)
world_title = str(world["meta"]["title"])
prompt = [
{
"role": "system",
"content": format_hero_grpo_system_prompt(world_title, max_game_steps, max_tool_calls),
},
{
"role": "user",
"content": _HERO_TASK_PROMPTS[index % len(_HERO_TASK_PROMPTS)],
},
]
rows.append(
{
"prompt": prompt,
"world_definition_json": json.dumps(world, separators=(",", ":")),
"seed": index,
"target_ratio": target_ratio,
"max_game_steps": max_game_steps,
"max_tool_calls": max_tool_calls,
}
)
return rows
class HeroToolEnvironment:
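    """Tool-calling wrapper that exposes one HeroEnvironment episode through the
    act / scratchpad_read / scratchpad_write tools used during GRPO rollouts."""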
def __init__(
self,
*,
artifacts_root: Path | None = None,
interface_provider: str | None = None,
interface_model: str | None = None,
interface_narrate: bool = False,
interface_translation_mode: str | None = None,
) -> None:
self.artifacts_root = artifacts_root
self.interface_provider = interface_provider
self.interface_model = interface_model
self.interface_narrate = interface_narrate
self.interface_translation_mode = interface_translation_mode
self.hero_env: HeroEnvironment | None = None
self.last_message = ""
def reset(
self,
*,
world_definition_json: str,
seed: int | None = None,
max_game_steps: int = 40,
max_tool_calls: int = 80,
prompt: Any | None = None,
**_: Any,
) -> str:
del prompt
interface_adapter = build_interface_adapter(
resolve_interface_config(
provider=self.interface_provider, # type: ignore[arg-type]
model_name=self.interface_model,
narrate_observations=self.interface_narrate,
translation_mode=self.interface_translation_mode, # type: ignore[arg-type]
)
)
self.hero_env = HeroEnvironment(
artifacts_root=self.artifacts_root,
interface_adapter=interface_adapter,
)
observation = self.hero_env.reset(
world_input=json.loads(world_definition_json),
seed=seed,
max_game_steps=max_game_steps,
max_tool_calls=max_tool_calls,
)
self.last_message = observation.message
return observation.message
def act(self, command: str) -> str:
"""Act in the dungeon with one strict CLI command.
Args:
command: Lowercase parser-style dungeon command.
Returns:
The environment's next observation message.
"""
return self._step({"tool": "act", "command": command})
def scratchpad_read(self) -> str:
"""Read the current scratchpad contents.
Returns:
The scratchpad text.
"""
return self._step({"tool": "scratchpad_read"})
def scratchpad_write(self, mode: str, content: str) -> str:
"""Write to the scratchpad.
Args:
            mode: Either "append" or "replace".
content: Text to write.
Returns:
The environment's acknowledgement message.
"""
return self._step({"tool": "scratchpad_write", "mode": mode, "content": content})
def _cumulative_reward(self) -> float:
if self.hero_env is None:
return -1.0
return float(self.hero_env.episode_stats.total_reward)
def _episode_done(self) -> bool:
if self.hero_env is None or self.hero_env.session is None:
return False
return bool(self.hero_env.session.done or self.hero_env.state.status in {"won", "lost", "timed_out"})
def _episode_won(self) -> bool:
if self.hero_env is None:
return False
return bool(self.hero_env.episode_stats.player_won)
def _step(self, action: dict[str, Any]) -> str:
if self.hero_env is None:
raise RuntimeError("HeroToolEnvironment.reset must be called before using tools.")
result = self.hero_env.step(action)
self.last_message = result.observation.message
return result.observation.message
def dm_reward_function(
*,
prompts: list[Any],
completions: list[Any],
target_ratio: list[float],
trainer_state: Any,
hero_policy: Any,
interface_provider: str | None = None,
interface_model: str | None = None,
interface_narrate: bool = False,
interface_translation_mode: str | None = None,
hero_max_game_steps: int = 40,
hero_max_tool_calls: int = 80,
artifacts_root: str | None = None,
**_: Any,
) -> list[float]:
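    """Closed-loop DM reward: compile each generated world, replay it with the hero
    policy, and return the rollout reward (or a graded compile-error penalty)."""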
del prompts
rewards: list[float] = []
for index, (completion, requested_ratio) in enumerate(zip(completions, target_ratio, strict=True)):
metrics = _cached_dm_rollout_metrics(
completion=completion,
requested_ratio=requested_ratio,
trainer_state=trainer_state,
completion_index=index,
hero_policy=hero_policy,
interface_provider=interface_provider,
interface_model=interface_model,
interface_narrate=interface_narrate,
interface_translation_mode=interface_translation_mode,
hero_max_game_steps=hero_max_game_steps,
hero_max_tool_calls=hero_max_tool_calls,
artifacts_root=artifacts_root,
)
if metrics.compile_error is not None:
rewards.append(_compile_error_penalty(metrics.compile_error))
continue
rewards.append(metrics.reward)
return rewards
def _dm_reward_artifacts_dir(
*,
artifacts_root: str | None,
trainer_state: Any,
completion_index: int,
) -> Path | None:
if artifacts_root is None:
return None
step = getattr(trainer_state, "global_step", 0)
return Path(artifacts_root) / "dm_reward_rollouts" / f"step_{step:05d}" / f"sample_{completion_index:02d}"
def dm_hero_success_reward(
*,
prompts: list[Any],
completions: list[Any],
target_ratio: list[float],
trainer_state: Any,
hero_policy: Any,
interface_provider: str | None = None,
interface_model: str | None = None,
interface_narrate: bool = False,
interface_translation_mode: str | None = None,
hero_max_game_steps: int = 40,
hero_max_tool_calls: int = 80,
artifacts_root: str | None = None,
**_: Any,
) -> list[float]:
del prompts
rewards: list[float] = []
for index, (completion, requested_ratio) in enumerate(zip(completions, target_ratio, strict=True)):
metrics = _cached_dm_rollout_metrics(
completion=completion,
requested_ratio=requested_ratio,
trainer_state=trainer_state,
completion_index=index,
hero_policy=hero_policy,
interface_provider=interface_provider,
interface_model=interface_model,
interface_narrate=interface_narrate,
interface_translation_mode=interface_translation_mode,
hero_max_game_steps=hero_max_game_steps,
hero_max_tool_calls=hero_max_tool_calls,
artifacts_root=artifacts_root,
)
if metrics.compile_error is not None:
rewards.append(_compile_error_penalty(metrics.compile_error))
continue
rewards.append(float(metrics.hero_player_won))
return rewards
def dm_hero_efficiency_reward(
*,
prompts: list[Any],
completions: list[Any],
target_ratio: list[float],
trainer_state: Any,
hero_policy: Any,
interface_provider: str | None = None,
interface_model: str | None = None,
interface_narrate: bool = False,
interface_translation_mode: str | None = None,
hero_max_game_steps: int = 40,
hero_max_tool_calls: int = 80,
artifacts_root: str | None = None,
**_: Any,
) -> list[float]:
del prompts
rewards: list[float] = []
for index, (completion, requested_ratio) in enumerate(zip(completions, target_ratio, strict=True)):
metrics = _cached_dm_rollout_metrics(
completion=completion,
requested_ratio=requested_ratio,
trainer_state=trainer_state,
completion_index=index,
hero_policy=hero_policy,
interface_provider=interface_provider,
interface_model=interface_model,
interface_narrate=interface_narrate,
interface_translation_mode=interface_translation_mode,
hero_max_game_steps=hero_max_game_steps,
hero_max_tool_calls=hero_max_tool_calls,
artifacts_root=artifacts_root,
)
if metrics.compile_error is not None:
rewards.append(_compile_error_penalty(metrics.compile_error))
continue
if not metrics.hero_player_won:
rewards.append(0.0)
continue
rewards.append(_clamp(metrics.efficiency_score, 0.0, 1.0))
return rewards
def dm_hero_cleanliness_reward(
*,
prompts: list[Any],
completions: list[Any],
target_ratio: list[float],
trainer_state: Any,
hero_policy: Any,
interface_provider: str | None = None,
interface_model: str | None = None,
interface_narrate: bool = False,
interface_translation_mode: str | None = None,
hero_max_game_steps: int = 40,
hero_max_tool_calls: int = 80,
artifacts_root: str | None = None,
**_: Any,
) -> list[float]:
del prompts
rewards: list[float] = []
for index, (completion, requested_ratio) in enumerate(zip(completions, target_ratio, strict=True)):
metrics = _cached_dm_rollout_metrics(
completion=completion,
requested_ratio=requested_ratio,
trainer_state=trainer_state,
completion_index=index,
hero_policy=hero_policy,
interface_provider=interface_provider,
interface_model=interface_model,
interface_narrate=interface_narrate,
interface_translation_mode=interface_translation_mode,
hero_max_game_steps=hero_max_game_steps,
hero_max_tool_calls=hero_max_tool_calls,
artifacts_root=artifacts_root,
)
if metrics.compile_error is not None:
rewards.append(_compile_error_penalty(metrics.compile_error))
continue
step_budget = max(1, metrics.hero_steps_taken or metrics.steps_taken or 0)
penalty = (metrics.invalid_command_count + (2 * metrics.wrong_submit_count)) / step_budget
score = max(0.0, 1.0 - penalty)
if metrics.hero_policy_error is not None:
score = min(score, 0.25)
rewards.append(_clamp(score, 0.0, 1.0))
return rewards
def dm_json_format_reward(
*,
prompts: list[Any],
completions: list[Any],
trainer_state: Any,
**_: Any,
) -> list[float]:
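    """Reward compact, parseable JSON output; stray prose around the object, <think>
    blocks, and code fences all reduce the score."""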
del prompts, trainer_state
rewards: list[float] = []
for completion in completions:
text = _completion_text(completion)
score = 0.0
json_text, leading_text, trailing_text = _extract_json_candidate_parts(text)
if json_text is None:
if "{" in text:
score += 0.05
if "<think>" in text:
score -= 0.10
rewards.append(_clamp(score, -0.25, 1.0))
continue
try:
json.loads(json_text)
score += 0.60
except Exception:
score += 0.20
outer_text = (leading_text + trailing_text).strip()
if not outer_text:
score += 0.25
else:
ratio = len(json_text) / max(1, len(_strip_code_fences(text).strip()))
score += 0.15 * ratio
score += 0.10 * _compactness_score(len(json_text), 4500)
if "<think>" in text:
score -= 0.15
if "```" in text:
score -= 0.05
rewards.append(_clamp(score, -0.25, 1.0))
return rewards
def dm_schema_reward(
*,
prompts: list[Any],
completions: list[Any],
target_ratio: list[float] | None = None,
trainer_state: Any,
**_: Any,
) -> list[float]:
del prompts, trainer_state
target_ratio = target_ratio or [None] * len(completions)
rewards: list[float] = []
for completion, requested_ratio in zip(completions, target_ratio, strict=True):
payload = _try_parse_completion_json(_completion_text(completion))
if not isinstance(payload, dict):
rewards.append(0.0)
continue
rewards.append(_dm_structural_prior_score(payload, requested_ratio))
return rewards
def dm_validation_reward(
*,
prompts: list[Any],
completions: list[Any],
trainer_state: Any,
**_: Any,
) -> list[float]:
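    """Full credit when the payload validates as a WorldDefinition; otherwise a
    partial score derived from the validation error list."""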
del prompts, trainer_state
rewards: list[float] = []
for completion in completions:
payload = _try_parse_completion_json(_completion_text(completion))
if not isinstance(payload, dict):
rewards.append(0.0)
continue
try:
WorldDefinition.model_validate(payload)
rewards.append(1.0)
except Exception as exc:
error_list = exc.errors() if hasattr(exc, "errors") else []
rewards.append(_validation_error_score(error_list))
return rewards
def dm_compile_prior_reward(
*,
prompts: list[Any],
completions: list[Any],
trainer_state: Any,
**_: Any,
) -> list[float]:
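    """Score whether the (repair-assisted) world both loads and passes
    validate_and_normalize, with the graded compile penalty otherwise."""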
del prompts, trainer_state
rewards: list[float] = []
for completion in completions:
try:
world = _load_dm_world_definition(_completion_text(completion), allow_repair=True)
except Exception as exc:
rewards.append(_compile_error_penalty(str(exc)))
continue
try:
validate_and_normalize(world)
rewards.append(1.0)
except Exception as exc:
rewards.append(_compile_error_penalty(str(exc)))
return rewards
def _bind_dm_reward_function(
*,
artifacts_root: str | None,
hero_policy: Any,
interface_provider: str | None,
interface_model: str | None,
interface_narrate: bool,
interface_translation_mode: str | None = None,
hero_max_game_steps: int,
hero_max_tool_calls: int,
) -> Any:
return _bind_dm_rollout_reward(
dm_reward_function,
artifacts_root=artifacts_root,
hero_policy=hero_policy,
interface_provider=interface_provider,
interface_model=interface_model,
interface_narrate=interface_narrate,
interface_translation_mode=interface_translation_mode,
hero_max_game_steps=hero_max_game_steps,
hero_max_tool_calls=hero_max_tool_calls,
)
def _bind_dm_rollout_reward(
reward_impl: Callable[..., list[float]],
*,
artifacts_root: str | None,
hero_policy: Any,
interface_provider: str | None,
interface_model: str | None,
interface_narrate: bool,
interface_translation_mode: str | None = None,
hero_max_game_steps: int,
hero_max_tool_calls: int,
) -> Any:
def reward_func(**kwargs: Any) -> list[float]:
return reward_impl(
artifacts_root=artifacts_root,
hero_policy=hero_policy,
interface_provider=interface_provider,
interface_model=interface_model,
interface_narrate=interface_narrate,
interface_translation_mode=interface_translation_mode,
hero_max_game_steps=hero_max_game_steps,
hero_max_tool_calls=hero_max_tool_calls,
**kwargs,
)
reward_func.__name__ = reward_impl.__name__
return reward_func
def _make_named_overlong_reward(*, name: str, max_completion_len: int) -> Callable[..., list[float]] | None:
if get_soft_overlong_punishment is None:
return None
soft_punish_cache = max(16, min(64, max_completion_len // 4))
reward_func = get_soft_overlong_punishment(max_completion_len=max_completion_len, soft_punish_cache=soft_punish_cache)
reward_func.__name__ = name
return reward_func
def _canonicalize_qwen_chat_template(tokenizer: Any) -> Any:
chat_template = getattr(tokenizer, "chat_template", "") or ""
if qwen3_chat_template is None:
return tokenizer
if "<|im_start|>" not in chat_template or "<|im_end|>" not in chat_template:
return tokenizer
tokenizer.chat_template = qwen3_chat_template
return tokenizer
def _chat_template_kwargs(tokenizer: Any) -> dict[str, Any] | None:
if not hasattr(tokenizer, "apply_chat_template"):
return None
try:
tokenizer.apply_chat_template(
[{"role": "user", "content": "ping"}],
tokenize=False,
add_generation_prompt=True,
enable_thinking=False,
)
except Exception:
return None
return {"enable_thinking": False}
def _ensure_tool_response_schema(tokenizer: Any) -> Any:
tokenizer = _canonicalize_qwen_chat_template(tokenizer)
chat_template = getattr(tokenizer, "chat_template", "") or ""
if qwen3_chat_template is None or qwen3_schema is None:
return tokenizer
if not hasattr(tokenizer, "parse_response"):
return tokenizer
if "<tool_call>" not in chat_template or "<|im_start|>" not in chat_template:
return tokenizer
tokenizer.chat_template = qwen3_chat_template
if getattr(tokenizer, "response_schema", None) is not None:
return tokenizer
tokenizer.response_schema = qwen3_schema
return tokenizer
def hero_tool_format_reward(
*,
prompts: list[Any],
completions: list[Any],
trainer_state: Any,
**_: Any,
) -> list[float]:
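    """Format-shaping reward: prefer exactly one well-formed <tool_call> naming a
    known hero tool, with little text outside it and no <think> or code fences."""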
del prompts, trainer_state
rewards: list[float] = []
for completion in completions:
text = _completion_text(completion)
tool_calls = _completion_tool_calls(completion)
score = 0.0
if len(tool_calls) == 1:
call = tool_calls[0]
score += 0.65 if call["source"] == "tool_call" else 0.30
if call["name"] in _HERO_TOOL_NAMES:
score += 0.15
outer_text = _normalize_outer_completion_text(text)
if not outer_text:
score += 0.15
else:
score += 0.10 * (1.0 - min(1.0, len(outer_text) / max(1, len(text.strip()))))
elif len(tool_calls) > 1:
score += 0.20
if all(call["name"] in _HERO_TOOL_NAMES for call in tool_calls):
score += 0.05
else:
if "<tool_call>" in text:
score += 0.05
elif '{"action"' in text.replace(" ", ""):
score += 0.10
if "<think>" in text:
score -= 0.15
if "```" in text:
score -= 0.05
rewards.append(_clamp(score, -0.25, 1.0))
return rewards
def hero_action_semantics_reward(
*,
prompts: list[Any],
completions: list[Any],
trainer_state: Any,
**_: Any,
) -> list[float]:
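    """Score the arguments of the single tool call with per-tool heuristics; unknown
    tools and calls not sourced from a <tool_call> block are discounted."""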
del prompts, trainer_state
rewards: list[float] = []
for completion in completions:
tool_calls = _completion_tool_calls(completion)
if len(tool_calls) != 1:
rewards.append(0.10 if len(tool_calls) > 1 else 0.0)
continue
tool_call = tool_calls[0]
tool_name = tool_call["name"]
arguments = tool_call["arguments"]
if tool_name == "act":
reward = _hero_act_semantics_reward(arguments)
elif tool_name == "scratchpad_read":
reward = 1.0 if not arguments else 0.80
elif tool_name == "scratchpad_write":
reward = _hero_scratchpad_write_reward(arguments)
else:
reward = -0.25
if tool_call["source"] != "tool_call":
reward *= 0.85
rewards.append(_clamp(reward, -0.25, 1.0))
return rewards
def hero_reward_function(
*,
prompts: list[Any],
completions: list[Any],
environments: list[HeroToolEnvironment],
trainer_state: Any,
**_: Any,
) -> list[float]:
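    """Environment-grounded reward: the episode's cumulative reward, minus a small
    penalty when the rollout ends before the episode is done."""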
del prompts, completions, trainer_state
rewards: list[float] = []
for environment in environments:
reward = environment._cumulative_reward()
if not environment._episode_done():
reward -= 0.05
rewards.append(reward)
return rewards
def create_dm_grpo_trainer(
config: GRPOLaunchConfig,
*,
target_ratios: list[float] | None = None,
artifacts_root: Path | None = None,
closed_loop: DMClosedLoopConfig | None = None,
):
_require_training_dependencies()
closed_loop = closed_loop or DMClosedLoopConfig()
rows = build_dm_grpo_dataset(num_prompts=config.num_prompts, target_ratios=target_ratios)
dataset = Dataset.from_list(rows)
tokenizer = AutoTokenizer.from_pretrained(config.model_name, trust_remote_code=config.trust_remote_code)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
tokenizer = _canonicalize_qwen_chat_template(tokenizer)
chat_template_kwargs = _chat_template_kwargs(tokenizer)
hero_client_config = resolve_structured_client_config(
"hero",
provider=closed_loop.hero_provider, # type: ignore[arg-type]
model_name=closed_loop.hero_model,
adapter_path=closed_loop.hero_adapter_path,
)
hero_policy = HeroLLMPolicy(
create_structured_client(hero_client_config),
model_name=hero_client_config.model_name,
)
reward_funcs: list[Any] = [
dm_json_format_reward,
dm_schema_reward,
dm_validation_reward,
dm_compile_prior_reward,
_bind_dm_rollout_reward(
dm_hero_success_reward,
artifacts_root=str(artifacts_root) if artifacts_root is not None else None,
hero_policy=hero_policy,
interface_provider=closed_loop.interface_provider,
interface_model=closed_loop.interface_model,
interface_narrate=closed_loop.interface_narrate,
interface_translation_mode=closed_loop.interface_translation_mode,
hero_max_game_steps=closed_loop.hero_max_game_steps,
hero_max_tool_calls=closed_loop.hero_max_tool_calls,
),
_bind_dm_rollout_reward(
dm_hero_efficiency_reward,
artifacts_root=str(artifacts_root) if artifacts_root is not None else None,
hero_policy=hero_policy,
interface_provider=closed_loop.interface_provider,
interface_model=closed_loop.interface_model,
interface_narrate=closed_loop.interface_narrate,
interface_translation_mode=closed_loop.interface_translation_mode,
hero_max_game_steps=closed_loop.hero_max_game_steps,
hero_max_tool_calls=closed_loop.hero_max_tool_calls,
),
_bind_dm_rollout_reward(
dm_hero_cleanliness_reward,
artifacts_root=str(artifacts_root) if artifacts_root is not None else None,
hero_policy=hero_policy,
interface_provider=closed_loop.interface_provider,
interface_model=closed_loop.interface_model,
interface_narrate=closed_loop.interface_narrate,
interface_translation_mode=closed_loop.interface_translation_mode,
hero_max_game_steps=closed_loop.hero_max_game_steps,
hero_max_tool_calls=closed_loop.hero_max_tool_calls,
),
_bind_dm_reward_function(
artifacts_root=str(artifacts_root) if artifacts_root is not None else None,
hero_policy=hero_policy,
interface_provider=closed_loop.interface_provider,
interface_model=closed_loop.interface_model,
interface_narrate=closed_loop.interface_narrate,
interface_translation_mode=closed_loop.interface_translation_mode,
hero_max_game_steps=closed_loop.hero_max_game_steps,
hero_max_tool_calls=closed_loop.hero_max_tool_calls,
),
]
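    # Weights are positional with reward_funcs: the format/schema/validation/compile
    # priors come first, the three hero-rollout diagnostics are held at 0.0, and the
    # closed-loop dm_reward_function carries the main signal.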
reward_weights = [0.25, 0.20, 0.50, 0.45, 0.0, 0.0, 0.0, 1.0]
overlong_reward = _make_named_overlong_reward(name="dm_overlong_reward", max_completion_len=config.max_completion_length)
if overlong_reward is not None:
reward_funcs.append(overlong_reward)
reward_weights.append(0.15)
model, peft_config, include_model_init_kwargs = _build_trainable_model(config)
return GRPOTrainer(
model=model,
reward_funcs=reward_funcs,
args=_build_grpo_config(
config,
max_tool_calling_iterations=None,
chat_template_kwargs=chat_template_kwargs,
reward_weights=reward_weights,
include_model_init_kwargs=include_model_init_kwargs,
),
train_dataset=dataset,
processing_class=tokenizer,
peft_config=peft_config,
)
def create_hero_grpo_trainer(
config: GRPOLaunchConfig,
*,
world_input: dict[str, Any] | None = None,
artifacts_root: Path | None = None,
interface_provider: str | None = None,
interface_model: str | None = None,
interface_narrate: bool = False,
interface_translation_mode: str | None = None,
max_game_steps: int = 40,
max_tool_calls: int = 80,
max_tool_calling_iterations: int = 32,
):
_require_training_dependencies()
rows = build_hero_grpo_dataset(
num_prompts=config.num_prompts,
world_input=world_input,
max_game_steps=max_game_steps,
max_tool_calls=max_tool_calls,
)
dataset = Dataset.from_list(rows)
tokenizer = AutoTokenizer.from_pretrained(config.model_name, trust_remote_code=config.trust_remote_code)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
tokenizer = _ensure_tool_response_schema(tokenizer)
chat_template_kwargs = _chat_template_kwargs(tokenizer)
    def environment_factory() -> HeroToolEnvironment:
        return HeroToolEnvironment(
            artifacts_root=artifacts_root,
            interface_provider=interface_provider,
            interface_model=interface_model,
            interface_narrate=interface_narrate,
            interface_translation_mode=interface_translation_mode,
        )
reward_funcs: list[Any] = [
hero_tool_format_reward,
hero_action_semantics_reward,
hero_reward_function,
]
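    # Positional weights: format and semantics shaping count less than the environment return.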
reward_weights = [0.40, 0.30, 1.0]
overlong_reward = _make_named_overlong_reward(
name="hero_overlong_reward",
max_completion_len=config.max_completion_length,
)
if overlong_reward is not None:
reward_funcs.append(overlong_reward)
reward_weights.append(0.15)
model, peft_config, include_model_init_kwargs = _build_trainable_model(config)
return GRPOTrainer(
model=model,
reward_funcs=reward_funcs,
args=_build_grpo_config(
config,
max_tool_calling_iterations=max_tool_calling_iterations,
chat_template_kwargs=chat_template_kwargs,
reward_weights=reward_weights,
include_model_init_kwargs=include_model_init_kwargs,
),
train_dataset=dataset,
processing_class=tokenizer,
peft_config=peft_config,
environment_factory=environment_factory,
)
def run_dm_grpo(
config: GRPOLaunchConfig,
*,
target_ratios: list[float] | None = None,
artifacts_root: Path | None = None,
closed_loop: DMClosedLoopConfig | None = None,
) -> Path:
trainer = create_dm_grpo_trainer(
config,
target_ratios=target_ratios,
artifacts_root=artifacts_root,
closed_loop=closed_loop,
)
trainer.train()
trainer.save_model()
return config.output_dir
def run_hero_grpo(
config: GRPOLaunchConfig,
*,
world_path: Path | None = None,
artifacts_root: Path | None = None,
interface_provider: str | None = None,
interface_model: str | None = None,
interface_narrate: bool = False,
interface_translation_mode: str | None = None,
max_game_steps: int = 40,
max_tool_calls: int = 80,
max_tool_calling_iterations: int = 32,
) -> Path:
world_input = load_world(str(world_path)) if world_path is not None else None
trainer = create_hero_grpo_trainer(
config,
world_input=world_input,
artifacts_root=artifacts_root,
interface_provider=interface_provider,
interface_model=interface_model,
interface_narrate=interface_narrate,
interface_translation_mode=interface_translation_mode,
max_game_steps=max_game_steps,
max_tool_calls=max_tool_calls,
max_tool_calling_iterations=max_tool_calling_iterations,
)
trainer.train()
trainer.save_model()
return config.output_dir
def _build_lora_config(config: GRPOLaunchConfig):
_require_training_dependencies()
return LoraConfig(
r=config.rank,
lora_alpha=config.alpha,
lora_dropout=config.dropout,
bias="none",
task_type="CAUSAL_LM",
target_modules="all-linear",
)
def _build_trainable_model(config: GRPOLaunchConfig) -> tuple[Any, Any | None, bool]:
_require_training_dependencies()
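    # Returns (model or model name, optional LoRA config, whether GRPOConfig should
    # receive model_init_kwargs). Fresh runs hand TRL the model name plus a new LoRA
    # config; resumed runs load the saved adapter onto the base model here instead.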
if config.resume_adapter_path is None:
return config.model_name, _build_lora_config(config), True
from peft import PeftModel
from transformers import AutoModelForCausalLM
adapter_path = Path(config.resume_adapter_path)
if not adapter_path.exists():
raise FileNotFoundError(f"resume_adapter_path does not exist: {adapter_path}")
model = AutoModelForCausalLM.from_pretrained(
config.model_name,
cache_dir=os.getenv("HF_HOME"),
token=os.getenv("HF_TOKEN"),
**_model_init_kwargs(config),
)
model = PeftModel.from_pretrained(model, str(adapter_path), is_trainable=True)
model.train()
return model, None, False
def _build_grpo_config(
config: GRPOLaunchConfig,
*,
max_tool_calling_iterations: int | None,
chat_template_kwargs: dict[str, Any] | None,
reward_weights: list[float] | None,
include_model_init_kwargs: bool = True,
):
_require_training_dependencies()
_require_vllm_if_requested(config)
config.output_dir.mkdir(parents=True, exist_ok=True)
report_to = ["wandb"] if config.use_wandb else []
model_init_kwargs = _model_init_kwargs(config) if include_model_init_kwargs else None
return GRPOConfig(
output_dir=str(config.output_dir),
run_name=config.run_name,
report_to=report_to,
max_steps=config.max_steps,
learning_rate=config.learning_rate,
per_device_train_batch_size=config.per_device_train_batch_size,
gradient_accumulation_steps=config.gradient_accumulation_steps,
num_generations=config.num_generations,
max_completion_length=config.max_completion_length,
temperature=config.temperature,
top_p=config.top_p,
top_k=config.top_k,
min_p=config.min_p,
repetition_penalty=config.repetition_penalty,
logging_steps=config.logging_steps,
save_steps=config.save_steps,
seed=config.seed,
bf16=torch.cuda.is_available(),
gradient_checkpointing=True,
remove_unused_columns=False,
loss_type=config.loss_type,
importance_sampling_level=config.importance_sampling_level,
use_transformers_paged=config.use_transformers_paged,
cache_implementation=config.cache_implementation,
use_vllm=config.use_vllm,
vllm_mode=config.vllm_mode,
vllm_gpu_memory_utilization=config.vllm_gpu_memory_utilization,
vllm_enable_sleep_mode=config.vllm_enable_sleep_mode,
log_completions=True,
log_unique_prompts=True,
num_completions_to_print=1,
max_tool_calling_iterations=max_tool_calling_iterations,
chat_template_kwargs=chat_template_kwargs,
reward_weights=reward_weights,
mask_truncated_completions=True,
model_init_kwargs=model_init_kwargs,
)
def _cached_dm_rollout_metrics(
*,
completion: Any,
requested_ratio: float,
trainer_state: Any,
completion_index: int,
hero_policy: Any,
interface_provider: str | None,
interface_model: str | None,
interface_narrate: bool,
interface_translation_mode: str | None,
hero_max_game_steps: int,
hero_max_tool_calls: int,
artifacts_root: str | None,
) -> DMRolloutMetrics:
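    """Memoize _evaluate_dm_rollout per (step, completion, settings) key; the cache is
    cleared whenever the trainer's global step advances."""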
global _DM_ROLLOUT_CACHE_STEP, _DM_ROLLOUT_CACHE
step = int(getattr(trainer_state, "global_step", 0) or 0)
if step != _DM_ROLLOUT_CACHE_STEP:
_DM_ROLLOUT_CACHE_STEP = step
_DM_ROLLOUT_CACHE = {}
completion_text = _completion_text(completion)
key = (
step,
completion_index,
requested_ratio,
hashlib.sha1(completion_text.encode("utf-8")).hexdigest(),
id(hero_policy),
interface_provider,
interface_model,
interface_narrate,
interface_translation_mode,
hero_max_game_steps,
hero_max_tool_calls,
artifacts_root,
)
cached = _DM_ROLLOUT_CACHE.get(key)
if cached is not None:
return cached
metrics = _evaluate_dm_rollout(
completion_text=completion_text,
requested_ratio=requested_ratio,
trainer_state=trainer_state,
completion_index=completion_index,
hero_policy=hero_policy,
interface_provider=interface_provider,
interface_model=interface_model,
interface_narrate=interface_narrate,
interface_translation_mode=interface_translation_mode,
hero_max_game_steps=hero_max_game_steps,
hero_max_tool_calls=hero_max_tool_calls,
artifacts_root=artifacts_root,
)
_DM_ROLLOUT_CACHE[key] = metrics
return metrics
def _evaluate_dm_rollout(
*,
completion_text: str,
requested_ratio: float,
trainer_state: Any,
completion_index: int,
hero_policy: Any,
interface_provider: str | None,
interface_model: str | None,
interface_narrate: bool,
interface_translation_mode: str | None,
hero_max_game_steps: int,
hero_max_tool_calls: int,
artifacts_root: str | None,
) -> DMRolloutMetrics:
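    """Compile the DM's world, play one hero episode against it via DMEnvironment, and
    fold the outcome into DMRolloutMetrics; any exception becomes a compile-style
    penalty with the error message recorded."""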
try:
world = _load_dm_world_definition(completion_text, allow_repair=True)
except Exception as exc:
return DMRolloutMetrics(
reward=_compile_error_penalty(str(exc)),
compile_error=str(exc),
requested_ratio=requested_ratio,
player_won=False,
steps_taken=None,
min_steps=None,
ratio=None,
efficiency_score=0.0,
quality_score=0.0,
invalid_command_count=0,
wrong_submit_count=0,
hero_player_won=False,
hero_total_reward=0.0,
hero_dense_return=0.0,
hero_steps_taken=0,
hero_tool_calls_total=0,
hero_policy_error=None,
)
interface_adapter = build_interface_adapter(
resolve_interface_config(
provider=interface_provider, # type: ignore[arg-type]
model_name=interface_model,
narrate_observations=interface_narrate,
translation_mode=interface_translation_mode, # type: ignore[arg-type]
)
)
env = DMEnvironment(
artifacts_root=_dm_reward_artifacts_dir(
artifacts_root=artifacts_root,
trainer_state=trainer_state,
completion_index=completion_index,
),
interface_adapter=interface_adapter,
)
runner = HeroRunner(
policy=hero_policy,
max_game_steps=hero_max_game_steps,
max_tool_calls=hero_max_tool_calls,
)
try:
env.reset(difficulty_hint=requested_ratio)
result = env.step(world, runner=runner)
observation = result.observation
reward = float(observation.reward or 0.0)
if observation.compile_error is not None:
reward = _compile_error_penalty(observation.compile_error)
elif abs(world.meta.difficulty_target - requested_ratio) > 1e-6:
reward -= 0.25
feedback = observation.feedback
breakdown = observation.reward_breakdown
hero_stats = runner.episode_stats
return DMRolloutMetrics(
reward=max(-1.0, reward),
compile_error=observation.compile_error,
requested_ratio=requested_ratio,
player_won=bool(observation.player_won),
steps_taken=observation.steps_taken,
min_steps=observation.min_steps,
ratio=observation.ratio,
efficiency_score=0.0 if breakdown is None or breakdown.efficiency_score is None else float(breakdown.efficiency_score),
quality_score=0.0 if breakdown is None else float(breakdown.quality_score),
invalid_command_count=0 if feedback is None else int(feedback.invalid_command_count),
wrong_submit_count=0 if feedback is None else int(feedback.wrong_submit_count),
hero_player_won=bool(observation.player_won) if hero_stats is None else bool(hero_stats.player_won),
hero_total_reward=0.0 if hero_stats is None else float(hero_stats.total_reward),
hero_dense_return=0.0 if hero_stats is None else float(hero_stats.dense_return),
hero_steps_taken=0 if hero_stats is None else int(hero_stats.steps_taken),
hero_tool_calls_total=0 if hero_stats is None else int(hero_stats.tool_calls_total),
hero_policy_error=runner.last_error,
)
except Exception as exc:
return DMRolloutMetrics(
reward=_compile_error_penalty(str(exc)),
compile_error=str(exc),
requested_ratio=requested_ratio,
player_won=False,
steps_taken=None,
min_steps=None,
ratio=None,
efficiency_score=0.0,
quality_score=0.0,
invalid_command_count=0,
wrong_submit_count=0,
hero_player_won=False,
hero_total_reward=0.0,
hero_dense_return=0.0,
hero_steps_taken=0,
hero_tool_calls_total=0,
hero_policy_error=runner.last_error,
)
def _model_init_kwargs(config: GRPOLaunchConfig) -> dict[str, Any]:
model_init_kwargs: dict[str, Any] = {
"trust_remote_code": config.trust_remote_code,
}
quantization_config = _build_quantization_config(config)
if quantization_config is not None:
model_init_kwargs["quantization_config"] = quantization_config
if torch.cuda.is_available():
model_init_kwargs["torch_dtype"] = torch.bfloat16
return model_init_kwargs
def _build_quantization_config(config: GRPOLaunchConfig):
_require_training_dependencies()
if not config.load_in_4bit or not torch.cuda.is_available():
return None
return BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=torch.bfloat16,
bnb_4bit_quant_type="nf4",
bnb_4bit_use_double_quant=True,
)
def _completion_text(completion: Any) -> str:
if isinstance(completion, str):
return completion
if isinstance(completion, list):
parts: list[str] = []
for message in completion:
if isinstance(message, dict) and message.get("role") == "assistant":
content = message.get("content")
if isinstance(content, str):
parts.append(content)
return "\n".join(parts)
return str(completion)
def _extract_json_object(text: str) -> str:
json_text, _, _ = _extract_json_candidate_parts(text)
if json_text is None:
raise ValueError("Completion did not contain a JSON object.")
return json_text
def _extract_json_candidate_parts(text: str) -> tuple[str | None, str, str]:
cleaned = _strip_code_fences(text).strip()
span = _find_json_object_span(cleaned)
if span is None:
return None, cleaned, ""
start, end = span
return cleaned[start:end], cleaned[:start], cleaned[end:]
def _try_parse_completion_json(text: str) -> Any | None:
json_text, _, _ = _extract_json_candidate_parts(text)
if json_text is None:
return None
try:
return json.loads(json_text)
except Exception:
return None
def _repair_dm_candidate_payload(payload: Any) -> Any:
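    """Recursively normalize common near-miss field names and shapes in DM output
    (e.g. requires_step_id, is_open/is_locked, trade_* aliases) so slightly malformed
    worlds can still validate."""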
if isinstance(payload, list):
return [_repair_dm_candidate_payload(item) for item in payload]
if not isinstance(payload, dict):
return payload
node_type = payload.get("type")
repaired: dict[str, Any] = {}
for key, value in payload.items():
normalized_key = "requires_step_ids" if key == "requires_step_id" else key
repaired[normalized_key] = _repair_dm_candidate_payload(value)
requires_step_ids = repaired.get("requires_step_ids")
if requires_step_ids is None and "requires_step_ids" in repaired:
repaired["requires_step_ids"] = []
elif isinstance(requires_step_ids, str):
repaired["requires_step_ids"] = [requires_step_ids]
if "open" not in repaired and "is_open" in repaired:
repaired["open"] = repaired.pop("is_open")
if "locked" not in repaired and "is_locked" in repaired:
repaired["locked"] = repaired.pop("is_locked")
if node_type in {"container", "door"}:
closed = repaired.pop("closed", None)
if isinstance(closed, bool) and "open" not in repaired:
repaired["open"] = not closed
if node_type == "fixture":
if "reveals_item_id" not in repaired and "reveal_item_id" in repaired:
repaired["reveals_item_id"] = repaired.pop("reveal_item_id")
if "reveals_readable_id" not in repaired and "reveal_readable_id" in repaired:
repaired["reveals_readable_id"] = repaired.pop("reveal_readable_id")
if node_type == "npc":
if "requires_item_id" not in repaired and "trade_requires_item_id" in repaired:
repaired["requires_item_id"] = repaired.pop("trade_requires_item_id")
if "gives_item_id" not in repaired and "trade_item_id" in repaired:
repaired["gives_item_id"] = repaired.pop("trade_item_id")
if "gives_clue_id" not in repaired and "trade_clue_id" in repaired:
repaired["gives_clue_id"] = repaired.pop("trade_clue_id")
if "subtype" not in repaired and repaired.get("type") in _DM_ALLOWED_ITEM_TYPES and "start_node_id" in repaired:
repaired["subtype"] = repaired.pop("type")
if "id" not in repaired and "clue_id" in repaired and "text" in repaired:
repaired["id"] = repaired.pop("clue_id")
if "input_item_ids" not in repaired and "input_item_a_id" in repaired and "input_item_b_id" in repaired:
repaired["input_item_ids"] = [repaired.pop("input_item_a_id"), repaired.pop("input_item_b_id")]
if node_type == "container":
repaired.pop("contains_items", None)
if "output_item_id" in repaired and (
"input_item_ids" in repaired or ("input_item_a_id" in repaired and "input_item_b_id" in repaired)
):
repaired.pop("label", None)
repaired.pop("description", None)
if node_type in {"location", "junction", "door"}:
repaired.pop("parent_id", None)
return repaired
def _repair_dm_world_payload(payload: dict[str, Any]) -> dict[str, Any]:
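    """Run the recursive field repairs, back-fill meta (title, start node, win
    condition), then apply the targeted world-level repair passes before validation."""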
repaired = _repair_dm_candidate_payload(payload)
if not isinstance(repaired, dict):
return payload
meta = repaired.get("meta")
if not isinstance(meta, dict):
meta = {}
else:
meta = dict(meta)
title = meta.get("title")
if not isinstance(title, str) or not title.strip():
meta["title"] = _infer_dm_world_title(repaired)
start_node_id = meta.get("start_node_id")
if not isinstance(start_node_id, str) or not start_node_id:
inferred_start = _infer_dm_start_node_id(repaired.get("nodes"))
if inferred_start is not None:
meta["start_node_id"] = inferred_start
win_condition = meta.get("win_condition")
if not isinstance(win_condition, dict):
win_condition = {}
else:
win_condition = dict(win_condition)
if not isinstance(win_condition.get("type"), str) or not win_condition.get("type"):
win_condition["type"] = "deduce"
if not isinstance(win_condition.get("target_npc_id"), str) or not win_condition.get("target_npc_id"):
inferred_guardian = _infer_dm_guardian_npc_id(repaired)
if inferred_guardian is not None:
win_condition["target_npc_id"] = inferred_guardian
if not isinstance(win_condition.get("answer_string"), str) or not win_condition.get("answer_string"):
inferred_answer = _infer_dm_answer_string(repaired.get("quest_chain"))
if inferred_answer:
win_condition["answer_string"] = inferred_answer
if win_condition:
meta["win_condition"] = win_condition
_repair_guardian_trade_fields(repaired, guardian_id=win_condition.get("target_npc_id"))
_repair_submit_actions(repaired)
_repair_door_lock_keys_from_edges(repaired)
_repair_missing_item_references(repaired)
_repair_produced_item_placements(repaired, default_start_node_id=meta.get("start_node_id"))
_repair_required_key_item_subtypes(repaired)
_repair_duplicate_recipe_ids(repaired)
_repair_guardian_room_access(repaired, guardian_id=win_condition.get("target_npc_id"), start_node_id=meta.get("start_node_id"))
_repair_missing_readable_clue_ids(repaired)
_repair_missing_clue_sources(repaired, guardian_id=win_condition.get("target_npc_id"))
_repair_take_action_aliases(repaired)
_repair_take_sources_from_room_prereqs(repaired)
_repair_locked_room_entry_steps(repaired)
_repair_missing_take_steps(repaired)
_repair_guardian_ending(
repaired,
guardian_id=win_condition.get("target_npc_id"),
answer_string=win_condition.get("answer_string"),
)
_repair_guardian_room_access(repaired, guardian_id=win_condition.get("target_npc_id"), start_node_id=meta.get("start_node_id"))
repaired["meta"] = meta
return repaired
def _infer_dm_world_title(payload: dict[str, Any]) -> str:
meta = payload.get("meta")
if isinstance(meta, dict):
for key in ("name", "world_name"):
value = meta.get(key)
if isinstance(value, str) and value.strip():
return value.strip()
nodes = payload.get("nodes")
if isinstance(nodes, list):
for node in nodes:
if not isinstance(node, dict) or node.get("type") not in {"location", "junction"}:
continue
label = node.get("label")
if isinstance(label, str) and label.strip():
return f"The {label.strip()}"
return "The Hidden Vault"
def _infer_dm_start_node_id(nodes: Any) -> str | None:
if not isinstance(nodes, list):
return None
for node in nodes:
if not isinstance(node, dict) or node.get("type") not in {"location", "junction"}:
continue
node_id = node.get("id")
if isinstance(node_id, str) and node_id:
return node_id
return None
def _infer_dm_guardian_npc_id(payload: dict[str, Any]) -> str | None:
quest_chain = payload.get("quest_chain")
if isinstance(quest_chain, list):
for step in reversed(quest_chain):
action = step.get("action") if isinstance(step, dict) else None
npc_id = _extract_single_action_argument(action, "talk")
if npc_id:
return npc_id
nodes = payload.get("nodes")
if not isinstance(nodes, list):
return None
first_npc_id: str | None = None
for node in nodes:
if not isinstance(node, dict) or node.get("type") != "npc":
continue
node_id = node.get("id")
if not isinstance(node_id, str) or not node_id:
continue
if first_npc_id is None:
first_npc_id = node_id
if "guardian" in node_id:
return node_id
return first_npc_id
def _infer_dm_answer_string(quest_chain: Any) -> str | None:
if not isinstance(quest_chain, list):
return None
for step in reversed(quest_chain):
action = step.get("action") if isinstance(step, dict) else None
answer = _extract_single_action_argument(action, "submit")
if answer is None:
continue
normalized = normalize_answer_text(answer)
if normalized:
return normalized
return None
def _repair_missing_readable_clue_ids(payload: dict[str, Any]) -> None:
nodes = payload.get("nodes")
clues = payload.get("clues")
if not isinstance(nodes, list) or not isinstance(clues, list):
return
clue_ids = [clue.get("id") for clue in clues if isinstance(clue, dict) and isinstance(clue.get("id"), str)]
if not clue_ids:
return
used_clue_ids = {
node.get("clue_id")
for node in nodes
if isinstance(node, dict) and node.get("type") == "readable" and isinstance(node.get("clue_id"), str)
}
available_clue_ids = [clue_id for clue_id in clue_ids if clue_id not in used_clue_ids]
if not available_clue_ids:
return
for node in nodes:
if not isinstance(node, dict) or node.get("type") != "readable" or node.get("clue_id"):
continue
if not available_clue_ids:
return
node["clue_id"] = available_clue_ids.pop(0)
def _repair_guardian_trade_fields(payload: dict[str, Any], *, guardian_id: Any) -> None:
if not isinstance(guardian_id, str) or not guardian_id:
return
nodes = payload.get("nodes")
if not isinstance(nodes, list):
return
for node in nodes:
if not isinstance(node, dict) or node.get("type") != "npc" or node.get("id") != guardian_id:
continue
node["requires_item_id"] = None
node["gives_item_id"] = None
node["gives_clue_id"] = None
return
def _repair_submit_actions(payload: dict[str, Any]) -> None:
quest_chain = payload.get("quest_chain")
if not isinstance(quest_chain, list):
return
for step in quest_chain:
if not isinstance(step, dict):
continue
action = step.get("action")
answer = _extract_single_action_argument(action, "submit")
if answer is None:
continue
if action == f'submit("{answer}")':
continue
step["action"] = f'submit("{normalize_answer_text(answer)}")'
def _repair_door_lock_keys_from_edges(payload: dict[str, Any]) -> None:
nodes = payload.get("nodes")
edges = payload.get("edges")
if not isinstance(nodes, list) or not isinstance(edges, list):
return
door_ids = [
node.get("id")
for node in nodes
if isinstance(node, dict) and node.get("type") == "door" and isinstance(node.get("id"), str)
]
sole_door_id = door_ids[0] if len(door_ids) == 1 else None
inferred_keys: dict[str, str] = {}
for edge in edges:
if not isinstance(edge, dict):
continue
door_node_id = edge.get("door_node_id")
required_item_id = edge.get("required_item_id")
if sole_door_id is not None and isinstance(door_node_id, str) and door_node_id not in door_ids:
edge["door_node_id"] = sole_door_id
door_node_id = sole_door_id
if not isinstance(door_node_id, str) or not isinstance(required_item_id, str):
continue
existing_key = inferred_keys.get(door_node_id)
if existing_key is None or existing_key == required_item_id:
inferred_keys[door_node_id] = required_item_id
if not inferred_keys:
return
for node in nodes:
if not isinstance(node, dict) or node.get("type") != "door":
continue
door_id = node.get("id")
if isinstance(door_id, str) and door_id in inferred_keys:
node["lock_key_id"] = inferred_keys[door_id]
def _repair_required_key_item_subtypes(payload: dict[str, Any]) -> None:
items = payload.get("items")
edges = payload.get("edges")
nodes = payload.get("nodes")
if not isinstance(items, list):
return
required_key_ids: set[str] = set()
if isinstance(edges, list):
for edge in edges:
if not isinstance(edge, dict):
continue
required_item_id = edge.get("required_item_id")
if isinstance(required_item_id, str) and required_item_id:
required_key_ids.add(required_item_id)
if isinstance(nodes, list):
for node in nodes:
if not isinstance(node, dict):
continue
lock_key_id = node.get("lock_key_id")
if isinstance(lock_key_id, str) and lock_key_id:
required_key_ids.add(lock_key_id)
if not required_key_ids:
return
for item in items:
if not isinstance(item, dict):
continue
item_id = item.get("id")
if isinstance(item_id, str) and item_id in required_key_ids:
item["subtype"] = "key"
def _repair_duplicate_recipe_ids(payload: dict[str, Any]) -> None:
recipes = payload.get("recipes")
if not isinstance(recipes, list):
return
protected_ids: set[str] = set()
for key in ("nodes", "items", "clues", "quest_chain"):
values = payload.get(key)
if not isinstance(values, list):
continue
for value in values:
if not isinstance(value, dict):
continue
id_key = "step_id" if key == "quest_chain" else "id"
value_id = value.get(id_key)
if isinstance(value_id, str) and value_id:
protected_ids.add(value_id)
recipe_ids: set[str] = set()
for recipe in recipes:
if not isinstance(recipe, dict):
continue
recipe_id = recipe.get("id")
if not isinstance(recipe_id, str) or not recipe_id:
continue
if recipe_id not in protected_ids and recipe_id not in recipe_ids:
recipe_ids.add(recipe_id)
continue
new_recipe_id = _unique_world_id(recipe_id, protected_ids | recipe_ids)
recipe["id"] = new_recipe_id
recipe_ids.add(new_recipe_id)
def _repair_guardian_room_access(payload: dict[str, Any], *, guardian_id: Any, start_node_id: Any) -> None:
if not isinstance(guardian_id, str) or not guardian_id:
return
nodes = payload.get("nodes")
edges = payload.get("edges")
quest_chain = payload.get("quest_chain")
if not isinstance(nodes, list) or not isinstance(edges, list):
return
reachable_rooms = _reachable_passage_room_ids(payload, start_node_id=start_node_id)
if not reachable_rooms:
return
preferred_room_id = _infer_guardian_talk_room_from_quest(quest_chain, guardian_id=guardian_id)
if preferred_room_id not in reachable_rooms:
        preferred_room_id = min(reachable_rooms)
for node in nodes:
if not isinstance(node, dict) or node.get("type") != "npc" or node.get("id") != guardian_id:
continue
parent_id = node.get("parent_id")
if isinstance(parent_id, str) and parent_id in reachable_rooms:
return
node["parent_id"] = preferred_room_id
current_guardian_room = _infer_guardian_talk_room_from_quest(quest_chain, guardian_id=guardian_id)
if current_guardian_room != preferred_room_id:
_insert_quest_step_before_guardian_talk(
quest_chain,
guardian_id=guardian_id,
step_id_base=f"go_{preferred_room_id}",
description=f"Go to {_humanize_identifier(preferred_room_id).lower()}.",
action=f"go({preferred_room_id})",
)
return
def _repair_missing_item_references(payload: dict[str, Any]) -> None:
items = payload.get("items")
nodes = payload.get("nodes")
edges = payload.get("edges")
if not isinstance(items, list):
return
existing_item_ids = {
item.get("id")
for item in items
if isinstance(item, dict) and isinstance(item.get("id"), str) and item.get("id")
}
quest_chain = payload.get("quest_chain")
def ensure_item(item_id: Any, *, subtype: str, start_node_id: str | None) -> None:
if not isinstance(item_id, str) or not item_id or item_id in existing_item_ids:
return
inferred_start_node_id = _infer_item_start_node_from_quest(quest_chain, item_id) or start_node_id
items.append(
{
"id": item_id,
"label": _humanize_identifier(item_id),
"description": f"A {_humanize_identifier(item_id).lower()} needed to solve the dungeon.",
"subtype": subtype,
"start_node_id": inferred_start_node_id,
}
)
existing_item_ids.add(item_id)
default_start_node_id = _infer_dm_start_node_id(payload.get("nodes"))
if isinstance(edges, list):
for edge in edges:
if not isinstance(edge, dict):
continue
ensure_item(edge.get("required_item_id"), subtype="key", start_node_id=default_start_node_id)
if not isinstance(nodes, list):
return
for node in nodes:
if not isinstance(node, dict):
continue
node_type = node.get("type")
if node_type in {"container", "door"}:
ensure_item(node.get("lock_key_id"), subtype="key", start_node_id=default_start_node_id)
elif node_type == "readable":
ensure_item(
node.get("requires_item_id"),
subtype="puzzle",
start_node_id=_node_room_start_node_id(node, default_start_node_id),
)
elif node_type == "fixture":
ensure_item(
node.get("requires_item_id"),
subtype="puzzle",
start_node_id=_node_room_start_node_id(node, default_start_node_id),
)
ensure_item(node.get("reveals_item_id"), subtype="puzzle", start_node_id=None)
elif node_type == "npc":
ensure_item(
node.get("requires_item_id"),
subtype="puzzle",
start_node_id=_node_room_start_node_id(node, default_start_node_id),
)
ensure_item(node.get("gives_item_id"), subtype="puzzle", start_node_id=None)
recipes = payload.get("recipes")
if not isinstance(recipes, list):
return
for recipe in recipes:
if not isinstance(recipe, dict):
continue
input_ids = recipe.get("input_item_ids")
if isinstance(input_ids, list):
for item_id in input_ids:
ensure_item(item_id, subtype="puzzle", start_node_id=default_start_node_id)
ensure_item(recipe.get("output_item_id"), subtype="puzzle", start_node_id=None)
def _repair_produced_item_placements(payload: dict[str, Any], *, default_start_node_id: Any) -> None:
items = payload.get("items")
if not isinstance(items, list):
return
produced_item_ids: set[str] = set()
recipes = payload.get("recipes")
if isinstance(recipes, list):
for recipe in recipes:
if not isinstance(recipe, dict):
continue
output_item_id = recipe.get("output_item_id")
if isinstance(output_item_id, str) and output_item_id:
produced_item_ids.add(output_item_id)
nodes = payload.get("nodes")
if isinstance(nodes, list):
for node in nodes:
if not isinstance(node, dict):
continue
if node.get("type") == "npc":
gives_item_id = node.get("gives_item_id")
if isinstance(gives_item_id, str) and gives_item_id:
produced_item_ids.add(gives_item_id)
elif node.get("type") == "fixture":
reveals_item_id = node.get("reveals_item_id")
if isinstance(reveals_item_id, str) and reveals_item_id:
produced_item_ids.add(reveals_item_id)
start_node_id = default_start_node_id if isinstance(default_start_node_id, str) and default_start_node_id else None
for item in items:
if not isinstance(item, dict):
continue
item_id = item.get("id")
if not isinstance(item_id, str) or not item_id:
continue
if item_id in produced_item_ids:
item["start_node_id"] = None
elif item.get("start_node_id") is None and start_node_id is not None:
item["start_node_id"] = start_node_id
def _repair_missing_clue_sources(payload: dict[str, Any], *, guardian_id: Any) -> None:
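"""For clues that no readable or NPC provides, synthesize gated readable nodes in the
guardian's room (or the start room) and insert matching quest steps before the guardian talk."""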
clues = payload.get("clues")
nodes = payload.get("nodes")
items = payload.get("items")
quest_chain = payload.get("quest_chain")
if not isinstance(clues, list) or not isinstance(nodes, list):
return
clue_text_by_id = {
clue.get("id"): clue.get("text")
for clue in clues
if isinstance(clue, dict) and isinstance(clue.get("id"), str)
}
if not clue_text_by_id:
return
sourced_clue_ids = set()
room_ids: set[str] = set()
guardian_room_id: str | None = None
for node in nodes:
if not isinstance(node, dict):
continue
node_type = node.get("type")
if node_type in {"location", "junction"}:
node_id = node.get("id")
if isinstance(node_id, str) and node_id:
room_ids.add(node_id)
elif node_type == "readable":
clue_id = node.get("clue_id")
if isinstance(clue_id, str) and clue_id:
sourced_clue_ids.add(clue_id)
elif node_type == "npc":
clue_id = node.get("gives_clue_id")
if isinstance(clue_id, str) and clue_id:
sourced_clue_ids.add(clue_id)
if isinstance(guardian_id, str) and guardian_id and node.get("id") == guardian_id:
parent_id = node.get("parent_id")
if isinstance(parent_id, str) and parent_id:
guardian_room_id = parent_id
missing_clue_ids = [clue_id for clue_id in clue_text_by_id if clue_id not in sourced_clue_ids]
if not missing_clue_ids:
return
target_room_id = guardian_room_id or _infer_dm_start_node_id(nodes)
if not isinstance(target_room_id, str) or target_room_id not in room_ids:
target_room_id = next(iter(room_ids), None)
if target_room_id is None:
return
gating_item_id = _select_synthetic_clue_gate_item_id(items, quest_chain)
if gating_item_id is None:
if not isinstance(items, list):
return
gating_item_id = "inspection_lens"
items.append(
{
"id": gating_item_id,
"label": "Inspection Lens",
"description": "A careful lens for reading faint inscriptions.",
"subtype": "puzzle",
"start_node_id": _infer_dm_start_node_id(nodes),
}
)
existing_node_ids = {
node.get("id")
for node in nodes
if isinstance(node, dict) and isinstance(node.get("id"), str) and node.get("id")
}
existing_safe_labels = {
parser_safe_text(node.get("label"))
for node in nodes
if isinstance(node, dict) and isinstance(node.get("label"), str) and node.get("label")
}
for clue_id in missing_clue_ids:
readable_id = _unique_world_id(f"{clue_id}_inscription", existing_node_ids)
label = _unique_world_label(f"{_humanize_identifier(clue_id)} Inscription", existing_safe_labels)
nodes.append(
{
"id": readable_id,
"type": "readable",
"label": label,
"description": f"A {label.lower()} can only be deciphered with the right tool.",
"parent_id": target_room_id,
"clue_id": clue_id,
"requires_item_id": gating_item_id,
"consumes_item": False,
"text_content": clue_text_by_id[clue_id] or f"A fragment about {_humanize_identifier(clue_id).lower()}.",
}
)
_insert_quest_step_before_guardian_talk(
quest_chain,
guardian_id=guardian_id,
step_id_base=f"inspect_{readable_id}",
description=f"Inspect the {label.lower()}.",
action=(
f"use({gating_item_id},{readable_id})"
if isinstance(gating_item_id, str) and gating_item_id
else f"read({readable_id})"
),
)
def _repair_take_action_aliases(payload: dict[str, Any]) -> None:
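"""Rewrite take(item,source) steps that point at non-room sources: takes from a fixture
that reveals the item are retargeted to the fixture's room, and takes from an NPC that
gives the item become give(required_item,npc) steps."""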
quest_chain = payload.get("quest_chain")
nodes = payload.get("nodes")
if not isinstance(quest_chain, list) or not isinstance(nodes, list):
return
fixture_by_id: dict[str, dict[str, Any]] = {}
npc_by_id: dict[str, dict[str, Any]] = {}
for node in nodes:
if not isinstance(node, dict):
continue
node_id = node.get("id")
if not isinstance(node_id, str) or not node_id:
continue
if node.get("type") == "fixture":
fixture_by_id[node_id] = node
elif node.get("type") == "npc":
npc_by_id[node_id] = node
for step in quest_chain:
if not isinstance(step, dict):
continue
arguments = _extract_action_arguments(step.get("action"), "take")
if arguments is None or len(arguments) != 2:
continue
item_id, source_id = arguments
fixture = fixture_by_id.get(source_id)
if fixture is not None and fixture.get("reveals_item_id") == item_id:
parent_id = fixture.get("parent_id")
if isinstance(parent_id, str) and parent_id:
step["action"] = f"take({item_id},{parent_id})"
continue
npc = npc_by_id.get(source_id)
if npc is None or npc.get("gives_item_id") != item_id:
continue
required_item_id = npc.get("requires_item_id")
if isinstance(required_item_id, str) and required_item_id:
step["action"] = f"give({required_item_id},{source_id})"
def _repair_take_sources_from_room_prereqs(payload: dict[str, Any]) -> None:
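"""Retarget take(item,source) steps (and the item's start node) to the room implied by
their go(...) prerequisites when the source is not a container."""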
quest_chain = payload.get("quest_chain")
items = payload.get("items")
nodes = payload.get("nodes")
if not isinstance(quest_chain, list) or not isinstance(items, list):
return
node_types: dict[str, str] = {}
if isinstance(nodes, list):
for node in nodes:
if not isinstance(node, dict):
continue
node_id = node.get("id")
node_type = node.get("type")
if isinstance(node_id, str) and isinstance(node_type, str):
node_types[node_id] = node_type
item_by_id = {
item.get("id"): item
for item in items
if isinstance(item, dict) and isinstance(item.get("id"), str) and item.get("id")
}
step_by_id = {
step.get("step_id"): step
for step in quest_chain
if isinstance(step, dict) and isinstance(step.get("step_id"), str) and step.get("step_id")
}
for step in quest_chain:
if not isinstance(step, dict):
continue
arguments = _extract_action_arguments(step.get("action"), "take")
if arguments is None or len(arguments) != 2:
continue
item_id, source_id = arguments
if node_types.get(source_id) == "container":
continue
requires_step_ids = step.get("requires_step_ids")
if not isinstance(requires_step_ids, list):
continue
required_room_id: str | None = None
for dependency in requires_step_ids:
if not isinstance(dependency, str):
continue
dependency_step = step_by_id.get(dependency)
if not isinstance(dependency_step, dict):
continue
room_id = _extract_single_action_argument(dependency_step.get("action"), "go")
if room_id:
required_room_id = room_id
if required_room_id is None or required_room_id == source_id:
continue
step["action"] = f"take({item_id},{required_room_id})"
item = item_by_id.get(item_id)
if isinstance(item, dict):
item["start_node_id"] = required_room_id
def _repair_missing_take_steps(payload: dict[str, Any]) -> None:
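"""Walk the quest chain with a simulated inventory and insert take(...) steps for items
that are required by use/unlock/give/combine actions but are neither held nor produced
by NPCs, fixtures, or recipes."""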
quest_chain = payload.get("quest_chain")
items = payload.get("items")
nodes = payload.get("nodes")
recipes = payload.get("recipes")
if not isinstance(quest_chain, list) or not isinstance(items, list):
return
item_start_nodes = {
item.get("id"): item.get("start_node_id")
for item in items
if isinstance(item, dict) and isinstance(item.get("id"), str)
}
produced_item_ids = set()
recipe_outputs: dict[frozenset[str], str] = {}
if isinstance(recipes, list):
for recipe in recipes:
if not isinstance(recipe, dict):
continue
output_item_id = recipe.get("output_item_id")
input_item_ids = recipe.get("input_item_ids")
if isinstance(output_item_id, str) and output_item_id:
produced_item_ids.add(output_item_id)
if isinstance(output_item_id, str) and isinstance(input_item_ids, list) and len(input_item_ids) == 2:
recipe_outputs[frozenset(str(item_id) for item_id in input_item_ids)] = output_item_id
npc_rewards: dict[str, str] = {}
if isinstance(nodes, list):
for node in nodes:
if not isinstance(node, dict) or node.get("type") != "npc":
continue
npc_id = node.get("id")
gives_item_id = node.get("gives_item_id")
if isinstance(npc_id, str) and npc_id and isinstance(gives_item_id, str) and gives_item_id:
produced_item_ids.add(gives_item_id)
npc_rewards[npc_id] = gives_item_id
for node in nodes:
if not isinstance(node, dict) or node.get("type") != "fixture":
continue
reveals_item_id = node.get("reveals_item_id")
if isinstance(reveals_item_id, str) and reveals_item_id:
produced_item_ids.add(reveals_item_id)
inventory: set[str] = set()
step_by_id = {
step.get("step_id"): step
for step in quest_chain
if isinstance(step, dict) and isinstance(step.get("step_id"), str) and step.get("step_id")
}
index = 0
while index < len(quest_chain):
step = quest_chain[index]
if not isinstance(step, dict):
index += 1
continue
required_item_ids = _quest_required_item_ids(step.get("action"))
inserted_step = False
for item_id in required_item_ids:
if item_id in inventory or item_id in produced_item_ids:
continue
source_node_id = _infer_room_prereq_for_step(step, step_by_id) or item_start_nodes.get(item_id)
if not isinstance(source_node_id, str) or not source_node_id:
continue
new_step_id = _insert_quest_step_before_index(
quest_chain,
index=index,
step_id_base=f"take_{item_id}",
description=f"Take the {_humanize_identifier(item_id).lower()}.",
action=f"take({item_id},{source_node_id})",
allow_existing_action=True,
)
if new_step_id is not None:
inventory.add(item_id)
item = next(
(
candidate
for candidate in items
if isinstance(candidate, dict) and candidate.get("id") == item_id
),
None,
)
if isinstance(item, dict):
item["start_node_id"] = source_node_id
inserted_step = True
index += 1
if inserted_step:
step = quest_chain[index]
if not isinstance(step, dict):
index += 1
continue
arguments = _extract_action_arguments(step.get("action"), "take")
if arguments is not None and len(arguments) == 2:
inventory.add(arguments[0])
index += 1
continue
arguments = _extract_action_arguments(step.get("action"), "give")
if arguments is not None and len(arguments) == 2:
inventory.discard(arguments[0])
rewarded_item_id = npc_rewards.get(arguments[1])
if rewarded_item_id:
inventory.add(rewarded_item_id)
index += 1
continue
arguments = _extract_action_arguments(step.get("action"), "combine")
if arguments is not None and len(arguments) == 2:
inventory.discard(arguments[0])
inventory.discard(arguments[1])
output_item_id = recipe_outputs.get(frozenset(arguments))
if output_item_id:
inventory.add(output_item_id)
index += 1
continue
index += 1
def _repair_guardian_ending(payload: dict[str, Any], *, guardian_id: Any, answer_string: Any) -> None:
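"""Normalize the final submit(...) answer to the win-condition answer string and ensure a
talk(guardian) step sits immediately before it, inserting or duplicating one if needed."""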
quest_chain = payload.get("quest_chain")
if not isinstance(quest_chain, list) or not quest_chain:
return
submit_index: int | None = None
for index in range(len(quest_chain) - 1, -1, -1):
step = quest_chain[index]
if isinstance(step, dict) and _extract_single_action_argument(step.get("action"), "submit") is not None:
submit_index = index
break
if submit_index is None:
return
submit_step = quest_chain[submit_index]
if not isinstance(submit_step, dict):
return
if isinstance(answer_string, str) and answer_string:
submit_step["action"] = f'submit("{normalize_answer_text(answer_string)}")'
talk_index = _guardian_talk_step_index(quest_chain, guardian_id=guardian_id)
if submit_index == len(quest_chain) - 1 and submit_index > 0:
penultimate = quest_chain[submit_index - 1]
if isinstance(penultimate, dict) and _extract_single_action_argument(penultimate.get("action"), "talk") == guardian_id:
return
if talk_index is None:
new_step_id = _insert_quest_step_before_index(
quest_chain,
index=submit_index,
step_id_base=f"talk_{guardian_id}",
description=f"Speak to the {_humanize_identifier(str(guardian_id)).lower()}.",
action=f"talk({guardian_id})",
allow_existing_action=True,
)
if new_step_id is not None:
submit_step["requires_step_ids"] = [new_step_id]
return
talk_step = quest_chain[talk_index]
if not isinstance(talk_step, dict):
return
if talk_index != submit_index - 1:
new_step_id = _insert_quest_step_before_index(
quest_chain,
index=submit_index,
step_id_base=talk_step.get("step_id") or f"talk_{guardian_id}",
description=talk_step.get("description") or f"Speak to the {_humanize_identifier(str(guardian_id)).lower()}.",
action=talk_step.get("action") or f"talk({guardian_id})",
allow_existing_action=True,
)
if new_step_id is not None:
submit_step["requires_step_ids"] = [new_step_id]
def _repair_locked_room_entry_steps(payload: dict[str, Any]) -> None:
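"""Insert unlock(door,key) and open(door) steps before go(...) steps that traverse edges
guarded by a door, unless those actions already occur earlier in the chain."""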
quest_chain = payload.get("quest_chain")
edges = payload.get("edges")
meta = payload.get("meta")
if not isinstance(quest_chain, list) or not isinstance(edges, list) or not isinstance(meta, dict):
return
edge_by_rooms = {
(edge.get("from_node_id"), edge.get("to_node_id")): edge
for edge in edges
if isinstance(edge, dict)
}
start_node_id = meta.get("start_node_id")
step_by_id = {
step.get("step_id"): step
for step in quest_chain
if isinstance(step, dict) and isinstance(step.get("step_id"), str) and step.get("step_id")
}
index = 0
while index < len(quest_chain):
step = quest_chain[index]
if not isinstance(step, dict):
index += 1
continue
target_room_id = _extract_single_action_argument(step.get("action"), "go")
if target_room_id is None:
index += 1
continue
current_room_id = _infer_room_prereq_for_step(step, step_by_id) or (
start_node_id if isinstance(start_node_id, str) else None
)
if current_room_id is None:
index += 1
continue
edge = edge_by_rooms.get((current_room_id, target_room_id))
if not isinstance(edge, dict) or not isinstance(edge.get("door_node_id"), str):
index += 1
continue
door_id = edge.get("door_node_id")
key_id = edge.get("required_item_id")
inserted = False
if isinstance(door_id, str) and isinstance(key_id, str):
unlock_action = f"unlock({door_id},{key_id})"
if not _action_exists_before_index(quest_chain, unlock_action, index):
if _insert_quest_step_before_index(
quest_chain,
index=index,
step_id_base=f"unlock_{door_id}",
description=f"Unlock the {_humanize_identifier(door_id).lower()}.",
action=unlock_action,
allow_existing_action=True,
):
inserted = True
index += 1
if isinstance(door_id, str):
open_action = f"open({door_id})"
if not _action_exists_before_index(quest_chain, open_action, index):
if _insert_quest_step_before_index(
quest_chain,
index=index,
step_id_base=f"open_{door_id}",
description=f"Open the {_humanize_identifier(door_id).lower()}.",
action=open_action,
allow_existing_action=True,
):
inserted = True
index += 1
if inserted:
step_by_id = {
candidate.get("step_id"): candidate
for candidate in quest_chain
if isinstance(candidate, dict)
and isinstance(candidate.get("step_id"), str)
and candidate.get("step_id")
}
index += 1
def _select_synthetic_clue_gate_item_id(items: Any, quest_chain: Any) -> str | None:
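"""Pick an item to gate synthetic readables: prefer puzzle items the quest already takes
or that have a start node, then other puzzle items, then keys; None if nothing fits."""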
if not isinstance(items, list):
return None
taken_item_ids = _quest_taken_item_ids(quest_chain)
prioritized: list[tuple[int, str]] = []
for item in items:
if not isinstance(item, dict):
continue
item_id = item.get("id")
subtype = item.get("subtype")
if not isinstance(item_id, str) or not item_id:
continue
if subtype == "puzzle" and item_id in taken_item_ids:
prioritized.append((0, item_id))
elif subtype == "puzzle" and item.get("start_node_id") is not None:
prioritized.append((0, item_id))
elif subtype == "puzzle":
prioritized.append((1, item_id))
elif subtype == "key":
prioritized.append((2, item_id))
if not prioritized:
return None
prioritized.sort()
return prioritized[0][1]
def _humanize_identifier(identifier: str) -> str:
return " ".join(part.capitalize() for part in identifier.split("_") if part) or identifier
def _node_room_start_node_id(node: dict[str, Any], default_start_node_id: str | None) -> str | None:
parent_id = node.get("parent_id")
if isinstance(parent_id, str) and parent_id:
return parent_id
return default_start_node_id
def _unique_world_id(base_id: str, existing_ids: set[str]) -> str:
candidate = base_id
suffix = 2
while candidate in existing_ids:
candidate = f"{base_id}_{suffix}"
suffix += 1
existing_ids.add(candidate)
return candidate
def _unique_world_label(base_label: str, existing_safe_labels: set[str]) -> str:
candidate = base_label
suffix = 2
while parser_safe_text(candidate) in existing_safe_labels:
candidate = f"{base_label} {suffix}"
suffix += 1
existing_safe_labels.add(parser_safe_text(candidate))
return candidate
def _insert_quest_step_before_guardian_talk(
quest_chain: Any,
*,
guardian_id: Any,
step_id_base: str,
description: str,
action: str,
) -> str | None:
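"""Insert a new quest step immediately before the talk(guardian) step, unless the action
already exists anywhere in the chain or no talk step is found; returns the new step id or None."""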
if not isinstance(quest_chain, list):
return None
if any(isinstance(step, dict) and step.get("action") == action for step in quest_chain):
return None
talk_index = _guardian_talk_step_index(quest_chain, guardian_id=guardian_id)
if talk_index is None:
return None
return _insert_quest_step_before_index(
quest_chain,
index=talk_index,
step_id_base=step_id_base,
description=description,
action=action,
)
def _quest_taken_item_ids(quest_chain: Any) -> set[str]:
if not isinstance(quest_chain, list):
return set()
taken_item_ids: set[str] = set()
for step in quest_chain:
if not isinstance(step, dict):
continue
arguments = _extract_action_arguments(step.get("action"), "take")
if arguments is None or not arguments:
continue
item_id = arguments[0]
if item_id:
taken_item_ids.add(item_id)
return taken_item_ids
def _infer_item_start_node_from_quest(quest_chain: Any, item_id: str) -> str | None:
if not isinstance(quest_chain, list):
return None
for step in quest_chain:
if not isinstance(step, dict):
continue
arguments = _extract_action_arguments(step.get("action"), "take")
if arguments is None or len(arguments) != 2:
continue
if arguments[0] == item_id:
return arguments[1]
return None
def _infer_guardian_talk_room_from_quest(quest_chain: Any, *, guardian_id: str) -> str | None:
talk_index = _guardian_talk_step_index(quest_chain, guardian_id=guardian_id)
if talk_index is None or not isinstance(quest_chain, list):
return None
for index in range(talk_index - 1, -1, -1):
step = quest_chain[index]
if not isinstance(step, dict):
continue
room_id = _extract_single_action_argument(step.get("action"), "go")
if room_id:
return room_id
return None
def _guardian_talk_step_index(quest_chain: Any, *, guardian_id: Any) -> int | None:
if not isinstance(quest_chain, list) or not isinstance(guardian_id, str) or not guardian_id:
return None
for index, step in enumerate(quest_chain):
if not isinstance(step, dict):
continue
target_id = _extract_single_action_argument(step.get("action"), "talk")
if target_id == guardian_id:
return index
return None
def _extract_action_arguments(action: Any, name: str) -> list[str] | None:
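"""Parse a call-style action string into its arguments,
e.g. "use(item_id,target_id)" -> ["item_id", "target_id"]; returns None if malformed."""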
if not isinstance(action, str):
return None
prefix = f"{name}("
if not action.startswith(prefix) or not action.endswith(")"):
return None
raw_arguments = action[len(prefix) : -1]
arguments = [argument.strip().strip('"').strip("'") for argument in raw_arguments.split(",")]
if any(not argument for argument in arguments):
return None
return arguments
def _insert_quest_step_before_index(
quest_chain: Any,
*,
index: int,
step_id_base: str,
description: str,
action: str,
allow_existing_action: bool = False,
) -> str | None:
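"""Insert a new step at index, inheriting the displaced step's requires_step_ids and making
the displaced step depend only on the new step; returns the new step id, or None when the
index is invalid or the action already exists and allow_existing_action is False."""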
if not isinstance(quest_chain, list) or index < 0 or index >= len(quest_chain):
return None
current_step = quest_chain[index]
if not isinstance(current_step, dict):
return None
if not allow_existing_action and any(isinstance(step, dict) and step.get("action") == action for step in quest_chain):
return None
existing_step_ids = {
step.get("step_id")
for step in quest_chain
if isinstance(step, dict) and isinstance(step.get("step_id"), str) and step.get("step_id")
}
existing_requires = current_step.get("requires_step_ids")
if isinstance(existing_requires, list):
requires_step_ids = [step_id for step_id in existing_requires if isinstance(step_id, str) and step_id]
else:
requires_step_ids = []
new_step_id = _unique_world_id(step_id_base, existing_step_ids)
quest_chain.insert(
index,
{
"step_id": new_step_id,
"description": description,
"requires_step_ids": requires_step_ids,
"action": action,
},
)
current_step["requires_step_ids"] = [new_step_id]
return new_step_id
def _action_exists_before_index(quest_chain: Any, action: str, index: int) -> bool:
if not isinstance(quest_chain, list):
return False
for current_step in quest_chain[:index]:
if isinstance(current_step, dict) and current_step.get("action") == action:
return True
return False
def _quest_required_item_ids(action: Any) -> list[str]:
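"""Return the item ids an action needs in hand: the key for unlock(door,key), both inputs
for combine(a,b), and the first argument for use(...) or give(...)."""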
for name, count in (("use", 2), ("unlock", 2), ("give", 2), ("combine", 2)):
arguments = _extract_action_arguments(action, name)
if arguments is None:
continue
if name == "combine" and len(arguments) == count:
return arguments
if len(arguments) == count:
return [arguments[0 if name != "unlock" else 1]]
return []
def _infer_room_prereq_for_step(step: Any, step_by_id: dict[str, Any]) -> str | None:
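"""Infer the room the hero occupies when executing a step by chasing go(...) actions
through its prerequisite chain."""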
if not isinstance(step, dict):
return None
step_id = step.get("step_id")
if isinstance(step_id, str) and step_id:
inferred_room = _infer_step_terminal_room(step_id, step_by_id, set())
if inferred_room is not None:
return inferred_room
requires_step_ids = step.get("requires_step_ids")
if not isinstance(requires_step_ids, list):
return None
room_id: str | None = None
for dependency in requires_step_ids:
if not isinstance(dependency, str):
continue
dependency_step = step_by_id.get(dependency)
if not isinstance(dependency_step, dict):
continue
maybe_room_id = _extract_single_action_argument(dependency_step.get("action"), "go")
if maybe_room_id:
room_id = maybe_room_id
return room_id
def _infer_step_terminal_room(step_id: str, step_by_id: dict[str, Any], seen: set[str]) -> str | None:
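"""Return the go(...) destination of the step itself or, recursively, of its prerequisites;
the seen set guards against dependency cycles."""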
if step_id in seen:
return None
step = step_by_id.get(step_id)
if not isinstance(step, dict):
return None
seen = set(seen)
seen.add(step_id)
target_room = _extract_single_action_argument(step.get("action"), "go")
if target_room:
return target_room
requires_step_ids = step.get("requires_step_ids")
if not isinstance(requires_step_ids, list):
return None
inferred_room: str | None = None
for dependency in requires_step_ids:
if not isinstance(dependency, str):
continue
dependency_room = _infer_step_terminal_room(dependency, step_by_id, seen)
if dependency_room:
inferred_room = dependency_room
return inferred_room
def _reachable_passage_room_ids(payload: dict[str, Any], *, start_node_id: Any) -> set[str]:
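"""Return the set of room ids reachable from the start node by following directed passage
edges; locked passages are not traversed."""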
if not isinstance(start_node_id, str) or not start_node_id:
return set()
edges = payload.get("edges")
if not isinstance(edges, list):
return {start_node_id}
graph: dict[str, set[str]] = {}
for edge in edges:
if not isinstance(edge, dict) or edge.get("type") != "passage":
continue
from_node_id = edge.get("from_node_id")
to_node_id = edge.get("to_node_id")
if not isinstance(from_node_id, str) or not isinstance(to_node_id, str):
continue
graph.setdefault(from_node_id, set()).add(to_node_id)
reachable = {start_node_id}
frontier = [start_node_id]
while frontier:
current = frontier.pop()
for nxt in graph.get(current, set()):
if nxt in reachable:
continue
reachable.add(nxt)
frontier.append(nxt)
return reachable
def _extract_single_action_argument(action: Any, name: str) -> str | None:
if not isinstance(action, str):
return None
prefix = f"{name}("
if not action.startswith(prefix) or not action.endswith(")"):
return None
raw_argument = action[len(prefix) : -1].strip()
if not raw_argument:
return None
if raw_argument[0] == raw_argument[-1] and raw_argument[0] in {'"', "'"}:
raw_argument = raw_argument[1:-1]
return raw_argument.strip()
def _load_dm_world_definition(text: str, *, allow_repair: bool) -> WorldDefinition:
payload = _try_parse_completion_json(text)
if not isinstance(payload, dict):
raise ValueError("Completion did not contain a JSON object.")
if allow_repair:
payload = _repair_dm_world_payload(payload)
return WorldDefinition.model_validate(payload)
def _find_json_object_span(text: str) -> tuple[int, int] | None:
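"""Locate the first balanced top-level JSON object in the text, tracking string literals
and escapes; returns the (start, end) slice indices or None."""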
start: int | None = None
depth = 0
in_string = False
escaped = False
for index, character in enumerate(text):
if in_string:
if escaped:
escaped = False
elif character == "\\":
escaped = True
elif character == '"':
in_string = False
continue
if character == '"':
in_string = True
continue
if character == "{":
if start is None:
start = index
depth += 1
continue
if character == "}":
if depth == 0:
continue
depth -= 1
if depth == 0 and start is not None:
return start, index + 1
return None
def _strip_code_fences(text: str) -> str:
cleaned = text.strip()
if cleaned.startswith("```"):
lines = cleaned.splitlines()
if lines and lines[0].startswith("```"):
lines = lines[1:]
if lines and lines[-1].strip() == "```":
lines = lines[:-1]
cleaned = "\n".join(lines).strip()
return cleaned
def _normalize_outer_completion_text(text: str) -> str:
without_tools = _TOOL_CALL_RE.sub("", text)
without_tools = _EMPTY_THINK_RE.sub("", without_tools)
without_tools = _strip_code_fences(without_tools)
return without_tools.strip()
def _string_key_coverage(value: Any, keys: tuple[str, ...]) -> float:
if not isinstance(value, dict):
return 0.0
return sum(1 for key in keys if key in value) / len(keys)
def _range_score(value: int, lower: int, upper: int) -> float:
if lower <= value <= upper:
return 1.0
if value < lower:
return max(0.0, value / max(1, lower))
return max(0.0, 1.0 - ((value - upper) / max(1, upper)))
def _compactness_score(length: int, target_max: int) -> float:
if length <= target_max:
return 1.0
overflow = length - target_max
return max(0.0, 1.0 - (overflow / max(1, target_max)))
def _dm_structural_prior_score(world: dict[str, Any], requested_ratio: float | None) -> float:
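"""Weighted heuristic in [0, 1] scoring how closely a raw world dict matches the expected
shape: required fields, win condition, node/item/clue/recipe/quest counts, allowed type
fractions, text compactness, guardian presence, and (when requested) difficulty-ratio agreement."""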
meta = world.get("meta")
nodes = world.get("nodes") if isinstance(world.get("nodes"), list) else []
edges = world.get("edges") if isinstance(world.get("edges"), list) else []
items = world.get("items") if isinstance(world.get("items"), list) else []
clues = world.get("clues") if isinstance(world.get("clues"), list) else []
recipes = world.get("recipes") if isinstance(world.get("recipes"), list) else []
quest_chain = world.get("quest_chain") if isinstance(world.get("quest_chain"), list) else []
components = [
(0.16, _string_key_coverage(world, _DM_REQUIRED_TOP_LEVEL_FIELDS)),
(0.08, _string_key_coverage(meta, ("title", "difficulty_target", "start_node_id", "win_condition"))),
(0.10, _dm_win_condition_score(meta)),
(0.10, _range_score(len(nodes), 10, 16)),
(0.07, _range_score(len(items), 5, 8)),
(0.09, _range_score(len(clues), 3, 5)),
(0.04, _range_score(len(recipes), 0, 1)),
(0.10, _range_score(len(quest_chain), 12, 20)),
(0.06, _valid_type_fraction(nodes, "type", _DM_ALLOWED_NODE_TYPES)),
(0.04, _valid_type_fraction(edges, "type", _DM_ALLOWED_EDGE_TYPES)),
(0.04, _valid_type_fraction(items, "subtype", _DM_ALLOWED_ITEM_TYPES)),
(0.06, _compact_world_text_score(nodes, items, clues, quest_chain)),
(0.06, _guardian_presence_score(meta, nodes)),
]
if requested_ratio is not None:
components.append((0.10, _difficulty_ratio_score(meta, requested_ratio)))
weighted_total = sum(weight * score for weight, score in components)
total_weight = sum(weight for weight, _ in components)
return _clamp(weighted_total / max(1e-6, total_weight), 0.0, 1.0)
def _dm_win_condition_score(meta: Any) -> float:
if not isinstance(meta, dict):
return 0.0
win_condition = meta.get("win_condition")
if not isinstance(win_condition, dict):
return 0.0
score = _string_key_coverage(win_condition, ("type", "target_npc_id", "answer_string"))
if win_condition.get("type") == "deduce":
score += 0.25
answer = win_condition.get("answer_string")
if isinstance(answer, str) and _LOWERCASE_ANSWER_RE.fullmatch(answer):
score += 0.25
return min(1.0, score)
def _guardian_presence_score(meta: Any, nodes: list[Any]) -> float:
if not isinstance(meta, dict):
return 0.0
win_condition = meta.get("win_condition")
if not isinstance(win_condition, dict):
return 0.0
guardian_id = win_condition.get("target_npc_id")
if not isinstance(guardian_id, str):
return 0.0
return 1.0 if any(isinstance(node, dict) and node.get("type") == "npc" and node.get("id") == guardian_id for node in nodes) else 0.0
def _difficulty_ratio_score(meta: Any, requested_ratio: float) -> float:
if not isinstance(meta, dict):
return 0.0
try:
actual_ratio = float(meta.get("difficulty_target"))
except Exception:
return 0.0
return max(0.0, 1.0 - abs(actual_ratio - requested_ratio))
def _valid_type_fraction(rows: list[Any], key: str, allowed_values: set[str]) -> float:
typed_rows = [row for row in rows if isinstance(row, dict)]
if not typed_rows:
return 0.0
valid = sum(1 for row in typed_rows if row.get(key) in allowed_values)
return valid / len(typed_rows)
def _compact_world_text_score(
nodes: list[Any],
items: list[Any],
clues: list[Any],
quest_chain: list[Any],
) -> float:
text_lengths: list[int] = []
for collection, keys in (
(nodes, ("label", "description")),
(items, ("label", "description")),
(clues, ("text",)),
(quest_chain, ("description", "action")),
):
for row in collection:
if not isinstance(row, dict):
continue
for key in keys:
value = row.get(key)
if isinstance(value, str):
text_lengths.append(len(value))
if not text_lengths:
return 0.0
average_length = sum(text_lengths) / len(text_lengths)
return _compactness_score(int(average_length), 80)
def _validation_error_score(errors: list[dict[str, Any]]) -> float:
if not errors:
return 0.0
penalty = 0.0
for error in errors:
error_type = str(error.get("type", ""))
location = tuple(str(part) for part in error.get("loc", ()))
field_name = location[-1] if location else ""
if error_type == "extra_forbidden":
penalty += 0.05
elif error_type.startswith("missing") and field_name in {"label", "description"}:
penalty += 0.02
elif error_type.startswith("missing") and field_name == "text_content":
penalty += 0.05
elif error_type.startswith("missing"):
penalty += 0.06
else:
penalty += 0.08
return _clamp(1.0 - penalty, 0.0, 1.0)
def _compile_error_penalty(error_message: str) -> float:
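"""Map a world-compile error message to a negative reward in roughly [-0.90, -0.35],
penalizing reachability and quest failures more heavily than count or duplicate-id issues;
unrecognized messages default to -0.75."""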
message = error_message.lower()
if not message:
return -0.5
if "between 3 and 5 clues" in message:
return -0.35
if "duplicate world id" in message or "duplicate " in message:
return -0.45
if "requires_step_id" in message or "requires_step_with" in message:
return -0.45
if "requires requires_item_id" in message:
return -0.50
if "must live in a location or junction" in message:
return -0.55
if "fixture" in message and "requires unknown item" in message:
return -0.60
if "unknown item" in message or "unknown clue" in message or "unknown node" in message:
return -0.65
if "must reveal exactly one item or readable" in message:
return -0.65
if "guardian npc cannot have trade fields" in message:
return -0.70
if "unused decorative items" in message or "clue '" in message:
return -0.75
if "final quest step" in message or "penultimate quest step" in message:
return -0.80
if "unreachable" in message or "guardian room" in message:
return -0.85
if "closed door" in message or "locked door" in message or "does not match key" in message:
return -0.85
if "quest " in message or "unsupported quest action" in message:
return -0.90
return -0.75
def _completion_tool_calls(completion: Any) -> list[dict[str, Any]]:
return _extract_tool_calls_from_text(_completion_text(completion))
def _extract_tool_calls_from_text(text: str) -> list[dict[str, Any]]:
tool_calls: list[dict[str, Any]] = []
for raw_payload in _TOOL_CALL_RE.findall(text):
try:
payload = json.loads(raw_payload)
except Exception:
continue
normalized = _normalize_tool_call(payload, source="tool_call")
if normalized is not None:
tool_calls.append(normalized)
if tool_calls:
return tool_calls
payload = _try_parse_completion_json(text)
normalized = _normalize_tool_call(payload, source="json_action")
if normalized is None:
return []
return [normalized]
def _normalize_tool_call(payload: Any, *, source: str) -> dict[str, Any] | None:
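"""Normalize an OpenAI-style function-call payload or a legacy {"action": {"tool": ...}}
payload into {"name", "arguments", "source"}; returns None for anything else."""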
if not isinstance(payload, dict):
return None
if payload.get("type") == "function" and isinstance(payload.get("function"), dict):
payload = payload["function"]
if isinstance(payload.get("name"), str):
arguments = payload.get("arguments", {})
if not isinstance(arguments, dict):
return None
return {"name": payload["name"], "arguments": arguments, "source": source}
action = payload.get("action")
if isinstance(action, dict) and isinstance(action.get("tool"), str):
arguments = {key: value for key, value in action.items() if key != "tool"}
return {"name": action["tool"], "arguments": arguments, "source": source}
return None
def _hero_act_semantics_reward(arguments: Any) -> float:
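"""Score the act tool's command string: 1.0 when it parses and is already lowercase and
trimmed, 0.85 when it parses unnormalized, 0.40 when only the lowercased/trimmed form
parses, 0.0 otherwise."""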
if not isinstance(arguments, dict):
return 0.0
command = arguments.get("command")
if not isinstance(command, str) or not command.strip():
return 0.0
normalized_command = command.strip().lower()
parsed = parse_cli_command(command)
if not parsed.valid:
recovered = parse_cli_command(normalized_command)
return 0.40 if recovered.valid else 0.0
return 1.0 if command == normalized_command else 0.85
def _hero_scratchpad_write_reward(arguments: Any) -> float:
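"""Score scratchpad writes: 0.45 for a valid mode, 0.35 for non-empty content, plus up to
0.20 for keeping the content compact (full credit at or under 240 characters)."""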
if not isinstance(arguments, dict):
return 0.0
mode = arguments.get("mode")
content = arguments.get("content")
score = 0.0
if mode in {"append", "replace"}:
score += 0.45
if isinstance(content, str) and content.strip():
score += 0.35
score += 0.20 * _compactness_score(len(content), 240)
return min(1.0, score)
def _clamp(value: float, lower: float, upper: float) -> float:
return max(lower, min(upper, value))
def _require_training_dependencies() -> None:
if TRAINING_IMPORT_ERROR is not None:
raise RuntimeError(
"Training dependencies are unavailable. Install the project with the training extras before using GRPO."
) from TRAINING_IMPORT_ERROR
def _require_vllm_if_requested(config: GRPOLaunchConfig) -> None:
if not config.use_vllm:
return
if importlib.util.find_spec("vllm") is None:
raise RuntimeError(
"vLLM is not installed but --use-vllm was requested. Install vllm in the training environment first."
)