| """MiniHack environment discovery and diagnostic utilities. |
| |
| Provides tools for scanning the gymnasium registry, validating action-space |
| consistency across environments, and benchmarking inference throughput. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import logging |
| import time |
| from types import SimpleNamespace |
|
|
| import torch |
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


# A MiniHack env ID counts as a navigation task if it contains one of these.
_NAV_KEYWORDS: tuple[str, ...] = ("Room", "Corridor", "Maze", "River")
# ...and none of these; "KeyRoom" would otherwise slip through via "Room".
_EXCLUDED_KEYWORDS: tuple[str, ...] = ("KeyRoom",)
# Environment whose action list serves as the ordering reference.
_REFERENCE_ENV_ID: str = "MiniHack-Room-15x15-v0"
|
|
|
|
def list_working_minihack_tasks() -> list[str]:
    """Scan the gymnasium registry for working MiniHack navigation tasks.

    A registered environment ID is a candidate when it contains
    ``"MiniHack"``, at least one substring from ``_NAV_KEYWORDS`` and none
    from ``_EXCLUDED_KEYWORDS``. Each candidate is instantiated with
    ``gym.make`` and then closed. IDs whose construction succeeds are
    reported as working; construction failures are logged at DEBUG level
    and counted as broken.

    Returns:
        Sorted list of working MiniHack navigation environment IDs.
    """
    import gymnasium as gym
    import minihack  # noqa: F401 - not referenced by name; presumably imported so MiniHack envs get registered

    candidates = sorted(
        env_id
        for env_id in gym.envs.registry
        if "MiniHack" in env_id
        and any(k in env_id for k in _NAV_KEYWORDS)
        and not any(x in env_id for x in _EXCLUDED_KEYWORDS)
    )

    working: list[str] = []
    broken: list[str] = []
    for env_id in candidates:
        try:
            env = gym.make(env_id)
        except Exception as exc:  # best-effort scan: any construction failure = broken
            logger.debug("Could not create %s: %r", env_id, exc)
            broken.append(env_id)
            continue
        # Each ID is classified exactly once: a failing close() is logged but
        # does not additionally mark an already-constructed env as broken.
        try:
            env.close()
        except Exception as exc:
            logger.debug("Error while closing %s: %r", env_id, exc)
        working.append(env_id)

    logger.info(
        "MiniHack navigation tasks — working: %d, broken: %d",
        len(working),
        len(broken),
    )
    return working
|
|
|
|
def _classify_action_order(reference_actions, env_actions) -> str:
    """Classify ``env_actions`` relative to ``reference_actions``.

    Only the overlapping prefix (up to the shorter of the two lengths) is
    compared element-wise. Any mismatch there is an ordering conflict,
    whatever the relative lengths are.

    Returns:
        ``"CONFLICT"`` if the shared prefix differs; otherwise ``"EXACT"``
        for equal lengths, ``"SUPERSET (+N)"`` when ``env_actions`` has N
        extra trailing actions, or ``"SUBSET (-N)"`` when it has N fewer.
    """
    # zip() stops at the shorter sequence, i.e. it walks exactly the shared prefix.
    if not all(ref == act for ref, act in zip(reference_actions, env_actions)):
        return "CONFLICT"
    diff = len(env_actions) - len(reference_actions)
    if diff > 0:
        return f"SUPERSET (+{diff})"
    if diff < 0:
        # diff is already negative, so this renders e.g. "SUBSET (-3)".
        return f"SUBSET ({diff})"
    return "EXACT"


def check_action_consistency_with_fixed_ref(
    env_list: list[str],
) -> list[tuple[str, str, int]]:
    """Validate action-space ordering against a fixed reference environment.

    Compares each environment's ``env.unwrapped.actions`` against those of
    ``MiniHack-Room-15x15-v0`` (``_REFERENCE_ENV_ID``) and classifies the
    relationship as one of:

    * ``REFERENCE`` - the environment is the reference itself.
    * ``EXACT`` - identical action lists.
    * ``SUPERSET (+N)`` - identical shared prefix, N extra trailing actions.
    * ``SUBSET (-N)`` - identical shared prefix, N fewer actions.
    * ``CONFLICT`` - the shared prefix differs (ordering mismatch), even if
      one list is longer than the other.
    * ``CRASHED`` - creating, inspecting or closing the env raised; the
      exception is logged at DEBUG level.

    Every result row is also logged at INFO level.

    Args:
        env_list: MiniHack environment IDs to check; processed in sorted order.

    Returns:
        One ``(env_id, status, action_space_size)`` tuple per entry of
        ``env_list``, sorted by ID. The size is ``0`` for ``CRASHED``.

    Raises:
        Any exception raised while creating or inspecting the reference
        environment.
    """
    import gymnasium as gym
    import minihack  # noqa: F401 - not referenced by name; presumably imported so MiniHack envs get registered

    ref_env = gym.make(_REFERENCE_ENV_ID)
    try:
        reference_actions = ref_env.unwrapped.actions
    finally:
        # Close the reference env even if reading its actions fails.
        ref_env.close()

    results: list[tuple[str, str, int]] = []
    for env_id in sorted(env_list):
        if env_id == _REFERENCE_ENV_ID:
            results.append((env_id, "REFERENCE", len(reference_actions)))
            continue
        try:
            env = gym.make(env_id)
            try:
                env_actions = env.unwrapped.actions
                status = _classify_action_order(reference_actions, env_actions)
                size = len(env_actions)
            finally:
                env.close()
        except Exception as exc:  # report, don't abort the whole scan
            logger.debug("Action check crashed for %s: %r", env_id, exc)
            status, size = "CRASHED", 0
        # Append exactly once, so a failing close() cannot yield two rows per env.
        results.append((env_id, status, size))

    for name, row_status, row_size in results:
        logger.info(" %-40s | %-14s | n_actions=%d", name, row_status, row_size)
    return results
|
|
|
|
def benchmark_inference(
    model: torch.nn.Module,
    cfg: SimpleNamespace,
    device: torch.device | str,
    n_actions: int = 100,
    *,
    n_warmup: int = 0,
) -> tuple[float, float]:
    """Measure ReMDM inference throughput.

    Calls ``remdm_sample`` ``n_actions`` times on all-zero dummy observations
    and measures the wall-clock time of those calls. Optional untimed warm-up
    calls can be made first, so that one-off costs (e.g. lazy initialisation)
    do not skew the measurement.

    Args:
        model: Denoising model. It is switched to eval mode via
            ``model.eval()`` and left in that mode.
        cfg: Config namespace. Read directly: ``crop_size``, ``map_h``,
            ``map_w`` and ``diffusion_steps_eval``. It is also passed through
            to ``remdm_sample``.
        device: Torch device (or device string) for the dummy inputs.
        n_actions: Number of timed planning calls (must be >= 0).
        n_warmup: Number of untimed planning calls made before timing
            starts (must be >= 0). Defaults to 0, i.e. no warm-up.

    Returns:
        ``(diffusion_steps_per_sec, actions_per_sec)`` as floats. Both are
        ``0.0`` if the measured time is not positive.

    Raises:
        ValueError: If ``n_actions`` or ``n_warmup`` is negative.
    """
    from src.diffusion.sampling import remdm_sample

    if n_actions < 0:
        raise ValueError(f"n_actions must be >= 0, got {n_actions}")
    if n_warmup < 0:
        raise ValueError(f"n_warmup must be >= 0, got {n_warmup}")

    # CUDA kernels launch asynchronously, so wait for them before reading the
    # clock. A bare torch.cuda.synchronize() waits on the *current* CUDA device,
    # which need not be ``device`` (e.g. "cuda:1"), so sync the target instead.
    target = torch.device(device)
    needs_sync = target.type == "cuda"

    model.eval()
    local_dummy = torch.zeros(
        (1, cfg.crop_size, cfg.crop_size), dtype=torch.long, device=device,
    )
    global_dummy = torch.zeros(
        (1, cfg.map_h, cfg.map_w), dtype=torch.long, device=device,
    )

    # Time pure inference without autograd bookkeeping (a no-op if
    # remdm_sample already disables gradients internally).
    with torch.no_grad():
        for _ in range(n_warmup):
            remdm_sample(model, local_dummy, global_dummy, cfg, device)
        if needs_sync:
            torch.cuda.synchronize(target)

        t0 = time.perf_counter()
        for _ in range(n_actions):
            remdm_sample(model, local_dummy, global_dummy, cfg, device)
        if needs_sync:
            torch.cuda.synchronize(target)
        elapsed = time.perf_counter() - t0

    # Assumes each remdm_sample call runs cfg.diffusion_steps_eval denoising
    # steps. TODO(review): confirm against remdm_sample.
    total_steps = n_actions * cfg.diffusion_steps_eval
    if elapsed > 0:
        steps_per_sec = total_steps / elapsed
        actions_per_sec = n_actions / elapsed
    else:
        steps_per_sec = actions_per_sec = 0.0

    logger.info(
        "Benchmark (%d actions): %.1f diffusion-steps/s | %.1f actions/s",
        n_actions,
        steps_per_sec,
        actions_per_sec,
    )
    return steps_per_sec, actions_per_sec
|
|