remdm-minihack / src /planners /baselines.py
MathisW78's picture
Demo notebook payload (source + checkpoint + assets)
f748552 verified
"""SB3 + Decision Transformer baselines for the ReMDM diffusion planner.
This module wraps standard discrete-action RL baselines (PPO, A2C, DQN,
recurrent PPO) plus two imitation baselines (Behavioural Cloning and
Decision Transformer) into the project's unified config + dispatch
surface so they can be compared head-to-head against the DAgger /
offline-BC diffusion planner on the same MiniHack environments.
Entry point: :func:`run_baselines`.
Hyperparameters live in ``configs/defaults.yaml`` under the
``baselines_*`` namespace; the unified env-step training budget
(``cfg.total_timesteps``) is shared with DAgger and offline BC.
W&B logging routes through the project's :class:`Logger` (with the W&B
project temporarily swapped to ``cfg.baselines_wandb_project``); SB3's
standard ``WandbCallback`` piggybacks on the active run and syncs its
tensorboard scalars automatically. No file in this module calls
``wandb.log(...)`` directly.
"""
from __future__ import annotations
import logging
import os
import random
from pathlib import Path
from types import SimpleNamespace
from typing import Any
import gymnasium as gym
import numpy as np
import orjson
import torch
import torch.nn as nn
from sb3_contrib import RecurrentPPO
from stable_baselines3 import A2C, DQN, PPO
from stable_baselines3.common.callbacks import CallbackList, EvalCallback
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.policies import ActorCriticPolicy
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
from stable_baselines3.common.vec_env import SubprocVecEnv
from torch.utils.data import DataLoader, Dataset
from wandb.integration.sb3 import WandbCallback
from src.envs.minihack_env import (
AdvancedObservationEnv,
collect_oracle_trajectory,
)
from src.planners.logging import Logger
logger = logging.getLogger(__name__)

# Online RL baselines trained through Stable-Baselines3 / sb3-contrib.
SB3_RL_ALGOS: tuple[str, ...] = ("ppo", "a2c", "dqn", "ppo-rnn")
# Imitation baselines trained on oracle demonstrations.
IMITATION_ALGOS: tuple[str, ...] = ("bc", "dt")
# Every algorithm name accepted by ``run_baselines``.
ALL_BASELINE_ALGOS: tuple[str, ...] = (*SB3_RL_ALGOS, *IMITATION_ALGOS)
# =============================================================================
# Observation wrapper for SB3 dict-policies
# =============================================================================
class _SB3MiniHackWrapper(gym.Wrapper):
    """Adapt ``AdvancedObservationEnv`` to SB3's dict-observation policies.

    The wrapped env emits a ``(local_crop, global_map)`` pair with shapes
    ``(crop, crop)`` and ``(map_h, map_w)``. This wrapper re-exposes it as
    ``{"local": [1, crop_h, crop_w], "global": [1, map_h, map_w]}`` so that
    ``MultiInputPolicy`` gets an explicit single channel per stream. When the
    step info carries ``"won"``, it is copied to ``"is_success"``, which is the
    key SB3's success tracking reads, so reported success equals our win rate.
    """

    def __init__(self, env: AdvancedObservationEnv) -> None:
        super().__init__(env)
        crop_h, crop_w = env.observation_space.shape
        env_cfg = env._cfg  # AdvancedObservationEnv keeps its config here
        local_space = gym.spaces.Box(
            low=0, high=6000, shape=(1, crop_h, crop_w), dtype=np.int16,
        )
        global_space = gym.spaces.Box(
            low=0,
            high=6000,
            shape=(1, env_cfg.map_h, env_cfg.map_w),
            dtype=np.int16,
        )
        self.observation_space = gym.spaces.Dict(
            {"local": local_space, "global": global_space}
        )

    def reset(self, **kwargs: Any) -> tuple[dict[str, np.ndarray], dict]:
        obs_pair, info = self.env.reset(**kwargs)
        crop, full_map = obs_pair
        return self._pack(crop, full_map), info

    def step(
        self, action: int,
    ) -> tuple[dict[str, np.ndarray], float, bool, bool, dict]:
        (crop, full_map), reward, terminated, truncated, info = self.env.step(action)
        if "won" in info:
            # Mirror our win flag under the key SB3 uses for success rate.
            info["is_success"] = info["won"]
        packed = self._pack(crop, full_map)
        return packed, reward, terminated, truncated, info

    @staticmethod
    def _pack(
        local: np.ndarray, glob: np.ndarray,
    ) -> dict[str, np.ndarray]:
        """Prepend a channel axis to each observation stream."""
        return {
            "local": np.expand_dims(local, 0),  # [1, crop, crop]
            "global": np.expand_dims(glob, 0),  # [1, H, W]
        }
# =============================================================================
# CNN feature extractor (shared by SB3 RL + BC)
# =============================================================================
class _MiniHackCNN(BaseFeaturesExtractor):
    """Dual-stream CNN feature extractor for the SB3 dict observation.

    Local stream: ``Conv(1->16, 3, pad 1) -> Conv(16->32, 3, pad 1)``; the
    3x3 / padding-1 convolutions keep the crop's spatial size.
    Global stream: ``Conv(1->16, 5, stride 2) -> Conv(16->32, 3, stride 2)``
    with no padding, so the full map is downsampled.
    Both streams are flattened and concatenated, then projected to
    ``features_dim`` via a single linear + ReLU.

    This module passes it as ``features_extractor_class`` both to the SB3 RL
    algorithms and to the BC ``ActorCriticPolicy``.

    Args:
        observation_space: Dict space with ``"local"`` and ``"global"``
            single-channel Box entries (as built by ``_SB3MiniHackWrapper``).
        features_dim: Size of the output feature vector.
    """
    def __init__(
        self, observation_space: gym.spaces.Dict, features_dim: int = 256,
    ) -> None:
        super().__init__(observation_space, features_dim)
        self.local_cnn = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Flatten(),
        )
        self.global_cnn = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=5, stride=2),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=3, stride=2),
            nn.ReLU(),
            nn.Flatten(),
        )
        # Infer each stream's flattened width by pushing one all-zero
        # observation through it, so any crop / map size works.
        with torch.no_grad():
            dummy_loc = torch.zeros(1, *observation_space["local"].shape)
            dummy_glob = torch.zeros(1, *observation_space["global"].shape)
            n_flatten = (
                self.local_cnn(dummy_loc).shape[1]
                + self.global_cnn(dummy_glob).shape[1]
            )
        self.linear = nn.Sequential(nn.Linear(n_flatten, features_dim), nn.ReLU())
    def forward(
        self, observations: dict[str, torch.Tensor],
    ) -> torch.Tensor:
        """Encode a batch of dict observations into ``[B, features_dim]``."""
        # Cast to float in case the incoming tensors are an integer dtype.
        loc = self.local_cnn(observations["local"].float())  # [B, F_l]
        glob = self.global_cnn(observations["global"].float())  # [B, F_g]
        return self.linear(torch.cat([loc, glob], dim=1))
