from __future__ import annotations

import hashlib
import importlib.util
import json
import os
import re
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Callable

from agents.hero.cli import parse_cli_command
from agents.hero.env import HeroEnvironment
from agents.hero.policy import HeroLLMPolicy
from agents.hero.prompt import format_hero_grpo_system_prompt
from agents.hero.runner import HeroRunner
from agents.hero.schema import validate_hero_action
from agents.master.base import normalize_answer_text, parser_safe_text
from agents.master.check import validate_and_normalize
from agents.master.env import DMEnvironment
from agents.master.prompt import build_dm_world_messages
from agents.master.sample import load_world, sample_world_definition
from agents.master.schema import WorldDefinition
from agents.shared.runtime import (
    build_interface_adapter,
    create_structured_client,
    resolve_interface_config,
    resolve_structured_client_config,
)

try:
    import torch
    from datasets import Dataset
    from peft import LoraConfig
    from transformers import AutoTokenizer, BitsAndBytesConfig
    from trl import GRPOConfig, GRPOTrainer
    from trl.chat_template_utils import qwen3_chat_template, qwen3_schema
    from trl.rewards import get_soft_overlong_punishment

    TRAINING_IMPORT_ERROR: Exception | None = None
except Exception as exc:  # pragma: no cover - exercised when train extras are unavailable
    torch = None  # type: ignore[assignment]
    Dataset = None  # type: ignore[assignment]
    LoraConfig = None  # type: ignore[assignment]
    GRPOConfig = None  # type: ignore[assignment]
    GRPOTrainer = None  # type: ignore[assignment]
    AutoTokenizer = None  # type: ignore[assignment]
    BitsAndBytesConfig = None  # type: ignore[assignment]
    qwen3_chat_template = None  # type: ignore[assignment]
    qwen3_schema = None  # type: ignore[assignment]
    get_soft_overlong_punishment = None  # type: ignore[assignment]
    TRAINING_IMPORT_ERROR = exc

_DEFAULT_TARGET_RATIOS = [1.25, 1.5, 1.75, 2.0]
_DM_REQUIRED_TOP_LEVEL_FIELDS = ("meta", "nodes", "edges", "items", "clues", "recipes", "quest_chain")
_DM_ALLOWED_NODE_TYPES = {"location", "junction", "container", "door", "readable", "fixture", "npc"}
_DM_ALLOWED_EDGE_TYPES = {"passage", "locked_passage"}
_DM_ALLOWED_ITEM_TYPES = {"key", "puzzle"}
_HERO_TOOL_NAMES = {"act", "scratchpad_read", "scratchpad_write"}
_TOOL_CALL_RE = re.compile(r"<tool_call>\s*(\{.*?\})\s*</tool_call>", re.DOTALL)
_EMPTY_THINK_RE = re.compile(r"<think>\s*</think>\s*", re.DOTALL)
_LOWERCASE_ANSWER_RE = re.compile(r"^[a-z0-9]+(?: [a-z0-9]+)*$")
_HERO_TASK_PROMPTS = (
    "Solve the dungeon by using tools until the episode ends.\nInitial observation:\n",
    "Play the dungeon to completion through tool calls only.\nInitial observation:\n",
    "Gather every clue and solve the dungeon via tools only.\nInitial observation:\n",
)
SUPPORTED_GRPO_LOSS_TYPES = ("grpo", "dapo", "bnpo", "dr_grpo", "cispo", "sapo", "luspo")
SUPPORTED_IMPORTANCE_SAMPLING_LEVELS = ("token", "sequence")


@dataclass
class GRPOLaunchConfig:
    model_name: str
    output_dir: Path
    resume_adapter_path: str | None = None
    max_steps: int = 10
    num_prompts: int = 16
    learning_rate: float = 1e-5
    per_device_train_batch_size: int = 2
    gradient_accumulation_steps: int = 8
    num_generations: int = 2
    max_completion_length: int = 512
    logging_steps: int = 1
    save_steps: int = 10
    seed: int = 42
    rank: int = 16
    alpha: int = 32
    dropout: float = 0.05
    temperature: float = 0.6
    top_p: float = 0.95
    top_k: int = 20
    min_p: float | None = None
    repetition_penalty: float = 1.0
    use_wandb: bool = True
    run_name: str | None = None
    trust_remote_code: bool = False
    load_in_4bit: bool = True
    loss_type: str = "dapo"
    importance_sampling_level: str = "token"
    use_transformers_paged: bool = False
    cache_implementation: str | None = None
    use_vllm: bool = False
    vllm_mode: str = "colocate"
    vllm_gpu_memory_utilization: float = 0.2
    vllm_enable_sleep_mode: bool = True

    def __post_init__(self) -> None:
        if self.loss_type not in SUPPORTED_GRPO_LOSS_TYPES:
            raise ValueError(
                f"loss_type must be one of {SUPPORTED_GRPO_LOSS_TYPES!r}; got {self.loss_type!r}."
            )
        if self.importance_sampling_level not in SUPPORTED_IMPORTANCE_SAMPLING_LEVELS:
            raise ValueError(
                "importance_sampling_level must be one of "
                f"{SUPPORTED_IMPORTANCE_SAMPLING_LEVELS!r}; got {self.importance_sampling_level!r}."
            )
        if self.loss_type == "luspo" and self.importance_sampling_level != "sequence":
            raise ValueError("luspo requires importance_sampling_level='sequence'.")
        if self.per_device_train_batch_size < 1:
            raise ValueError("per_device_train_batch_size must be at least 1.")
        if self.gradient_accumulation_steps < 1:
            raise ValueError("gradient_accumulation_steps must be at least 1.")
        if self.num_generations < 2:
            raise ValueError("num_generations must be at least 2 for GRPO.")
        if self.max_steps < 1:
            raise ValueError("max_steps must be at least 1.")
        if self.num_prompts < 1:
            raise ValueError("num_prompts must be at least 1.")
        if self.temperature <= 0.0:
            raise ValueError("temperature must be greater than 0.")
        if not 0.0 < self.top_p <= 1.0:
            raise ValueError("top_p must be in the interval (0, 1].")
        if self.top_k < 0:
            raise ValueError("top_k must be non-negative.")
        if self.min_p is not None and not 0.0 <= self.min_p <= 1.0:
            raise ValueError("min_p must be in the interval [0, 1] when provided.")
        if self.repetition_penalty < 1.0:
            raise ValueError("repetition_penalty must be at least 1.0.")
        if self.vllm_mode not in {"server", "colocate"}:
            raise ValueError("vllm_mode must be 'server' or 'colocate'.")
        if not 0.0 < self.vllm_gpu_memory_utilization < 1.0:
            raise ValueError("vllm_gpu_memory_utilization must be in the interval (0, 1).")
        world_size = max(1, int(os.environ.get("WORLD_SIZE", "1")))
        generation_batch_size = self.per_device_train_batch_size * world_size
        if generation_batch_size % self.num_generations != 0:
            raise ValueError(
                "generation_batch_size "
                f"({generation_batch_size}) must be divisible by num_generations ({self.num_generations}). "
                "Increase --per-device-train-batch-size, reduce --num-generations, or launch with more processes."
            )
        minimum_prompt_rows = generation_batch_size * self.gradient_accumulation_steps
        if self.num_prompts < minimum_prompt_rows:
            raise ValueError(
                "num_prompts "
                f"({self.num_prompts}) must be at least generation_batch_size * gradient_accumulation_steps "
                f"({minimum_prompt_rows}) so GRPO can complete one optimizer step."
            )


@dataclass
class DMClosedLoopConfig:
    hero_provider: str | None = None
    hero_model: str | None = None
    hero_adapter_path: str | None = None
    interface_provider: str | None = None
    interface_model: str | None = None
    interface_narrate: bool = False
    interface_translation_mode: str | None = None
    hero_max_game_steps: int = 40
    hero_max_tool_calls: int = 80


@dataclass
class DMRolloutMetrics:
    reward: float
    compile_error: str | None
    requested_ratio: float
    player_won: bool
    steps_taken: int | None
    min_steps: int | None
    ratio: float | None
    efficiency_score: float
    quality_score: float
    invalid_command_count: int
    wrong_submit_count: int
    hero_player_won: bool
    hero_total_reward: float
    hero_dense_return: float
    hero_steps_taken: int
    hero_tool_calls_total: int
    hero_policy_error: str | None


_DM_ROLLOUT_CACHE_STEP = -1
_DM_ROLLOUT_CACHE: dict[tuple[Any, ...], DMRolloutMetrics] = {}
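
# The rollout cache lets the several DM reward functions share one hero rollout
# per completion: entries are keyed by (step, completion hash, rollout settings),
# and the whole cache is discarded whenever trainer_state.global_step advances.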


def build_dm_grpo_dataset(
    *,
    num_prompts: int = 8,
    target_ratios: list[float] | None = None,
) -> list[dict[str, Any]]:
    ratios = target_ratios or _DEFAULT_TARGET_RATIOS
    rows: list[dict[str, Any]] = []
    for index in range(num_prompts):
        target_ratio = ratios[index % len(ratios)]
        reference_world = sample_world_definition(seed=index, difficulty_target=target_ratio)
        prompt = [
            {"role": message.role, "content": message.content}
            for message in build_dm_world_messages(
                target_ratio=target_ratio,
                reference_world=reference_world,
                prompt_style=index,
            )
        ]
        rows.append({"prompt": prompt, "target_ratio": target_ratio, "seed": index})
    return rows
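
# Each row uses TRL's conversational prompt format, e.g.
# {"prompt": [{"role": "user", "content": ...}], "target_ratio": 1.25, "seed": 0}.
# The extra columns survive because training runs with remove_unused_columns=False,
# and GRPOTrainer forwards them to the reward functions as keyword arguments.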


def build_hero_grpo_dataset(
    *,
    num_prompts: int = 8,
    world_input: dict[str, Any] | None = None,
    target_ratios: list[float] | None = None,
    max_game_steps: int = 40,
    max_tool_calls: int = 80,
) -> list[dict[str, Any]]:
    ratios = target_ratios or _DEFAULT_TARGET_RATIOS
    rows: list[dict[str, Any]] = []
    for index in range(num_prompts):
        target_ratio = ratios[index % len(ratios)]
        world = world_input or sample_world_definition(seed=index, difficulty_target=target_ratio)
        world_title = str(world["meta"]["title"])
        prompt = [
            {
                "role": "system",
                "content": format_hero_grpo_system_prompt(world_title, max_game_steps, max_tool_calls),
            },
            {
                "role": "user",
                "content": _HERO_TASK_PROMPTS[index % len(_HERO_TASK_PROMPTS)],
            },
        ]
        rows.append(
            {
                "prompt": prompt,
                "world_definition_json": json.dumps(world, separators=(",", ":")),
                "seed": index,
                "target_ratio": target_ratio,
                "max_game_steps": max_game_steps,
                "max_tool_calls": max_tool_calls,
            }
        )
    return rows


class HeroToolEnvironment:
    def __init__(
        self,
        *,
        artifacts_root: Path | None = None,
        interface_provider: str | None = None,
        interface_model: str | None = None,
        interface_narrate: bool = False,
        interface_translation_mode: str | None = None,
    ) -> None:
        self.artifacts_root = artifacts_root
        self.interface_provider = interface_provider
        self.interface_model = interface_model
        self.interface_narrate = interface_narrate
        self.interface_translation_mode = interface_translation_mode
        self.hero_env: HeroEnvironment | None = None
        self.last_message = ""

    def reset(
        self,
        *,
        world_definition_json: str,
        seed: int | None = None,
        max_game_steps: int = 40,
        max_tool_calls: int = 80,
        prompt: Any | None = None,
        **_: Any,
    ) -> str:
        del prompt
        interface_adapter = build_interface_adapter(
            resolve_interface_config(
                provider=self.interface_provider,  # type: ignore[arg-type]
                model_name=self.interface_model,
                narrate_observations=self.interface_narrate,
                translation_mode=self.interface_translation_mode,  # type: ignore[arg-type]
            )
        )
        self.hero_env = HeroEnvironment(
            artifacts_root=self.artifacts_root,
            interface_adapter=interface_adapter,
        )
        observation = self.hero_env.reset(
            world_input=json.loads(world_definition_json),
            seed=seed,
            max_game_steps=max_game_steps,
            max_tool_calls=max_tool_calls,
        )
        self.last_message = observation.message
        return observation.message

    def act(self, command: str) -> str:
        """Act in the dungeon with one strict CLI command.

        Args:
            command: Lowercase parser-style dungeon command.

        Returns:
            The environment's next observation message.
        """
        return self._step({"tool": "act", "command": command})

    def scratchpad_read(self) -> str:
        """Read the current scratchpad contents.

        Returns:
            The scratchpad text.
        """
        return self._step({"tool": "scratchpad_read"})

    def scratchpad_write(self, mode: str, content: str) -> str:
        """Write to the scratchpad.

        Args:
            mode: Either "append" or "replace".
            content: Text to write.

        Returns:
            The environment's acknowledgement message.
        """
        return self._step({"tool": "scratchpad_write", "mode": mode, "content": content})

    def _cumulative_reward(self) -> float:
        if self.hero_env is None:
            return -1.0
        return float(self.hero_env.episode_stats.total_reward)

    def _episode_done(self) -> bool:
        if self.hero_env is None or self.hero_env.session is None:
            return False
        return bool(self.hero_env.session.done or self.hero_env.state.status in {"won", "lost", "timed_out"})

    def _episode_won(self) -> bool:
        if self.hero_env is None:
            return False
        return bool(self.hero_env.episode_stats.player_won)

    def _step(self, action: dict[str, Any]) -> str:
        if self.hero_env is None:
            raise RuntimeError("HeroToolEnvironment.reset must be called before using tools.")
        result = self.hero_env.step(action)
        self.last_message = result.observation.message
        return result.observation.message
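
# Illustrative round trip (world_json stands in for a serialized world definition):
#     env = HeroToolEnvironment()
#     first_obs = env.reset(world_definition_json=world_json, seed=0)
#     next_obs = env.act("look")  # "look" is a placeholder command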


def dm_reward_function(
    *,
    prompts: list[Any],
    completions: list[Any],
    target_ratio: list[float],
    trainer_state: Any,
    hero_policy: Any,
    interface_provider: str | None = None,
    interface_model: str | None = None,
    interface_narrate: bool = False,
    interface_translation_mode: str | None = None,
    hero_max_game_steps: int = 40,
    hero_max_tool_calls: int = 80,
    artifacts_root: str | None = None,
    **_: Any,
) -> list[float]:
    del prompts
    rewards: list[float] = []
    for index, (completion, requested_ratio) in enumerate(zip(completions, target_ratio, strict=True)):
        metrics = _cached_dm_rollout_metrics(
            completion=completion,
            requested_ratio=requested_ratio,
            trainer_state=trainer_state,
            completion_index=index,
            hero_policy=hero_policy,
            interface_provider=interface_provider,
            interface_model=interface_model,
            interface_narrate=interface_narrate,
            interface_translation_mode=interface_translation_mode,
            hero_max_game_steps=hero_max_game_steps,
            hero_max_tool_calls=hero_max_tool_calls,
            artifacts_root=artifacts_root,
        )
        if metrics.compile_error is not None:
            rewards.append(_compile_error_penalty(metrics.compile_error))
            continue
        rewards.append(metrics.reward)
    return rewards
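
# All DM rollout rewards (dm_reward_function and the dm_hero_* functions below)
# read from _cached_dm_rollout_metrics, so the expensive hero rollout runs at
# most once per completion per optimizer step even though several reward heads
# score it.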


def _dm_reward_artifacts_dir(
    *,
    artifacts_root: str | None,
    trainer_state: Any,
    completion_index: int,
) -> Path | None:
    if artifacts_root is None:
        return None
    step = getattr(trainer_state, "global_step", 0)
    return Path(artifacts_root) / "dm_reward_rollouts" / f"step_{step:05d}" / f"sample_{completion_index:02d}"


def dm_hero_success_reward(
    *,
    prompts: list[Any],
    completions: list[Any],
    target_ratio: list[float],
    trainer_state: Any,
    hero_policy: Any,
    interface_provider: str | None = None,
    interface_model: str | None = None,
    interface_narrate: bool = False,
    interface_translation_mode: str | None = None,
    hero_max_game_steps: int = 40,
    hero_max_tool_calls: int = 80,
    artifacts_root: str | None = None,
    **_: Any,
) -> list[float]:
    del prompts
    rewards: list[float] = []
    for index, (completion, requested_ratio) in enumerate(zip(completions, target_ratio, strict=True)):
        metrics = _cached_dm_rollout_metrics(
            completion=completion,
            requested_ratio=requested_ratio,
            trainer_state=trainer_state,
            completion_index=index,
            hero_policy=hero_policy,
            interface_provider=interface_provider,
            interface_model=interface_model,
            interface_narrate=interface_narrate,
            interface_translation_mode=interface_translation_mode,
            hero_max_game_steps=hero_max_game_steps,
            hero_max_tool_calls=hero_max_tool_calls,
            artifacts_root=artifacts_root,
        )
        if metrics.compile_error is not None:
            rewards.append(_compile_error_penalty(metrics.compile_error))
            continue
        rewards.append(float(metrics.hero_player_won))
    return rewards


def dm_hero_efficiency_reward(
    *,
    prompts: list[Any],
    completions: list[Any],
    target_ratio: list[float],
    trainer_state: Any,
    hero_policy: Any,
    interface_provider: str | None = None,
    interface_model: str | None = None,
    interface_narrate: bool = False,
    interface_translation_mode: str | None = None,
    hero_max_game_steps: int = 40,
    hero_max_tool_calls: int = 80,
    artifacts_root: str | None = None,
    **_: Any,
) -> list[float]:
    del prompts
    rewards: list[float] = []
    for index, (completion, requested_ratio) in enumerate(zip(completions, target_ratio, strict=True)):
        metrics = _cached_dm_rollout_metrics(
            completion=completion,
            requested_ratio=requested_ratio,
            trainer_state=trainer_state,
            completion_index=index,
            hero_policy=hero_policy,
            interface_provider=interface_provider,
            interface_model=interface_model,
            interface_narrate=interface_narrate,
            interface_translation_mode=interface_translation_mode,
            hero_max_game_steps=hero_max_game_steps,
            hero_max_tool_calls=hero_max_tool_calls,
            artifacts_root=artifacts_root,
        )
        if metrics.compile_error is not None:
            rewards.append(_compile_error_penalty(metrics.compile_error))
            continue
        if not metrics.hero_player_won:
            rewards.append(0.0)
            continue
        rewards.append(_clamp(metrics.efficiency_score, 0.0, 1.0))
    return rewards


def dm_hero_cleanliness_reward(
    *,
    prompts: list[Any],
    completions: list[Any],
    target_ratio: list[float],
    trainer_state: Any,
    hero_policy: Any,
    interface_provider: str | None = None,
    interface_model: str | None = None,
    interface_narrate: bool = False,
    interface_translation_mode: str | None = None,
    hero_max_game_steps: int = 40,
    hero_max_tool_calls: int = 80,
    artifacts_root: str | None = None,
    **_: Any,
) -> list[float]:
    del prompts
    rewards: list[float] = []
    for index, (completion, requested_ratio) in enumerate(zip(completions, target_ratio, strict=True)):
        metrics = _cached_dm_rollout_metrics(
            completion=completion,
            requested_ratio=requested_ratio,
            trainer_state=trainer_state,
            completion_index=index,
            hero_policy=hero_policy,
            interface_provider=interface_provider,
            interface_model=interface_model,
            interface_narrate=interface_narrate,
            interface_translation_mode=interface_translation_mode,
            hero_max_game_steps=hero_max_game_steps,
            hero_max_tool_calls=hero_max_tool_calls,
            artifacts_root=artifacts_root,
        )
        if metrics.compile_error is not None:
            rewards.append(_compile_error_penalty(metrics.compile_error))
            continue
        step_budget = max(1, metrics.hero_steps_taken or metrics.steps_taken or 0)
        penalty = (metrics.invalid_command_count + (2 * metrics.wrong_submit_count)) / step_budget
        score = max(0.0, 1.0 - penalty)
        if metrics.hero_policy_error is not None:
            score = min(score, 0.25)
        rewards.append(_clamp(score, 0.0, 1.0))
    return rewards
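
# Worked example of the cleanliness penalty above: 3 invalid commands and 1
# wrong submit over a 20-step episode give penalty = (3 + 2 * 1) / 20 = 0.25,
# hence score = 0.75, further capped at 0.25 if the hero policy errored.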


def dm_json_format_reward(
    *,
    prompts: list[Any],
    completions: list[Any],
    trainer_state: Any,
    **_: Any,
) -> list[float]:
    del prompts, trainer_state
    rewards: list[float] = []
    for completion in completions:
        text = _completion_text(completion)
        score = 0.0
        json_text, leading_text, trailing_text = _extract_json_candidate_parts(text)
        if json_text is None:
            if "{" in text:
                score += 0.05
            if "<think>" in text:
                score -= 0.10
            rewards.append(_clamp(score, -0.25, 1.0))
            continue
        try:
            json.loads(json_text)
            score += 0.60
        except Exception:
            score += 0.20
        outer_text = (leading_text + trailing_text).strip()
        if not outer_text:
            score += 0.25
        else:
            ratio = len(json_text) / max(1, len(_strip_code_fences(text).strip()))
            score += 0.15 * ratio
        score += 0.10 * _compactness_score(len(json_text), 4500)
        if "<think>" in text:
            score -= 0.15
        if "```" in text:
            score -= 0.05
        rewards.append(_clamp(score, -0.25, 1.0))
    return rewards
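
# Scoring sketch: a completion that is exactly one parseable JSON object earns
# 0.60 (parses) + 0.25 (no surrounding text) plus a compactness bonus of up to
# 0.10; <think> blocks and code fences cost 0.15 and 0.05; the total is clamped
# to [-0.25, 1.0].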


def dm_schema_reward(
    *,
    prompts: list[Any],
    completions: list[Any],
    target_ratio: list[float] | None = None,
    trainer_state: Any,
    **_: Any,
) -> list[float]:
    del prompts, trainer_state
    target_ratio = target_ratio or [None] * len(completions)
    rewards: list[float] = []
    for completion, requested_ratio in zip(completions, target_ratio, strict=True):
        payload = _try_parse_completion_json(_completion_text(completion))
        if not isinstance(payload, dict):
            rewards.append(0.0)
            continue
        rewards.append(_dm_structural_prior_score(payload, requested_ratio))
    return rewards


def dm_validation_reward(
    *,
    prompts: list[Any],
    completions: list[Any],
    trainer_state: Any,
    **_: Any,
) -> list[float]:
    del prompts, trainer_state
    rewards: list[float] = []
    for completion in completions:
        payload = _try_parse_completion_json(_completion_text(completion))
        if not isinstance(payload, dict):
            rewards.append(0.0)
            continue
        try:
            WorldDefinition.model_validate(payload)
            rewards.append(1.0)
        except Exception as exc:
            error_list = exc.errors() if hasattr(exc, "errors") else []
            rewards.append(_validation_error_score(error_list))
    return rewards


def dm_compile_prior_reward(
    *,
    prompts: list[Any],
    completions: list[Any],
    trainer_state: Any,
    **_: Any,
) -> list[float]:
    del prompts, trainer_state
    rewards: list[float] = []
    for completion in completions:
        try:
            world = _load_dm_world_definition(_completion_text(completion), allow_repair=True)
        except Exception as exc:
            rewards.append(_compile_error_penalty(str(exc)))
            continue
        try:
            validate_and_normalize(world)
            rewards.append(1.0)
        except Exception as exc:
            rewards.append(_compile_error_penalty(str(exc)))
    return rewards


def _bind_dm_reward_function(
    *,
    artifacts_root: str | None,
    hero_policy: Any,
    interface_provider: str | None,
    interface_model: str | None,
    interface_narrate: bool,
    interface_translation_mode: str | None = None,
    hero_max_game_steps: int,
    hero_max_tool_calls: int,
) -> Any:
    return _bind_dm_rollout_reward(
        dm_reward_function,
        artifacts_root=artifacts_root,
        hero_policy=hero_policy,
        interface_provider=interface_provider,
        interface_model=interface_model,
        interface_narrate=interface_narrate,
        interface_translation_mode=interface_translation_mode,
        hero_max_game_steps=hero_max_game_steps,
        hero_max_tool_calls=hero_max_tool_calls,
    )


def _bind_dm_rollout_reward(
    reward_impl: Callable[..., list[float]],
    *,
    artifacts_root: str | None,
    hero_policy: Any,
    interface_provider: str | None,
    interface_model: str | None,
    interface_narrate: bool,
    interface_translation_mode: str | None = None,
    hero_max_game_steps: int,
    hero_max_tool_calls: int,
) -> Any:
    def reward_func(**kwargs: Any) -> list[float]:
        return reward_impl(
            artifacts_root=artifacts_root,
            hero_policy=hero_policy,
            interface_provider=interface_provider,
            interface_model=interface_model,
            interface_narrate=interface_narrate,
            interface_translation_mode=interface_translation_mode,
            hero_max_game_steps=hero_max_game_steps,
            hero_max_tool_calls=hero_max_tool_calls,
            **kwargs,
        )

    reward_func.__name__ = reward_impl.__name__
    return reward_func


def _make_named_overlong_reward(*, name: str, max_completion_len: int) -> Callable[..., list[float]] | None:
    if get_soft_overlong_punishment is None:
        return None
    soft_punish_cache = max(16, min(64, max_completion_len // 4))
    reward_func = get_soft_overlong_punishment(max_completion_len=max_completion_len, soft_punish_cache=soft_punish_cache)
    reward_func.__name__ = name
    return reward_func
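
# With the default max_completion_length of 512, the soft-punish window is
# max(16, min(64, 512 // 4)) = 64 tokens; per TRL's soft overlong punishment
# (from the DAPO recipe), completions whose length lands inside that final
# window are penalized increasingly as they approach the hard limit.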


def _canonicalize_qwen_chat_template(tokenizer: Any) -> Any:
    chat_template = getattr(tokenizer, "chat_template", "") or ""
    if qwen3_chat_template is None:
        return tokenizer
    if "<|im_start|>" not in chat_template or "<|im_end|>" not in chat_template:
        return tokenizer
    tokenizer.chat_template = qwen3_chat_template
    return tokenizer


def _chat_template_kwargs(tokenizer: Any) -> dict[str, Any] | None:
    if not hasattr(tokenizer, "apply_chat_template"):
        return None
    try:
        tokenizer.apply_chat_template(
            [{"role": "user", "content": "ping"}],
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=False,
        )
    except Exception:
        return None
    return {"enable_thinking": False}


def _ensure_tool_response_schema(tokenizer: Any) -> Any:
    tokenizer = _canonicalize_qwen_chat_template(tokenizer)
    chat_template = getattr(tokenizer, "chat_template", "") or ""
    if qwen3_chat_template is None or qwen3_schema is None:
        return tokenizer
    if not hasattr(tokenizer, "parse_response"):
        return tokenizer
    if "<tool_call>" not in chat_template or "<|im_start|>" not in chat_template:
        return tokenizer
    tokenizer.chat_template = qwen3_chat_template
    if getattr(tokenizer, "response_schema", None) is not None:
        return tokenizer
    tokenizer.response_schema = qwen3_schema
    return tokenizer


def hero_tool_format_reward(
    *,
    prompts: list[Any],
    completions: list[Any],
    trainer_state: Any,
    **_: Any,
) -> list[float]:
    del prompts, trainer_state
    rewards: list[float] = []
    for completion in completions:
        text = _completion_text(completion)
        tool_calls = _completion_tool_calls(completion)
        score = 0.0
        if len(tool_calls) == 1:
            call = tool_calls[0]
            score += 0.65 if call["source"] == "tool_call" else 0.30
            if call["name"] in _HERO_TOOL_NAMES:
                score += 0.15
            outer_text = _normalize_outer_completion_text(text)
            if not outer_text:
                score += 0.15
            else:
                score += 0.10 * (1.0 - min(1.0, len(outer_text) / max(1, len(text.strip()))))
        elif len(tool_calls) > 1:
            score += 0.20
            if all(call["name"] in _HERO_TOOL_NAMES for call in tool_calls):
                score += 0.05
        else:
            if "<tool_call>" in text:
                score += 0.05
            elif '{"action"' in text.replace(" ", ""):
                score += 0.10
        if "<think>" in text:
            score -= 0.15
        if "```" in text:
            score -= 0.05
        rewards.append(_clamp(score, -0.25, 1.0))
    return rewards


def hero_action_semantics_reward(
    *,
    prompts: list[Any],
    completions: list[Any],
    trainer_state: Any,
    **_: Any,
) -> list[float]:
    del prompts, trainer_state
    rewards: list[float] = []
    for completion in completions:
        tool_calls = _completion_tool_calls(completion)
        if len(tool_calls) != 1:
            rewards.append(0.10 if len(tool_calls) > 1 else 0.0)
            continue
        tool_call = tool_calls[0]
        tool_name = tool_call["name"]
        arguments = tool_call["arguments"]
        if tool_name == "act":
            reward = _hero_act_semantics_reward(arguments)
        elif tool_name == "scratchpad_read":
            reward = 1.0 if not arguments else 0.80
        elif tool_name == "scratchpad_write":
            reward = _hero_scratchpad_write_reward(arguments)
        else:
            reward = -0.25
        if tool_call["source"] != "tool_call":
            reward *= 0.85
        rewards.append(_clamp(reward, -0.25, 1.0))
    return rewards


def hero_reward_function(
    *,
    prompts: list[Any],
    completions: list[Any],
    environments: list[HeroToolEnvironment],
    trainer_state: Any,
    **_: Any,
) -> list[float]:
    del prompts, completions, trainer_state
    rewards: list[float] = []
    for environment in environments:
        reward = environment._cumulative_reward()
        if not environment._episode_done():
            reward -= 0.05
        rewards.append(reward)
    return rewards
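
# hero_reward_function carries weight 1.0 in create_hero_grpo_trainer below: it
# reports the environment's cumulative episode reward, minus a small 0.05 nudge
# whenever the episode never terminated within its budgets.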


def create_dm_grpo_trainer(
    config: GRPOLaunchConfig,
    *,
    target_ratios: list[float] | None = None,
    artifacts_root: Path | None = None,
    closed_loop: DMClosedLoopConfig | None = None,
):
    _require_training_dependencies()
    closed_loop = closed_loop or DMClosedLoopConfig()
    rows = build_dm_grpo_dataset(num_prompts=config.num_prompts, target_ratios=target_ratios)
    dataset = Dataset.from_list(rows)
    tokenizer = AutoTokenizer.from_pretrained(config.model_name, trust_remote_code=config.trust_remote_code)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer = _canonicalize_qwen_chat_template(tokenizer)
    chat_template_kwargs = _chat_template_kwargs(tokenizer)
    hero_client_config = resolve_structured_client_config(
        "hero",
        provider=closed_loop.hero_provider,  # type: ignore[arg-type]
        model_name=closed_loop.hero_model,
        adapter_path=closed_loop.hero_adapter_path,
    )
    hero_policy = HeroLLMPolicy(
        create_structured_client(hero_client_config),
        model_name=hero_client_config.model_name,
    )
    reward_funcs: list[Any] = [
        dm_json_format_reward,
        dm_schema_reward,
        dm_validation_reward,
        dm_compile_prior_reward,
        _bind_dm_rollout_reward(
            dm_hero_success_reward,
            artifacts_root=str(artifacts_root) if artifacts_root is not None else None,
            hero_policy=hero_policy,
            interface_provider=closed_loop.interface_provider,
            interface_model=closed_loop.interface_model,
            interface_narrate=closed_loop.interface_narrate,
            interface_translation_mode=closed_loop.interface_translation_mode,
            hero_max_game_steps=closed_loop.hero_max_game_steps,
            hero_max_tool_calls=closed_loop.hero_max_tool_calls,
        ),
        _bind_dm_rollout_reward(
            dm_hero_efficiency_reward,
            artifacts_root=str(artifacts_root) if artifacts_root is not None else None,
            hero_policy=hero_policy,
            interface_provider=closed_loop.interface_provider,
            interface_model=closed_loop.interface_model,
            interface_narrate=closed_loop.interface_narrate,
            interface_translation_mode=closed_loop.interface_translation_mode,
            hero_max_game_steps=closed_loop.hero_max_game_steps,
            hero_max_tool_calls=closed_loop.hero_max_tool_calls,
        ),
        _bind_dm_rollout_reward(
            dm_hero_cleanliness_reward,
            artifacts_root=str(artifacts_root) if artifacts_root is not None else None,
            hero_policy=hero_policy,
            interface_provider=closed_loop.interface_provider,
            interface_model=closed_loop.interface_model,
            interface_narrate=closed_loop.interface_narrate,
            interface_translation_mode=closed_loop.interface_translation_mode,
            hero_max_game_steps=closed_loop.hero_max_game_steps,
            hero_max_tool_calls=closed_loop.hero_max_tool_calls,
        ),
        _bind_dm_reward_function(
            artifacts_root=str(artifacts_root) if artifacts_root is not None else None,
            hero_policy=hero_policy,
            interface_provider=closed_loop.interface_provider,
            interface_model=closed_loop.interface_model,
            interface_narrate=closed_loop.interface_narrate,
            interface_translation_mode=closed_loop.interface_translation_mode,
            hero_max_game_steps=closed_loop.hero_max_game_steps,
            hero_max_tool_calls=closed_loop.hero_max_tool_calls,
        ),
    ]
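    # The weights below align positionally with reward_funcs. Based on TRL's
    # documented behavior, zero-weighted reward functions still run and are
    # logged, so the three hero rollout rewards act as tracked metrics without
    # steering the DM loss directly.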
    reward_weights = [0.25, 0.20, 0.50, 0.45, 0.0, 0.0, 0.0, 1.0]
    overlong_reward = _make_named_overlong_reward(name="dm_overlong_reward", max_completion_len=config.max_completion_length)
    if overlong_reward is not None:
        reward_funcs.append(overlong_reward)
        reward_weights.append(0.15)
    model, peft_config, include_model_init_kwargs = _build_trainable_model(config)
    return GRPOTrainer(
        model=model,
        reward_funcs=reward_funcs,
        args=_build_grpo_config(
            config,
            max_tool_calling_iterations=None,
            chat_template_kwargs=chat_template_kwargs,
            reward_weights=reward_weights,
            include_model_init_kwargs=include_model_init_kwargs,
        ),
        train_dataset=dataset,
        processing_class=tokenizer,
        peft_config=peft_config,
    )


def create_hero_grpo_trainer(
    config: GRPOLaunchConfig,
    *,
    world_input: dict[str, Any] | None = None,
    artifacts_root: Path | None = None,
    interface_provider: str | None = None,
    interface_model: str | None = None,
    interface_narrate: bool = False,
    interface_translation_mode: str | None = None,
    max_game_steps: int = 40,
    max_tool_calls: int = 80,
    max_tool_calling_iterations: int = 32,
):
    _require_training_dependencies()
    rows = build_hero_grpo_dataset(
        num_prompts=config.num_prompts,
        world_input=world_input,
        max_game_steps=max_game_steps,
        max_tool_calls=max_tool_calls,
    )
    dataset = Dataset.from_list(rows)
    tokenizer = AutoTokenizer.from_pretrained(config.model_name, trust_remote_code=config.trust_remote_code)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer = _ensure_tool_response_schema(tokenizer)
    chat_template_kwargs = _chat_template_kwargs(tokenizer)

    def environment_factory() -> HeroToolEnvironment:
        return HeroToolEnvironment(
            artifacts_root=artifacts_root,
            interface_provider=interface_provider,
            interface_model=interface_model,
            interface_narrate=interface_narrate,
            interface_translation_mode=interface_translation_mode,
        )

    reward_funcs: list[Any] = [
        hero_tool_format_reward,
        hero_action_semantics_reward,
        hero_reward_function,
    ]
    reward_weights = [0.40, 0.30, 1.0]
    overlong_reward = _make_named_overlong_reward(
        name="hero_overlong_reward",
        max_completion_len=config.max_completion_length,
    )
    if overlong_reward is not None:
        reward_funcs.append(overlong_reward)
        reward_weights.append(0.15)
    model, peft_config, include_model_init_kwargs = _build_trainable_model(config)
    return GRPOTrainer(
        model=model,
        reward_funcs=reward_funcs,
        args=_build_grpo_config(
            config,
            max_tool_calling_iterations=max_tool_calling_iterations,
            chat_template_kwargs=chat_template_kwargs,
            reward_weights=reward_weights,
            include_model_init_kwargs=include_model_init_kwargs,
        ),
        train_dataset=dataset,
        processing_class=tokenizer,
        peft_config=peft_config,
        environment_factory=environment_factory,
    )


def run_dm_grpo(
    config: GRPOLaunchConfig,
    *,
    target_ratios: list[float] | None = None,
    artifacts_root: Path | None = None,
    closed_loop: DMClosedLoopConfig | None = None,
) -> Path:
    trainer = create_dm_grpo_trainer(
        config,
        target_ratios=target_ratios,
        artifacts_root=artifacts_root,
        closed_loop=closed_loop,
    )
    trainer.train()
    trainer.save_model()
    return config.output_dir


def run_hero_grpo(
    config: GRPOLaunchConfig,
    *,
    world_path: Path | None = None,
    artifacts_root: Path | None = None,
    interface_provider: str | None = None,
    interface_model: str | None = None,
    interface_narrate: bool = False,
    interface_translation_mode: str | None = None,
    max_game_steps: int = 40,
    max_tool_calls: int = 80,
    max_tool_calling_iterations: int = 32,
) -> Path:
    world_input = load_world(str(world_path)) if world_path is not None else None
    trainer = create_hero_grpo_trainer(
        config,
        world_input=world_input,
        artifacts_root=artifacts_root,
        interface_provider=interface_provider,
        interface_model=interface_model,
        interface_narrate=interface_narrate,
        interface_translation_mode=interface_translation_mode,
        max_game_steps=max_game_steps,
        max_tool_calls=max_tool_calls,
        max_tool_calling_iterations=max_tool_calling_iterations,
    )
    trainer.train()
    trainer.save_model()
    return config.output_dir
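
# Illustrative launch (model name and paths are placeholders, not recommendations):
#     run_hero_grpo(
#         GRPOLaunchConfig(model_name="Qwen/Qwen3-0.6B", output_dir=Path("runs/hero_grpo")),
#         max_game_steps=40,
#     )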


def _build_lora_config(config: GRPOLaunchConfig):
    _require_training_dependencies()
    return LoraConfig(
        r=config.rank,
        lora_alpha=config.alpha,
        lora_dropout=config.dropout,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules="all-linear",
    )


def _build_trainable_model(config: GRPOLaunchConfig) -> tuple[Any, Any | None, bool]:
    _require_training_dependencies()
    if config.resume_adapter_path is None:
        return config.model_name, _build_lora_config(config), True
    from peft import PeftModel
    from transformers import AutoModelForCausalLM

    adapter_path = Path(config.resume_adapter_path)
    if not adapter_path.exists():
        raise FileNotFoundError(f"resume_adapter_path does not exist: {adapter_path}")
    model = AutoModelForCausalLM.from_pretrained(
        config.model_name,
        cache_dir=os.getenv("HF_HOME"),
        token=os.getenv("HF_TOKEN"),
        **_model_init_kwargs(config),
    )
    model = PeftModel.from_pretrained(model, str(adapter_path), is_trainable=True)
    model.train()
    return model, None, False


def _build_grpo_config(
    config: GRPOLaunchConfig,
    *,
    max_tool_calling_iterations: int | None,
    chat_template_kwargs: dict[str, Any] | None,
    reward_weights: list[float] | None,
    include_model_init_kwargs: bool = True,
):
    _require_training_dependencies()
    _require_vllm_if_requested(config)
    config.output_dir.mkdir(parents=True, exist_ok=True)
    report_to = ["wandb"] if config.use_wandb else []
    model_init_kwargs = _model_init_kwargs(config) if include_model_init_kwargs else None
    return GRPOConfig(
        output_dir=str(config.output_dir),
        run_name=config.run_name,
        report_to=report_to,
        max_steps=config.max_steps,
        learning_rate=config.learning_rate,
        per_device_train_batch_size=config.per_device_train_batch_size,
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        num_generations=config.num_generations,
        max_completion_length=config.max_completion_length,
        temperature=config.temperature,
        top_p=config.top_p,
        top_k=config.top_k,
        min_p=config.min_p,
        repetition_penalty=config.repetition_penalty,
        logging_steps=config.logging_steps,
        save_steps=config.save_steps,
        seed=config.seed,
        bf16=torch.cuda.is_available(),
        gradient_checkpointing=True,
        remove_unused_columns=False,
        loss_type=config.loss_type,
        importance_sampling_level=config.importance_sampling_level,
        use_transformers_paged=config.use_transformers_paged,
        cache_implementation=config.cache_implementation,
        use_vllm=config.use_vllm,
        vllm_mode=config.vllm_mode,
        vllm_gpu_memory_utilization=config.vllm_gpu_memory_utilization,
        vllm_enable_sleep_mode=config.vllm_enable_sleep_mode,
        log_completions=True,
        log_unique_prompts=True,
        num_completions_to_print=1,
        max_tool_calling_iterations=max_tool_calling_iterations,
        chat_template_kwargs=chat_template_kwargs,
        reward_weights=reward_weights,
        mask_truncated_completions=True,
        model_init_kwargs=model_init_kwargs,
    )


def _cached_dm_rollout_metrics(
    *,
    completion: Any,
    requested_ratio: float,
    trainer_state: Any,
    completion_index: int,
    hero_policy: Any,
    interface_provider: str | None,
    interface_model: str | None,
    interface_narrate: bool,
    interface_translation_mode: str | None,
    hero_max_game_steps: int,
    hero_max_tool_calls: int,
    artifacts_root: str | None,
) -> DMRolloutMetrics:
    global _DM_ROLLOUT_CACHE_STEP, _DM_ROLLOUT_CACHE
    step = int(getattr(trainer_state, "global_step", 0) or 0)
    if step != _DM_ROLLOUT_CACHE_STEP:
        _DM_ROLLOUT_CACHE_STEP = step
        _DM_ROLLOUT_CACHE = {}
    completion_text = _completion_text(completion)
    key = (
        step,
        completion_index,
        requested_ratio,
        hashlib.sha1(completion_text.encode("utf-8")).hexdigest(),
        id(hero_policy),
        interface_provider,
        interface_model,
        interface_narrate,
        interface_translation_mode,
        hero_max_game_steps,
        hero_max_tool_calls,
        artifacts_root,
    )
    cached = _DM_ROLLOUT_CACHE.get(key)
    if cached is not None:
        return cached
    metrics = _evaluate_dm_rollout(
        completion_text=completion_text,
        requested_ratio=requested_ratio,
        trainer_state=trainer_state,
        completion_index=completion_index,
        hero_policy=hero_policy,
        interface_provider=interface_provider,
        interface_model=interface_model,
        interface_narrate=interface_narrate,
        interface_translation_mode=interface_translation_mode,
        hero_max_game_steps=hero_max_game_steps,
        hero_max_tool_calls=hero_max_tool_calls,
        artifacts_root=artifacts_root,
    )
    _DM_ROLLOUT_CACHE[key] = metrics
    return metrics


def _evaluate_dm_rollout(
    *,
    completion_text: str,
    requested_ratio: float,
    trainer_state: Any,
    completion_index: int,
    hero_policy: Any,
    interface_provider: str | None,
    interface_model: str | None,
    interface_narrate: bool,
    interface_translation_mode: str | None,
    hero_max_game_steps: int,
    hero_max_tool_calls: int,
    artifacts_root: str | None,
) -> DMRolloutMetrics:
    try:
        world = _load_dm_world_definition(completion_text, allow_repair=True)
    except Exception as exc:
        return DMRolloutMetrics(
            reward=_compile_error_penalty(str(exc)),
            compile_error=str(exc),
            requested_ratio=requested_ratio,
            player_won=False,
            steps_taken=None,
            min_steps=None,
            ratio=None,
            efficiency_score=0.0,
            quality_score=0.0,
            invalid_command_count=0,
            wrong_submit_count=0,
            hero_player_won=False,
            hero_total_reward=0.0,
            hero_dense_return=0.0,
            hero_steps_taken=0,
            hero_tool_calls_total=0,
            hero_policy_error=None,
        )
    interface_adapter = build_interface_adapter(
        resolve_interface_config(
            provider=interface_provider,  # type: ignore[arg-type]
            model_name=interface_model,
            narrate_observations=interface_narrate,
            translation_mode=interface_translation_mode,  # type: ignore[arg-type]
        )
    )
    env = DMEnvironment(
        artifacts_root=_dm_reward_artifacts_dir(
            artifacts_root=artifacts_root,
            trainer_state=trainer_state,
            completion_index=completion_index,
        ),
        interface_adapter=interface_adapter,
    )
    runner = HeroRunner(
        policy=hero_policy,
        max_game_steps=hero_max_game_steps,
        max_tool_calls=hero_max_tool_calls,
    )
    try:
        env.reset(difficulty_hint=requested_ratio)
        result = env.step(world, runner=runner)
        observation = result.observation
        reward = float(observation.reward or 0.0)
        if observation.compile_error is not None:
            reward = _compile_error_penalty(observation.compile_error)
        elif abs(world.meta.difficulty_target - requested_ratio) > 1e-6:
            reward -= 0.25
        feedback = observation.feedback
        breakdown = observation.reward_breakdown
        hero_stats = runner.episode_stats
        return DMRolloutMetrics(
            reward=max(-1.0, reward),
            compile_error=observation.compile_error,
            requested_ratio=requested_ratio,
            player_won=bool(observation.player_won),
            steps_taken=observation.steps_taken,
            min_steps=observation.min_steps,
            ratio=observation.ratio,
            efficiency_score=0.0 if breakdown is None or breakdown.efficiency_score is None else float(breakdown.efficiency_score),
            quality_score=0.0 if breakdown is None else float(breakdown.quality_score),
            invalid_command_count=0 if feedback is None else int(feedback.invalid_command_count),
            wrong_submit_count=0 if feedback is None else int(feedback.wrong_submit_count),
            hero_player_won=bool(observation.player_won) if hero_stats is None else bool(hero_stats.player_won),
            hero_total_reward=0.0 if hero_stats is None else float(hero_stats.total_reward),
            hero_dense_return=0.0 if hero_stats is None else float(hero_stats.dense_return),
            hero_steps_taken=0 if hero_stats is None else int(hero_stats.steps_taken),
            hero_tool_calls_total=0 if hero_stats is None else int(hero_stats.tool_calls_total),
            hero_policy_error=runner.last_error,
        )
    except Exception as exc:
        return DMRolloutMetrics(
            reward=_compile_error_penalty(str(exc)),
            compile_error=str(exc),
            requested_ratio=requested_ratio,
            player_won=False,
            steps_taken=None,
            min_steps=None,
            ratio=None,
            efficiency_score=0.0,
            quality_score=0.0,
            invalid_command_count=0,
            wrong_submit_count=0,
            hero_player_won=False,
            hero_total_reward=0.0,
            hero_dense_return=0.0,
            hero_steps_taken=0,
            hero_tool_calls_total=0,
            hero_policy_error=runner.last_error,
        )


def _model_init_kwargs(config: GRPOLaunchConfig) -> dict[str, Any]:
    model_init_kwargs: dict[str, Any] = {
        "trust_remote_code": config.trust_remote_code,
    }
    quantization_config = _build_quantization_config(config)
    if quantization_config is not None:
        model_init_kwargs["quantization_config"] = quantization_config
    if torch.cuda.is_available():
        model_init_kwargs["torch_dtype"] = torch.bfloat16
    return model_init_kwargs


def _build_quantization_config(config: GRPOLaunchConfig):
    _require_training_dependencies()
    if not config.load_in_4bit or not torch.cuda.is_available():
        return None
    return BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
    )


def _completion_text(completion: Any) -> str:
    if isinstance(completion, str):
        return completion
    if isinstance(completion, list):
        parts: list[str] = []
        for message in completion:
            if isinstance(message, dict) and message.get("role") == "assistant":
                content = message.get("content")
                if isinstance(content, str):
                    parts.append(content)
        return "\n".join(parts)
    return str(completion)


def _extract_json_object(text: str) -> str:
    json_text, _, _ = _extract_json_candidate_parts(text)
    if json_text is None:
        raise ValueError("Completion did not contain a JSON object.")
    return json_text


def _extract_json_candidate_parts(text: str) -> tuple[str | None, str, str]:
    cleaned = _strip_code_fences(text).strip()
    span = _find_json_object_span(cleaned)
    if span is None:
        return None, cleaned, ""
    start, end = span
    return cleaned[start:end], cleaned[:start], cleaned[end:]


def _try_parse_completion_json(text: str) -> Any | None:
    json_text, _, _ = _extract_json_candidate_parts(text)
    if json_text is None:
        return None
    try:
        return json.loads(json_text)
    except Exception:
        return None


def _repair_dm_candidate_payload(payload: Any) -> Any:
    if isinstance(payload, list):
        return [_repair_dm_candidate_payload(item) for item in payload]
    if not isinstance(payload, dict):
        return payload
    node_type = payload.get("type")
    repaired: dict[str, Any] = {}
    for key, value in payload.items():
        normalized_key = "requires_step_ids" if key == "requires_step_id" else key
        repaired[normalized_key] = _repair_dm_candidate_payload(value)
    requires_step_ids = repaired.get("requires_step_ids")
    if requires_step_ids is None and "requires_step_ids" in repaired:
        repaired["requires_step_ids"] = []
    elif isinstance(requires_step_ids, str):
        repaired["requires_step_ids"] = [requires_step_ids]
    if "open" not in repaired and "is_open" in repaired:
        repaired["open"] = repaired.pop("is_open")
    if "locked" not in repaired and "is_locked" in repaired:
        repaired["locked"] = repaired.pop("is_locked")
    if node_type in {"container", "door"}:
        closed = repaired.pop("closed", None)
        if isinstance(closed, bool) and "open" not in repaired:
            repaired["open"] = not closed
    if node_type == "fixture":
        if "reveals_item_id" not in repaired and "reveal_item_id" in repaired:
            repaired["reveals_item_id"] = repaired.pop("reveal_item_id")
        if "reveals_readable_id" not in repaired and "reveal_readable_id" in repaired:
            repaired["reveals_readable_id"] = repaired.pop("reveal_readable_id")
    if node_type == "npc":
        if "requires_item_id" not in repaired and "trade_requires_item_id" in repaired:
            repaired["requires_item_id"] = repaired.pop("trade_requires_item_id")
        if "gives_item_id" not in repaired and "trade_item_id" in repaired:
            repaired["gives_item_id"] = repaired.pop("trade_item_id")
        if "gives_clue_id" not in repaired and "trade_clue_id" in repaired:
            repaired["gives_clue_id"] = repaired.pop("trade_clue_id")
    if "subtype" not in repaired and repaired.get("type") in _DM_ALLOWED_ITEM_TYPES and "start_node_id" in repaired:
        repaired["subtype"] = repaired.pop("type")
    if "id" not in repaired and "clue_id" in repaired and "text" in repaired:
        repaired["id"] = repaired.pop("clue_id")
    if "input_item_ids" not in repaired and "input_item_a_id" in repaired and "input_item_b_id" in repaired:
        repaired["input_item_ids"] = [repaired.pop("input_item_a_id"), repaired.pop("input_item_b_id")]
    if node_type == "container":
        repaired.pop("contains_items", None)
    if "output_item_id" in repaired and (
        "input_item_ids" in repaired or ("input_item_a_id" in repaired and "input_item_b_id" in repaired)
    ):
        repaired.pop("label", None)
        repaired.pop("description", None)
    if node_type in {"location", "junction", "door"}:
        repaired.pop("parent_id", None)
    return repaired
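
# Examples of the alias repairs performed above (behavior-preserving renames):
# "requires_step_id" -> "requires_step_ids", "is_open" -> "open",
# "is_locked" -> "locked", "reveal_item_id" -> "reveals_item_id", and
# closed=True on containers/doors becomes open=False.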
| def _repair_dm_world_payload(payload: dict[str, Any]) -> dict[str, Any]: | |
| repaired = _repair_dm_candidate_payload(payload) | |
| if not isinstance(repaired, dict): | |
| return payload | |
| meta = repaired.get("meta") | |
| if not isinstance(meta, dict): | |
| meta = {} | |
| else: | |
| meta = dict(meta) | |
| title = meta.get("title") | |
| if not isinstance(title, str) or not title.strip(): | |
| meta["title"] = _infer_dm_world_title(repaired) | |
| start_node_id = meta.get("start_node_id") | |
| if not isinstance(start_node_id, str) or not start_node_id: | |
| inferred_start = _infer_dm_start_node_id(repaired.get("nodes")) | |
| if inferred_start is not None: | |
| meta["start_node_id"] = inferred_start | |
| win_condition = meta.get("win_condition") | |
| if not isinstance(win_condition, dict): | |
| win_condition = {} | |
| else: | |
| win_condition = dict(win_condition) | |
| if not isinstance(win_condition.get("type"), str) or not win_condition.get("type"): | |
| win_condition["type"] = "deduce" | |
| if not isinstance(win_condition.get("target_npc_id"), str) or not win_condition.get("target_npc_id"): | |
| inferred_guardian = _infer_dm_guardian_npc_id(repaired) | |
| if inferred_guardian is not None: | |
| win_condition["target_npc_id"] = inferred_guardian | |
| if not isinstance(win_condition.get("answer_string"), str) or not win_condition.get("answer_string"): | |
| inferred_answer = _infer_dm_answer_string(repaired.get("quest_chain")) | |
| if inferred_answer: | |
| win_condition["answer_string"] = inferred_answer | |
| if win_condition: | |
| meta["win_condition"] = win_condition | |
| _repair_guardian_trade_fields(repaired, guardian_id=win_condition.get("target_npc_id")) | |
| _repair_submit_actions(repaired) | |
| _repair_door_lock_keys_from_edges(repaired) | |
| _repair_missing_item_references(repaired) | |
| _repair_produced_item_placements(repaired, default_start_node_id=meta.get("start_node_id")) | |
| _repair_required_key_item_subtypes(repaired) | |
| _repair_duplicate_recipe_ids(repaired) | |
| _repair_guardian_room_access(repaired, guardian_id=win_condition.get("target_npc_id"), start_node_id=meta.get("start_node_id")) | |
| _repair_missing_readable_clue_ids(repaired) | |
| _repair_missing_clue_sources(repaired, guardian_id=win_condition.get("target_npc_id")) | |
| _repair_take_action_aliases(repaired) | |
| _repair_take_sources_from_room_prereqs(repaired) | |
| _repair_locked_room_entry_steps(repaired) | |
| _repair_missing_take_steps(repaired) | |
| _repair_guardian_ending( | |
| repaired, | |
| guardian_id=win_condition.get("target_npc_id"), | |
| answer_string=win_condition.get("answer_string"), | |
| ) | |
| _repair_guardian_room_access(repaired, guardian_id=win_condition.get("target_npc_id"), start_node_id=meta.get("start_node_id")) | |
| repaired["meta"] = meta | |
| return repaired | |
| def _infer_dm_world_title(payload: dict[str, Any]) -> str: | |
| meta = payload.get("meta") | |
| if isinstance(meta, dict): | |
| for key in ("name", "world_name"): | |
| value = meta.get(key) | |
| if isinstance(value, str) and value.strip(): | |
| return value.strip() | |
| nodes = payload.get("nodes") | |
| if isinstance(nodes, list): | |
| for node in nodes: | |
| if not isinstance(node, dict) or node.get("type") not in {"location", "junction"}: | |
| continue | |
| label = node.get("label") | |
| if isinstance(label, str) and label.strip(): | |
| return f"The {label.strip()}" | |
| return "The Hidden Vault" | |
| def _infer_dm_start_node_id(nodes: Any) -> str | None: | |
| if not isinstance(nodes, list): | |
| return None | |
| for node in nodes: | |
| if not isinstance(node, dict) or node.get("type") not in {"location", "junction"}: | |
| continue | |
| node_id = node.get("id") | |
| if isinstance(node_id, str) and node_id: | |
| return node_id | |
| return None | |
| def _infer_dm_guardian_npc_id(payload: dict[str, Any]) -> str | None: | |
| quest_chain = payload.get("quest_chain") | |
| if isinstance(quest_chain, list): | |
| for step in reversed(quest_chain): | |
| action = step.get("action") if isinstance(step, dict) else None | |
| npc_id = _extract_single_action_argument(action, "talk") | |
| if npc_id: | |
| return npc_id | |
| nodes = payload.get("nodes") | |
| if not isinstance(nodes, list): | |
| return None | |
| first_npc_id: str | None = None | |
| for node in nodes: | |
| if not isinstance(node, dict) or node.get("type") != "npc": | |
| continue | |
| node_id = node.get("id") | |
| if not isinstance(node_id, str) or not node_id: | |
| continue | |
| if first_npc_id is None: | |
| first_npc_id = node_id | |
| if "guardian" in node_id: | |
| return node_id | |
| return first_npc_id | |
| def _infer_dm_answer_string(quest_chain: Any) -> str | None: | |
| if not isinstance(quest_chain, list): | |
| return None | |
| for step in reversed(quest_chain): | |
| action = step.get("action") if isinstance(step, dict) else None | |
| answer = _extract_single_action_argument(action, "submit") | |
| if answer is None: | |
| continue | |
| normalized = normalize_answer_text(answer) | |
| if normalized: | |
| return normalized | |
| return None | |
| def _repair_missing_readable_clue_ids(payload: dict[str, Any]) -> None: | |
| nodes = payload.get("nodes") | |
| clues = payload.get("clues") | |
| if not isinstance(nodes, list) or not isinstance(clues, list): | |
| return | |
| clue_ids = [clue.get("id") for clue in clues if isinstance(clue, dict) and isinstance(clue.get("id"), str)] | |
| if not clue_ids: | |
| return | |
| used_clue_ids = { | |
| node.get("clue_id") | |
| for node in nodes | |
| if isinstance(node, dict) and node.get("type") == "readable" and isinstance(node.get("clue_id"), str) | |
| } | |
| available_clue_ids = [clue_id for clue_id in clue_ids if clue_id not in used_clue_ids] | |
| if not available_clue_ids: | |
| return | |
| for node in nodes: | |
| if not isinstance(node, dict) or node.get("type") != "readable" or node.get("clue_id"): | |
| continue | |
| if not available_clue_ids: | |
| return | |
| node["clue_id"] = available_clue_ids.pop(0) | |
| def _repair_guardian_trade_fields(payload: dict[str, Any], *, guardian_id: Any) -> None: | |
| if not isinstance(guardian_id, str) or not guardian_id: | |
| return | |
| nodes = payload.get("nodes") | |
| if not isinstance(nodes, list): | |
| return | |
| for node in nodes: | |
| if not isinstance(node, dict) or node.get("type") != "npc" or node.get("id") != guardian_id: | |
| continue | |
| node["requires_item_id"] = None | |
| node["gives_item_id"] = None | |
| node["gives_clue_id"] = None | |
| return | |
| def _repair_submit_actions(payload: dict[str, Any]) -> None: | |
| quest_chain = payload.get("quest_chain") | |
| if not isinstance(quest_chain, list): | |
| return | |
| for step in quest_chain: | |
| if not isinstance(step, dict): | |
| continue | |
| action = step.get("action") | |
| answer = _extract_single_action_argument(action, "submit") | |
| if answer is None: | |
| continue | |
| if action == f'submit("{answer}")': | |
| continue | |
| step["action"] = f'submit("{normalize_answer_text(answer)}")' | |
| def _repair_door_lock_keys_from_edges(payload: dict[str, Any]) -> None: | |
| nodes = payload.get("nodes") | |
| edges = payload.get("edges") | |
| if not isinstance(nodes, list) or not isinstance(edges, list): | |
| return | |
| door_ids = [ | |
| node.get("id") | |
| for node in nodes | |
| if isinstance(node, dict) and node.get("type") == "door" and isinstance(node.get("id"), str) | |
| ] | |
| sole_door_id = door_ids[0] if len(door_ids) == 1 else None | |
| inferred_keys: dict[str, str] = {} | |
| for edge in edges: | |
| if not isinstance(edge, dict): | |
| continue | |
| door_node_id = edge.get("door_node_id") | |
| required_item_id = edge.get("required_item_id") | |
| if sole_door_id is not None and isinstance(door_node_id, str) and door_node_id not in door_ids: | |
| edge["door_node_id"] = sole_door_id | |
| door_node_id = sole_door_id | |
| if not isinstance(door_node_id, str) or not isinstance(required_item_id, str): | |
| continue | |
| existing_key = inferred_keys.get(door_node_id) | |
| if existing_key is None or existing_key == required_item_id: | |
| inferred_keys[door_node_id] = required_item_id | |
| if not inferred_keys: | |
| return | |
| for node in nodes: | |
| if not isinstance(node, dict) or node.get("type") != "door": | |
| continue | |
| door_id = node.get("id") | |
| if isinstance(door_id, str) and door_id in inferred_keys: | |
| node["lock_key_id"] = inferred_keys[door_id] | |
| def _repair_required_key_item_subtypes(payload: dict[str, Any]) -> None: | |
| items = payload.get("items") | |
| edges = payload.get("edges") | |
| nodes = payload.get("nodes") | |
| if not isinstance(items, list): | |
| return | |
| required_key_ids: set[str] = set() | |
| if isinstance(edges, list): | |
| for edge in edges: | |
| if not isinstance(edge, dict): | |
| continue | |
| required_item_id = edge.get("required_item_id") | |
| if isinstance(required_item_id, str) and required_item_id: | |
| required_key_ids.add(required_item_id) | |
| if isinstance(nodes, list): | |
| for node in nodes: | |
| if not isinstance(node, dict): | |
| continue | |
| lock_key_id = node.get("lock_key_id") | |
| if isinstance(lock_key_id, str) and lock_key_id: | |
| required_key_ids.add(lock_key_id) | |
| if not required_key_ids: | |
| return | |
| for item in items: | |
| if not isinstance(item, dict): | |
| continue | |
| item_id = item.get("id") | |
| if isinstance(item_id, str) and item_id in required_key_ids: | |
| item["subtype"] = "key" | |
| def _repair_duplicate_recipe_ids(payload: dict[str, Any]) -> None: | |
| recipes = payload.get("recipes") | |
| if not isinstance(recipes, list): | |
| return | |
| protected_ids: set[str] = set() | |
| for key in ("nodes", "items", "clues", "quest_chain"): | |
| values = payload.get(key) | |
| if not isinstance(values, list): | |
| continue | |
| for value in values: | |
| if not isinstance(value, dict): | |
| continue | |
| id_key = "step_id" if key == "quest_chain" else "id" | |
| value_id = value.get(id_key) | |
| if isinstance(value_id, str) and value_id: | |
| protected_ids.add(value_id) | |
| recipe_ids: set[str] = set() | |
| for recipe in recipes: | |
| if not isinstance(recipe, dict): | |
| continue | |
| recipe_id = recipe.get("id") | |
| if not isinstance(recipe_id, str) or not recipe_id: | |
| continue | |
| if recipe_id not in protected_ids and recipe_id not in recipe_ids: | |
| recipe_ids.add(recipe_id) | |
| continue | |
| new_recipe_id = _unique_world_id(recipe_id, protected_ids | recipe_ids) | |
| recipe["id"] = new_recipe_id | |
| recipe_ids.add(new_recipe_id) | |
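| # Example (illustrative): if a recipe reuses the id "forge" already taken by a | |
| # node, it is renamed via _unique_world_id to "forge_2" (then "forge_3", ...). | |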
| def _repair_guardian_room_access(payload: dict[str, Any], *, guardian_id: Any, start_node_id: Any) -> None: | |
| if not isinstance(guardian_id, str) or not guardian_id: | |
| return | |
| nodes = payload.get("nodes") | |
| edges = payload.get("edges") | |
| quest_chain = payload.get("quest_chain") | |
| if not isinstance(nodes, list) or not isinstance(edges, list): | |
| return | |
| reachable_rooms = _reachable_passage_room_ids(payload, start_node_id=start_node_id) | |
| if not reachable_rooms: | |
| return | |
| preferred_room_id = _infer_guardian_talk_room_from_quest(quest_chain, guardian_id=guardian_id) | |
| if preferred_room_id not in reachable_rooms: | |
| preferred_room_id = min(reachable_rooms) | |
| for node in nodes: | |
| if not isinstance(node, dict) or node.get("type") != "npc" or node.get("id") != guardian_id: | |
| continue | |
| parent_id = node.get("parent_id") | |
| if isinstance(parent_id, str) and parent_id in reachable_rooms: | |
| return | |
| node["parent_id"] = preferred_room_id | |
| current_guardian_room = _infer_guardian_talk_room_from_quest(quest_chain, guardian_id=guardian_id) | |
| if current_guardian_room != preferred_room_id: | |
| _insert_quest_step_before_guardian_talk( | |
| quest_chain, | |
| guardian_id=guardian_id, | |
| step_id_base=f"go_{preferred_room_id}", | |
| description=f"Go to {_humanize_identifier(preferred_room_id).lower()}.", | |
| action=f"go({preferred_room_id})", | |
| ) | |
| return | |
| def _repair_missing_item_references(payload: dict[str, Any]) -> None: | |
| items = payload.get("items") | |
| nodes = payload.get("nodes") | |
| edges = payload.get("edges") | |
| if not isinstance(items, list): | |
| return | |
| existing_item_ids = { | |
| item.get("id") | |
| for item in items | |
| if isinstance(item, dict) and isinstance(item.get("id"), str) and item.get("id") | |
| } | |
| quest_chain = payload.get("quest_chain") | |
| def ensure_item(item_id: Any, *, subtype: str, start_node_id: str | None) -> None: | |
| if not isinstance(item_id, str) or not item_id or item_id in existing_item_ids: | |
| return | |
| inferred_start_node_id = _infer_item_start_node_from_quest(quest_chain, item_id) or start_node_id | |
| items.append( | |
| { | |
| "id": item_id, | |
| "label": _humanize_identifier(item_id), | |
| "description": f"A {_humanize_identifier(item_id).lower()} needed to solve the dungeon.", | |
| "subtype": subtype, | |
| "start_node_id": inferred_start_node_id, | |
| } | |
| ) | |
| existing_item_ids.add(item_id) | |
| default_start_node_id = _infer_dm_start_node_id(payload.get("nodes")) | |
| if isinstance(edges, list): | |
| for edge in edges: | |
| if not isinstance(edge, dict): | |
| continue | |
| ensure_item(edge.get("required_item_id"), subtype="key", start_node_id=default_start_node_id) | |
| if isinstance(nodes, list): | |
| for node in nodes: | |
| if not isinstance(node, dict): | |
| continue | |
| node_type = node.get("type") | |
| if node_type in {"container", "door"}: | |
| ensure_item(node.get("lock_key_id"), subtype="key", start_node_id=default_start_node_id) | |
| elif node_type == "readable": | |
| ensure_item( | |
| node.get("requires_item_id"), | |
| subtype="puzzle", | |
| start_node_id=_node_room_start_node_id(node, default_start_node_id), | |
| ) | |
| elif node_type == "fixture": | |
| ensure_item( | |
| node.get("requires_item_id"), | |
| subtype="puzzle", | |
| start_node_id=_node_room_start_node_id(node, default_start_node_id), | |
| ) | |
| ensure_item(node.get("reveals_item_id"), subtype="puzzle", start_node_id=None) | |
| elif node_type == "npc": | |
| ensure_item( | |
| node.get("requires_item_id"), | |
| subtype="puzzle", | |
| start_node_id=_node_room_start_node_id(node, default_start_node_id), | |
| ) | |
| ensure_item(node.get("gives_item_id"), subtype="puzzle", start_node_id=None) | |
| recipes = payload.get("recipes") | |
| if not isinstance(recipes, list): | |
| return | |
| for recipe in recipes: | |
| if not isinstance(recipe, dict): | |
| continue | |
| input_ids = recipe.get("input_item_ids") | |
| if isinstance(input_ids, list): | |
| for item_id in input_ids: | |
| ensure_item(item_id, subtype="puzzle", start_node_id=default_start_node_id) | |
| ensure_item(recipe.get("output_item_id"), subtype="puzzle", start_node_id=None) | |
| def _repair_produced_item_placements(payload: dict[str, Any], *, default_start_node_id: Any) -> None: | |
| items = payload.get("items") | |
| if not isinstance(items, list): | |
| return | |
| produced_item_ids: set[str] = set() | |
| recipes = payload.get("recipes") | |
| if isinstance(recipes, list): | |
| for recipe in recipes: | |
| if not isinstance(recipe, dict): | |
| continue | |
| output_item_id = recipe.get("output_item_id") | |
| if isinstance(output_item_id, str) and output_item_id: | |
| produced_item_ids.add(output_item_id) | |
| nodes = payload.get("nodes") | |
| if isinstance(nodes, list): | |
| for node in nodes: | |
| if not isinstance(node, dict): | |
| continue | |
| if node.get("type") == "npc": | |
| gives_item_id = node.get("gives_item_id") | |
| if isinstance(gives_item_id, str) and gives_item_id: | |
| produced_item_ids.add(gives_item_id) | |
| elif node.get("type") == "fixture": | |
| reveals_item_id = node.get("reveals_item_id") | |
| if isinstance(reveals_item_id, str) and reveals_item_id: | |
| produced_item_ids.add(reveals_item_id) | |
| start_node_id = default_start_node_id if isinstance(default_start_node_id, str) and default_start_node_id else None | |
| for item in items: | |
| if not isinstance(item, dict): | |
| continue | |
| item_id = item.get("id") | |
| if not isinstance(item_id, str) or not item_id: | |
| continue | |
| if item_id in produced_item_ids: | |
| item["start_node_id"] = None | |
| elif item.get("start_node_id") is None and start_node_id is not None: | |
| item["start_node_id"] = start_node_id | |
| def _repair_missing_clue_sources(payload: dict[str, Any], *, guardian_id: Any) -> None: | |
| clues = payload.get("clues") | |
| nodes = payload.get("nodes") | |
| items = payload.get("items") | |
| quest_chain = payload.get("quest_chain") | |
| if not isinstance(clues, list) or not isinstance(nodes, list): | |
| return | |
| clue_text_by_id = { | |
| clue.get("id"): clue.get("text") | |
| for clue in clues | |
| if isinstance(clue, dict) and isinstance(clue.get("id"), str) | |
| } | |
| if not clue_text_by_id: | |
| return | |
| sourced_clue_ids = set() | |
| room_ids: set[str] = set() | |
| guardian_room_id: str | None = None | |
| for node in nodes: | |
| if not isinstance(node, dict): | |
| continue | |
| node_type = node.get("type") | |
| if node_type in {"location", "junction"}: | |
| node_id = node.get("id") | |
| if isinstance(node_id, str) and node_id: | |
| room_ids.add(node_id) | |
| elif node_type == "readable": | |
| clue_id = node.get("clue_id") | |
| if isinstance(clue_id, str) and clue_id: | |
| sourced_clue_ids.add(clue_id) | |
| elif node_type == "npc": | |
| clue_id = node.get("gives_clue_id") | |
| if isinstance(clue_id, str) and clue_id: | |
| sourced_clue_ids.add(clue_id) | |
| if isinstance(guardian_id, str) and guardian_id and node.get("id") == guardian_id: | |
| parent_id = node.get("parent_id") | |
| if isinstance(parent_id, str) and parent_id: | |
| guardian_room_id = parent_id | |
| missing_clue_ids = [clue_id for clue_id in clue_text_by_id if clue_id not in sourced_clue_ids] | |
| if not missing_clue_ids: | |
| return | |
| target_room_id = guardian_room_id or _infer_dm_start_node_id(nodes) | |
| if not isinstance(target_room_id, str) or target_room_id not in room_ids: | |
| target_room_id = next(iter(room_ids), None) | |
| if target_room_id is None: | |
| return | |
| gating_item_id = _select_synthetic_clue_gate_item_id(items, quest_chain) | |
| if gating_item_id is None: | |
| if not isinstance(items, list): | |
| return | |
| gating_item_id = "inspection_lens" | |
| items.append( | |
| { | |
| "id": gating_item_id, | |
| "label": "Inspection Lens", | |
| "description": "A careful lens for reading faint inscriptions.", | |
| "subtype": "puzzle", | |
| "start_node_id": _infer_dm_start_node_id(nodes), | |
| } | |
| ) | |
| existing_node_ids = { | |
| node.get("id") | |
| for node in nodes | |
| if isinstance(node, dict) and isinstance(node.get("id"), str) and node.get("id") | |
| } | |
| existing_safe_labels = { | |
| parser_safe_text(node.get("label")) | |
| for node in nodes | |
| if isinstance(node, dict) and isinstance(node.get("label"), str) and node.get("label") | |
| } | |
| for clue_id in missing_clue_ids: | |
| readable_id = _unique_world_id(f"{clue_id}_inscription", existing_node_ids) | |
| label = _unique_world_label(f"{_humanize_identifier(clue_id)} Inscription", existing_safe_labels) | |
| nodes.append( | |
| { | |
| "id": readable_id, | |
| "type": "readable", | |
| "label": label, | |
| "description": f"A {label.lower()} can only be deciphered with the right tool.", | |
| "parent_id": target_room_id, | |
| "clue_id": clue_id, | |
| "requires_item_id": gating_item_id, | |
| "consumes_item": False, | |
| "text_content": clue_text_by_id[clue_id] or f"A fragment about {_humanize_identifier(clue_id).lower()}.", | |
| } | |
| ) | |
| _insert_quest_step_before_guardian_talk( | |
| quest_chain, | |
| guardian_id=guardian_id, | |
| step_id_base=f"inspect_{readable_id}", | |
| description=f"Inspect the {label.lower()}.", | |
| action=( | |
| f"use({gating_item_id},{readable_id})" | |
| if isinstance(gating_item_id, str) and gating_item_id | |
| else f"read({readable_id})" | |
| ), | |
| ) | |
| def _repair_take_action_aliases(payload: dict[str, Any]) -> None: | |
| quest_chain = payload.get("quest_chain") | |
| nodes = payload.get("nodes") | |
| if not isinstance(quest_chain, list) or not isinstance(nodes, list): | |
| return | |
| fixture_by_id: dict[str, dict[str, Any]] = {} | |
| npc_by_id: dict[str, dict[str, Any]] = {} | |
| for node in nodes: | |
| if not isinstance(node, dict): | |
| continue | |
| node_id = node.get("id") | |
| if not isinstance(node_id, str) or not node_id: | |
| continue | |
| if node.get("type") == "fixture": | |
| fixture_by_id[node_id] = node | |
| elif node.get("type") == "npc": | |
| npc_by_id[node_id] = node | |
| for step in quest_chain: | |
| if not isinstance(step, dict): | |
| continue | |
| arguments = _extract_action_arguments(step.get("action"), "take") | |
| if arguments is None or len(arguments) != 2: | |
| continue | |
| item_id, source_id = arguments | |
| fixture = fixture_by_id.get(source_id) | |
| if fixture is not None and fixture.get("reveals_item_id") == item_id: | |
| parent_id = fixture.get("parent_id") | |
| if isinstance(parent_id, str) and parent_id: | |
| step["action"] = f"take({item_id},{parent_id})" | |
| continue | |
| npc = npc_by_id.get(source_id) | |
| if npc is None or npc.get("gives_item_id") != item_id: | |
| continue | |
| required_item_id = npc.get("requires_item_id") | |
| if isinstance(required_item_id, str) and required_item_id: | |
| step["action"] = f"give({required_item_id},{source_id})" | |
| def _repair_take_sources_from_room_prereqs(payload: dict[str, Any]) -> None: | |
| quest_chain = payload.get("quest_chain") | |
| items = payload.get("items") | |
| nodes = payload.get("nodes") | |
| if not isinstance(quest_chain, list) or not isinstance(items, list): | |
| return | |
| node_types: dict[str, str] = {} | |
| if isinstance(nodes, list): | |
| for node in nodes: | |
| if not isinstance(node, dict): | |
| continue | |
| node_id = node.get("id") | |
| node_type = node.get("type") | |
| if isinstance(node_id, str) and isinstance(node_type, str): | |
| node_types[node_id] = node_type | |
| item_by_id = { | |
| item.get("id"): item | |
| for item in items | |
| if isinstance(item, dict) and isinstance(item.get("id"), str) and item.get("id") | |
| } | |
| step_by_id = { | |
| step.get("step_id"): step | |
| for step in quest_chain | |
| if isinstance(step, dict) and isinstance(step.get("step_id"), str) and step.get("step_id") | |
| } | |
| for step in quest_chain: | |
| if not isinstance(step, dict): | |
| continue | |
| arguments = _extract_action_arguments(step.get("action"), "take") | |
| if arguments is None or len(arguments) != 2: | |
| continue | |
| item_id, source_id = arguments | |
| if node_types.get(source_id) == "container": | |
| continue | |
| requires_step_ids = step.get("requires_step_ids") | |
| if not isinstance(requires_step_ids, list): | |
| continue | |
| required_room_id: str | None = None | |
| for dependency in requires_step_ids: | |
| if not isinstance(dependency, str): | |
| continue | |
| dependency_step = step_by_id.get(dependency) | |
| if not isinstance(dependency_step, dict): | |
| continue | |
| room_id = _extract_single_action_argument(dependency_step.get("action"), "go") | |
| if room_id: | |
| required_room_id = room_id | |
| if required_room_id is None or required_room_id == source_id: | |
| continue | |
| step["action"] = f"take({item_id},{required_room_id})" | |
| item = item_by_id.get(item_id) | |
| if isinstance(item, dict): | |
| item["start_node_id"] = required_room_id | |
| def _repair_missing_take_steps(payload: dict[str, Any]) -> None: | |
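| # Walks the quest chain with a simulated inventory: take(...) adds an item, | |
| # give(...) spends it (and credits any npc reward), combine(...) consumes both | |
| # inputs and credits the recipe output. A missing take step is inserted just | |
| # before the step that first requires the item, sourcing it from the step's | |
| # room prerequisite or, failing that, the item's start node. | |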
| quest_chain = payload.get("quest_chain") | |
| items = payload.get("items") | |
| nodes = payload.get("nodes") | |
| recipes = payload.get("recipes") | |
| if not isinstance(quest_chain, list) or not isinstance(items, list): | |
| return | |
| item_start_nodes = { | |
| item.get("id"): item.get("start_node_id") | |
| for item in items | |
| if isinstance(item, dict) and isinstance(item.get("id"), str) | |
| } | |
| produced_item_ids = set() | |
| recipe_outputs: dict[frozenset[str], str] = {} | |
| if isinstance(recipes, list): | |
| for recipe in recipes: | |
| if not isinstance(recipe, dict): | |
| continue | |
| output_item_id = recipe.get("output_item_id") | |
| input_item_ids = recipe.get("input_item_ids") | |
| if isinstance(output_item_id, str) and output_item_id: | |
| produced_item_ids.add(output_item_id) | |
| if isinstance(output_item_id, str) and isinstance(input_item_ids, list) and len(input_item_ids) == 2: | |
| recipe_outputs[frozenset(str(item_id) for item_id in input_item_ids)] = output_item_id | |
| npc_rewards: dict[str, str] = {} | |
| if isinstance(nodes, list): | |
| for node in nodes: | |
| if not isinstance(node, dict) or node.get("type") != "npc": | |
| continue | |
| npc_id = node.get("id") | |
| gives_item_id = node.get("gives_item_id") | |
| if isinstance(npc_id, str) and npc_id and isinstance(gives_item_id, str) and gives_item_id: | |
| produced_item_ids.add(gives_item_id) | |
| npc_rewards[npc_id] = gives_item_id | |
| for node in nodes: | |
| if not isinstance(node, dict) or node.get("type") != "fixture": | |
| continue | |
| reveals_item_id = node.get("reveals_item_id") | |
| if isinstance(reveals_item_id, str) and reveals_item_id: | |
| produced_item_ids.add(reveals_item_id) | |
| inventory: set[str] = set() | |
| step_by_id = { | |
| step.get("step_id"): step | |
| for step in quest_chain | |
| if isinstance(step, dict) and isinstance(step.get("step_id"), str) and step.get("step_id") | |
| } | |
| index = 0 | |
| while index < len(quest_chain): | |
| step = quest_chain[index] | |
| if not isinstance(step, dict): | |
| index += 1 | |
| continue | |
| required_item_ids = _quest_required_item_ids(step.get("action")) | |
| inserted_step = False | |
| for item_id in required_item_ids: | |
| if item_id in inventory or item_id in produced_item_ids: | |
| continue | |
| source_node_id = _infer_room_prereq_for_step(step, step_by_id) or item_start_nodes.get(item_id) | |
| if not isinstance(source_node_id, str) or not source_node_id: | |
| continue | |
| new_step_id = _insert_quest_step_before_index( | |
| quest_chain, | |
| index=index, | |
| step_id_base=f"take_{item_id}", | |
| description=f"Take the {_humanize_identifier(item_id).lower()}.", | |
| action=f"take({item_id},{source_node_id})", | |
| allow_existing_action=True, | |
| ) | |
| if new_step_id is not None: | |
| inventory.add(item_id) | |
| item = next( | |
| ( | |
| candidate | |
| for candidate in items | |
| if isinstance(candidate, dict) and candidate.get("id") == item_id | |
| ), | |
| None, | |
| ) | |
| if isinstance(item, dict): | |
| item["start_node_id"] = source_node_id | |
| inserted_step = True | |
| index += 1 | |
| if inserted_step: | |
| step = quest_chain[index] | |
| if not isinstance(step, dict): | |
| index += 1 | |
| continue | |
| arguments = _extract_action_arguments(step.get("action"), "take") | |
| if arguments is not None and len(arguments) == 2: | |
| inventory.add(arguments[0]) | |
| index += 1 | |
| continue | |
| arguments = _extract_action_arguments(step.get("action"), "give") | |
| if arguments is not None and len(arguments) == 2: | |
| inventory.discard(arguments[0]) | |
| rewarded_item_id = npc_rewards.get(arguments[1]) | |
| if rewarded_item_id: | |
| inventory.add(rewarded_item_id) | |
| index += 1 | |
| continue | |
| arguments = _extract_action_arguments(step.get("action"), "combine") | |
| if arguments is not None and len(arguments) == 2: | |
| inventory.discard(arguments[0]) | |
| inventory.discard(arguments[1]) | |
| output_item_id = recipe_outputs.get(frozenset(arguments)) | |
| if output_item_id: | |
| inventory.add(output_item_id) | |
| index += 1 | |
| continue | |
| index += 1 | |
| def _repair_guardian_ending(payload: dict[str, Any], *, guardian_id: Any, answer_string: Any) -> None: | |
| quest_chain = payload.get("quest_chain") | |
| if not isinstance(quest_chain, list) or not quest_chain: | |
| return | |
| submit_index: int | None = None | |
| for index in range(len(quest_chain) - 1, -1, -1): | |
| step = quest_chain[index] | |
| if isinstance(step, dict) and _extract_single_action_argument(step.get("action"), "submit") is not None: | |
| submit_index = index | |
| break | |
| if submit_index is None: | |
| return | |
| submit_step = quest_chain[submit_index] | |
| if not isinstance(submit_step, dict): | |
| return | |
| if isinstance(answer_string, str) and answer_string: | |
| submit_step["action"] = f'submit("{normalize_answer_text(answer_string)}")' | |
| talk_index = _guardian_talk_step_index(quest_chain, guardian_id=guardian_id) | |
| if submit_index == len(quest_chain) - 1 and submit_index > 0: | |
| penultimate = quest_chain[submit_index - 1] | |
| if isinstance(penultimate, dict) and _extract_single_action_argument(penultimate.get("action"), "talk") == guardian_id: | |
| return | |
| if talk_index is None: | |
| # Without a concrete guardian id there is no sensible talk step to insert. | |
| if not isinstance(guardian_id, str) or not guardian_id: | |
| return | |
| new_step_id = _insert_quest_step_before_index( | |
| quest_chain, | |
| index=submit_index, | |
| step_id_base=f"talk_{guardian_id}", | |
| description=f"Speak to the {_humanize_identifier(str(guardian_id)).lower()}.", | |
| action=f"talk({guardian_id})", | |
| allow_existing_action=True, | |
| ) | |
| if new_step_id is not None: | |
| submit_step["requires_step_ids"] = [new_step_id] | |
| return | |
| talk_step = quest_chain[talk_index] | |
| if not isinstance(talk_step, dict): | |
| return | |
| if talk_index != submit_index - 1: | |
| new_step_id = _insert_quest_step_before_index( | |
| quest_chain, | |
| index=submit_index, | |
| step_id_base=talk_step.get("step_id") or f"talk_{guardian_id}", | |
| description=talk_step.get("description") or f"Speak to the {_humanize_identifier(str(guardian_id)).lower()}.", | |
| action=talk_step.get("action") or f"talk({guardian_id})", | |
| allow_existing_action=True, | |
| ) | |
| if new_step_id is not None: | |
| submit_step["requires_step_ids"] = [new_step_id] | |
| def _repair_locked_room_entry_steps(payload: dict[str, Any]) -> None: | |
| quest_chain = payload.get("quest_chain") | |
| edges = payload.get("edges") | |
| meta = payload.get("meta") | |
| if not isinstance(quest_chain, list) or not isinstance(edges, list) or not isinstance(meta, dict): | |
| return | |
| edge_by_rooms = { | |
| (edge.get("from_node_id"), edge.get("to_node_id")): edge | |
| for edge in edges | |
| if isinstance(edge, dict) | |
| } | |
| start_node_id = meta.get("start_node_id") | |
| step_by_id = { | |
| step.get("step_id"): step | |
| for step in quest_chain | |
| if isinstance(step, dict) and isinstance(step.get("step_id"), str) and step.get("step_id") | |
| } | |
| index = 0 | |
| while index < len(quest_chain): | |
| step = quest_chain[index] | |
| if not isinstance(step, dict): | |
| index += 1 | |
| continue | |
| target_room_id = _extract_single_action_argument(step.get("action"), "go") | |
| if target_room_id is None: | |
| index += 1 | |
| continue | |
| current_room_id = _infer_room_prereq_for_step(step, step_by_id) or ( | |
| start_node_id if isinstance(start_node_id, str) else None | |
| ) | |
| if current_room_id is None: | |
| index += 1 | |
| continue | |
| edge = edge_by_rooms.get((current_room_id, target_room_id)) | |
| if not isinstance(edge, dict) or not isinstance(edge.get("door_node_id"), str): | |
| index += 1 | |
| continue | |
| door_id = edge.get("door_node_id") | |
| key_id = edge.get("required_item_id") | |
| inserted = False | |
| if isinstance(door_id, str) and isinstance(key_id, str): | |
| unlock_action = f"unlock({door_id},{key_id})" | |
| if not _action_exists_before_index(quest_chain, unlock_action, index): | |
| if _insert_quest_step_before_index( | |
| quest_chain, | |
| index=index, | |
| step_id_base=f"unlock_{door_id}", | |
| description=f"Unlock the {_humanize_identifier(door_id).lower()}.", | |
| action=unlock_action, | |
| allow_existing_action=True, | |
| ): | |
| inserted = True | |
| index += 1 | |
| if isinstance(door_id, str): | |
| open_action = f"open({door_id})" | |
| if not _action_exists_before_index(quest_chain, open_action, index): | |
| if _insert_quest_step_before_index( | |
| quest_chain, | |
| index=index, | |
| step_id_base=f"open_{door_id}", | |
| description=f"Open the {_humanize_identifier(door_id).lower()}.", | |
| action=open_action, | |
| allow_existing_action=True, | |
| ): | |
| inserted = True | |
| index += 1 | |
| if inserted: | |
| step_by_id = { | |
| candidate.get("step_id"): candidate | |
| for candidate in quest_chain | |
| if isinstance(candidate, dict) | |
| and isinstance(candidate.get("step_id"), str) | |
| and candidate.get("step_id") | |
| } | |
| index += 1 | |
| def _select_synthetic_clue_gate_item_id(items: Any, quest_chain: Any) -> str | None: | |
| if not isinstance(items, list): | |
| return None | |
| taken_item_ids = _quest_taken_item_ids(quest_chain) | |
| prioritized: list[tuple[int, str]] = [] | |
| for item in items: | |
| if not isinstance(item, dict): | |
| continue | |
| item_id = item.get("id") | |
| subtype = item.get("subtype") | |
| if not isinstance(item_id, str) or not item_id: | |
| continue | |
| if subtype == "puzzle" and item_id in taken_item_ids: | |
| prioritized.append((0, item_id)) | |
| elif subtype == "puzzle" and item.get("start_node_id") is not None: | |
| prioritized.append((0, item_id)) | |
| elif subtype == "puzzle": | |
| prioritized.append((1, item_id)) | |
| elif subtype == "key": | |
| prioritized.append((2, item_id)) | |
| if not prioritized: | |
| return None | |
| prioritized.sort() | |
| return prioritized[0][1] | |
| def _humanize_identifier(identifier: str) -> str: | |
| return " ".join(part.capitalize() for part in identifier.split("_") if part) or identifier | |
| def _node_room_start_node_id(node: dict[str, Any], default_start_node_id: str | None) -> str | None: | |
| parent_id = node.get("parent_id") | |
| if isinstance(parent_id, str) and parent_id: | |
| return parent_id | |
| return default_start_node_id | |
| def _unique_world_id(base_id: str, existing_ids: set[str]) -> str: | |
| candidate = base_id | |
| suffix = 2 | |
| while candidate in existing_ids: | |
| candidate = f"{base_id}_{suffix}" | |
| suffix += 1 | |
| existing_ids.add(candidate) | |
| return candidate | |
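| # Example (illustrative): _unique_world_id("door", {"door", "door_2"}) returns | |
| # "door_3" and adds it to the set, so repeated calls stay collision-free. | |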
| def _unique_world_label(base_label: str, existing_safe_labels: set[str]) -> str: | |
| candidate = base_label | |
| suffix = 2 | |
| while parser_safe_text(candidate) in existing_safe_labels: | |
| candidate = f"{base_label} {suffix}" | |
| suffix += 1 | |
| existing_safe_labels.add(parser_safe_text(candidate)) | |
| return candidate | |
| def _insert_quest_step_before_guardian_talk( | |
| quest_chain: Any, | |
| *, | |
| guardian_id: Any, | |
| step_id_base: str, | |
| description: str, | |
| action: str, | |
| ) -> str | None: | |
| if not isinstance(quest_chain, list): | |
| return None | |
| if any(isinstance(step, dict) and step.get("action") == action for step in quest_chain): | |
| return None | |
| talk_index = _guardian_talk_step_index(quest_chain, guardian_id=guardian_id) | |
| if talk_index is None: | |
| return None | |
| return _insert_quest_step_before_index( | |
| quest_chain, | |
| index=talk_index, | |
| step_id_base=step_id_base, | |
| description=description, | |
| action=action, | |
| ) | |
| def _quest_taken_item_ids(quest_chain: Any) -> set[str]: | |
| if not isinstance(quest_chain, list): | |
| return set() | |
| taken_item_ids: set[str] = set() | |
| for step in quest_chain: | |
| if not isinstance(step, dict): | |
| continue | |
| arguments = _extract_action_arguments(step.get("action"), "take") | |
| if arguments is None or not arguments: | |
| continue | |
| item_id = arguments[0] | |
| if item_id: | |
| taken_item_ids.add(item_id) | |
| return taken_item_ids | |
| def _infer_item_start_node_from_quest(quest_chain: Any, item_id: str) -> str | None: | |
| if not isinstance(quest_chain, list): | |
| return None | |
| for step in quest_chain: | |
| if not isinstance(step, dict): | |
| continue | |
| arguments = _extract_action_arguments(step.get("action"), "take") | |
| if arguments is None or len(arguments) != 2: | |
| continue | |
| if arguments[0] == item_id: | |
| return arguments[1] | |
| return None | |
| def _infer_guardian_talk_room_from_quest(quest_chain: Any, *, guardian_id: str) -> str | None: | |
| talk_index = _guardian_talk_step_index(quest_chain, guardian_id=guardian_id) | |
| if talk_index is None or not isinstance(quest_chain, list): | |
| return None | |
| for index in range(talk_index - 1, -1, -1): | |
| step = quest_chain[index] | |
| if not isinstance(step, dict): | |
| continue | |
| room_id = _extract_single_action_argument(step.get("action"), "go") | |
| if room_id: | |
| return room_id | |
| return None | |
| def _guardian_talk_step_index(quest_chain: Any, *, guardian_id: Any) -> int | None: | |
| if not isinstance(quest_chain, list) or not isinstance(guardian_id, str) or not guardian_id: | |
| return None | |
| for index, step in enumerate(quest_chain): | |
| if not isinstance(step, dict): | |
| continue | |
| target_id = _extract_single_action_argument(step.get("action"), "talk") | |
| if target_id == guardian_id: | |
| return index | |
| return None | |
| def _extract_action_arguments(action: Any, name: str) -> list[str] | None: | |
| if not isinstance(action, str): | |
| return None | |
| prefix = f"{name}(" | |
| if not action.startswith(prefix) or not action.endswith(")"): | |
| return None | |
| raw_arguments = action[len(prefix) : -1] | |
| arguments = [argument.strip().strip('"').strip("'") for argument in raw_arguments.split(",")] | |
| if any(not argument for argument in arguments): | |
| return None | |
| return arguments | |
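| # Example (illustrative): | |
| #   _extract_action_arguments("take(brass_key, cellar)", "take") | |
| #   -> ["brass_key", "cellar"] | |
| # Quotes are stripped per argument; an empty argument (e.g. "take()") or a | |
| # non-matching action name yields None. | |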
| def _insert_quest_step_before_index( | |
| quest_chain: Any, | |
| *, | |
| index: int, | |
| step_id_base: str, | |
| description: str, | |
| action: str, | |
| allow_existing_action: bool = False, | |
| ) -> str | None: | |
| if not isinstance(quest_chain, list) or index < 0 or index >= len(quest_chain): | |
| return None | |
| current_step = quest_chain[index] | |
| if not isinstance(current_step, dict): | |
| return None | |
| if not allow_existing_action and any(isinstance(step, dict) and step.get("action") == action for step in quest_chain): | |
| return None | |
| existing_step_ids = { | |
| step.get("step_id") | |
| for step in quest_chain | |
| if isinstance(step, dict) and isinstance(step.get("step_id"), str) and step.get("step_id") | |
| } | |
| existing_requires = current_step.get("requires_step_ids") | |
| if isinstance(existing_requires, list): | |
| requires_step_ids = [step_id for step_id in existing_requires if isinstance(step_id, str) and step_id] | |
| else: | |
| requires_step_ids = [] | |
| new_step_id = _unique_world_id(step_id_base, existing_step_ids) | |
| quest_chain.insert( | |
| index, | |
| { | |
| "step_id": new_step_id, | |
| "description": description, | |
| "requires_step_ids": requires_step_ids, | |
| "action": action, | |
| }, | |
| ) | |
| current_step["requires_step_ids"] = [new_step_id] | |
| return new_step_id | |
| def _action_exists_before_index(quest_chain: Any, action: str, index: int) -> bool: | |
| if not isinstance(quest_chain, list): | |
| return False | |
| for current_step in quest_chain[:index]: | |
| if isinstance(current_step, dict) and current_step.get("action") == action: | |
| return True | |
| return False | |
| def _quest_required_item_ids(action: Any) -> list[str]: | |
| for name, count in (("use", 2), ("unlock", 2), ("give", 2), ("combine", 2)): | |
| arguments = _extract_action_arguments(action, name) | |
| if arguments is None: | |
| continue | |
| if name == "combine" and len(arguments) == count: | |
| return arguments | |
| if len(arguments) == count: | |
| return [arguments[0 if name != "unlock" else 1]] | |
| return [] | |
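| # Example (illustrative): "unlock(iron_door,brass_key)" requires ["brass_key"] | |
| # (the key, not the door), "use(lens,plaque)" requires ["lens"], and | |
| # "combine(hilt,blade)" requires both inputs. | |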
| def _infer_room_prereq_for_step(step: Any, step_by_id: dict[str, Any]) -> str | None: | |
| if not isinstance(step, dict): | |
| return None | |
| step_id = step.get("step_id") | |
| if isinstance(step_id, str) and step_id: | |
| inferred_room = _infer_step_terminal_room(step_id, step_by_id, set()) | |
| if inferred_room is not None: | |
| return inferred_room | |
| requires_step_ids = step.get("requires_step_ids") | |
| if not isinstance(requires_step_ids, list): | |
| return None | |
| room_id: str | None = None | |
| for dependency in requires_step_ids: | |
| if not isinstance(dependency, str): | |
| continue | |
| dependency_step = step_by_id.get(dependency) | |
| if not isinstance(dependency_step, dict): | |
| continue | |
| maybe_room_id = _extract_single_action_argument(dependency_step.get("action"), "go") | |
| if maybe_room_id: | |
| room_id = maybe_room_id | |
| return room_id | |
| def _infer_step_terminal_room(step_id: str, step_by_id: dict[str, Any], seen: set[str]) -> str | None: | |
| if step_id in seen: | |
| return None | |
| step = step_by_id.get(step_id) | |
| if not isinstance(step, dict): | |
| return None | |
| seen = set(seen) | |
| seen.add(step_id) | |
| target_room = _extract_single_action_argument(step.get("action"), "go") | |
| if target_room: | |
| return target_room | |
| requires_step_ids = step.get("requires_step_ids") | |
| if not isinstance(requires_step_ids, list): | |
| return None | |
| inferred_room: str | None = None | |
| for dependency in requires_step_ids: | |
| if not isinstance(dependency, str): | |
| continue | |
| dependency_room = _infer_step_terminal_room(dependency, step_by_id, seen) | |
| if dependency_room: | |
| inferred_room = dependency_room | |
| return inferred_room | |
| def _reachable_passage_room_ids(payload: dict[str, Any], *, start_node_id: Any) -> set[str]: | |
| if not isinstance(start_node_id, str) or not start_node_id: | |
| return set() | |
| edges = payload.get("edges") | |
| if not isinstance(edges, list): | |
| return {start_node_id} | |
| graph: dict[str, set[str]] = {} | |
| for edge in edges: | |
| if not isinstance(edge, dict) or edge.get("type") != "passage": | |
| continue | |
| from_node_id = edge.get("from_node_id") | |
| to_node_id = edge.get("to_node_id") | |
| if not isinstance(from_node_id, str) or not isinstance(to_node_id, str): | |
| continue | |
| graph.setdefault(from_node_id, set()).add(to_node_id) | |
| reachable = {start_node_id} | |
| frontier = [start_node_id] | |
| while frontier: | |
| current = frontier.pop() | |
| for nxt in graph.get(current, set()): | |
| if nxt in reachable: | |
| continue | |
| reachable.add(nxt) | |
| frontier.append(nxt) | |
| return reachable | |
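| # Note: traversal is directed and only follows open "passage" edges, so rooms | |
| # behind locked_passage edges count as unreachable here. Example (illustrative): | |
| # edges [{"type": "passage", "from_node_id": "hall", "to_node_id": "cellar"}] | |
| # with start "hall" yield {"hall", "cellar"}. | |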
| def _extract_single_action_argument(action: Any, name: str) -> str | None: | |
| if not isinstance(action, str): | |
| return None | |
| prefix = f"{name}(" | |
| if not action.startswith(prefix) or not action.endswith(")"): | |
| return None | |
| raw_argument = action[len(prefix) : -1].strip() | |
| if not raw_argument: | |
| return None | |
| if raw_argument[0] == raw_argument[-1] and raw_argument[0] in {'"', "'"}: | |
| raw_argument = raw_argument[1:-1] | |
| return raw_argument.strip() | |
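| # Example (illustrative): | |
| #   _extract_single_action_argument('submit("amber oath")', "submit") -> "amber oath" | |
| #   _extract_single_action_argument("go(cellar)", "go") -> "cellar" | |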
| def _load_dm_world_definition(text: str, *, allow_repair: bool) -> WorldDefinition: | |
| payload = _try_parse_completion_json(text) | |
| if not isinstance(payload, dict): | |
| raise ValueError("Completion did not contain a JSON object.") | |
| if allow_repair: | |
| payload = _repair_dm_world_payload(payload) | |
| return WorldDefinition.model_validate(payload) | |
| def _find_json_object_span(text: str) -> tuple[int, int] | None: | |
| start: int | None = None | |
| depth = 0 | |
| in_string = False | |
| escaped = False | |
| for index, character in enumerate(text): | |
| if in_string: | |
| if escaped: | |
| escaped = False | |
| elif character == "\\": | |
| escaped = True | |
| elif character == '"': | |
| in_string = False | |
| continue | |
| if character == '"': | |
| in_string = True | |
| continue | |
| if character == "{": | |
| if start is None: | |
| start = index | |
| depth += 1 | |
| continue | |
| if character == "}": | |
| if depth == 0: | |
| continue | |
| depth -= 1 | |
| if depth == 0 and start is not None: | |
| return start, index + 1 | |
| return None | |
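| # Example (illustrative): _find_json_object_span('noise {"a": 1} tail') returns | |
| # (6, 14), the half-open span of the first balanced object; braces inside JSON | |
| # strings are ignored thanks to the in_string/escaped tracking. | |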
| def _strip_code_fences(text: str) -> str: | |
| cleaned = text.strip() | |
| if cleaned.startswith("```"): | |
| lines = cleaned.splitlines() | |
| if lines and lines[0].startswith("```"): | |
| lines = lines[1:] | |
| if lines and lines[-1].strip() == "```": | |
| lines = lines[:-1] | |
| cleaned = "\n".join(lines).strip() | |
| return cleaned | |
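| # Example (illustrative): a completion wrapped as "```json\n{...}\n```" is | |
| # reduced to "{...}"; text without a leading fence is only stripped of | |
| # surrounding whitespace. | |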
| def _normalize_outer_completion_text(text: str) -> str: | |
| without_tools = _TOOL_CALL_RE.sub("", text) | |
| without_tools = _EMPTY_THINK_RE.sub("", without_tools) | |
| without_tools = _strip_code_fences(without_tools) | |
| return without_tools.strip() | |
| def _string_key_coverage(value: Any, keys: tuple[str, ...]) -> float: | |
| if not isinstance(value, dict): | |
| return 0.0 | |
| return sum(1 for key in keys if key in value) / len(keys) | |
| def _range_score(value: int, lower: int, upper: int) -> float: | |
| if lower <= value <= upper: | |
| return 1.0 | |
| if value < lower: | |
| return max(0.0, value / max(1, lower)) | |
| return max(0.0, 1.0 - ((value - upper) / max(1, upper))) | |
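| # Worked example: with lower=10 and upper=16, a count of 8 scores 8/10 = 0.8 | |
| # and a count of 20 scores 1 - (4/16) = 0.75; anything inside the band is 1.0. | |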
| def _compactness_score(length: int, target_max: int) -> float: | |
| if length <= target_max: | |
| return 1.0 | |
| overflow = length - target_max | |
| return max(0.0, 1.0 - (overflow / max(1, target_max))) | |
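| # Worked example: a length of 100 against target_max=80 scores | |
| # 1 - (20/80) = 0.75; lengths at or under the target score 1.0. | |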
| def _dm_structural_prior_score(world: dict[str, Any], requested_ratio: float | None) -> float: | |
| meta = world.get("meta") | |
| nodes = world.get("nodes") if isinstance(world.get("nodes"), list) else [] | |
| edges = world.get("edges") if isinstance(world.get("edges"), list) else [] | |
| items = world.get("items") if isinstance(world.get("items"), list) else [] | |
| clues = world.get("clues") if isinstance(world.get("clues"), list) else [] | |
| recipes = world.get("recipes") if isinstance(world.get("recipes"), list) else [] | |
| quest_chain = world.get("quest_chain") if isinstance(world.get("quest_chain"), list) else [] | |
| components = [ | |
| (0.16, _string_key_coverage(world, _DM_REQUIRED_TOP_LEVEL_FIELDS)), | |
| (0.08, _string_key_coverage(meta, ("title", "difficulty_target", "start_node_id", "win_condition"))), | |
| (0.10, _dm_win_condition_score(meta)), | |
| (0.10, _range_score(len(nodes), 10, 16)), | |
| (0.07, _range_score(len(items), 5, 8)), | |
| (0.09, _range_score(len(clues), 3, 5)), | |
| (0.04, _range_score(len(recipes), 0, 1)), | |
| (0.10, _range_score(len(quest_chain), 12, 20)), | |
| (0.06, _valid_type_fraction(nodes, "type", _DM_ALLOWED_NODE_TYPES)), | |
| (0.04, _valid_type_fraction(edges, "type", _DM_ALLOWED_EDGE_TYPES)), | |
| (0.04, _valid_type_fraction(items, "subtype", _DM_ALLOWED_ITEM_TYPES)), | |
| (0.06, _compact_world_text_score(nodes, items, clues, quest_chain)), | |
| (0.06, _guardian_presence_score(meta, nodes)), | |
| ] | |
| if requested_ratio is not None: | |
| components.append((0.10, _difficulty_ratio_score(meta, requested_ratio))) | |
| weighted_total = sum(weight * score for weight, score in components) | |
| total_weight = sum(weight for weight, _ in components) | |
| return _clamp(weighted_total / max(1e-6, total_weight), 0.0, 1.0) | |
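| # The base component weights sum to 1.00 (1.10 once the optional difficulty | |
| # ratio term is appended), so dividing by total_weight keeps the prior a | |
| # weighted average in [0, 1] either way. | |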
| def _dm_win_condition_score(meta: Any) -> float: | |
| if not isinstance(meta, dict): | |
| return 0.0 | |
| win_condition = meta.get("win_condition") | |
| if not isinstance(win_condition, dict): | |
| return 0.0 | |
| score = _string_key_coverage(win_condition, ("type", "target_npc_id", "answer_string")) | |
| if win_condition.get("type") == "deduce": | |
| score += 0.25 | |
| answer = win_condition.get("answer_string") | |
| if isinstance(answer, str) and _LOWERCASE_ANSWER_RE.fullmatch(answer): | |
| score += 0.25 | |
| return min(1.0, score) | |
| def _guardian_presence_score(meta: Any, nodes: list[Any]) -> float: | |
| if not isinstance(meta, dict): | |
| return 0.0 | |
| win_condition = meta.get("win_condition") | |
| if not isinstance(win_condition, dict): | |
| return 0.0 | |
| guardian_id = win_condition.get("target_npc_id") | |
| if not isinstance(guardian_id, str): | |
| return 0.0 | |
| return 1.0 if any(isinstance(node, dict) and node.get("type") == "npc" and node.get("id") == guardian_id for node in nodes) else 0.0 | |
| def _difficulty_ratio_score(meta: Any, requested_ratio: float) -> float: | |
| if not isinstance(meta, dict): | |
| return 0.0 | |
| try: | |
| actual_ratio = float(meta.get("difficulty_target")) | |
| except Exception: | |
| return 0.0 | |
| return max(0.0, 1.0 - abs(actual_ratio - requested_ratio)) | |
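| # Worked example: requested_ratio=1.5 with difficulty_target=1.75 scores | |
| # 1 - |1.75 - 1.5| = 0.75; a mismatch of 1.0 or more scores 0.0. | |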
| def _valid_type_fraction(rows: list[Any], key: str, allowed_values: set[str]) -> float: | |
| typed_rows = [row for row in rows if isinstance(row, dict)] | |
| if not typed_rows: | |
| return 0.0 | |
| valid = sum(1 for row in typed_rows if row.get(key) in allowed_values) | |
| return valid / len(typed_rows) | |
| def _compact_world_text_score( | |
| nodes: list[Any], | |
| items: list[Any], | |
| clues: list[Any], | |
| quest_chain: list[Any], | |
| ) -> float: | |
| text_lengths: list[int] = [] | |
| for collection, keys in ( | |
| (nodes, ("label", "description")), | |
| (items, ("label", "description")), | |
| (clues, ("text",)), | |
| (quest_chain, ("description", "action")), | |
| ): | |
| for row in collection: | |
| if not isinstance(row, dict): | |
| continue | |
| for key in keys: | |
| value = row.get(key) | |
| if isinstance(value, str): | |
| text_lengths.append(len(value)) | |
| if not text_lengths: | |
| return 0.0 | |
| average_length = sum(text_lengths) / len(text_lengths) | |
| return _compactness_score(int(average_length), 80) | |
| def _validation_error_score(errors: list[dict[str, Any]]) -> float: | |
| if not errors: | |
| return 0.0 | |
| penalty = 0.0 | |
| for error in errors: | |
| error_type = str(error.get("type", "")) | |
| location = tuple(str(part) for part in error.get("loc", ())) | |
| field_name = location[-1] if location else "" | |
| if error_type == "extra_forbidden": | |
| penalty += 0.05 | |
| elif error_type.startswith("missing") and field_name in {"label", "description"}: | |
| penalty += 0.02 | |
| elif error_type.startswith("missing") and field_name == "text_content": | |
| penalty += 0.05 | |
| elif error_type.startswith("missing"): | |
| penalty += 0.06 | |
| else: | |
| penalty += 0.08 | |
| return _clamp(1.0 - penalty, 0.0, 1.0) | |
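| # Worked example: one extra_forbidden field (0.05) plus one missing label | |
| # (0.02) yields 1 - 0.07 = 0.93. An empty error list returns 0.0, so this | |
| # scorer is presumably only consulted when validation actually failed. | |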
| def _compile_error_penalty(error_message: str) -> float: | |
| message = error_message.lower() | |
| if not message: | |
| return -0.5 | |
| if "between 3 and 5 clues" in message: | |
| return -0.35 | |
| if "duplicate world id" in message or "duplicate " in message: | |
| return -0.45 | |
| if "requires_step_id" in message or "requires_step_with" in message: | |
| return -0.45 | |
| if "requires requires_item_id" in message: | |
| return -0.50 | |
| if "must live in a location or junction" in message: | |
| return -0.55 | |
| if "fixture" in message and "requires unknown item" in message: | |
| return -0.60 | |
| if "unknown item" in message or "unknown clue" in message or "unknown node" in message: | |
| return -0.65 | |
| if "must reveal exactly one item or readable" in message: | |
| return -0.65 | |
| if "guardian npc cannot have trade fields" in message: | |
| return -0.70 | |
| if "unused decorative items" in message or "clue '" in message: | |
| return -0.75 | |
| if "final quest step" in message or "penultimate quest step" in message: | |
| return -0.80 | |
| if "unreachable" in message or "guardian room" in message: | |
| return -0.85 | |
| if "closed door" in message or "locked door" in message or "does not match key" in message: | |
| return -0.85 | |
| if "quest " in message or "unsupported quest action" in message: | |
| return -0.90 | |
| return -0.75 | |
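| # First matching substring wins, so milder, more specific compile errors are | |
| # checked before the broader (and harsher) catch-alls; an empty message scores | |
| # -0.5 and unrecognized messages fall through to -0.75. | |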
| def _completion_tool_calls(completion: Any) -> list[dict[str, Any]]: | |
| return _extract_tool_calls_from_text(_completion_text(completion)) | |
| def _extract_tool_calls_from_text(text: str) -> list[dict[str, Any]]: | |
| tool_calls: list[dict[str, Any]] = [] | |
| for raw_payload in _TOOL_CALL_RE.findall(text): | |
| try: | |
| payload = json.loads(raw_payload) | |
| except Exception: | |
| continue | |
| normalized = _normalize_tool_call(payload, source="tool_call") | |
| if normalized is not None: | |
| tool_calls.append(normalized) | |
| if tool_calls: | |
| return tool_calls | |
| payload = _try_parse_completion_json(text) | |
| normalized = _normalize_tool_call(payload, source="json_action") | |
| if normalized is None: | |
| return [] | |
| return [normalized] | |
| def _normalize_tool_call(payload: Any, *, source: str) -> dict[str, Any] | None: | |
| if not isinstance(payload, dict): | |
| return None | |
| if payload.get("type") == "function" and isinstance(payload.get("function"), dict): | |
| payload = payload["function"] | |
| if isinstance(payload.get("name"), str): | |
| arguments = payload.get("arguments", {}) | |
| if not isinstance(arguments, dict): | |
| return None | |
| return {"name": payload["name"], "arguments": arguments, "source": source} | |
| action = payload.get("action") | |
| if isinstance(action, dict) and isinstance(action.get("tool"), str): | |
| arguments = {key: value for key, value in action.items() if key != "tool"} | |
| return {"name": action["tool"], "arguments": arguments, "source": source} | |
| return None | |
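| # Example (illustrative) of the two accepted payload shapes: | |
| #   {"type": "function", "function": {"name": "act", "arguments": {"command": "look"}}} | |
| #   {"action": {"tool": "act", "command": "look"}} | |
| # Both normalize to {"name": "act", "arguments": {"command": "look"}, "source": ...}. | |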
| def _hero_act_semantics_reward(arguments: Any) -> float: | |
| if not isinstance(arguments, dict): | |
| return 0.0 | |
| command = arguments.get("command") | |
| if not isinstance(command, str) or not command.strip(): | |
| return 0.0 | |
| normalized_command = command.strip().lower() | |
| parsed = parse_cli_command(command) | |
| if not parsed.valid: | |
| recovered = parse_cli_command(normalized_command) | |
| return 0.40 if recovered.valid else 0.0 | |
| return 1.0 if command == normalized_command else 0.85 | |
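| # Reward tiers: 1.0 for a command that is already lowercased and trimmed, 0.85 | |
| # for a valid but non-normalized command, 0.40 when only the lowercased form | |
| # parses, and 0.0 otherwise (the CLI grammar itself is delegated to | |
| # parse_cli_command). | |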
| def _hero_scratchpad_write_reward(arguments: Any) -> float: | |
| if not isinstance(arguments, dict): | |
| return 0.0 | |
| mode = arguments.get("mode") | |
| content = arguments.get("content") | |
| score = 0.0 | |
| if mode in {"append", "replace"}: | |
| score += 0.45 | |
| if isinstance(content, str) and content.strip(): | |
| score += 0.35 | |
| score += 0.20 * _compactness_score(len(content), 240) | |
| return min(1.0, score) | |
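| # Worked example: mode="append" with 240 characters of content scores | |
| # 0.45 + 0.35 + 0.20 * 1.0 = 1.0; the same write at 480 characters scores | |
| # 0.45 + 0.35 + 0.20 * 0.0 = 0.8 via the compactness term. | |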
| def _clamp(value: float, lower: float, upper: float) -> float: | |
| return max(lower, min(upper, value)) | |
| def _require_training_dependencies() -> None: | |
| if TRAINING_IMPORT_ERROR is not None: | |
| raise RuntimeError( | |
| "Training dependencies are unavailable. Install the project with the training extras before using GRPO." | |
| ) from TRAINING_IMPORT_ERROR | |
| def _require_vllm_if_requested(config: GRPOLaunchConfig) -> None: | |
| if not config.use_vllm: | |
| return | |
| if importlib.util.find_spec("vllm") is None: | |
| raise RuntimeError( | |
| "vLLM is not installed but --use-vllm was requested. Install vllm in the training environment first." | |
| ) | |