File size: 5,244 Bytes
f748552
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
"""MiniHack environment discovery and diagnostic utilities.

Provides tools for scanning the gymnasium registry, validating action-space
consistency across environments, and benchmarking inference throughput.
"""

from __future__ import annotations

import logging
import time
from collections.abc import Sequence
from types import SimpleNamespace

import torch

logger = logging.getLogger(name=__name__)

# An environment ID must contain at least one of these substrings to be
# treated as a navigation task.
_NAV_KEYWORDS = (
    "Room",
    "Corridor",
    "Maze",
    "River",
)
# IDs containing any of these substrings are dropped even if they match a
# navigation keyword ("KeyRoom" IDs also contain "Room").
_EXCLUDED_KEYWORDS = (
    "KeyRoom",
)
# Environment whose action list is the ordering baseline for consistency checks.
_REFERENCE_ENV_ID = "MiniHack-Room-15x15-v0"


def list_working_minihack_tasks() -> list[str]:
    """Scan the gymnasium registry for working MiniHack navigation tasks.

    A registered ID is a candidate when it contains ``"MiniHack"``, at least
    one entry of ``_NAV_KEYWORDS`` and no entry of ``_EXCLUDED_KEYWORDS``.
    Each candidate is instantiated with ``gym.make`` and closed again. IDs
    whose construction raises are treated as broken. They are logged at
    DEBUG level with a traceback so failures can be diagnosed instead of
    silently disappearing.

    Returns:
        Sorted list of MiniHack navigation environment IDs that could be
        created.
    """
    import gymnasium as gym
    import minihack  # noqa: F401 — registers envs

    candidates = sorted(
        env_id
        for env_id in gym.envs.registry
        if "MiniHack" in env_id
        and any(k in env_id for k in _NAV_KEYWORDS)
        and not any(x in env_id for x in _EXCLUDED_KEYWORDS)
    )

    working: list[str] = []
    broken: list[str] = []
    for env_id in candidates:
        try:
            env = gym.make(env_id)
        except Exception:  # best-effort probe: any construction failure = broken
            logger.debug("Failed to create %s", env_id, exc_info=True)
            broken.append(env_id)
            continue
        working.append(env_id)
        # Close separately so a failing close() cannot make a successfully
        # created env count as both working and broken.
        try:
            env.close()
        except Exception:
            logger.debug("Error while closing %s", env_id, exc_info=True)

    logger.info(
        "MiniHack navigation tasks — working: %d, broken: %d",
        len(working),
        len(broken),
    )
    if broken:
        logger.debug("Broken MiniHack navigation tasks: %s", ", ".join(broken))
    return working


def _classify_action_list(
    reference: Sequence[object],
    candidate: Sequence[object],
) -> str:
    """Classify *candidate*'s action ordering relative to *reference*.

    The two lists are compatible only if they agree element-wise on their
    common prefix (up to the shorter length). Any mismatch there yields
    ``"CONFLICT"`` regardless of which list is longer.

    Returns:
        ``"EXACT"``, ``"SUPERSET (+N)"``, ``"SUBSET (-N)"`` or ``"CONFLICT"``.
    """
    # zip() stops at the shorter sequence, so this compares the common prefix.
    if not all(r == c for r, c in zip(reference, candidate)):
        return "CONFLICT"
    diff = len(candidate) - len(reference)
    if diff > 0:
        return f"SUPERSET (+{diff})"
    if diff < 0:
        return f"SUBSET ({diff})"  # diff is negative, so this renders "-N"
    return "EXACT"