# =============================================================================
# Decision Transformer
# =============================================================================
class _MiniHackStateEncoder(nn.Module):
    """CNN encoder mapping a (local, global) obs pair to a state embedding.

    Uses the same two conv streams as ``_MiniHackCNN``: a size-preserving 3x3
    local stream and a stride-2 global stream. Their flattened outputs are
    concatenated and linearly projected (no activation) to ``embed_dim``.

    Args:
        embed_dim: Output embedding size.
        crop_h, crop_w: Spatial size of the local crop.
        map_h, map_w: Spatial size of the global map.
    """
    def __init__(
        self,
        embed_dim: int = 128,
        crop_h: int = 9,
        crop_w: int = 9,
        map_h: int = 21,
        map_w: int = 79,
    ) -> None:
        super().__init__()
        self.local_cnn = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Flatten(),
        )
        self.global_cnn = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=5, stride=2),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=3, stride=2),
            nn.ReLU(),
            nn.Flatten(),
        )
        # Size the projection by running a single zero observation through
        # each stream at the configured crop / map dimensions.
        with torch.no_grad():
            dummy_loc = torch.zeros(1, 1, crop_h, crop_w)
            dummy_glob = torch.zeros(1, 1, map_h, map_w)
            local_flat = self.local_cnn(dummy_loc).shape[1]
            global_flat = self.global_cnn(dummy_glob).shape[1]
        self.proj = nn.Linear(local_flat + global_flat, embed_dim)
    def forward(
        self, local_obs: torch.Tensor, global_obs: torch.Tensor,
    ) -> torch.Tensor:
        """Return ``[B, T, D]`` for 5-D input, ``[B, D]`` for 4-D input."""
        # Accepts (B, T, 1, H, W) or (B, 1, H, W).
        if local_obs.dim() == 5:
            # Fold time into batch so the 2-D convs see (B*T, 1, H, W).
            # NOTE(review): .view requires contiguous inputs.
            B, T = local_obs.shape[:2]
            local_obs = local_obs.view(B * T, *local_obs.shape[2:])
            global_obs = global_obs.view(B * T, *global_obs.shape[2:])
            reshape = True
        else:
            B, T = local_obs.shape[0], 1
            reshape = False
        loc_feat = self.local_cnn(local_obs.float())  # [B*T, F_l]
        glob_feat = self.global_cnn(global_obs.float())  # [B*T, F_g]
        out = self.proj(torch.cat([loc_feat, glob_feat], dim=-1))  # [B*T, D]
        if reshape:
            out = out.view(B, T, -1)
        return out
class _DecisionTransformer(nn.Module):
    """Causal Decision Transformer over interleaved (R, s, a) tokens.

    Each timestep contributes three tokens: a return-to-go embedding, a
    ``_MiniHackStateEncoder`` state embedding and an action embedding.
    Each token is summed with a per-timestep position embedding and a
    token-type embedding. A causal ``nn.TransformerEncoder`` runs over the
    ``3T`` sequence, and actions are predicted from the hidden state at each
    state token.

    Args:
        n_actions: Size of the discrete action space.
        embed_dim, n_heads, n_layers, dropout: Transformer hyperparameters.
        context_len: Window length the model is trained/evaluated with
            (stored only; not used inside this class).
        max_ep_len: Number of position embeddings; timesteps must lie in
            ``[0, max_ep_len)``.
        crop_h, crop_w, map_h, map_w: Observation sizes for the state encoder.
    """
    def __init__(
        self,
        n_actions: int,
        embed_dim: int = 128,
        n_heads: int = 4,
        n_layers: int = 3,
        context_len: int = 30,
        max_ep_len: int = 500,
        dropout: float = 0.1,
        crop_h: int = 9,
        crop_w: int = 9,
        map_h: int = 21,
        map_w: int = 79,
    ) -> None:
        super().__init__()
        self.embed_dim = embed_dim
        self.context_len = context_len
        self.n_actions = n_actions
        self.max_ep_len = max_ep_len
        self.state_encoder = _MiniHackStateEncoder(
            embed_dim, crop_h, crop_w, map_h, map_w,
        )
        # NOTE(review): index n_actions is reserved as a pad id, but the
        # dataset and eval loop in this module pad/fill actions with 0.
        self.action_embed = nn.Embedding(n_actions + 1, embed_dim)  # +1 for pad
        self.return_embed = nn.Linear(1, embed_dim)
        self.pos_embed = nn.Embedding(max_ep_len, embed_dim)
        # Token types: 0 = return-to-go, 1 = state, 2 = action.
        self.token_type_embed = nn.Embedding(3, embed_dim)
        self.embed_ln = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=n_heads,
            dim_feedforward=embed_dim * 4,
            dropout=dropout,
            activation="gelu",
            batch_first=True,
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)
        self.action_head = nn.Linear(embed_dim, n_actions)
        self.apply(self._init_weights)
    @staticmethod
    def _init_weights(module: nn.Module) -> None:
        """N(0, 0.02) init for Linear/Embedding weights; zero biases; unit LayerNorm."""
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
        elif isinstance(module, nn.LayerNorm):
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)
    def forward(
        self,
        returns_to_go: torch.Tensor,  # [B, T, 1]
        local_obs: torch.Tensor,  # [B, T, 1, H_l, W_l]
        global_obs: torch.Tensor,  # [B, T, 1, H_g, W_g]
        actions: torch.Tensor,  # [B, T]
        timesteps: torch.Tensor,  # [B, T]
        attention_mask: torch.Tensor | None = None,  # [B, T]
    ) -> torch.Tensor:
        """Return action logits ``[B, T, n_actions]``, one per state token.

        ``attention_mask`` marks real steps with 1 and padding with 0; padded
        steps are excluded as attention keys.
        """
        B, T = returns_to_go.shape[:2]
        device = returns_to_go.device
        rtg_embed = self.return_embed(returns_to_go)  # [B, T, D]
        state_embed = self.state_encoder(local_obs, global_obs)  # [B, T, D]
        action_embed = self.action_embed(actions)  # [B, T, D]
        pos_embed = self.pos_embed(timesteps)  # [B, T, D]
        rtg_embed = rtg_embed + pos_embed + self.token_type_embed.weight[0]
        state_embed = state_embed + pos_embed + self.token_type_embed.weight[1]
        action_embed = action_embed + pos_embed + self.token_type_embed.weight[2]
        # Interleave (R_0, s_0, a_0, R_1, s_1, a_1, ...) -> [B, 3T, D]
        stacked = torch.stack([rtg_embed, state_embed, action_embed], dim=2)
        stacked = stacked.view(B, 3 * T, self.embed_dim)
        stacked = self.dropout(self.embed_ln(stacked))
        seq_len = 3 * T
        # True above the diagonal = "may not attend": token i only sees
        # tokens 0..i, so s_t cannot see its own target action a_t.
        causal_mask = torch.triu(
            torch.ones(seq_len, seq_len, device=device), diagonal=1,
        ).bool()
        key_padding_mask = None
        if attention_mask is not None:
            # Repeat each step's mask over its three tokens; True = ignore.
            expanded = attention_mask.unsqueeze(-1).repeat(1, 1, 3).view(B, 3 * T)
            key_padding_mask = expanded == 0
        hidden = self.transformer(
            stacked, mask=causal_mask, src_key_padding_mask=key_padding_mask,
        )
        # State token positions are 1, 4, 7, ... -> stride 3.
        state_hidden = hidden[:, 1::3, :]  # [B, T, D]
        return self.action_head(state_hidden)  # [B, T, A]
    @torch.no_grad()
    def get_action(
        self,
        returns_to_go: torch.Tensor,
        local_obs: torch.Tensor,
        global_obs: torch.Tensor,
        actions: torch.Tensor,
        timesteps: torch.Tensor,
    ) -> torch.Tensor:
        """Greedy action for the last step of each window, shape ``[B]``.

        Side effect: switches the module to eval mode (disables dropout).
        """
        self.eval()
        logits = self.forward(
            returns_to_go, local_obs, global_obs, actions, timesteps,
        )
        return logits[:, -1, :].argmax(dim=-1)
