| """SB3 + Decision Transformer baselines for the ReMDM diffusion planner. |
| |
| This module wraps standard discrete-action RL baselines (PPO, A2C, DQN, |
| recurrent PPO) plus two imitation baselines (Behavioural Cloning and |
| Decision Transformer) into the project's unified config + dispatch |
| surface so they can be compared head-to-head against the DAgger / |
| offline-BC diffusion planner on the same MiniHack environments. |
| |
| Entry point: :func:`run_baselines`. |
| |
| Hyperparameters live in ``configs/defaults.yaml`` under the |
| ``baselines_*`` namespace; the unified env-step training budget |
| (``cfg.total_timesteps``) is shared with DAgger and offline BC. |
| |
| W&B logging routes through the project's :class:`Logger` (with the W&B |
| project swapped to ``cfg.baselines_wandb_project`` when that is set; the |
| swap is not undone afterwards); SB3's standard ``WandbCallback`` piggybacks |
| on the active run and syncs its tensorboard scalars automatically. Nothing |
| in this module calls ``wandb.log(...)`` directly. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import logging |
| import os |
| import random |
| from pathlib import Path |
| from types import SimpleNamespace |
| from typing import Any |
|
|
| import gymnasium as gym |
| import numpy as np |
| import orjson |
| import torch |
| import torch.nn as nn |
| from sb3_contrib import RecurrentPPO |
| from stable_baselines3 import A2C, DQN, PPO |
| from stable_baselines3.common.callbacks import CallbackList, EvalCallback |
| from stable_baselines3.common.monitor import Monitor |
| from stable_baselines3.common.policies import ActorCriticPolicy |
| from stable_baselines3.common.torch_layers import BaseFeaturesExtractor |
| from stable_baselines3.common.vec_env import SubprocVecEnv |
| from torch.utils.data import DataLoader, Dataset |
| from wandb.integration.sb3 import WandbCallback |
|
|
| from src.envs.minihack_env import ( |
| AdvancedObservationEnv, |
| collect_oracle_trajectory, |
| ) |
| from src.planners.logging import Logger |
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


# Baseline names grouped by family. ``ALL_BASELINE_ALGOS`` lists the RL
# names first, followed by the imitation names.
SB3_RL_ALGOS: tuple[str, ...] = ("ppo", "a2c", "dqn", "ppo-rnn")
IMITATION_ALGOS: tuple[str, ...] = ("bc", "dt")
ALL_BASELINE_ALGOS: tuple[str, ...] = (*SB3_RL_ALGOS, *IMITATION_ALGOS)
|
|
|
|
| |
| |
| |
|
|
|
|
class _SB3MiniHackWrapper(gym.Wrapper):
    """Expose ``AdvancedObservationEnv`` observations as an SB3-friendly dict.

    The wrapped env yields observations as a ``(local, global)`` pair. This
    wrapper prepends a channel axis to each array and returns them under the
    ``"local"`` / ``"global"`` keys, matching the ``gym.spaces.Dict`` it
    advertises. On every ``step`` it also copies ``info["won"]`` (when
    present) into ``info["is_success"]`` so SB3's success tracking reports
    the win rate.
    """

    def __init__(self, env: AdvancedObservationEnv) -> None:
        super().__init__(env)
        crop_h, crop_w = env.observation_space.shape
        env_cfg = env._cfg

        def _single_channel_box(height: int, width: int) -> gym.spaces.Box:
            # Both streams share bounds and dtype; only the spatial size differs.
            return gym.spaces.Box(
                low=0, high=6000, shape=(1, height, width), dtype=np.int16,
            )

        self.observation_space = gym.spaces.Dict(
            {
                "local": _single_channel_box(crop_h, crop_w),
                "global": _single_channel_box(env_cfg.map_h, env_cfg.map_w),
            }
        )

    def reset(self, **kwargs: Any) -> tuple[dict[str, np.ndarray], dict]:
        """Reset the wrapped env and return ``(dict_obs, info)``."""
        raw_obs, info = self.env.reset(**kwargs)
        local, glob = raw_obs
        return self._pack(local, glob), info

    def step(
        self, action: int,
    ) -> tuple[dict[str, np.ndarray], float, bool, bool, dict]:
        """Step the wrapped env; mirror ``info["won"]`` to ``info["is_success"]``."""
        raw_obs, reward, terminated, truncated, info = self.env.step(action)
        local, glob = raw_obs
        if "won" in info:
            info["is_success"] = info["won"]
        return self._pack(local, glob), reward, terminated, truncated, info

    @staticmethod
    def _pack(
        local: np.ndarray, glob: np.ndarray,
    ) -> dict[str, np.ndarray]:
        """Add a leading channel axis to each array and key them by stream."""
        streams = (("local", local), ("global", glob))
        return {key: np.expand_dims(arr, axis=0) for key, arr in streams}
|
|
|
|
| |
| |
| |
|
|
|
|
class _MiniHackCNN(BaseFeaturesExtractor):
    """Two-stream convolutional feature extractor for the SB3 dict observation.

    * ``"local"`` stream: two 3x3 convolutions (1 -> 16 -> 32 channels,
      padding 1), each followed by ReLU.
    * ``"global"`` stream: a 5x5 stride-2 convolution (1 -> 16) and a 3x3
      stride-2 convolution (16 -> 32), each followed by ReLU.

    The flattened outputs of both streams are concatenated and mapped to
    ``features_dim`` by one linear layer followed by ReLU.
    """

    def __init__(
        self, observation_space: gym.spaces.Dict, features_dim: int = 256,
    ) -> None:
        super().__init__(observation_space, features_dim)
        # Submodules are created in a fixed order (local, global, linear) so
        # that seeded weight initialisation is reproducible.
        self.local_cnn = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Flatten(),
        )
        self.global_cnn = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=5, stride=2),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=3, stride=2),
            nn.ReLU(),
            nn.Flatten(),
        )

        # Push an all-zero batch of one through each stream to learn how many
        # features it produces once flattened.
        streams = (("local", self.local_cnn), ("global", self.global_cnn))
        with torch.no_grad():
            flat_sizes = [
                stream(torch.zeros(1, *observation_space[key].shape)).shape[1]
                for key, stream in streams
            ]
        self.linear = nn.Sequential(
            nn.Linear(sum(flat_sizes), features_dim), nn.ReLU(),
        )

    def forward(
        self, observations: dict[str, torch.Tensor],
    ) -> torch.Tensor:
        """Encode a batch of dict observations into ``features_dim`` features."""
        local_feat = self.local_cnn(observations["local"].float())
        global_feat = self.global_cnn(observations["global"].float())
        joint = torch.cat([local_feat, global_feat], dim=1)
        return self.linear(joint)
|
|
|
|
| |
| |
| |
|
|
|
|
class _MiniHackStateEncoder(nn.Module):
    """CNN encoder mapping a (local, global) obs pair to a state embedding.

    Both inputs carry a single channel. They may be given per step as
    ``(B, 1, H, W)`` or per sequence as ``(B, T, 1, H, W)``; the output is
    ``(B, embed_dim)`` or ``(B, T, embed_dim)`` respectively.

    Args:
        embed_dim: Size of the returned embedding.
        crop_h: Height of the local crop.
        crop_w: Width of the local crop.
        map_h: Height of the global map.
        map_w: Width of the global map.
    """

    def __init__(
        self,
        embed_dim: int = 128,
        crop_h: int = 9,
        crop_w: int = 9,
        map_h: int = 21,
        map_w: int = 79,
    ) -> None:
        super().__init__()
        self.local_cnn = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Flatten(),
        )
        self.global_cnn = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=5, stride=2),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=3, stride=2),
            nn.ReLU(),
            nn.Flatten(),
        )
        # Probe each stream with zeros to learn its flattened feature count.
        with torch.no_grad():
            dummy_loc = torch.zeros(1, 1, crop_h, crop_w)
            dummy_glob = torch.zeros(1, 1, map_h, map_w)
            local_flat = self.local_cnn(dummy_loc).shape[1]
            global_flat = self.global_cnn(dummy_glob).shape[1]
        self.proj = nn.Linear(local_flat + global_flat, embed_dim)

    def forward(
        self, local_obs: torch.Tensor, global_obs: torch.Tensor,
    ) -> torch.Tensor:
        """Embed one step or a whole sequence of observations.

        A 5-D ``local_obs`` is treated as ``(B, T, C, H, W)``: the two leading
        axes are folded together for the convolutions and restored on the
        output. ``global_obs`` is assumed to have the same leading axes.
        """
        has_time_axis = local_obs.dim() == 5
        if has_time_axis:
            batch, steps = local_obs.shape[:2]
            # ``reshape`` rather than ``view``: ``view`` raises on
            # non-contiguous inputs (e.g. transposed or sliced batches),
            # while ``reshape`` copies only when it has to.
            local_obs = local_obs.reshape(batch * steps, *local_obs.shape[2:])
            global_obs = global_obs.reshape(batch * steps, *global_obs.shape[2:])

        loc_feat = self.local_cnn(local_obs.float())
        glob_feat = self.global_cnn(global_obs.float())
        out = self.proj(torch.cat([loc_feat, glob_feat], dim=-1))
        if has_time_axis:
            out = out.reshape(batch, steps, -1)
        return out
|
|
|
|
class _DecisionTransformer(nn.Module):
    """Causal transformer over interleaved (return-to-go, state, action) tokens.

    Every timestep contributes three tokens in the order return-to-go, state,
    action. Action logits are read from the hidden state at each *state*
    token, which under the causal mask cannot see that step's action token.
    """

    def __init__(
        self,
        n_actions: int,
        embed_dim: int = 128,
        n_heads: int = 4,
        n_layers: int = 3,
        context_len: int = 30,
        max_ep_len: int = 500,
        dropout: float = 0.1,
        crop_h: int = 9,
        crop_w: int = 9,
        map_h: int = 21,
        map_w: int = 79,
    ) -> None:
        super().__init__()
        self.embed_dim = embed_dim
        self.context_len = context_len
        self.n_actions = n_actions
        self.max_ep_len = max_ep_len

        # NOTE: submodules are registered in a fixed order; ``self.apply``
        # below walks them in that order, so reordering would change the
        # weights obtained for a given seed.
        self.state_encoder = _MiniHackStateEncoder(
            embed_dim, crop_h, crop_w, map_h, map_w,
        )
        self.action_embed = nn.Embedding(n_actions + 1, embed_dim)
        self.return_embed = nn.Linear(1, embed_dim)
        self.pos_embed = nn.Embedding(max_ep_len, embed_dim)
        self.token_type_embed = nn.Embedding(3, embed_dim)
        self.embed_ln = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)

        layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=n_heads,
            dim_feedforward=embed_dim * 4,
            dropout=dropout,
            activation="gelu",
            batch_first=True,
        )
        self.transformer = nn.TransformerEncoder(layer, num_layers=n_layers)
        self.action_head = nn.Linear(embed_dim, n_actions)

        self.apply(self._init_weights)

    @staticmethod
    def _init_weights(module: nn.Module) -> None:
        """Initialise linear/embedding weights from N(0, 0.02) and reset LayerNorm."""
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
            return
        if isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            return
        if isinstance(module, nn.LayerNorm):
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)

    def forward(
        self,
        returns_to_go: torch.Tensor,
        local_obs: torch.Tensor,
        global_obs: torch.Tensor,
        actions: torch.Tensor,
        timesteps: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Return per-timestep action logits.

        Args:
            returns_to_go: Return-to-go values; the first two axes are
                ``(batch, time)`` and the last axis has size 1 (it feeds a
                ``Linear(1, embed_dim)``).
            local_obs: Local observations, passed to the state encoder.
            global_obs: Global observations, passed to the state encoder.
            actions: Integer action indices, ``(batch, time)``.
            timesteps: Integer episode timesteps, ``(batch, time)``.
            attention_mask: Optional ``(batch, time)`` mask; entries equal to
                0 mark padded timesteps that must not be attended to.

        Returns:
            Logits read from the state-token positions, one row per timestep.
        """
        batch, steps = returns_to_go.shape[:2]
        n_tokens = 3 * steps

        # Each token = content embedding + timestep embedding + token-type embedding.
        time_embed = self.pos_embed(timesteps)
        rtg_type, state_type, action_type = self.token_type_embed.weight
        rtg_tok = self.return_embed(returns_to_go) + time_embed + rtg_type
        state_tok = self.state_encoder(local_obs, global_obs) + time_embed + state_type
        action_tok = self.action_embed(actions) + time_embed + action_type

        # Interleave to (R_0, s_0, a_0, R_1, s_1, a_1, ...).
        tokens = torch.stack((rtg_tok, state_tok, action_tok), dim=2)
        tokens = tokens.reshape(batch, n_tokens, self.embed_dim)
        tokens = self.dropout(self.embed_ln(tokens))

        # True above the diagonal: a token may not attend to later tokens.
        causal_mask = torch.triu(
            torch.ones(n_tokens, n_tokens, device=returns_to_go.device),
            diagonal=1,
        ).bool()

        key_padding_mask = None
        if attention_mask is not None:
            # Each timestep's flag covers its three consecutive tokens.
            key_padding_mask = attention_mask.repeat_interleave(3, dim=1) == 0

        hidden = self.transformer(
            tokens, mask=causal_mask, src_key_padding_mask=key_padding_mask,
        )
        # Index 1 of every triple is the state token.
        return self.action_head(hidden[:, 1::3, :])

    @torch.no_grad()
    def get_action(
        self,
        returns_to_go: torch.Tensor,
        local_obs: torch.Tensor,
        global_obs: torch.Tensor,
        actions: torch.Tensor,
        timesteps: torch.Tensor,
    ) -> torch.Tensor:
        """Greedy action for the final timestep of the given context.

        Puts the module in eval mode (and leaves it there) before the
        forward pass.
        """
        self.eval()
        all_logits = self.forward(
            returns_to_go, local_obs, global_obs, actions, timesteps,
        )
        final_step_logits = all_logits[:, -1, :]
        return final_step_logits.argmax(dim=-1)
