remdm-minihack / src /envs /discovery.py
MathisW78's picture
Demo notebook payload (source + checkpoint + assets)
f748552 verified
"""MiniHack environment discovery and diagnostic utilities.
Provides tools for scanning the gymnasium registry, validating action-space
consistency across environments, and benchmarking inference throughput.
"""
from __future__ import annotations
import logging
import time
from types import SimpleNamespace
import torch
logger = logging.getLogger(__name__)
_NAV_KEYWORDS = ("Room", "Corridor", "Maze", "River")
_EXCLUDED_KEYWORDS = ("KeyRoom",)
_REFERENCE_ENV_ID = "MiniHack-Room-15x15-v0"
def list_working_minihack_tasks() -> list[str]:
    """Scan the gymnasium registry for MiniHack navigation tasks that can be built.

    A registry ID is a candidate when it contains ``"MiniHack"``, at least
    one of ``_NAV_KEYWORDS`` and none of ``_EXCLUDED_KEYWORDS``. Each
    candidate is instantiated with ``gym.make`` and closed again. Any
    exception raised while creating the environment marks it as broken; the
    reason is logged at DEBUG level. A failure while *closing* an
    environment is logged but does not reclassify it, because creation
    itself succeeded.

    Returns:
        Sorted list of candidate IDs whose ``gym.make`` call succeeded.
    """
    import gymnasium as gym
    import minihack  # noqa: F401 — registers envs

    candidates = sorted(
        env_id
        for env_id in gym.envs.registry.keys()
        if "MiniHack" in env_id
        and any(k in env_id for k in _NAV_KEYWORDS)
        and not any(x in env_id for x in _EXCLUDED_KEYWORDS)
    )

    working: list[str] = []
    broken: list[str] = []
    for env_id in candidates:
        try:
            env = gym.make(env_id)
        except Exception:
            # Best-effort probe: any creation failure means "broken", but
            # keep the reason available for debugging instead of dropping it.
            logger.debug("Could not create %s", env_id, exc_info=True)
            broken.append(env_id)
            continue
        # Keep cleanup errors out of the classification so an env is never
        # counted as both working and broken.
        try:
            env.close()
        except Exception:
            logger.debug("Error while closing %s", env_id, exc_info=True)
        working.append(env_id)

    logger.info(
        "MiniHack navigation tasks — working: %d, broken: %d",
        len(working),
        len(broken),
    )
    return working
def check_action_consistency_with_fixed_ref(
    env_list: list[str],
) -> list[tuple[str, str, int]]:
    """Validate action-space ordering against a fixed reference environment.

    Compares each environment's ``unwrapped.actions`` sequence against that
    of ``_REFERENCE_ENV_ID`` (``MiniHack-Room-15x15-v0``). Each environment
    is classified as follows:

    * ``REFERENCE``: the ID is the reference environment itself.
    * ``EXACT``: same length and identical action at every index.
    * ``SUPERSET (+N)``: the reference actions are a prefix of this env's
      actions, and the env has ``N`` extra actions.
    * ``SUBSET (-N)``: this env's actions are a prefix of the reference
      actions, and the env has ``N`` fewer actions.
    * ``CONFLICT``: some index in the common prefix holds a different action.
    * ``CRASHED``: creating the env, reading its actions or closing it
      raised. The size is reported as 0 and the error is logged at DEBUG.

    Args:
        env_list: MiniHack environment IDs to check (processed in sorted order).

    Returns:
        One ``(env_id, status, action_space_size)`` tuple per ID in
        ``env_list``.

    Raises:
        Any exception raised while creating the reference environment.
    """
    import gymnasium as gym
    import minihack  # noqa: F401

    ref_env = gym.make(_REFERENCE_ENV_ID)
    try:
        reference_actions = ref_env.unwrapped.actions  # type: ignore[attr-defined]
    finally:
        ref_env.close()

    results: list[tuple[str, str, int]] = []
    for env_id in sorted(env_list):
        if env_id == _REFERENCE_ENV_ID:
            results.append((env_id, "REFERENCE", len(reference_actions)))
            continue

        try:
            env = gym.make(env_id)
            try:
                env_actions = env.unwrapped.actions  # type: ignore[attr-defined]
            finally:
                env.close()
        except Exception:
            logger.debug("Action check crashed for %s", env_id, exc_info=True)
            results.append((env_id, "CRASHED", 0))
            continue

        # zip() stops at the shorter sequence, so this compares the common
        # prefix only; length differences are handled via `diff` below.
        prefix_matches = all(
            ref_action == env_action
            for ref_action, env_action in zip(reference_actions, env_actions)
        )
        diff = len(env_actions) - len(reference_actions)
        if not prefix_matches:
            status = "CONFLICT"
        elif diff == 0:
            status = "EXACT"
        elif diff > 0:
            status = f"SUPERSET (+{diff})"
        else:
            status = f"SUBSET ({diff})"  # diff is negative, e.g. "SUBSET (-3)"
        results.append((env_id, status, len(env_actions)))

    for name, status, size in results:
        logger.info(" %-40s | %-14s | n_actions=%d", name, status, size)
    return results