class _DTDataset(Dataset):
    """One Decision Transformer training window per (trajectory, start) pair.

    A sample covers at most ``context_len`` consecutive steps beginning at
    ``start``. Windows that run past the end of the trajectory are
    right-padded with zeros, and ``attention_mask`` flags real steps (1) vs
    padding (0). Timesteps are clipped into ``[0, max_ep_len)`` and actions
    into ``[0, n_actions)`` so they are valid embedding indices.
    """

    def __init__(
        self,
        trajectories: list[dict[str, np.ndarray]],
        context_len: int,
        max_ep_len: int,
        n_actions: int,
    ) -> None:
        self.trajectories = trajectories
        self.context_len = context_len
        self.max_ep_len = max_ep_len
        self.n_actions = n_actions
        self.indices: list[tuple[int, int]] = []
        for traj_idx, traj in enumerate(trajectories):
            self.indices.extend(
                (traj_idx, start) for start in range(len(traj["actions"]))
            )

    def __len__(self) -> int:
        return len(self.indices)

    def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
        traj_idx, start = self.indices[idx]
        traj = self.trajectories[traj_idx]
        stop = min(start + self.context_len, len(traj["actions"]))
        window = slice(start, stop)
        n_real = stop - start
        n_pad = self.context_len - n_real

        local = traj["local"][window].copy()
        glob = traj["global"][window].copy()
        actions = np.clip(traj["actions"][window].copy(), 0, self.n_actions - 1)
        rtg = traj["returns_to_go"][window].copy()
        timesteps = np.clip(np.arange(start, stop), 0, self.max_ep_len - 1)

        if n_pad > 0:
            frame_pad = ((0, n_pad), (0, 0), (0, 0), (0, 0))
            seq_pad = (0, n_pad)
            local = np.pad(local, frame_pad, mode="constant")
            glob = np.pad(glob, frame_pad, mode="constant")
            actions = np.pad(actions, seq_pad, mode="constant")
            rtg = np.pad(rtg, seq_pad, mode="constant")
            timesteps = np.pad(timesteps, seq_pad, mode="constant")

        attention_mask = (np.arange(self.context_len) < n_real).astype(np.float32)

        sample = {
            "local": torch.tensor(local, dtype=torch.float32),
            "global": torch.tensor(glob, dtype=torch.float32),
            "actions": torch.tensor(actions, dtype=torch.long),
            "returns_to_go": torch.tensor(rtg, dtype=torch.float32).unsqueeze(-1),
            "timesteps": torch.tensor(timesteps, dtype=torch.long),
            "attention_mask": torch.tensor(attention_mask, dtype=torch.float32),
        }
        return sample
# =============================================================================
# SB3 callbacks + env factory
# =============================================================================
class _PrefixedEvalCallback(EvalCallback):
    """``EvalCallback`` that also records mean_reward / avg_steps / win_rate
    under a unique per-environment prefix (e.g. ``ID/Room-5x5``).

    One instance is created per evaluation environment. SB3 truncates
    metric names at 36 chars, which collides on long MiniHack env IDs; the
    prefix, built from the env ID with ``MiniHack-`` / ``-v0`` stripped,
    keeps each environment's numbers distinct.

    Each evaluation round is recorded exactly once: the prefixed values are
    written only on the step where the parent appended a new entry to
    ``evaluations_results``. Re-recording the latest round on every
    subsequent step would re-emit stale numbers at every later logger dump.

    NOTE(review): the parent only fills ``evaluations_results`` when
    ``log_path`` is set; both call sites in this module pass one.
    """

    def __init__(
        self, eval_env: SubprocVecEnv, prefix: str, **kwargs: Any,
    ) -> None:
        super().__init__(eval_env, **kwargs)
        self.prefix = prefix
        # Number of evaluation rounds already forwarded to the logger.
        self._n_logged_evals = 0

    def _on_step(self) -> bool:
        cont = super()._on_step()
        n_evals = len(self.evaluations_results)
        if n_evals > self._n_logged_evals:
            # A new evaluation round just finished; record it once.
            self._n_logged_evals = n_evals
            self.logger.record(
                f"{self.prefix}/mean_reward",
                float(np.mean(self.evaluations_results[-1])),
            )
            self.logger.record(
                f"{self.prefix}/avg_steps",
                float(np.mean(self.evaluations_length[-1])),
            )
            if self.evaluations_successes:
                self.logger.record(
                    f"{self.prefix}/win_rate",
                    float(np.mean(self.evaluations_successes[-1])),
                )
        return cont
def _make_sb3_env_fn(env_id: str, cfg: SimpleNamespace, log_dir: str):
    """Return a picklable thunk that builds one wrapped + monitored env.

    The thunk builds ``Monitor(_SB3MiniHackWrapper(AdvancedObservationEnv))``
    and creates ``log_dir`` if needed. Several vec-env workers may be handed
    the same ``log_dir``; ``run_baselines`` does this for every training
    worker. Passing the bare directory to ``Monitor`` would make all of them
    open the same ``monitor.csv``, with each truncating the file and
    interleaving its writes. Instead, each Monitor writes its own
    ``<pid>.monitor.csv``. The thunk runs inside the worker process, so the
    pid is unique per concurrently running worker.
    """
    def _init() -> Monitor:
        os.makedirs(log_dir, exist_ok=True)
        env = AdvancedObservationEnv(env_id, des_file=None, cfg=cfg)
        env = _SB3MiniHackWrapper(env)
        # Monitor appends ".monitor.csv" to a path that is not a directory.
        monitor_path = os.path.join(log_dir, str(os.getpid()))
        return Monitor(env, monitor_path)
    return _init
# =============================================================================
# Helpers
# =============================================================================
def _short(env_id: str) -> str:
    """Compact env label: drop every ``MiniHack-`` and then every ``-v0``."""
    label = env_id
    for token in ("MiniHack-", "-v0"):
        label = label.replace(token, "")
    return label
def _eval_episodes_per_env(cfg: SimpleNamespace) -> int:
    """Evaluation episodes per environment for the baselines.

    ``cfg.baselines_eval_episodes_per_env`` wins whenever it is present and
    not ``None``; otherwise the shared ``cfg.eval_episodes_per_env`` is used.
    """
    override = getattr(cfg, "baselines_eval_episodes_per_env", None)
    chosen = cfg.eval_episodes_per_env if override is None else override
    return int(chosen)