|
|
|
|
class _DTDataset(Dataset):
    """Fixed-length training windows cut from Decision Transformer trajectories.

    There is one sample per ``(trajectory, start step)`` pair. A window
    covers up to ``context_len`` steps beginning at ``start``; windows that
    run past the end of the trajectory are zero-padded on the right, and
    ``attention_mask`` marks which positions hold real data.
    """

    def __init__(
        self,
        trajectories: list[dict[str, np.ndarray]],
        context_len: int,
        max_ep_len: int,
        n_actions: int,
    ) -> None:
        self.trajectories = trajectories
        self.context_len = context_len
        self.max_ep_len = max_ep_len
        self.n_actions = n_actions
        # Every step of every trajectory is a valid window start.
        window_starts: list[tuple[int, int]] = []
        for traj_idx, traj in enumerate(trajectories):
            window_starts.extend(
                (traj_idx, start) for start in range(len(traj["actions"]))
            )
        self.indices = window_starts

    def __len__(self) -> int:
        return len(self.indices)

    def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
        traj_idx, start = self.indices[idx]
        traj = self.trajectories[traj_idx]
        stop = min(start + self.context_len, len(traj["actions"]))
        window = slice(start, stop)
        n_valid = stop - start
        n_pad = self.context_len - n_valid

        local = traj["local"][window].copy()
        glob = traj["global"][window].copy()
        rtg = traj["returns_to_go"][window].copy()
        # Clamp so embedding lookups stay in range.
        timesteps = np.clip(np.arange(start, stop), 0, self.max_ep_len - 1)
        actions = np.clip(traj["actions"][window].copy(), 0, self.n_actions - 1)

        if n_pad > 0:
            # Pad only along the leading (time) axis.
            frame_pad = ((0, n_pad), (0, 0), (0, 0), (0, 0))
            tail_pad = (0, n_pad)
            local = np.pad(local, frame_pad, mode="constant")
            glob = np.pad(glob, frame_pad, mode="constant")
            actions = np.pad(actions, tail_pad, mode="constant")
            rtg = np.pad(rtg, tail_pad, mode="constant")
            timesteps = np.pad(timesteps, tail_pad, mode="constant")

        # 1.0 for real steps, 0.0 for padding.
        attention_mask = (np.arange(self.context_len) < n_valid).astype(np.float32)

        return {
            "local": torch.tensor(local, dtype=torch.float32),
            "global": torch.tensor(glob, dtype=torch.float32),
            "actions": torch.tensor(actions, dtype=torch.long),
            "returns_to_go": torch.tensor(rtg, dtype=torch.float32).unsqueeze(-1),
            "timesteps": torch.tensor(timesteps, dtype=torch.long),
            "attention_mask": torch.tensor(attention_mask, dtype=torch.float32),
        }