def benchmark_inference(
    model: torch.nn.Module,
    cfg: SimpleNamespace,
    device: torch.device | str,
    n_actions: int = 100,
    n_warmup: int = 0,
) -> tuple[float, float]:
    """Measure ReMDM inference throughput.

    Calls ``remdm_sample`` ``n_actions`` times on all-zero dummy
    observations and measures the wall-clock time of those calls.

    Args:
        model: Denoising model. It is put in eval mode for the benchmark,
            and its previous train/eval mode is restored afterwards.
        cfg: Config namespace. This function reads ``crop_size``, ``map_h``,
            ``map_w`` and ``diffusion_steps_eval`` directly and passes
            ``cfg`` on to ``remdm_sample``.
        device: Torch device the dummy inputs are allocated on. If it is a
            CUDA device, that device is synchronized around the timed region.
        n_actions: Number of timed planning calls.
        n_warmup: Number of untimed planning calls run before the clock
            starts, e.g. to exclude one-off initialisation cost. Defaults
            to 0 (no warm-up).

    Returns:
        ``(diffusion_steps_per_sec, actions_per_sec)`` as floats. Steps are
        counted as ``n_actions * cfg.diffusion_steps_eval``. Both values are
        ``0.0`` if the measured time is not positive.
    """
    from src.diffusion.sampling import remdm_sample

    target = torch.device(device)

    def _sync() -> None:
        # Synchronize the device actually benchmarked. A bare
        # torch.cuda.synchronize() only waits on the *current* CUDA device.
        if target.type == "cuda":
            torch.cuda.synchronize(target)

    local_dummy = torch.zeros(
        (1, cfg.crop_size, cfg.crop_size), dtype=torch.long, device=device,
    )
    global_dummy = torch.zeros(
        (1, cfg.map_h, cfg.map_w), dtype=torch.long, device=device,
    )

    was_training = model.training
    model.eval()
    try:
        for _ in range(n_warmup):
            remdm_sample(model, local_dummy, global_dummy, cfg, device)

        _sync()
        t0 = time.perf_counter()
        for _ in range(n_actions):
            remdm_sample(model, local_dummy, global_dummy, cfg, device)
        _sync()
        elapsed = time.perf_counter() - t0
    finally:
        # Do not leave the caller's model in a different mode than it arrived in.
        model.train(was_training)

    total_steps = n_actions * cfg.diffusion_steps_eval
    steps_per_sec = total_steps / elapsed if elapsed > 0 else 0.0
    actions_per_sec = n_actions / elapsed if elapsed > 0 else 0.0
    logger.info(
        "Benchmark (%d actions): %.1f diffusion-steps/s | %.1f actions/s",
        n_actions,
        steps_per_sec,
        actions_per_sec,
    )
    return steps_per_sec, actions_per_sec