def _seed_everything(seed: int) -> None:
    """Seed Python ``random``, NumPy's global RNG and torch (CPU, then CUDA if present)."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
def _resolve_output_dir(cfg: SimpleNamespace, override: str | None) -> Path:
    """Return the baselines output directory, creating it (and parents) if needed.

    A truthy ``override`` is used verbatim; otherwise ``cfg.baselines_output_dir``.
    """
    out = Path(override or cfg.baselines_output_dir)
    out.mkdir(parents=True, exist_ok=True)
    return out
def _init_baseline_logger(
    cfg: SimpleNamespace, run_name: str,
) -> Logger:
    """Create a project :class:`Logger` aimed at the baselines W&B project.

    Before constructing ``Logger(cfg)``, the W&B settings are written onto
    ``cfg`` itself:

    * ``cfg.wandb_project`` is replaced by ``cfg.baselines_wandb_project``,
      but only when that attribute exists and is truthy;
    * ``cfg.wandb_run_name`` is set to ``run_name``;
    * ``cfg.wandb_resume_id`` is cleared to ``None``.

    These mutations persist after the call returns. The originals are
    deliberately not restored: every baseline seed and the final summary run
    reuse this helper, and (per the original design note) main.py exits after
    ``run_baselines`` returns.
    """
    baselines_project = getattr(cfg, "baselines_wandb_project", None)
    if baselines_project:
        cfg.wandb_project = baselines_project
    cfg.wandb_run_name, cfg.wandb_resume_id = run_name, None
    return Logger(cfg)
# =============================================================================
# BC training
# =============================================================================
def _collect_bc_dataset(
    cfg: SimpleNamespace,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Roll out the BFS oracle on each ID env and stack flat (s, a) pairs.

    Seeds ``0 .. baselines_bc_oracle_episodes_per_env - 1`` are tried per ID
    env; seeds for which the oracle returns ``None`` are skipped. A channel
    axis is inserted at position 1 of both observation streams.

    Returns:
        ``(local, global, actions)`` concatenated over all trajectories.

    Raises:
        RuntimeError: if no trajectory could be collected at all.
    """
    episodes_per_env = int(cfg.baselines_bc_oracle_episodes_per_env)
    local_chunks: list[np.ndarray] = []
    global_chunks: list[np.ndarray] = []
    action_chunks: list[np.ndarray] = []
    for env_id in cfg.id_envs:
        for traj_seed in range(episodes_per_env):
            traj = collect_oracle_trajectory(env_id, traj_seed, cfg)
            if traj is None:
                continue
            # (T, H, W) -> (T, 1, H, W)
            local_chunks.append(np.expand_dims(traj["local"], axis=1))
            global_chunks.append(np.expand_dims(traj["global"], axis=1))
            action_chunks.append(traj["actions"])
    if not action_chunks:
        raise RuntimeError("BC oracle collection produced zero trajectories")
    return tuple(
        np.concatenate(chunks, axis=0)
        for chunks in (local_chunks, global_chunks, action_chunks)
    )
class _BCDataset(Dataset):
    """Flat (observation, action) pairs for behavioural cloning.

    Observations come back in the SB3 dict layout (``local`` / ``global``,
    float32) so a collated batch can be passed directly to
    ``ActorCriticPolicy.evaluate_actions``. Actions are int64.
    """

    def __init__(
        self, loc: np.ndarray, glob: np.ndarray, acts: np.ndarray,
    ) -> None:
        self.loc = torch.tensor(loc, dtype=torch.float32)
        self.glob = torch.tensor(glob, dtype=torch.float32)
        self.acts = torch.tensor(acts, dtype=torch.int64)

    def __len__(self) -> int:
        return len(self.acts)

    def __getitem__(
        self, idx: int,
    ) -> dict[str, dict[str, torch.Tensor] | torch.Tensor]:
        obs = {"local": self.loc[idx], "global": self.glob[idx]}
        return {"obs": obs, "acts": self.acts[idx]}
def _eval_sb3_policy_manually(
    policy: ActorCriticPolicy,
    env_id: str,
    cfg: SimpleNamespace,
    log_dir: str,
    n_episodes: int,
) -> tuple[float, float]:
    """Greedy evaluation of an SB3 policy on a single Monitor-wrapped env.

    Steps a one-env ``SubprocVecEnv`` with
    ``policy.predict(obs, deterministic=True)`` until ``n_episodes`` episodes
    have finished. The vec env is closed even if evaluation raises.

    Returns:
        ``(win_rate, avg_steps)``: the fraction of finished episodes whose
        final info has a truthy ``"won"``, and the mean Monitor episode length
        (``info["episode"]["l"]``).
    """
    vec_env = SubprocVecEnv([_make_sb3_env_fn(env_id, cfg, log_dir)])
    n_won = 0
    step_total = 0
    n_done = 0
    try:
        obs = vec_env.reset()
        while n_done < n_episodes:
            action, _state = policy.predict(obs, deterministic=True)
            obs, _rew, dones, infos = vec_env.step(action)
            if not dones[0]:
                continue
            final_info = infos[0]
            n_done += 1
            n_won += int(bool(final_info.get("won", False)))
            step_total += final_info["episode"]["l"]
    finally:
        vec_env.close()
    return n_won / n_episodes, step_total / n_episodes
