| """Stateless evaluation runner. |
| |
| Runs episodes using the diffusion model and collects per-environment |
| win rates, average rewards, and step counts. All episodes for a given |
| environment are rolled out in lockstep so that replanning calls are |
| batched into single GPU forward passes (B = n_episodes). |
| """ |
|
|
| from __future__ import annotations |
|
|
| import json |
| import logging |
| from datetime import datetime, timezone |
| from pathlib import Path |
| from types import SimpleNamespace |
|
|
| import numpy as np |
| import torch |
|
|
| from src.models.denoiser import ModelEMA, make_model |
| from src.planners.logging import Logger |
|
|
| logger = logging.getLogger(__name__) |
|
|
|
|
class Evaluator:
    """Stateless evaluation runner.

    Runs the model on a set of environments and returns aggregate
    statistics per environment. Episodes within each environment are
    executed in lockstep so replanning calls are GPU-batched.
    """

    @torch.no_grad()
    def evaluate(
        self,
        env_ids: list[str],
        model: torch.nn.Module,
        n_episodes: int,
        cfg: SimpleNamespace,
        device: torch.device | str,
        des_files: list[str] | None = None,
        blind_global: bool = False,
        base_seed: int = 42,
        max_steps: int = 500,
    ) -> dict[str, dict]:
        """Evaluate *model* on each environment in *env_ids*.

        All *n_episodes* for a given environment run in lockstep so
        that replanning forward passes are batched (B = active envs
        needing a replan).

        Per-episode seeds are derived deterministically from
        ``(base_seed, env_id, episode index)``, so a given evaluation is
        reproducible across processes. The built-in ``hash`` is not used
        because it is salted per interpreter run for ``str`` values.

        Args:
            env_ids: List of MiniHack environment IDs.
            model: Denoising model (put into eval mode here).
            n_episodes: Episodes per environment.
            cfg: Config namespace.
            device: Torch device.
            des_files: Optional list of ``.des`` file paths for custom
                scenario evaluation. Each file yields one extra env entry
                keyed by its filename stem.
            blind_global: If ``True``, zero out global map observations
                (local-only ablation mode).
            base_seed: Offset mixed into every per-episode seed.
            max_steps: Per-episode step limit.

        Returns:
            ``{env_id: {"win_rate", "wins", "avg_reward", "avg_steps",
            "n_episodes"}}``
        """
        model.eval()

        # Built-in envs carry no .des content; every custom .des file adds
        # one target keyed by its filename stem. All files are read up
        # front so a missing file fails before any rollout starts.
        eval_targets: list[tuple[str, str | None]] = [
            (eid, None) for eid in env_ids
        ]
        for des_path in des_files or ():
            stem = Path(des_path).stem
            if any(name == stem for name, _ in eval_targets):
                logger.warning(
                    "Eval target %r (from %s) duplicates an earlier entry; "
                    "its results will overwrite the earlier ones",
                    stem, des_path,
                )
            with open(des_path, encoding="utf-8") as fh:
                eval_targets.append((stem, fh.read()))

        results: dict[str, dict] = {}
        for env_id, des_content in eval_targets:
            ep_results = self._run_episodes_batched(
                model, env_id, n_episodes, cfg, device,
                seeds=self._episode_seeds(env_id, n_episodes, base_seed),
                des_content=des_content,
                blind_global=blind_global,
                max_steps=max_steps,
            )
            results[env_id] = self._summarize(ep_results)
        return results

    @staticmethod
    def _episode_seeds(
        env_id: str, n_episodes: int, base_seed: int = 42,
    ) -> list[int]:
        """Return one reproducible RNG seed in ``[0, 2**31)`` per episode."""
        import zlib  # local import: keeps the module import block unchanged

        return [
            (base_seed + zlib.crc32(f"{env_id}:{ep}".encode("utf-8")))
            % (2**31)
            for ep in range(n_episodes)
        ]

    @staticmethod
    def _summarize(ep_results: list[dict]) -> dict:
        """Aggregate per-episode dicts into win rate and mean reward/steps."""
        n_done = len(ep_results)
        denom = max(n_done, 1)  # avoid ZeroDivisionError for 0 episodes
        wins = sum(1 for r in ep_results if r["won"])
        return {
            "win_rate": wins / denom,
            "wins": wins,
            "avg_reward": sum(r["total_reward"] for r in ep_results) / denom,
            "avg_steps": sum(r["steps"] for r in ep_results) / denom,
            "n_episodes": n_done,
        }

    @staticmethod
    def _close_env(env, env_id: str) -> None:
        """Close *env*, logging (not raising) errors so cleanup continues."""
        try:
            env.close()
        except Exception:
            logger.warning("Failed to close env %s", env_id, exc_info=True)

    @torch.no_grad()
    def _run_episodes_batched(
        self,
        model: torch.nn.Module,
        env_id: str,
        n_episodes: int,
        cfg: SimpleNamespace,
        device: torch.device | str,
        seeds: list[int],
        des_content: str | None = None,
        blind_global: bool = False,
        max_steps: int = 500,
    ) -> list[dict]:
        """Run episodes in lockstep with batched model inference.

        Creates one environment per episode, steps them in lockstep,
        and batches all replanning calls into single forward passes
        (B = number of active envs needing a replan at each step).
        Each plan is executed for ``min(cfg.replan_every, cfg.seq_len)``
        steps (at least one) before that episode is replanned.

        Args:
            model: Denoising model (eval mode).
            env_id: MiniHack environment ID.
            n_episodes: Number of episodes to run.
            cfg: Config namespace.
            device: Torch device.
            seeds: Per-episode RNG seeds (length *n_episodes*).
            des_content: Optional ``.des`` file content for custom
                scenarios.
            blind_global: If ``True``, zero out global map observations.
            max_steps: Maximum number of lockstep iterations.

        Returns:
            List of per-episode dicts with ``"won"``, ``"steps"``,
            ``"total_reward"`` keys, in episode order. Episodes whose env
            could not be created report ``won=False`` with zero steps; an
            episode whose ``step`` raises is ended at that point.
        """
        from src.diffusion.sampling import remdm_sample
        from src.envs.minihack_env import make_env

        n = n_episodes
        cs = cfg.crop_size
        # A plan holds cfg.seq_len actions; never index past its end even
        # if replan_every is configured larger than the plan length.
        plan_horizon = max(1, min(cfg.replan_every, cfg.seq_len))
        physics_aware = getattr(cfg, "physics_aware_sampling", False)

        # Pre-sized so envs[i] always belongs to episode i.
        envs: list = [None] * n
        cur_local = np.zeros((n, cs, cs), dtype=np.int16)
        cur_global = np.zeros((n, cfg.map_h, cfg.map_w), dtype=np.int16)

        plans = np.zeros((n, cfg.seq_len), dtype=np.int64)
        step_in_plan = np.zeros(n, dtype=np.int32)
        need_replan = np.ones(n, dtype=bool)
        done = np.zeros(n, dtype=bool)
        won = np.zeros(n, dtype=bool)
        total_reward = np.zeros(n, dtype=np.float64)
        n_steps = np.zeros(n, dtype=np.int32)

        try:
            for i in range(n):
                env = None
                try:
                    env = make_env(env_id, des_content, cfg)
                    (local, glb), _ = env.reset(seed=seeds[i])
                    cur_local[i] = local
                    cur_global[i] = glb
                except Exception:
                    logger.warning(
                        "Failed to create env %s (ep %d)",
                        env_id, i, exc_info=True,
                    )
                    # Don't leak an env whose reset / obs copy failed.
                    if env is not None:
                        self._close_env(env, env_id)
                    done[i] = True
                else:
                    # Register only fully initialised envs.
                    envs[i] = env

            for _ in range(max_steps):
                if done.all():
                    break

                # One batched forward pass for every active episode whose
                # current plan is exhausted.
                replan_idx = np.flatnonzero(need_replan & ~done)
                if replan_idx.size > 0:
                    local_t = torch.from_numpy(
                        cur_local[replan_idx],
                    ).long().to(device)
                    glb_t = torch.from_numpy(
                        cur_global[replan_idx],
                    ).long().to(device)
                    plans[replan_idx] = remdm_sample(
                        model, local_t, glb_t, cfg, device,
                        physics_aware=physics_aware,
                        blind_global=blind_global,
                    ).cpu().numpy()
                    step_in_plan[replan_idx] = 0
                    need_replan[replan_idx] = False

                for i in np.flatnonzero(~done).tolist():
                    # Clamp into [0, action_dim) in case a sampled token
                    # falls outside the action range.
                    action = min(
                        max(int(plans[i, step_in_plan[i]]), 0),
                        cfg.action_dim - 1,
                    )
                    step_in_plan[i] += 1
                    n_steps[i] += 1
                    if step_in_plan[i] >= plan_horizon:
                        need_replan[i] = True

                    try:
                        obs, reward, term, trunc, info = envs[i].step(action)
                        local, glb = obs
                        total_reward[i] += reward
                        cur_local[i] = local
                        cur_global[i] = glb
                        if info.get("won", False):
                            won[i] = True
                        if term or trunc:
                            done[i] = True
                    except Exception:
                        logger.warning(
                            "Episode %d step failed for %s",
                            i, env_id, exc_info=True,
                        )
                        done[i] = True
        finally:
            # Close every env even if one close() fails.
            for env in envs:
                if env is not None:
                    self._close_env(env, env_id)

        return [
            {
                "won": bool(won[i]),
                "steps": int(n_steps[i]),
                "total_reward": float(total_reward[i]),
            }
            for i in range(n)
        ]