def check_action_consistency_with_fixed_ref(
    env_list: list[str],
) -> list[tuple[str, str, int]]:
    """Validate action-space ordering against a fixed reference environment.

    Compares each environment's ``unwrapped.actions`` list against that of
    ``MiniHack-Room-15x15-v0`` and classifies the relationship as one of:

    * ``REFERENCE``: the ID is the reference environment itself.
    * ``EXACT``: same actions in the same order.
    * ``SUPERSET (+N)``: the reference list is a prefix of the env's list.
    * ``SUBSET (-N)``: the env's list is a prefix of the reference list.
    * ``CONFLICT``: the common prefix differs.
    * ``CRASHED``: the env could not be created or inspected. The
      traceback is logged at DEBUG level.

    Args:
        env_list: MiniHack environment IDs to check.

    Returns:
        List of ``(env_id, status, action_space_size)`` tuples, sorted by
        ``env_id``. ``action_space_size`` is 0 for ``CRASHED`` entries.

    Raises:
        Any exception raised while creating the reference environment,
        since no comparison is possible without it.
    """
    import gymnasium as gym
    import minihack  # noqa: F401

    ref_env = gym.make(_REFERENCE_ENV_ID)
    try:
        reference_actions = ref_env.unwrapped.actions  # type: ignore[attr-defined]
    finally:
        ref_env.close()

    results: list[tuple[str, str, int]] = []
    for env_id in sorted(env_list):
        if env_id == _REFERENCE_ENV_ID:
            results.append((env_id, "REFERENCE", len(reference_actions)))
            continue
        # Only environment creation/inspection is guarded; classification is
        # pure and should surface its own bugs instead of reporting CRASHED.
        try:
            env = gym.make(env_id)
            try:
                env_actions = env.unwrapped.actions  # type: ignore[attr-defined]
            finally:
                env.close()
        except Exception:
            logger.debug("Failed to inspect actions of %s", env_id, exc_info=True)
            results.append((env_id, "CRASHED", 0))
            continue
        status = _classify_action_list(reference_actions, env_actions)
        results.append((env_id, status, len(env_actions)))

    for name, status, size in results:
        logger.info("  %-40s | %-14s | n_actions=%d", name, status, size)
    return results


def benchmark_inference(
    model: torch.nn.Module,
    cfg: SimpleNamespace,
    device: torch.device | str,
    n_actions: int = 100,
    *,
    n_warmup: int = 0,
) -> tuple[float, float]:
    """Measure ReMDM inference throughput.

    Runs ``n_actions`` timed ``remdm_sample`` planning calls on all-zero
    dummy observations under ``torch.no_grad()`` and measures wall-clock
    time. Optional untimed warm-up calls run first, so one-off costs such
    as kernel initialisation can be excluded from the measurement. When
    ``device`` is a CUDA device, that specific device is synchronized
    before and after the timed loop.

    Note: switches ``model`` to eval mode as a side effect.

    Args:
        model: Denoising model to benchmark.
        cfg: Config namespace. This function reads ``crop_size``, ``map_h``,
            ``map_w`` and ``diffusion_steps_eval``; the whole namespace is
            forwarded to ``remdm_sample``.
        device: Torch device (or device string) for the dummy inputs.
        n_actions: Number of timed planning calls.
        n_warmup: Number of untimed planning calls run before timing starts.

    Returns:
        ``(diffusion_steps_per_sec, actions_per_sec)`` as floats. Both are
        0.0 if the measured elapsed time is not positive.
    """
    from src.diffusion.sampling import remdm_sample

    dev = torch.device(device)

    def _sync() -> None:
        # torch.cuda.synchronize() with no argument waits only on the
        # *current* CUDA device; pass the target explicitly so timing is
        # correct for e.g. "cuda:1", and skip it entirely for non-CUDA devices.
        if dev.type == "cuda":
            torch.cuda.synchronize(dev)

    model.eval()
    local_dummy = torch.zeros(
        (1, cfg.crop_size, cfg.crop_size), dtype=torch.long, device=dev,
    )
    global_dummy = torch.zeros(
        (1, cfg.map_h, cfg.map_w), dtype=torch.long, device=dev,
    )

    with torch.no_grad():
        for _ in range(n_warmup):
            remdm_sample(model, local_dummy, global_dummy, cfg, device)
        _sync()

        t0 = time.perf_counter()
        for _ in range(n_actions):
            remdm_sample(model, local_dummy, global_dummy, cfg, device)
        _sync()
        elapsed = time.perf_counter() - t0

    total_steps = n_actions * cfg.diffusion_steps_eval
    if elapsed > 0:
        steps_per_sec = total_steps / elapsed
        actions_per_sec = n_actions / elapsed
    else:
        steps_per_sec = actions_per_sec = 0.0

    logger.info(
        "Benchmark (%d actions): %.1f diffusion-steps/s | %.1f actions/s",
        n_actions,
        steps_per_sec,
        actions_per_sec,
    )
    return steps_per_sec, actions_per_sec