def _train_bc(
    cfg: SimpleNamespace,
    train_env: SubprocVecEnv,
    log: Logger,
    log_dir: str,
    seed: int,
) -> tuple[ActorCriticPolicy, dict[str, float]]:
    """Train a Behavioural Cloning baseline. Returns (policy, seed_metrics).

    Collects BFS-oracle (s, a) pairs on every ID env, then fits an SB3
    ``ActorCriticPolicy`` with the ``_MiniHackCNN`` extractor by minimising
    the negative log-likelihood of the oracle actions (AdamW, cosine LR
    schedule, grad-norm clip 1.0). Afterwards it evaluates the policy
    greedily on every ID and OOD env.

    Args:
        cfg: Project config. Reads ``device``, ``baselines_bc_*``,
            ``weight_decay``, ``id_envs``, ``ood_envs`` and the eval-episode keys.
        train_env: Vec env used only for its observation/action spaces; no
            environment steps are taken on it here.
        log: Project logger that receives per-epoch training metrics and the
            final evaluation metrics.
        log_dir: Root directory for the eval envs' Monitor logs.
        seed: Seed label, used only in log messages.

    Returns:
        The trained policy, and a dict mapping
        ``"{ID|OOD}/{env}/win_rate"`` (percent) and
        ``"{ID|OOD}/{env}/avg_steps"`` to floats.
    """
    device = torch.device(cfg.device)
    n_eval = _eval_episodes_per_env(cfg)
    logger.info("Collecting oracle demonstrations for BC...")
    loc_arr, glob_arr, acts_arr = _collect_bc_dataset(cfg)
    logger.info("BC dataset: %d transitions", len(acts_arr))
    bc_loader = DataLoader(
        _BCDataset(loc_arr, glob_arr, acts_arr),
        batch_size=int(cfg.baselines_bc_batch_size),
        shuffle=True,
        num_workers=4,
        pin_memory=torch.cuda.is_available(),
    )
    lr = float(cfg.baselines_bc_lr)
    # NOTE(review): ActorCriticPolicy also builds its own optimizer from
    # lr_schedule; that one is unused, as training below uses a separate AdamW.
    policy = ActorCriticPolicy(
        observation_space=train_env.observation_space,
        action_space=train_env.action_space,
        lr_schedule=lambda _progress: lr,
        features_extractor_class=_MiniHackCNN,
        features_extractor_kwargs={"features_dim": 256},
    ).to(device)
    n_epochs = int(cfg.baselines_bc_epochs)
    optimizer = torch.optim.AdamW(
        policy.parameters(),
        lr=lr,
        weight_decay=float(cfg.weight_decay),
    )
    # Stepped once per epoch, so the cosine period is the whole run.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=n_epochs,
    )
    policy.train()
    for epoch in range(n_epochs):
        total_loss = 0.0
        for batch in bc_loader:
            obs = {k: v.to(policy.device) for k, v in batch["obs"].items()}
            acts = batch["acts"].to(policy.device)
            # Only the action log-probs are needed for maximum-likelihood BC.
            _values, log_prob, _entropy = policy.evaluate_actions(obs, acts)
            loss = -log_prob.mean()
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(policy.parameters(), 1.0)
            optimizer.step()
            total_loss += loss.item()
        scheduler.step()
        avg_loss = total_loss / max(1, len(bc_loader))
        # NOTE(review): read after scheduler.step(), so this is the LR the
        # *next* epoch will use, not the one this epoch trained with.
        current_lr = scheduler.get_last_lr()[0]
        log.log(
            {
                "train/bc_loss": avg_loss,
                "train/lr": current_lr,
                "train/epoch": epoch + 1,
            },
            step=epoch + 1,
        )
        logger.info(
            "BC epoch %02d/%02d | loss=%.4f | lr=%.2e",
            epoch + 1, n_epochs, avg_loss, current_lr,
        )
    seed_metrics: dict[str, float] = {}
    for split, env_list in (("ID", cfg.id_envs), ("OOD", cfg.ood_envs)):
        logger.info("--- BC %s evaluation (seed=%d) ---", split, seed)
        for env_id in env_list:
            short = _short(env_id)
            win_rate, avg_steps = _eval_sb3_policy_manually(
                policy,
                env_id,
                cfg,
                f"{log_dir}/eval_{split.lower()}/{env_id}",
                n_eval,
            )
            seed_metrics[f"{split}/{short}/win_rate"] = win_rate * 100
            seed_metrics[f"{split}/{short}/avg_steps"] = avg_steps
            logger.info(
                "%-30s | win_rate=%5.1f%% | avg_steps=%5.1f",
                short, win_rate * 100, avg_steps,
            )
    # Evaluation metrics are logged one step past the last training epoch.
    log.log(seed_metrics, step=n_epochs + 1)
    return policy, seed_metrics
# =============================================================================
# Decision Transformer training
# =============================================================================
def _collect_dt_trajectories(
    cfg: SimpleNamespace,
) -> list[dict[str, np.ndarray]]:
    """Collect oracle trajectories with sparse reward + return-to-go labels.

    Each trajectory gets a synthetic reward of 1.0 on its last step and 0
    elsewhere, and ``returns_to_go[t]`` is the sum of rewards from ``t`` to
    the end. Observation streams get a channel axis at position 1. Seeds
    for which the oracle returns ``None`` are skipped.
    """
    episodes_per_env = int(cfg.baselines_dt_oracle_episodes_per_env)
    dataset: list[dict[str, np.ndarray]] = []
    for env_id in cfg.id_envs:
        for traj_seed in range(episodes_per_env):
            traj = collect_oracle_trajectory(env_id, traj_seed, cfg)
            if traj is None:
                continue
            rewards = np.zeros(len(traj["actions"]), dtype=np.float32)
            rewards[-1] = 1.0  # sparse goal reward
            # Suffix sum: rtg[t] = rewards[t] + rewards[t + 1] + ... + rewards[-1].
            rtg = np.cumsum(rewards[::-1])[::-1].copy()
            dataset.append(
                {
                    "local": np.expand_dims(traj["local"], axis=1),
                    "global": np.expand_dims(traj["global"], axis=1),
                    "actions": traj["actions"],
                    "rewards": rewards,
                    "returns_to_go": rtg,
                }
            )
    return dataset
def _dt_action_context(action_hist: list[int], ctx: int) -> np.ndarray:
    """Action-token input for a Decision Transformer window of ``ctx`` steps.

    The window's state tokens cover the last ``ctx`` steps; the final one is
    the current step, whose action is still unknown. Slot ``i`` must carry the
    action taken at the same step as state slot ``i``, which is the alignment
    ``_DTDataset`` uses when it slices states and actions with one
    ``[start:end]`` window. So the previous ``ctx - 1`` actions fill the
    leading slots and the last slot holds a ``0`` placeholder. Because the
    action token follows the state token in the interleaved sequence, the
    causal mask keeps the current state token from attending to that
    placeholder.

    If fewer than ``ctx - 1`` past actions exist, the known ones are
    right-aligned against the placeholder and earlier slots stay ``0``.
    """
    act_in = np.zeros(ctx, dtype=np.int64)
    n_past = ctx - 1
    if n_past > 0 and action_hist:
        past = action_hist[-n_past:]
        act_in[n_past - len(past):n_past] = past
    return act_in


def _eval_dt(
    model: _DecisionTransformer,
    env_id: str,
    cfg: SimpleNamespace,
    target_return: float,
    n_episodes: int,
    max_ep_len: int,
    eval_max_steps: int,
    context_len: int,
) -> tuple[float, float]:
    """Roll out a trained Decision Transformer with target-return conditioning.

    Each episode starts with return-to-go ``target_return``, which is
    decreased by every reward received. At each step the model sees the
    last ``min(t + 1, context_len)`` (return, state, action) triples, and the
    greedy action for the current state is clipped into
    ``[0, cfg.action_dim)``. An episode ends on terminate/truncate or after
    ``eval_max_steps`` steps. Timesteps are clamped to ``max_ep_len - 1``.

    Returns:
        ``(win_rate, avg_steps)``: the fraction of episodes whose last step
        info has a truthy ``"won"``, and the mean number of steps taken.
    """
    device = torch.device(cfg.device)
    env = AdvancedObservationEnv(env_id, des_file=None, cfg=cfg)
    env = _SB3MiniHackWrapper(env)
    model.eval()
    wins = 0
    total_steps = 0
    try:
        for _ep in range(n_episodes):
            obs, _ = env.reset()
            done = False
            local_hist: list[np.ndarray] = []
            global_hist: list[np.ndarray] = []
            action_hist: list[int] = []
            rtg_hist: list[float] = []
            ts_hist: list[int] = []
            current_rtg = float(target_return)
            t = 0
            info: dict = {}
            while not done and t < eval_max_steps:
                local_hist.append(obs["local"])
                global_hist.append(obs["global"])
                rtg_hist.append(current_rtg)
                ts_hist.append(min(t, max_ep_len - 1))
                ctx = min(len(local_hist), context_len)
                local_in = np.stack(local_hist[-ctx:], axis=0)
                global_in = np.stack(global_hist[-ctx:], axis=0)
                rtg_in = np.array(rtg_hist[-ctx:], dtype=np.float32)
                ts_in = np.array(ts_hist[-ctx:], dtype=np.int64)
                # Actions aligned slot-for-slot with the states above.
                act_in = _dt_action_context(action_hist, ctx)
                local_t = torch.tensor(local_in, dtype=torch.float32).unsqueeze(0).to(device)
                global_t = torch.tensor(global_in, dtype=torch.float32).unsqueeze(0).to(device)
                rtg_t = torch.tensor(rtg_in, dtype=torch.float32).unsqueeze(0).unsqueeze(-1).to(device)
                act_t = torch.tensor(act_in, dtype=torch.long).unsqueeze(0).to(device)
                ts_t = torch.tensor(ts_in, dtype=torch.long).unsqueeze(0).to(device)
                with torch.no_grad():
                    action = int(
                        model.get_action(rtg_t, local_t, global_t, act_t, ts_t).item()
                    )
                action = max(0, min(action, int(cfg.action_dim) - 1))
                action_hist.append(action)
                obs, reward, terminated, truncated, info = env.step(action)
                done = terminated or truncated
                current_rtg -= float(reward)
                t += 1
            if info.get("won", False):
                wins += 1
            total_steps += t
    finally:
        env.close()
    return wins / n_episodes, total_steps / n_episodes