|
|
|
|
| def format_eval_results( |
| results: dict[str, dict], label: str = "Eval", |
| ) -> str: |
| """Format evaluation results as an ASCII table. |
| |
| Args: |
| results: Output of ``Evaluator.evaluate``. |
| label: Table header label. |
| |
| Returns: |
| Formatted string. |
| """ |
| lines = [f"{'=' * 60}", f" {label} Results", f"{'=' * 60}"] |
| lines.append( |
| f" {'Environment':<35} {'WinRate':>8} {'Steps':>8}" |
| ) |
| lines.append(f" {'-' * 53}") |
| for env_id, stats in results.items(): |
| wr = f"{stats['win_rate']:.2%}" |
| st = f"{stats['avg_steps']:.1f}" |
| lines.append(f" {env_id:<35} {wr:>8} {st:>8}") |
| lines.append(f"{'=' * 60}") |
| return "\n".join(lines) |
|
|
|
|
def save_eval_json(
    results: dict,
    path: str,
    metadata: dict | None = None,
) -> None:
    """Save evaluation results to a JSON file (best effort).

    The payload is ``{"timestamp": <UTC ISO-8601>, "results": results}``,
    plus ``"metadata"`` when *metadata* is non-empty. Parent directories
    are created as needed. The JSON is first written to a sibling
    ``*.tmp`` file and then moved over *path*, so a failure part-way
    through serialization never leaves a truncated file behind.

    Any error (directory creation, serialization, write or rename) is
    logged with its traceback and swallowed, so a failed save does not
    abort the surrounding run.

    Args:
        results: Evaluation results dict.
        path: Output file path.
        metadata: Optional extra metadata (e.g. iteration).
    """
    payload = {
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "results": results,
    }
    if metadata:
        payload["metadata"] = metadata

    target = Path(path).resolve()
    tmp = target.with_name(target.name + ".tmp")
    try:
        target.parent.mkdir(parents=True, exist_ok=True)
        with open(tmp, "w", encoding="utf-8") as f:
            # default=str stringifies values json cannot encode natively
            # instead of failing the whole save.
            json.dump(payload, f, indent=2, default=str)
        # os.replace semantics: overwrites target in a single rename.
        tmp.replace(target)
    except Exception:
        logger.exception("Failed to save eval JSON to %s", target)
        try:
            tmp.unlink(missing_ok=True)
        except OSError:
            # Cleanup is best effort; the original error is already logged.
            logger.debug("Could not remove temp file %s", tmp, exc_info=True)