|
|
|
|
| |
| |
| |
|
|
|
|
class _PrefixedEvalCallback(EvalCallback):
    """``EvalCallback`` that also records its latest results under ``prefix``.

    After the parent's step logic has run, the mean of the most recent
    evaluation's rewards, episode lengths and (when available) successes is
    recorded as ``<prefix>/mean_reward``, ``<prefix>/avg_steps`` and
    ``<prefix>/win_rate``.

    SB3 truncates metric names at 36 chars, which collides on long
    MiniHack env IDs; the prefix lets us strip ``MiniHack-`` / ``-v0``
    cleanly.
    """

    def __init__(
        self, eval_env: SubprocVecEnv, prefix: str, **kwargs: Any,
    ) -> None:
        super().__init__(eval_env, **kwargs)
        self.prefix = prefix

    def _on_step(self) -> bool:
        keep_training = super()._on_step()
        if not self.evaluations_results:
            # No evaluation has been stored yet, so there is nothing to report.
            return keep_training

        latest = {
            "mean_reward": self.evaluations_results[-1],
            "avg_steps": self.evaluations_length[-1],
        }
        if self.evaluations_successes:
            latest["win_rate"] = self.evaluations_successes[-1]

        for metric, samples in latest.items():
            self.logger.record(f"{self.prefix}/{metric}", float(np.mean(samples)))
        return keep_training
|
|
|
|
def _make_sb3_env_fn(env_id: str, cfg: SimpleNamespace, log_dir: str):
    """Return a zero-argument factory that builds one monitored, wrapped env.

    The returned callable creates ``log_dir`` if needed, builds an
    ``AdvancedObservationEnv`` for ``env_id``, wraps it in
    ``_SB3MiniHackWrapper`` and then in ``Monitor`` (writing to ``log_dir``).
    """

    def _build_env() -> Monitor:
        os.makedirs(log_dir, exist_ok=True)
        base_env = AdvancedObservationEnv(env_id, des_file=None, cfg=cfg)
        return Monitor(_SB3MiniHackWrapper(base_env), log_dir)

    return _build_env
|
|
|
|
| |
| |
| |
|
|
|
|
def _short(env_id: str) -> str:
    """Return ``env_id`` with every ``"MiniHack-"`` and ``"-v0"`` substring removed."""
    trimmed = env_id
    for token in ("MiniHack-", "-v0"):
        trimmed = trimmed.replace(token, "")
    return trimmed
|
|
|
|
def _eval_episodes_per_env(cfg: SimpleNamespace) -> int:
    """Return the number of evaluation episodes per environment.

    ``cfg.baselines_eval_episodes_per_env`` wins when it exists and is not
    ``None``; otherwise ``cfg.eval_episodes_per_env`` is used.
    """
    episodes = getattr(cfg, "baselines_eval_episodes_per_env", None)
    if episodes is None:
        episodes = cfg.eval_episodes_per_env
    return int(episodes)