def _train_dt(
    cfg: SimpleNamespace,
    log: Logger,
    log_dir: str,
    seed: int,
) -> tuple[_DecisionTransformer, dict[str, float]]:
    """Train a Decision Transformer baseline. Returns (model, seed_metrics).

    Collects sparse-reward oracle trajectories on the ID envs and trains a
    ``_DecisionTransformer`` with masked cross-entropy on the oracle actions
    (padding steps excluded). It then evaluates on every ID and OOD env,
    conditioning on the largest initial return-to-go seen in the data.

    Args:
        cfg: Project config. Reads ``device``, ``action_dim``, ``crop_size``,
            ``map_h``, ``map_w``, ``weight_decay``, ``baselines_dt_*``,
            ``id_envs``, ``ood_envs`` and the eval-episode keys.
        log: Project logger for per-epoch and final evaluation metrics.
        log_dir: Unused in this function body (kept for a signature parallel
            to ``_train_bc``).
        seed: Seed label, used only in log messages.

    Returns:
        The trained model, and a dict mapping
        ``"{ID|OOD}/{env}/win_rate"`` (percent) and
        ``"{ID|OOD}/{env}/avg_steps"`` to floats.

    Raises:
        RuntimeError: if oracle collection yields no trajectories.
    """
    device = torch.device(cfg.device)
    context_len = int(cfg.baselines_dt_context_len)
    max_ep_len = int(cfg.baselines_dt_max_ep_len)
    eval_max_steps = int(cfg.baselines_dt_eval_max_steps)
    n_eval = _eval_episodes_per_env(cfg)
    n_epochs = int(cfg.baselines_dt_epochs)
    logger.info("Collecting oracle demonstrations for DT...")
    trajectories = _collect_dt_trajectories(cfg)
    if not trajectories:
        raise RuntimeError("DT oracle collection produced zero trajectories")
    traj_lengths = [len(t["actions"]) for t in trajectories]
    logger.info(
        "DT dataset: %d trajectories, %d transitions (len: min=%d max=%d mean=%.1f)",
        len(trajectories),
        sum(traj_lengths),
        min(traj_lengths),
        max(traj_lengths),
        float(np.mean(traj_lengths)),
    )
    # Timesteps index pos_embed, which has max_ep_len rows; _DTDataset clips
    # longer trajectories' timesteps to the last row.
    if max(traj_lengths) > max_ep_len:
        logger.warning(
            "Longest oracle trajectory (%d) exceeds baselines_dt_max_ep_len (%d); "
            "positions will be clamped.",
            max(traj_lengths),
            max_ep_len,
        )
    # Condition evaluation on the best return observed in the demonstrations.
    target_return = float(max(t["returns_to_go"][0] for t in trajectories))
    dataset = _DTDataset(
        trajectories,
        context_len=context_len,
        max_ep_len=max_ep_len,
        n_actions=int(cfg.action_dim),
    )
    loader = DataLoader(
        dataset,
        batch_size=int(cfg.baselines_dt_batch_size),
        shuffle=True,
        num_workers=4,
        pin_memory=torch.cuda.is_available(),
    )
    model = _DecisionTransformer(
        n_actions=int(cfg.action_dim),
        embed_dim=int(cfg.baselines_dt_embed_dim),
        n_heads=int(cfg.baselines_dt_n_heads),
        n_layers=int(cfg.baselines_dt_n_layers),
        context_len=context_len,
        max_ep_len=max_ep_len,
        crop_h=int(cfg.crop_size),
        crop_w=int(cfg.crop_size),
        map_h=int(cfg.map_h),
        map_w=int(cfg.map_w),
    ).to(device)
    n_params = sum(p.numel() for p in model.parameters())
    logger.info("DT parameters: %d", n_params)
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=float(cfg.baselines_dt_lr),
        weight_decay=float(cfg.weight_decay),
    )
    # Stepped once per epoch, so the cosine period is the whole run.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=n_epochs,
    )
    for epoch in range(n_epochs):
        model.train()
        total_loss = 0.0
        n_batches = 0
        for batch in loader:
            local = batch["local"].to(device)
            glob = batch["global"].to(device)
            actions = batch["actions"].to(device)
            rtg = batch["returns_to_go"].to(device)
            timesteps = batch["timesteps"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            logits = model(rtg, local, glob, actions, timesteps, attention_mask)
            logits_flat = logits.reshape(-1, int(cfg.action_dim))
            targets_flat = actions.reshape(-1)
            mask_flat = attention_mask.reshape(-1)
            ce = nn.functional.cross_entropy(
                logits_flat, targets_flat, reduction="none",
            )
            # Average CE over real (unpadded) steps only.
            loss = (ce * mask_flat).sum() / mask_flat.sum().clamp(min=1.0)
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            total_loss += loss.item()
            n_batches += 1
        scheduler.step()
        avg_loss = total_loss / max(1, n_batches)
        # NOTE(review): get_last_lr() after scheduler.step() reports the LR
        # for the next epoch, not the one this epoch trained with.
        log.log(
            {
                "train/dt_loss": avg_loss,
                "train/lr": float(scheduler.get_last_lr()[0]),
                "train/epoch": epoch + 1,
            },
            step=epoch + 1,
        )
        logger.info(
            "DT epoch %02d/%02d | loss=%.4f | lr=%.2e",
            epoch + 1,
            n_epochs,
            avg_loss,
            float(scheduler.get_last_lr()[0]),
        )
    seed_metrics: dict[str, float] = {}
    logger.info("DT eval target return = %.2f", target_return)
    for split, env_list in (("ID", cfg.id_envs), ("OOD", cfg.ood_envs)):
        logger.info("--- DT %s evaluation (seed=%d) ---", split, seed)
        for env_id in env_list:
            short = _short(env_id)
            win_rate, avg_steps = _eval_dt(
                model,
                env_id,
                cfg,
                target_return=target_return,
                n_episodes=n_eval,
                max_ep_len=max_ep_len,
                eval_max_steps=eval_max_steps,
                context_len=context_len,
            )
            seed_metrics[f"{split}/{short}/win_rate"] = win_rate * 100
            seed_metrics[f"{split}/{short}/avg_steps"] = avg_steps
            logger.info(
                "%-30s | win_rate=%5.1f%% | avg_steps=%5.1f",
                short, win_rate * 100, avg_steps,
            )
    # Evaluation metrics are logged one step past the last training epoch.
    log.log(seed_metrics, step=n_epochs + 1)
    return model, seed_metrics
# =============================================================================
# SB3 RL training
# =============================================================================
def _build_sb3_model(
    algo: str,
    train_env: SubprocVecEnv,
    cfg: SimpleNamespace,
    seed: int,
    tb_log_dir: str,
):
    """Construct one of {ppo, a2c, dqn, ppo-rnn} with the MiniHack CNN.

    Every model uses ``_MiniHackCNN`` (256-d features), ``verbose=1``,
    TensorBoard logging to ``tb_log_dir`` and the given ``seed``. PPO, A2C
    and DQN use ``MultiInputPolicy``, while ``ppo-rnn`` uses sb3-contrib's
    ``MultiInputLstmPolicy``. DQN additionally reads
    ``cfg.baselines_dqn_buffer_size``.

    Raises:
        ValueError: for any other ``algo``.
    """
    shared = {
        "policy_kwargs": {
            "features_extractor_class": _MiniHackCNN,
            "features_extractor_kwargs": {"features_dim": 256},
        },
        "verbose": 1,
        "tensorboard_log": tb_log_dir,
        "seed": seed,
    }
    if algo in ("ppo", "a2c"):
        algo_cls = PPO if algo == "ppo" else A2C
        return algo_cls("MultiInputPolicy", train_env, **shared)
    if algo == "ppo-rnn":
        return RecurrentPPO("MultiInputLstmPolicy", train_env, **shared)
    if algo == "dqn":
        return DQN(
            "MultiInputPolicy",
            train_env,
            buffer_size=int(cfg.baselines_dqn_buffer_size),
            **shared,
        )
    raise ValueError(f"Unknown SB3 algo: {algo!r}")
def _build_sb3_callbacks(
    cfg: SimpleNamespace,
    train_env: SubprocVecEnv,
    log_dir: str,
    model_dir: str,
) -> CallbackList:
    """Build the callback list for SB3 RL training.

    The list holds a ``WandbCallback`` (saving models under ``model_dir``),
    followed by one ``_PrefixedEvalCallback`` per ID env and then one per OOD
    env. Each eval callback owns its own single-env ``SubprocVecEnv``. Only
    the ID callbacks save a best model. Evaluation runs every
    ``cfg.baselines_eval_freq_env_steps`` env steps, converted to callback
    calls by dividing by ``train_env.num_envs`` (minimum 1).
    """
    callbacks: list = [WandbCallback(model_save_path=model_dir)]
    n_eval = _eval_episodes_per_env(cfg)
    eval_freq = max(
        1, int(cfg.baselines_eval_freq_env_steps) // train_env.num_envs,
    )
    # (metric prefix, log sub-directory, cfg attribute, save best model?)
    split_specs = (
        ("ID", "eval_id", "id_envs", True),
        ("OOD", "eval_ood", "ood_envs", False),
    )
    for prefix_root, subdir, cfg_attr, save_best in split_specs:
        for env_id in getattr(cfg, cfg_attr):
            short = _short(env_id)
            eval_dir = f"{log_dir}/{subdir}/{env_id}"
            eval_env = SubprocVecEnv([_make_sb3_env_fn(env_id, cfg, eval_dir)])
            best_path = f"{model_dir}/best_{env_id}/" if save_best else None
            callbacks.append(
                _PrefixedEvalCallback(
                    eval_env,
                    prefix=f"{prefix_root}/{short}",
                    best_model_save_path=best_path,
                    log_path=f"{eval_dir}/",
                    eval_freq=eval_freq,
                    n_eval_episodes=n_eval,
                    deterministic=True,
                )
            )
    return CallbackList(callbacks)
# =============================================================================
# Aggregation
# =============================================================================
def _aggregate(
    all_seed_results: list[dict[str, Any]],
) -> dict[str, dict[str, float | list[float]]]:
    """Compute mean/std across seeds for every shared metric key.

    Metric keys are taken from the first seed's result, excluding ``"seed"``.
    For each key, the values of every seed that has it are reduced to
    ``{"mean", "std", "values"}``; ``std`` is the population std (``np.std``,
    ddof=0). An empty input yields ``{}``.
    """
    if not all_seed_results:
        return {}
    agg: dict[str, dict[str, float | list[float]]] = {}
    for key in (k for k in all_seed_results[0] if k != "seed"):
        values = [result[key] for result in all_seed_results if key in result]
        if not values:
            continue
        agg[key] = {
            "mean": float(np.mean(values)),
            "std": float(np.std(values)),
            "values": [float(v) for v in values],
        }
    return agg
def _print_aggregated(seeds: list[int], agg: dict[str, dict[str, Any]]) -> None:
    """Log an ID/OOD table of per-environment win rate and step stats.

    Keys of ``agg`` are expected as ``"<split>/<env>/<metric>"``. Environments
    are listed alphabetically within each split, and missing stats print as
    0.0. If ``agg`` is empty, a single notice is logged instead.
    """
    if not agg:
        logger.info("No per-environment metrics to aggregate (RL eval is callback-driven)")
        return
    logger.info("Aggregated results across %d seeds: %s", len(seeds), seeds)
    for split in ("ID", "OOD"):
        split_prefix = f"{split}/"
        per_env: dict[str, dict[str, dict[str, Any]]] = {}
        for key, stats in agg.items():
            if key.startswith(split_prefix):
                _split, env_name, metric_name = key.split("/", 2)
                per_env.setdefault(env_name, {})[metric_name] = stats
        if not per_env:
            continue
        logger.info("--- %s environments ---", split)
        for env_name in sorted(per_env):
            metrics = per_env[env_name]
            win = metrics.get("win_rate", {})
            steps = metrics.get("avg_steps", {})
            logger.info(
                "%-30s | win_rate=%5.1f%% +/- %4.1f | avg_steps=%5.1f +/- %4.1f",
                env_name,
                win.get("mean", 0.0),
                win.get("std", 0.0),
                steps.get("mean", 0.0),
                steps.get("std", 0.0),
            )
def _save_aggregated(
    out_path: Path,
    algo: str,
    seeds: list[int],
    all_seed_results: list[dict[str, Any]],
    agg: dict[str, dict[str, Any]],
) -> None:
    """Write per-seed results plus mean/std aggregates to ``out_path`` as indented JSON.

    Only ``mean`` and ``std`` are kept from each aggregate entry; the raw
    per-seed values are already present under ``per_seed_results``.
    """
    aggregated_summary = {
        name: {"mean": stats["mean"], "std": stats["std"]}
        for name, stats in agg.items()
    }
    payload = {
        "algorithm": algo,
        "seeds": seeds,
        "n_seeds": len(seeds),
        "per_seed_results": all_seed_results,
        "aggregated": aggregated_summary,
    }
    encoded = orjson.dumps(payload, option=orjson.OPT_INDENT_2)
    out_path.write_bytes(encoded)
    logger.info("Aggregated results written to %s", out_path)
# =============================================================================
# Public entry point
# =============================================================================
def run_baselines(
    cfg: SimpleNamespace,
    algo: str,
    seeds: list[int] | None = None,
    output_path: str | None = None,
) -> None:
    """Train and evaluate one baseline algorithm across one or more seeds.

    For each seed, this function:

    * seeds every RNG;
    * opens a project Logger (via ``_init_baseline_logger``);
    * trains the chosen baseline;
    * saves the final model under ``<output_dir>/models/<run_id>/``.

    BC and DT evaluate explicitly after training and contribute
    per-environment metrics. SB3 RL algorithms are evaluated by
    ``EvalCallback``s during ``learn`` and contribute only their seed. When
    any metrics were collected, they are aggregated (mean/std), written to
    JSON and logged to a dedicated summary run.

    Args:
        cfg: Project config namespace (must contain ``baselines_*`` keys).
        algo: One of ``ppo``, ``a2c``, ``dqn``, ``ppo-rnn``, ``bc``, ``dt``.
        seeds: Optional list of seeds. ``None`` -> ``[cfg.seed]`` (or
            a single seed of ``0`` if ``cfg.seed`` is ``None``).
        output_path: Optional override for the aggregated-results JSON
            destination. When ``None``, results land under
            ``cfg.baselines_output_dir``.

    Raises:
        ValueError: if ``algo`` is unknown or ``seeds`` is empty.
    """
    if algo not in ALL_BASELINE_ALGOS:
        raise ValueError(
            f"Unknown algo {algo!r}. Choose one of {ALL_BASELINE_ALGOS}."
        )
    if seeds is None:
        seeds = [cfg.seed if cfg.seed is not None else 0]
    if not seeds:
        raise ValueError("seeds must be non-empty")
    out_dir = _resolve_output_dir(cfg, None)
    if output_path is not None:
        agg_json_path = Path(output_path)
        agg_json_path.parent.mkdir(parents=True, exist_ok=True)
    else:
        agg_json_path = out_dir / f"results_{algo}_{len(seeds)}seeds.json"
    logger.info(
        "Running baseline %s on %d seed(s): %s (output -> %s)",
        algo, len(seeds), seeds, agg_json_path,
    )
    all_seed_results: list[dict[str, Any]] = []
    n_envs_per_id = int(cfg.baselines_n_envs_per_id)
    for seed_idx, seed in enumerate(seeds):
        logger.info(
            "============================================================\n"
            " %s seed %d (%d/%d)\n"
            "============================================================",
            algo.upper(), seed, seed_idx + 1, len(seeds),
        )
        _seed_everything(seed)
        run_name = f"{algo}-multitask-seed{seed}"
        log = _init_baseline_logger(cfg, run_name)
        # Prefer the W&B run id so local artefacts match the W&B run.
        run_id = (
            log._run.id  # type: ignore[union-attr]
            if log._use_wandb and log._run is not None
            else f"local-{algo}-seed{seed}"
        )
        log_dir = str(out_dir / "logs" / run_id)
        model_dir = str(out_dir / "models" / run_id)
        os.makedirs(log_dir, exist_ok=True)
        os.makedirs(model_dir, exist_ok=True)
        seed_results: dict[str, Any] = {"seed": seed}
        try:
            if algo == "dt":
                model, dt_metrics = _train_dt(cfg, log, log_dir, seed)
                seed_results.update(dt_metrics)
                torch.save(
                    {
                        "model_state_dict": model.state_dict(),
                        # Everything needed to rebuild _DecisionTransformer
                        # (crop / map dims size the state encoder's layers).
                        "config": {
                            "n_actions": int(cfg.action_dim),
                            "embed_dim": int(cfg.baselines_dt_embed_dim),
                            "n_heads": int(cfg.baselines_dt_n_heads),
                            "n_layers": int(cfg.baselines_dt_n_layers),
                            "context_len": int(cfg.baselines_dt_context_len),
                            "max_ep_len": int(cfg.baselines_dt_max_ep_len),
                            "crop_h": int(cfg.crop_size),
                            "crop_w": int(cfg.crop_size),
                            "map_h": int(cfg.map_h),
                            "map_w": int(cfg.map_w),
                        },
                    },
                    f"{model_dir}/dt_final_seed{seed}.pt",
                )
            else:
                # SB3 RL families and BC both need the parallel train env
                # (BC only uses its observation / action spaces).
                train_env_fns = [
                    _make_sb3_env_fn(env_id, cfg, log_dir)
                    for env_id in list(cfg.id_envs) * n_envs_per_id
                ]
                train_env = SubprocVecEnv(train_env_fns)
                try:
                    if algo == "bc":
                        policy, bc_metrics = _train_bc(
                            cfg, train_env, log, log_dir, seed,
                        )
                        seed_results.update(bc_metrics)
                        policy.save(f"{model_dir}/bc_final_seed{seed}")
                    else:
                        sb3_model = _build_sb3_model(
                            algo, train_env, cfg, seed,
                            tb_log_dir=str(out_dir / "tb" / run_id),
                        )
                        callbacks = _build_sb3_callbacks(
                            cfg, train_env, log_dir, model_dir,
                        )
                        try:
                            logger.info(
                                "Training %s for %d env-steps across %d ID maps "
                                "(%d parallel envs)...",
                                algo.upper(),
                                int(cfg.total_timesteps),
                                len(cfg.id_envs),
                                train_env.num_envs,
                            )
                            sb3_model.learn(
                                total_timesteps=int(cfg.total_timesteps),
                                callback=callbacks,
                            )
                            sb3_model.save(f"{model_dir}/{algo}_final_seed{seed}")
                        finally:
                            # _build_sb3_callbacks spawns one SubprocVecEnv per
                            # eval env; close them so their worker processes
                            # do not outlive this seed.
                            for callback in callbacks.callbacks:
                                if isinstance(callback, EvalCallback):
                                    callback.eval_env.close()
                finally:
                    train_env.close()
            all_seed_results.append(seed_results)
        finally:
            log.finish()
        logger.info("%s seed %d complete.", algo.upper(), seed)
    agg = _aggregate(all_seed_results)
    _print_aggregated(seeds, agg)
    if agg:
        _save_aggregated(agg_json_path, algo, seeds, all_seed_results, agg)
        # Final summary write to the project Logger so the aggregated
        # numbers land on a dedicated W&B run (skipped when there is
        # nothing to summarise, e.g. callback-evaluated RL baselines).
        summary_run_name = f"{algo}-multitask-summary"
        summary_log = _init_baseline_logger(cfg, summary_run_name)
        try:
            summary_payload: dict[str, float] = {}
            for key, stats in agg.items():
                summary_payload[f"summary/{key}/mean"] = stats["mean"]
                summary_payload[f"summary/{key}/std"] = stats["std"]
            summary_log.log_summary(summary_payload)
        finally:
            summary_log.finish()
    logger.info("All %d seed(s) complete.", len(seeds))