|
|
|
|
def run_inference(
    cfg,
    checkpoint_path: str,
    env_ids: list[str] | None,
    episodes: int,
    output_path: str | None,
    use_ema: bool,
    log: Logger | None = None,
    des_files: list[str] | None = None,
    blind_global: bool = False,
) -> None:
    """Evaluate a checkpoint on specified environments.

    Builds a fresh model with ``make_model(cfg)`` and loads the
    checkpoint weights. When requested and available, the EMA weights
    are applied. The model is then evaluated with ``Evaluator``. Results
    are printed as a table, sent to *log* if given, and written to
    *output_path* as JSON if given.

    Args:
        cfg: Config namespace. This function reads ``device``,
            ``ema_decay``, ``id_envs`` and ``ood_envs``, and passes
            *cfg* on to the model factory and evaluator.
        checkpoint_path: Path to a ``torch.save`` file holding either a
            dict with ``"model_state_dict"`` (and optionally
            ``"ema_state_dict"``) or a bare model state dict.
        env_ids: Environments to evaluate; ``None`` means
            ``cfg.id_envs`` followed by ``cfg.ood_envs``.
        episodes: Episodes per environment.
        output_path: Optional JSON output path.
        use_ema: Apply EMA weights when the checkpoint contains them. A
            warning is logged if they are missing.
        log: Optional experiment logger.
        des_files: Optional custom ``.des`` scenario files.
        blind_global: Zero out global map observations (ablation).
    """
    device = cfg.device
    logger.info("Inference on %s", device)

    model = make_model(cfg).to(device)
    # NOTE(security): weights_only=False unpickles arbitrary Python
    # objects, which can execute code. Only load trusted checkpoints.
    ckpt = torch.load(
        checkpoint_path, map_location=device, weights_only=False,
    )

    if "model_state_dict" in ckpt:
        model.load_state_dict(ckpt["model_state_dict"])
        ema_state = ckpt.get("ema_state_dict")
    else:
        # Bare state dict: the file holds only the model weights.
        model.load_state_dict(ckpt)
        ema_state = None

    if use_ema:
        if ema_state is not None:
            ema = ModelEMA(model, decay=cfg.ema_decay)
            ema.load_state_dict(ema_state)
            ema.apply_to(model)
        else:
            # Previously a silent fallback; make it visible.
            logger.warning(
                "use_ema=True but checkpoint %s has no 'ema_state_dict'; "
                "evaluating raw model weights",
                checkpoint_path,
            )

    model.eval()

    if env_ids is None:
        # list() tolerates configs that store these as tuples.
        env_ids = list(cfg.id_envs) + list(cfg.ood_envs)

    evaluator = Evaluator()
    results = evaluator.evaluate(
        env_ids, model, episodes, cfg, device,
        des_files=des_files, blind_global=blind_global,
    )

    print(format_eval_results(results, label="Inference"))

    if log is not None:
        log.log_eval(results, step=0, prefix="inference")
        log.log_summary(
            {f"inference/{env_id}/win_rate": stats["win_rate"]
             for env_id, stats in results.items()}
        )

    if output_path:
        save_eval_json(results, output_path)
        logger.info("Results saved to %s", output_path)
|
|