|
|
|
|
def _seed_everything(seed: int) -> None:
    """Seed Python's ``random``, NumPy and PyTorch (CPU, plus CUDA if available)."""
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
|
|
|
|
def _resolve_output_dir(cfg: SimpleNamespace, override: str | None) -> Path:
    """Pick the output directory, create it if missing, and return it.

    A truthy ``override`` takes precedence over ``cfg.baselines_output_dir``.
    """
    out_dir = Path(override or cfg.baselines_output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    return out_dir
|
|
|
|
def _init_baseline_logger(
    cfg: SimpleNamespace, run_name: str,
) -> Logger:
    """Build the project ``Logger`` for one baseline run.

    ``cfg`` is mutated in place before ``Logger(cfg)`` is constructed:

    * ``cfg.wandb_project`` is replaced by ``cfg.baselines_wandb_project``
      when that attribute exists and is truthy;
    * ``cfg.wandb_run_name`` is set to ``run_name``;
    * ``cfg.wandb_resume_id`` is cleared to ``None``.

    The previous values are deliberately not restored: each baseline seed
    goes through this helper again, and main.py exits after
    ``run_baselines`` returns.
    """
    baselines_project = getattr(cfg, "baselines_wandb_project", None)
    if baselines_project:
        cfg.wandb_project = baselines_project
    cfg.wandb_run_name, cfg.wandb_resume_id = run_name, None
    return Logger(cfg)
|
|
|
|
| |
| |
| |
|
|
|
|
def _collect_bc_dataset(
    cfg: SimpleNamespace,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Gather oracle demonstrations on the ID envs as flat arrays.

    For every env in ``cfg.id_envs`` the oracle is rolled out with seeds
    ``0 .. baselines_bc_oracle_episodes_per_env - 1``; rollouts for which
    ``collect_oracle_trajectory`` returns ``None`` are skipped.

    Returns:
        ``(local, global, actions)`` concatenated over all kept trajectories
        along axis 0. The two observation arrays get an extra axis inserted
        at position 1.

    Raises:
        RuntimeError: If no trajectory was collected at all.
    """
    episodes_per_env = int(cfg.baselines_bc_oracle_episodes_per_env)

    demos = []
    for env_id in cfg.id_envs:
        for traj_seed in range(episodes_per_env):
            traj = collect_oracle_trajectory(env_id, traj_seed, cfg)
            if traj is not None:
                demos.append(traj)

    if not demos:
        raise RuntimeError("BC oracle collection produced zero trajectories")

    def _stack_obs(key: str) -> np.ndarray:
        # Insert the channel axis per trajectory, then join along time.
        return np.concatenate(
            [np.expand_dims(demo[key], axis=1) for demo in demos], axis=0,
        )

    all_actions = np.concatenate([demo["actions"] for demo in demos], axis=0)
    return _stack_obs("local"), _stack_obs("global"), all_actions
|
|
|
|
class _BCDataset(Dataset):
    """Flat (observation, action) pairs for behavioural cloning.

    The arrays are converted to tensors once at construction time
    (observations as float32, actions as int64). Each item is
    ``{"obs": {"local": ..., "global": ...}, "acts": ...}``.
    """

    def __init__(
        self, loc: np.ndarray, glob: np.ndarray, acts: np.ndarray,
    ) -> None:
        self.loc, self.glob = (
            torch.tensor(arr, dtype=torch.float32) for arr in (loc, glob)
        )
        self.acts = torch.tensor(acts, dtype=torch.int64)

    def __len__(self) -> int:
        return len(self.acts)

    def __getitem__(
        self, idx: int,
    ) -> dict[str, dict[str, torch.Tensor] | torch.Tensor]:
        observation = {"local": self.loc[idx], "global": self.glob[idx]}
        return {"obs": observation, "acts": self.acts[idx]}
|
|
|
|
def _eval_sb3_policy_manually(
    policy: ActorCriticPolicy,
    env_id: str,
    cfg: SimpleNamespace,
    log_dir: str,
    n_episodes: int,
) -> tuple[float, float]:
    """Roll out ``policy`` deterministically and summarise the episodes.

    A single-env ``SubprocVecEnv`` (Monitor-wrapped via
    ``_make_sb3_env_fn``) is stepped with ``policy.predict`` until
    ``n_episodes`` episodes have finished.

    Args:
        policy: Policy whose ``predict`` method selects actions.
        env_id: Environment to evaluate on.
        cfg: Project config forwarded to the env factory.
        log_dir: Directory handed to the env factory for Monitor output.
        n_episodes: Number of completed episodes to average over.

    Returns:
        ``(win_rate, avg_steps)``: the fraction of episodes whose final
        ``info`` has a truthy ``"won"`` entry, and the mean episode length
        as reported by Monitor. ``(0.0, 0.0)`` when ``n_episodes <= 0``.
    """
    if n_episodes <= 0:
        # Nothing to evaluate: skip spawning the subprocess env and avoid
        # the division by zero below.
        return 0.0, 0.0

    eval_env = SubprocVecEnv([_make_sb3_env_fn(env_id, cfg, log_dir)])
    try:
        obs = eval_env.reset()
        wins = 0
        total_steps = 0
        completed = 0
        while completed < n_episodes:
            action, _ = policy.predict(obs, deterministic=True)
            obs, _rewards, dones, infos = eval_env.step(action)
            if not dones[0]:
                continue
            completed += 1
            if infos[0].get("won", False):
                wins += 1
            # Monitor attaches the finished episode's length under "episode"/"l".
            total_steps += infos[0]["episode"]["l"]
    finally:
        # Always tear down the worker process, even if the rollout fails.
        eval_env.close()
    return wins / n_episodes, total_steps / n_episodes
|
|
|
|
def _train_bc(
    cfg: SimpleNamespace,
    train_env: SubprocVecEnv,
    log: Logger,
    log_dir: str,
    seed: int,
) -> tuple[ActorCriticPolicy, dict[str, float]]:
    """Fit a Behavioural Cloning policy on oracle data and evaluate it.

    Steps: collect oracle demonstrations, maximise the log-likelihood of the
    demonstrated actions under an ``ActorCriticPolicy`` (AdamW + cosine LR
    schedule, gradient norm clipped to 1.0), then evaluate on every ID and
    OOD environment.

    Args:
        cfg: Project config (reads ``device``, ``weight_decay``, the
            ``baselines_bc_*`` keys, ``id_envs`` and ``ood_envs``).
        train_env: Only its observation and action spaces are used here.
        log: Project logger receiving per-epoch and final metrics.
        log_dir: Root directory for evaluation Monitor logs.
        seed: Seed of this run; used in log messages only.

    Returns:
        ``(policy, seed_metrics)`` where ``seed_metrics`` maps
        ``"<split>/<env>/win_rate"`` (in percent) and
        ``"<split>/<env>/avg_steps"`` to floats.
    """
    device = torch.device(cfg.device)
    n_eval = _eval_episodes_per_env(cfg)

    logger.info("Collecting oracle demonstrations for BC...")
    local_obs, global_obs, expert_actions = _collect_bc_dataset(cfg)
    logger.info("BC dataset: %d transitions", len(expert_actions))

    demo_loader = DataLoader(
        _BCDataset(local_obs, global_obs, expert_actions),
        batch_size=int(cfg.baselines_bc_batch_size),
        shuffle=True,
        num_workers=4,
        pin_memory=torch.cuda.is_available(),
    )

    lr = float(cfg.baselines_bc_lr)

    def _constant_lr(_progress: float) -> float:
        # The policy constructor requires a schedule; training below uses
        # its own optimizer and scheduler instead.
        return lr

    policy = ActorCriticPolicy(
        observation_space=train_env.observation_space,
        action_space=train_env.action_space,
        lr_schedule=_constant_lr,
        features_extractor_class=_MiniHackCNN,
        features_extractor_kwargs={"features_dim": 256},
    ).to(device)

    n_epochs = int(cfg.baselines_bc_epochs)
    optimizer = torch.optim.AdamW(
        policy.parameters(),
        lr=lr,
        weight_decay=float(cfg.weight_decay),
    )
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=n_epochs,
    )

    policy.train()
    for epoch_idx in range(1, n_epochs + 1):
        running_loss = 0.0
        for batch in demo_loader:
            obs = {
                name: tensor.to(policy.device)
                for name, tensor in batch["obs"].items()
            }
            acts = batch["acts"].to(policy.device)
            _values, log_prob, _entropy = policy.evaluate_actions(obs, acts)
            # Negative log-likelihood of the demonstrated actions.
            loss = -log_prob.mean()
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(policy.parameters(), 1.0)
            optimizer.step()
            running_loss += loss.item()
        scheduler.step()

        mean_loss = running_loss / max(1, len(demo_loader))
        lr_now = scheduler.get_last_lr()[0]
        log.log(
            {
                "train/bc_loss": mean_loss,
                "train/lr": lr_now,
                "train/epoch": epoch_idx,
            },
            step=epoch_idx,
        )
        logger.info(
            "BC epoch %02d/%02d | loss=%.4f | lr=%.2e",
            epoch_idx, n_epochs, mean_loss, lr_now,
        )

    seed_metrics: dict[str, float] = {}
    for split, env_list in (("ID", cfg.id_envs), ("OOD", cfg.ood_envs)):
        logger.info("--- BC %s evaluation (seed=%d) ---", split, seed)
        eval_root = f"{log_dir}/eval_{split.lower()}"
        for env_id in env_list:
            short = _short(env_id)
            win_rate, avg_steps = _eval_sb3_policy_manually(
                policy, env_id, cfg, f"{eval_root}/{env_id}", n_eval,
            )
            win_pct = win_rate * 100
            seed_metrics[f"{split}/{short}/win_rate"] = win_pct
            seed_metrics[f"{split}/{short}/avg_steps"] = avg_steps
            logger.info(
                "%-30s | win_rate=%5.1f%% | avg_steps=%5.1f",
                short, win_pct, avg_steps,
            )
    # Final metrics are logged one step after the last training epoch.
    log.log(seed_metrics, step=n_epochs + 1)
    return policy, seed_metrics
|
|
|
|
| |
| |
| |
|
|
|
|
def _collect_dt_trajectories(
    cfg: SimpleNamespace,
) -> list[dict[str, np.ndarray]]:
    """Collect oracle trajectories labelled with sparse rewards and returns-to-go.

    For every env in ``cfg.id_envs`` the oracle is rolled out with seeds
    ``0 .. baselines_dt_oracle_episodes_per_env - 1``; rollouts for which
    ``collect_oracle_trajectory`` returns ``None`` are skipped. Each kept
    trajectory is assigned a reward of 1.0 on its final step and 0
    elsewhere, and ``returns_to_go[t]`` is the sum of rewards from step
    ``t`` to the end.

    Returns:
        One dict per trajectory with keys ``"local"``, ``"global"`` (each
        with an extra axis inserted at position 1), ``"actions"``,
        ``"rewards"`` and ``"returns_to_go"``. May be empty.
    """
    episodes_per_env = int(cfg.baselines_dt_oracle_episodes_per_env)
    collected: list[dict[str, np.ndarray]] = []
    for env_id in cfg.id_envs:
        for traj_seed in range(episodes_per_env):
            traj = collect_oracle_trajectory(env_id, traj_seed, cfg)
            if traj is None:
                continue

            n_steps = len(traj["actions"])
            rewards = np.zeros(n_steps, dtype=np.float32)
            rewards[-1] = 1.0
            # Suffix sums: accumulate from the last step backwards.
            returns_to_go = np.ascontiguousarray(
                np.cumsum(rewards[::-1])[::-1]
            )

            collected.append(
                {
                    "local": np.expand_dims(traj["local"], axis=1),
                    "global": np.expand_dims(traj["global"], axis=1),
                    "actions": traj["actions"],
                    "rewards": rewards,
                    "returns_to_go": returns_to_go,
                }
            )
    return collected
|
|
|
|
def _dt_action_window(action_hist: list[int], ctx: int) -> np.ndarray:
    """Build the action tokens for the last ``ctx`` states of an episode.

    Position ``i`` of the result holds the action taken at the ``i``-th state
    of the window, which is how ``_DTDataset`` lays out training windows.
    The action for the current (last) state has not been chosen yet, so its
    slot is a ``0`` placeholder; the causal mask keeps the model from
    reading it when predicting that step.

    Args:
        action_hist: Actions already executed this episode, oldest first.
        ctx: Number of states in the window. The caller guarantees
            ``ctx <= len(action_hist) + 1``.

    Returns:
        ``int64`` array ending in the placeholder.
    """
    first = max(0, len(action_hist) - (ctx - 1))
    return np.asarray([*action_hist[first:], 0], dtype=np.int64)


def _eval_dt(
    model: _DecisionTransformer,
    env_id: str,
    cfg: SimpleNamespace,
    target_return: float,
    n_episodes: int,
    max_ep_len: int,
    eval_max_steps: int,
    context_len: int,
) -> tuple[float, float]:
    """Roll out a trained Decision Transformer with target-return conditioning.

    At every step the model receives the last ``context_len`` steps of
    (return-to-go, observation, action, timestep) history and picks the
    greedy action for the current state. The return-to-go starts at
    ``target_return`` and is reduced by each received reward.

    Args:
        model: Trained model; switched to eval mode.
        env_id: Environment to evaluate on.
        cfg: Project config (reads ``device`` and ``action_dim``).
        target_return: Initial return-to-go used for conditioning.
        n_episodes: Number of episodes to run.
        max_ep_len: Timesteps fed to the model are clamped to
            ``max_ep_len - 1``.
        eval_max_steps: Hard cap on steps per episode.
        context_len: Maximum history length fed to the model.

    Returns:
        ``(win_rate, avg_steps)`` over the episodes, or ``(0.0, 0.0)`` when
        ``n_episodes <= 0``.
    """
    if n_episodes <= 0:
        # Nothing to evaluate; also avoids the division by zero below.
        return 0.0, 0.0

    device = torch.device(cfg.device)
    n_actions = int(cfg.action_dim)
    env = AdvancedObservationEnv(env_id, des_file=None, cfg=cfg)
    env = _SB3MiniHackWrapper(env)
    model.eval()
    wins = 0
    total_steps = 0
    try:
        for _ep in range(n_episodes):
            obs, _ = env.reset()
            done = False

            local_hist: list[np.ndarray] = []
            global_hist: list[np.ndarray] = []
            action_hist: list[int] = []
            rtg_hist: list[float] = []
            ts_hist: list[int] = []

            current_rtg = float(target_return)
            t = 0
            info: dict = {}
            while not done and t < eval_max_steps:
                local_hist.append(obs["local"])
                global_hist.append(obs["global"])
                rtg_hist.append(current_rtg)
                ts_hist.append(min(t, max_ep_len - 1))

                ctx = min(len(local_hist), context_len)
                local_in = np.stack(local_hist[-ctx:], axis=0)
                global_in = np.stack(global_hist[-ctx:], axis=0)
                rtg_in = np.array(rtg_hist[-ctx:], dtype=np.float32)
                ts_in = np.array(ts_hist[-ctx:], dtype=np.int64)
                # Actions must line up with the states they were taken in;
                # the previous left-padded / unshifted slice fed a_{i-1} at
                # the position where training data has a_i.
                act_in = _dt_action_window(action_hist, ctx)

                local_t = torch.tensor(local_in, dtype=torch.float32).unsqueeze(0).to(device)
                global_t = torch.tensor(global_in, dtype=torch.float32).unsqueeze(0).to(device)
                rtg_t = torch.tensor(rtg_in, dtype=torch.float32).unsqueeze(0).unsqueeze(-1).to(device)
                act_t = torch.tensor(act_in, dtype=torch.long).unsqueeze(0).to(device)
                ts_t = torch.tensor(ts_in, dtype=torch.long).unsqueeze(0).to(device)

                with torch.no_grad():
                    action = int(
                        model.get_action(rtg_t, local_t, global_t, act_t, ts_t).item()
                    )
                action = max(0, min(action, n_actions - 1))
                action_hist.append(action)

                obs, reward, terminated, truncated, info = env.step(action)
                done = terminated or truncated
                current_rtg -= float(reward)
                t += 1

            if info.get("won", False):
                wins += 1
            total_steps += t
    finally:
        env.close()

    return wins / n_episodes, total_steps / n_episodes
|
|
|
|
def _train_dt(
    cfg: SimpleNamespace,
    log: Logger,
    log_dir: str,
    seed: int,
) -> tuple[_DecisionTransformer, dict[str, float]]:
    """Fit a Decision Transformer on oracle data and evaluate it.

    Steps: collect oracle trajectories, train with masked cross-entropy on
    the demonstrated actions (AdamW + cosine LR schedule, gradient norm
    clipped to 1.0), then evaluate on every ID and OOD environment while
    conditioning on the largest initial return-to-go seen in the data.

    Args:
        cfg: Project config (reads ``device``, ``action_dim``,
            ``crop_size``, ``map_h``, ``map_w``, ``weight_decay``, the
            ``baselines_dt_*`` keys, ``id_envs`` and ``ood_envs``).
        log: Project logger receiving per-epoch and final metrics.
        log_dir: Accepted for signature parity; not used by this function.
        seed: Seed of this run; used in log messages only.

    Returns:
        ``(model, seed_metrics)`` where ``seed_metrics`` maps
        ``"<split>/<env>/win_rate"`` (in percent) and
        ``"<split>/<env>/avg_steps"`` to floats.

    Raises:
        RuntimeError: If oracle collection yields no trajectories.
    """
    device = torch.device(cfg.device)
    context_len = int(cfg.baselines_dt_context_len)
    max_ep_len = int(cfg.baselines_dt_max_ep_len)
    eval_max_steps = int(cfg.baselines_dt_eval_max_steps)
    n_eval = _eval_episodes_per_env(cfg)
    n_epochs = int(cfg.baselines_dt_epochs)

    logger.info("Collecting oracle demonstrations for DT...")
    trajectories = _collect_dt_trajectories(cfg)
    if not trajectories:
        raise RuntimeError("DT oracle collection produced zero trajectories")

    lengths = [len(traj["actions"]) for traj in trajectories]
    shortest, longest = min(lengths), max(lengths)
    logger.info(
        "DT dataset: %d trajectories, %d transitions (len: min=%d max=%d mean=%.1f)",
        len(trajectories),
        sum(lengths),
        shortest,
        longest,
        float(np.mean(lengths)),
    )
    if longest > max_ep_len:
        logger.warning(
            "Longest oracle trajectory (%d) exceeds baselines_dt_max_ep_len (%d); "
            "positions will be clamped.",
            longest,
            max_ep_len,
        )

    # Condition evaluation on the best initial return-to-go in the data.
    target_return = float(max(traj["returns_to_go"][0] for traj in trajectories))

    n_actions = int(cfg.action_dim)
    loader = DataLoader(
        _DTDataset(
            trajectories,
            context_len=context_len,
            max_ep_len=max_ep_len,
            n_actions=n_actions,
        ),
        batch_size=int(cfg.baselines_dt_batch_size),
        shuffle=True,
        num_workers=4,
        pin_memory=torch.cuda.is_available(),
    )

    model = _DecisionTransformer(
        n_actions=n_actions,
        embed_dim=int(cfg.baselines_dt_embed_dim),
        n_heads=int(cfg.baselines_dt_n_heads),
        n_layers=int(cfg.baselines_dt_n_layers),
        context_len=context_len,
        max_ep_len=max_ep_len,
        crop_h=int(cfg.crop_size),
        crop_w=int(cfg.crop_size),
        map_h=int(cfg.map_h),
        map_w=int(cfg.map_w),
    ).to(device)
    logger.info("DT parameters: %d", sum(p.numel() for p in model.parameters()))

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=float(cfg.baselines_dt_lr),
        weight_decay=float(cfg.weight_decay),
    )
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=n_epochs,
    )

    for epoch_idx in range(1, n_epochs + 1):
        model.train()
        running_loss, n_batches = 0.0, 0
        for batch in loader:
            batch = {name: tensor.to(device) for name, tensor in batch.items()}
            step_mask = batch["attention_mask"]
            expert_actions = batch["actions"]

            logits = model(
                batch["returns_to_go"],
                batch["local"],
                batch["global"],
                expert_actions,
                batch["timesteps"],
                step_mask,
            )
            per_step_ce = nn.functional.cross_entropy(
                logits.reshape(-1, n_actions),
                expert_actions.reshape(-1),
                reduction="none",
            )
            # Average over real (unpadded) steps only.
            valid = step_mask.reshape(-1)
            loss = (per_step_ce * valid).sum() / valid.sum().clamp(min=1.0)

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            running_loss += loss.item()
            n_batches += 1
        scheduler.step()

        mean_loss = running_loss / max(1, n_batches)
        lr_now = float(scheduler.get_last_lr()[0])
        log.log(
            {
                "train/dt_loss": mean_loss,
                "train/lr": lr_now,
                "train/epoch": epoch_idx,
            },
            step=epoch_idx,
        )
        logger.info(
            "DT epoch %02d/%02d | loss=%.4f | lr=%.2e",
            epoch_idx,
            n_epochs,
            mean_loss,
            lr_now,
        )

    seed_metrics: dict[str, float] = {}
    logger.info("DT eval target return = %.2f", target_return)
    for split, env_list in (("ID", cfg.id_envs), ("OOD", cfg.ood_envs)):
        logger.info("--- DT %s evaluation (seed=%d) ---", split, seed)
        for env_id in env_list:
            short = _short(env_id)
            win_rate, avg_steps = _eval_dt(
                model,
                env_id,
                cfg,
                target_return=target_return,
                n_episodes=n_eval,
                max_ep_len=max_ep_len,
                eval_max_steps=eval_max_steps,
                context_len=context_len,
            )
            win_pct = win_rate * 100
            seed_metrics[f"{split}/{short}/win_rate"] = win_pct
            seed_metrics[f"{split}/{short}/avg_steps"] = avg_steps
            logger.info(
                "%-30s | win_rate=%5.1f%% | avg_steps=%5.1f",
                short, win_pct, avg_steps,
            )
    # Final metrics are logged one step after the last training epoch.
    log.log(seed_metrics, step=n_epochs + 1)
    return model, seed_metrics
|
|
|
|
| |
| |
| |
|
|
|
|
def _build_sb3_model(
    algo: str,
    train_env: SubprocVecEnv,
    cfg: SimpleNamespace,
    seed: int,
    tb_log_dir: str,
):
    """Instantiate the SB3 model for ``algo`` with the MiniHack CNN extractor.

    Args:
        algo: One of ``"ppo"``, ``"ppo-rnn"``, ``"a2c"`` or ``"dqn"``.
        train_env: Vectorised training environment.
        cfg: Project config; only ``baselines_dqn_buffer_size`` is read, and
            only for ``"dqn"``.
        seed: Seed forwarded to the model.
        tb_log_dir: Tensorboard log directory forwarded to the model.

    Raises:
        ValueError: If ``algo`` is not one of the supported names.
    """
    policy_kwargs = {
        "features_extractor_class": _MiniHackCNN,
        "features_extractor_kwargs": {"features_dim": 256},
    }
    shared_kwargs = {
        "policy_kwargs": policy_kwargs,
        "verbose": 1,
        "tensorboard_log": tb_log_dir,
        "seed": seed,
    }
    # (name, model class, policy identifier), checked in order.
    specs = (
        ("ppo", PPO, "MultiInputPolicy"),
        ("ppo-rnn", RecurrentPPO, "MultiInputLstmPolicy"),
        ("a2c", A2C, "MultiInputPolicy"),
        ("dqn", DQN, "MultiInputPolicy"),
    )
    for name, model_cls, policy_name in specs:
        if algo != name:
            continue
        extra_kwargs = {}
        if name == "dqn":
            # Only DQN takes a replay-buffer size.
            extra_kwargs["buffer_size"] = int(cfg.baselines_dqn_buffer_size)
        return model_cls(policy_name, train_env, **shared_kwargs, **extra_kwargs)
    raise ValueError(f"Unknown SB3 algo: {algo!r}")
|
|
|
|
def _build_sb3_callbacks(
    cfg: SimpleNamespace,
    train_env: SubprocVecEnv,
    log_dir: str,
    model_dir: str,
) -> CallbackList:
    """Assemble the W&B callback plus one eval callback per ID / OOD env.

    The resulting list is ordered: ``WandbCallback``, then the ID envs in
    ``cfg.id_envs`` order, then the OOD envs in ``cfg.ood_envs`` order.
    Only ID callbacks save a best model (under
    ``<model_dir>/best_<env_id>/``). Each eval callback gets its own
    single-env ``SubprocVecEnv``.
    """
    callbacks: list = [WandbCallback(model_save_path=model_dir)]
    n_eval = _eval_episodes_per_env(cfg)
    # ``baselines_eval_freq_env_steps`` is divided by the number of parallel
    # envs to obtain the callback's ``eval_freq``; never less than 1.
    eval_freq = max(
        1, int(cfg.baselines_eval_freq_env_steps) // train_env.num_envs,
    )

    for split, envs_attr in (("ID", "id_envs"), ("OOD", "ood_envs")):
        split_dir = f"{log_dir}/eval_{split.lower()}"
        for env_id in getattr(cfg, envs_attr):
            env_log_dir = f"{split_dir}/{env_id}"
            eval_env = SubprocVecEnv(
                [_make_sb3_env_fn(env_id, cfg, env_log_dir)]
            )
            best_path = f"{model_dir}/best_{env_id}/" if split == "ID" else None
            callbacks.append(
                _PrefixedEvalCallback(
                    eval_env,
                    prefix=f"{split}/{_short(env_id)}",
                    best_model_save_path=best_path,
                    log_path=f"{env_log_dir}/",
                    eval_freq=eval_freq,
                    n_eval_episodes=n_eval,
                    deterministic=True,
                )
            )
    return CallbackList(callbacks)
|
|
|
|
| |
| |
| |
|
|
|
|
def _aggregate(
    all_seed_results: list[dict[str, Any]],
) -> dict[str, dict[str, float | list[float]]]:
    """Summarise per-seed metrics as mean / std / raw values.

    The metric keys are taken from the first seed's result (excluding
    ``"seed"``); for each key, the values of every seed that reports it are
    pooled. ``std`` is NumPy's default (population) standard deviation.
    An empty input yields an empty dict.
    """
    if not all_seed_results:
        return {}

    summary: dict[str, dict[str, float | list[float]]] = {}
    for key in all_seed_results[0]:
        if key == "seed":
            continue
        samples = [result[key] for result in all_seed_results if key in result]
        if not samples:
            continue
        summary[key] = {
            "mean": float(np.mean(samples)),
            "std": float(np.std(samples)),
            "values": list(map(float, samples)),
        }
    return summary
|
|
|
|
def _print_aggregated(seeds: list[int], agg: dict[str, dict[str, Any]]) -> None:
    """Log a per-environment table of aggregated win rate and step count.

    Keys of ``agg`` are expected to look like ``"<split>/<env>/<metric>"``
    with ``split`` being ``"ID"`` or ``"OOD"``. Environments are listed
    alphabetically within each split; a missing metric is shown as 0.0.
    """
    if not agg:
        logger.info("No per-environment metrics to aggregate (RL eval is callback-driven)")
        return

    logger.info("Aggregated results across %d seeds: %s", len(seeds), seeds)
    for split in ("ID", "OOD"):
        split_tag = f"{split}/"
        per_env: dict[str, dict[str, dict[str, Any]]] = {}
        for key, stats in agg.items():
            if key.startswith(split_tag):
                _split, env_name, metric_name = key.split("/", 2)
                per_env.setdefault(env_name, {})[metric_name] = stats
        if not per_env:
            continue

        logger.info("--- %s environments ---", split)
        for env_name in sorted(per_env):
            env_stats = per_env[env_name]
            win_rate = env_stats.get("win_rate", {})
            avg_steps = env_stats.get("avg_steps", {})
            logger.info(
                "%-30s | win_rate=%5.1f%% +/- %4.1f | avg_steps=%5.1f +/- %4.1f",
                env_name,
                win_rate.get("mean", 0.0),
                win_rate.get("std", 0.0),
                avg_steps.get("mean", 0.0),
                avg_steps.get("std", 0.0),
            )
|
|
|
|
def _save_aggregated(
    out_path: Path,
    algo: str,
    seeds: list[int],
    all_seed_results: list[dict[str, Any]],
    agg: dict[str, dict[str, Any]],
) -> None:
    """Write the per-seed and aggregated metrics to ``out_path`` as JSON.

    The file contains the algorithm name, the seed list and its length,
    the raw per-seed result dicts, and the ``mean``/``std`` of every
    aggregated metric (the per-seed ``values`` lists are omitted because
    the same numbers already appear under ``per_seed_results``).

    Args:
        out_path: Destination file; its parent directory is created if
            it does not exist.
        algo: Baseline algorithm identifier.
        seeds: Seeds that were run.
        all_seed_results: One metrics dict per seed.
        agg: Mapping ``metric key -> stats`` with at least ``mean`` and
            ``std`` entries.
    """
    payload = {
        "algorithm": algo,
        "seeds": seeds,
        "n_seeds": len(seeds),
        "per_seed_results": all_seed_results,
        "aggregated": {
            key: {"mean": stats["mean"], "std": stats["std"]}
            for key, stats in agg.items()
        },
    }
    # Keep the function usable on its own: do not rely on the caller having
    # created the destination directory.
    out_path.parent.mkdir(parents=True, exist_ok=True)
    # The per-seed dicts are written as received and may hold NumPy scalars,
    # which orjson refuses to encode unless OPT_SERIALIZE_NUMPY is set.
    json_options = orjson.OPT_INDENT_2 | orjson.OPT_SERIALIZE_NUMPY
    out_path.write_bytes(orjson.dumps(payload, option=json_options))
    logger.info("Aggregated results written to %s", out_path)
|
|
|
|
| |
| |
| |
|
|
|
|
def run_baselines(
    cfg: SimpleNamespace,
    algo: str,
    seeds: list[int] | None = None,
    output_path: str | None = None,
) -> None:
    """Train and evaluate one baseline algorithm across one or more seeds.

    For each seed the function calls ``_seed_everything``, opens a logger
    run named ``<algo>-multitask-seed<seed>``, trains the requested
    baseline, and saves the final model under
    ``<out_dir>/models/<run_id>/``. Metrics returned by the DT and BC
    trainers are collected per seed; the SB3 algorithms contribute only
    their seed number because their evaluation runs through callbacks.
    After the loop the collected metrics are aggregated and printed. If
    any metrics exist they are written to JSON and logged to a separate
    ``<algo>-multitask-summary`` run.

    Args:
        cfg: Project config namespace (must contain ``baselines_*`` keys).
        algo: One of ``ppo``, ``a2c``, ``dqn``, ``ppo-rnn``, ``bc``, ``dt``.
        seeds: Optional list of seeds. ``None`` -> ``[cfg.seed]`` (or
            a single seed of ``0`` if ``cfg.seed`` is ``None``).
        output_path: Optional override for the aggregated-results JSON
            destination. When ``None``, results land under
            ``cfg.baselines_output_dir``.

    Raises:
        ValueError: If ``algo`` is not in ``ALL_BASELINE_ALGOS`` or
            ``seeds`` is an empty list.
    """
    if algo not in ALL_BASELINE_ALGOS:
        raise ValueError(
            f"Unknown algo {algo!r}. Choose one of {ALL_BASELINE_ALGOS}."
        )

    if seeds is None:
        seeds = [cfg.seed if cfg.seed is not None else 0]
    if not seeds:
        raise ValueError("seeds must be non-empty")

    out_dir = _resolve_output_dir(cfg, None)
    if output_path is not None:
        agg_json_path = Path(output_path)
        agg_json_path.parent.mkdir(parents=True, exist_ok=True)
    else:
        agg_json_path = out_dir / f"results_{algo}_{len(seeds)}seeds.json"

    logger.info(
        "Running baseline %s on %d seed(s): %s (output -> %s)",
        algo, len(seeds), seeds, agg_json_path,
    )

    all_seed_results: list[dict[str, Any]] = []
    n_envs_per_id = int(cfg.baselines_n_envs_per_id)

    for seed_idx, seed in enumerate(seeds):
        logger.info(
            "============================================================\n"
            " %s seed %d (%d/%d)\n"
            "============================================================",
            algo.upper(), seed, seed_idx + 1, len(seeds),
        )
        _seed_everything(seed)

        run_name = f"{algo}-multitask-seed{seed}"
        log = _init_baseline_logger(cfg, run_name)
        # Everything after the logger is created sits inside the try so the
        # run is always finished, even if directory setup fails.
        try:
            # Use the W&B run id when a run is active, otherwise fall back
            # to a deterministic local identifier.
            run_id = (
                log._run.id
                if log._use_wandb and log._run is not None
                else f"local-{algo}-seed{seed}"
            )
            log_dir = str(out_dir / "logs" / run_id)
            model_dir = str(out_dir / "models" / run_id)
            os.makedirs(log_dir, exist_ok=True)
            os.makedirs(model_dir, exist_ok=True)

            seed_results: dict[str, Any] = {"seed": seed}
            if algo == "dt":
                model, dt_metrics = _train_dt(cfg, log, log_dir, seed)
                seed_results.update(dt_metrics)
                # Store the architecture hyperparameters next to the weights.
                torch.save(
                    {
                        "model_state_dict": model.state_dict(),
                        "config": {
                            "n_actions": int(cfg.action_dim),
                            "embed_dim": int(cfg.baselines_dt_embed_dim),
                            "n_heads": int(cfg.baselines_dt_n_heads),
                            "n_layers": int(cfg.baselines_dt_n_layers),
                            "context_len": int(cfg.baselines_dt_context_len),
                            "max_ep_len": int(cfg.baselines_dt_max_ep_len),
                        },
                    },
                    f"{model_dir}/dt_final_seed{seed}.pt",
                )
            else:
                # One vectorised training env holding ``n_envs_per_id``
                # copies of every ID environment.
                train_env_fns = [
                    _make_sb3_env_fn(env_id, cfg, log_dir)
                    for env_id in list(cfg.id_envs) * n_envs_per_id
                ]
                train_env = SubprocVecEnv(train_env_fns)
                callbacks = None
                try:
                    if algo == "bc":
                        policy, bc_metrics = _train_bc(
                            cfg, train_env, log, log_dir, seed,
                        )
                        seed_results.update(bc_metrics)
                        policy.save(f"{model_dir}/bc_final_seed{seed}")
                    else:
                        sb3_model = _build_sb3_model(
                            algo, train_env, cfg, seed,
                            tb_log_dir=str(out_dir / "tb" / run_id),
                        )
                        callbacks = _build_sb3_callbacks(
                            cfg, train_env, log_dir, model_dir,
                        )
                        logger.info(
                            "Training %s for %d env-steps across %d ID maps "
                            "(%d parallel envs)...",
                            algo.upper(),
                            int(cfg.total_timesteps),
                            len(cfg.id_envs),
                            train_env.num_envs,
                        )
                        sb3_model.learn(
                            total_timesteps=int(cfg.total_timesteps),
                            callback=callbacks,
                        )
                        sb3_model.save(f"{model_dir}/{algo}_final_seed{seed}")
                finally:
                    try:
                        # The eval callbacks each hold their own vectorised
                        # env; close them explicitly so their worker
                        # processes do not linger into the next seed.
                        for callback in getattr(callbacks, "callbacks", ()):
                            eval_env = getattr(callback, "eval_env", None)
                            if eval_env is not None:
                                eval_env.close()
                    finally:
                        train_env.close()

            all_seed_results.append(seed_results)
        finally:
            log.finish()
        logger.info("%s seed %d complete.", algo.upper(), seed)

    agg = _aggregate(all_seed_results)
    _print_aggregated(seeds, agg)
    if agg:
        _save_aggregated(agg_json_path, algo, seeds, all_seed_results, agg)

        # Only open the summary run when there is something to report.
        summary_run_name = f"{algo}-multitask-summary"
        summary_log = _init_baseline_logger(cfg, summary_run_name)
        try:
            summary_payload: dict[str, float] = {}
            for key, stats in agg.items():
                summary_payload[f"summary/{key}/mean"] = stats["mean"]
                summary_payload[f"summary/{key}/std"] = stats["std"]
            summary_log.log_summary(summary_payload)
        finally:
            summary_log.finish()
    logger.info("All %d seed(s) complete.", len(seeds))
|
|