# remdm-minihack/src/models/denoiser.py
# MathisW78: Demo notebook payload (source + checkpoint + assets)
# f748552 (verified)
"""Dual-stream denoising transformer for MiniHack.
Ported from minihack_reference/src/model.py. Architecture follows the
Craftax denoiser conventions (forward return format, obs-encoder pattern)
while using the MiniHack dual-stream design (local CNN + gated global
CNN + auxiliary goal head).
"""
from __future__ import annotations
import copy
import logging
import shutil
from types import SimpleNamespace
import torch
import torch.nn as nn
from torch import Tensor
logger = logging.getLogger(__name__)
class LocalDiffusionPlannerWithGlobal(nn.Module):
    """Dual-stream transformer for masked diffusion action planning.

    Combines a local 9x9 glyph crop with a gated global 21x79 map
    context. Produces action logits and an auxiliary staircase-coordinate
    prediction.

    Architecture:
        Local stream:  Embedding(6000, 64) -> CNN(64->32->64) -> Linear
                       -> 1 token
        Global stream: Embedding(6000, 32) -> CNN(32->32->64) -> Pool(2, 4)
                       -> Linear -> 8 tokens, gated by
                       sigmoid(learnable scalar)
        Goal head:     mean(global_tokens) -> MLP -> [B, 2] (before gate)
        Action stream: Embedding(action_dim + 2, n_embd) + timestep
                       + position
        Transformer:   concat all -> TransformerEncoder
                       -> last seq_len tokens -> head

    Args:
        cfg: Config namespace with ``action_dim``, ``n_embd``, ``n_head``,
            ``n_layer``, ``n_global_tokens``, ``seq_len``,
            ``global_gate_init``, ``num_diffusion_steps`` and, optionally,
            ``dropout`` (defaults to 0.0).

    Raises:
        ValueError: If ``n_embd`` is not divisible by ``n_head``, or if
            ``n_global_tokens`` does not match the fixed 2x4 pooling grid.
    """

    # Output grid of the adaptive pooling over the global feature map;
    # each grid cell becomes one global token (2 * 4 = 8).
    _GLOBAL_POOL_GRID = (2, 4)

    def __init__(self, cfg: SimpleNamespace) -> None:
        super().__init__()
        action_dim = cfg.action_dim
        n_embd = cfg.n_embd
        n_head = cfg.n_head
        n_layer = cfg.n_layer
        n_global_tokens = cfg.n_global_tokens
        seq_len = cfg.seq_len

        # Explicit raises instead of ``assert`` so the checks survive ``-O``.
        if n_embd % n_head != 0:
            raise ValueError(
                f"n_embd ({n_embd}) must be divisible by n_head ({n_head})"
            )
        pool_h, pool_w = self._GLOBAL_POOL_GRID
        if n_global_tokens != pool_h * pool_w:
            # The pooled map is reshaped to [B, n_global_tokens, -1] and fed
            # to Linear(64, n_embd); any other token count would only fail
            # later, inside forward(), with an opaque shape error.
            raise ValueError(
                f"n_global_tokens ({n_global_tokens}) must equal "
                f"{pool_h * pool_w} (the {pool_h}x{pool_w} global pool grid)"
            )
        self.n_global_tokens = n_global_tokens

        # ── Local stream: 9x9 crop -> 1 token ──────────────────────
        self.embedding = nn.Embedding(6000, 64)
        self.cnn = nn.Sequential(
            nn.Conv2d(64, 32, 3, padding=1),
            nn.GELU(),
            nn.Conv2d(32, 64, 3, padding=1),
            nn.GELU(),
            nn.Flatten(),
            nn.Linear(64 * 9 * 9, n_embd),
        )

        # ── Action stream ──────────────────────────────────────────
        # Two ids beyond the real actions; presumably special tokens such
        # as MASK/PAD -- TODO confirm against the diffusion code.
        self.action_emb = nn.Embedding(action_dim + 2, n_embd)
        self.timestep_emb = nn.Embedding(
            cfg.num_diffusion_steps, n_embd,
        )
        self.pos_emb = nn.Embedding(seq_len, n_embd)

        # ── Transformer ───────────────────────────────────────────
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=n_embd,
            nhead=n_head,
            dim_feedforward=n_embd * 4,
            dropout=getattr(cfg, "dropout", 0.0),
            activation="gelu",
            norm_first=True,
            batch_first=True,
        )
        self.transformer = nn.TransformerEncoder(
            encoder_layer, num_layers=n_layer, enable_nested_tensor=False,
        )
        self.head = nn.Linear(n_embd, action_dim)

        # ── Global stream: 21x79 map -> 8 tokens ──────────────────
        self.global_embedding = nn.Embedding(6000, 32)
        self.global_cnn = nn.Sequential(
            nn.Conv2d(32, 32, 5, stride=2, padding=2),
            nn.GELU(),
            nn.Conv2d(32, 64, 3, stride=2, padding=1),
            nn.GELU(),
        )
        self.global_pool = nn.AdaptiveAvgPool2d(self._GLOBAL_POOL_GRID)
        self.global_proj = nn.Linear(64, n_embd)
        # float() so an integer config value (e.g. ``0``) still yields a
        # floating-point tensor; integer tensors cannot require gradients.
        self.global_gate = nn.Parameter(
            torch.tensor(float(cfg.global_gate_init))
        )

        # ── Auxiliary goal head (before gate) ──────────────────────
        self.goal_head = nn.Sequential(
            nn.Linear(n_embd, 128),
            nn.GELU(),
            nn.Linear(128, 2),
        )

    def forward(
        self,
        local_obs: Tensor,
        global_obs: Tensor,
        action_seq: Tensor,
        t_discrete: int | Tensor,
    ) -> dict[str, Tensor]:
        """Forward pass producing action logits and goal prediction.

        Args:
            local_obs: Local glyph crop. Shape ``[B, 9, 9]``, int.
            global_obs: Full glyph map. Shape ``[B, 21, 79]``, int.
            action_seq: Noisy action sequence. Shape ``[B, seq_len]``, int.
            t_discrete: Discrete timestep index: a Python int, a 0-dim
                tensor, or a tensor with one entry per batch element.

        Returns:
            Dict with keys:
                - ``"actions"``: ``[B, seq_len, action_dim]`` logits.
                - ``"goal_pred"``: ``[B, 2]`` normalised staircase coords.
        """
        B, Seq = action_seq.shape
        device = local_obs.device

        # Local stream -> [B, 1, n_embd]
        x_local = self.embedding(local_obs)            # [B, 9, 9, 64]
        x_local = x_local.permute(0, 3, 1, 2)          # [B, 64, 9, 9]
        local_token = self.cnn(x_local).unsqueeze(1)   # [B, 1, n_embd]

        # Global stream -> [B, 8, n_embd]
        x_global = self.global_embedding(global_obs)   # [B, 21, 79, 32]
        x_global = x_global.permute(0, 3, 1, 2)        # [B, 32, 21, 79]
        gf = self.global_cnn(x_global)                 # [B, 64, H', W']
        gf = self.global_pool(gf)                      # [B, 64, 2, 4]
        global_tokens = gf.permute(0, 2, 3, 1)         # [B, 2, 4, 64]
        global_tokens = global_tokens.reshape(
            B, self.n_global_tokens, -1
        )                                              # [B, 8, 64]
        global_tokens = self.global_proj(global_tokens)  # [B, 8, n_embd]

        # Aux goal head (before gate for direct gradient to CNN)
        goal_pred = self.goal_head(
            global_tokens.mean(dim=1)
        )                                              # [B, 2]

        # Apply gate
        gate = torch.sigmoid(self.global_gate)
        global_tokens = global_tokens * gate           # [B, 8, n_embd]

        # Action stream -> [B, seq_len, n_embd]
        positions = torch.arange(
            Seq, device=device,
        ).unsqueeze(0).expand(B, -1)                   # [B, seq_len]

        # Normalise the timestep to a long tensor of shape [B] so that
        # Python ints, 0-dim tensors and per-sample tensors all broadcast
        # the same way below.
        t_tensor = torch.as_tensor(
            t_discrete, dtype=torch.long, device=device,
        ).reshape(-1).expand(B)

        seq_emb = (
            self.action_emb(action_seq)
            + self.timestep_emb(t_tensor).unsqueeze(1)
            + self.pos_emb(positions)
        )                                              # [B, seq_len, n_embd]

        # Concatenate: [local(1), global(8), actions(seq_len)]
        x = torch.cat(
            [local_token, global_tokens, seq_emb], dim=1,
        )                                              # [B, 1+8+seq_len, n_embd]

        # Transformer
        out = self.transformer(x)                      # [B, 1+8+seq_len, n_embd]

        # Drop the observation prefix; only action positions are decoded.
        n_prefix = 1 + self.n_global_tokens
        action_logits = self.head(
            out[:, n_prefix:, :]
        )                                              # [B, seq_len, action_dim]

        return {"actions": action_logits, "goal_pred": goal_pred}
class LocalDiffusionPlanner(nn.Module):
    """Local-only ablation model (no global stream, no goal head).

    Encodes the 9x9 local glyph crop into a single token, prepends it to
    the embedded action sequence and decodes per-position action logits
    with a transformer encoder.

    Args:
        cfg: Config namespace with ``action_dim``, ``n_embd``, ``n_head``,
            ``n_layer``, ``seq_len``, ``num_diffusion_steps`` and,
            optionally, ``dropout`` (defaults to 0.0).
    """

    def __init__(self, cfg: SimpleNamespace) -> None:
        super().__init__()
        action_dim = cfg.action_dim
        n_embd = cfg.n_embd
        seq_len = cfg.seq_len

        # Local stream: 9x9 crop -> 1 token
        self.embedding = nn.Embedding(6000, 64)
        self.cnn = nn.Sequential(
            nn.Conv2d(64, 32, 3, padding=1),
            nn.GELU(),
            nn.Conv2d(32, 64, 3, padding=1),
            nn.GELU(),
            nn.Flatten(),
            nn.Linear(64 * 9 * 9, n_embd),
        )

        # Action stream. Two ids beyond the real actions; presumably special
        # tokens such as MASK/PAD -- TODO confirm against the diffusion code.
        self.action_emb = nn.Embedding(action_dim + 2, n_embd)
        self.timestep_emb = nn.Embedding(cfg.num_diffusion_steps, n_embd)
        self.pos_emb = nn.Embedding(seq_len, n_embd)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=n_embd,
            nhead=cfg.n_head,
            dim_feedforward=n_embd * 4,
            dropout=getattr(cfg, "dropout", 0.0),
            activation="gelu",
            norm_first=True,
            batch_first=True,
        )
        # Nested tensors are never used with norm_first=True; disabling
        # them explicitly avoids the construction-time warning and matches
        # LocalDiffusionPlannerWithGlobal.
        self.transformer = nn.TransformerEncoder(
            encoder_layer,
            num_layers=cfg.n_layer,
            enable_nested_tensor=False,
        )
        self.head = nn.Linear(n_embd, action_dim)

    def forward(
        self,
        local_obs: Tensor,
        global_obs: Tensor,
        action_seq: Tensor,
        t_discrete: int | Tensor,
    ) -> dict[str, Tensor]:
        """Forward pass (ignores global_obs).

        Args:
            local_obs: ``[B, 9, 9]`` int.
            global_obs: ``[B, 21, 79]`` int (ignored; accepted only so the
                call signature matches the dual-stream model).
            action_seq: ``[B, seq_len]`` int.
            t_discrete: Timestep index: a Python int, a 0-dim tensor, or a
                tensor with one entry per batch element.

        Returns:
            Dict with ``"actions"`` key only (no goal_pred); the value has
            shape ``[B, seq_len, action_dim]``.
        """
        B, Seq = action_seq.shape
        device = local_obs.device

        x_state = self.embedding(local_obs).permute(0, 3, 1, 2)
        state_emb = self.cnn(x_state).unsqueeze(1)     # [B, 1, n_embd]

        positions = torch.arange(
            Seq, device=device,
        ).unsqueeze(0).expand(B, -1)                   # [B, seq_len]

        # Normalise the timestep to a long tensor of shape [B] so that
        # Python ints, 0-dim tensors and per-sample tensors all broadcast
        # the same way below.
        t_tensor = torch.as_tensor(
            t_discrete, dtype=torch.long, device=device,
        ).reshape(-1).expand(B)

        seq_emb = (
            self.action_emb(action_seq)
            + self.timestep_emb(t_tensor).unsqueeze(1)
            + self.pos_emb(positions)
        )                                              # [B, seq_len, n_embd]

        x = torch.cat([state_emb, seq_emb], dim=1)     # [B, 1+seq_len, n_embd]
        out = self.transformer(x)

        # Skip the single observation token; decode action positions only.
        return {"actions": self.head(out[:, 1:, :])}
# ── Factory ──────────────────────────────────────────────────────────
def make_model(cfg: SimpleNamespace) -> nn.Module:
    """Build the denoiser used by default for MiniHack.

    Args:
        cfg: Config namespace forwarded unchanged to the model constructor.

    Returns:
        A freshly constructed ``LocalDiffusionPlannerWithGlobal``.
    """
    model = LocalDiffusionPlannerWithGlobal(cfg)
    return model
def _has_c_compiler() -> bool:
    """Report whether some C compiler executable can be located.

    Candidates are probed in order: the value of the ``CC`` environment
    variable (if set and non-empty), then ``cc``, then ``gcc``. Each is
    resolved with ``shutil.which``; the first hit wins.

    Returns:
        ``True`` if any candidate resolves to an executable, else ``False``.
    """
    import os

    candidates = (os.environ.get("CC"), "cc", "gcc")
    return any(shutil.which(name) for name in candidates if name)
def try_compile(model: nn.Module, cfg: SimpleNamespace) -> nn.Module:
    """Wrap *model* with ``torch.compile`` if enabled and a C compiler exists.

    Falls back to the uncompiled model, logging a warning, when
    ``torch.compile`` is unavailable, when no C compiler can be found
    (common on managed GPU nodes that lack ``gcc``/``cc``), or when
    ``torch.compile`` itself refuses to run on this platform.

    Args:
        model: The raw (uncompiled) model.
        cfg: Config namespace; reads ``torch_compile`` bool (treated as
            ``False`` when absent).

    Returns:
        Compiled model, or *model* unchanged on fallback.
    """
    if not getattr(cfg, "torch_compile", False):
        return model

    if not hasattr(torch, "compile"):
        logger.warning(
            "torch.compile requested but this PyTorch build does not "
            "provide it; falling back to eager mode"
        )
        return model

    if not _has_c_compiler():
        logger.warning(
            "torch.compile requested but no C compiler found "
            "(CC env var, cc, gcc); falling back to eager mode"
        )
        return model

    logger.info("Compiling model with torch.compile")
    try:
        return torch.compile(model, mode="default")  # type: ignore[return-value]
    except RuntimeError as err:
        # torch.compile raises RuntimeError up front on platform/Python
        # combinations it does not support; honour the documented fallback
        # instead of aborting the run.
        logger.warning(
            "torch.compile is not usable here (%s); "
            "falling back to eager mode",
            err,
        )
        return model
# ── EMA ──────────────────────────────────────────────────────────────
class ModelEMA:
    """Exponential moving average of model parameters.

    Maintains a shadow copy of parameters updated as
    ``theta_ema <- decay * theta_ema + (1 - decay) * theta``.

    Parameter names are matched tolerantly with respect to the
    ``_orig_mod.`` prefix that a ``torch.compile`` wrapper adds, so an EMA
    built from a raw model can be updated from / applied to its compiled
    wrapper and vice versa. Only parameters are tracked, not buffers.

    Args:
        model: Source model.
        decay: EMA decay factor (default 0.999).
    """

    # Prefix that ``torch.compile`` wrappers put in front of parameter names.
    _COMPILED_PREFIX = "_orig_mod."

    def __init__(self, model: nn.Module, decay: float = 0.999) -> None:
        self._decay = decay
        # Detached clones: the shadow never shares storage or autograd
        # history with the live parameters.
        self._shadow: dict[str, Tensor] = {
            name: param.detach().clone()
            for name, param in model.named_parameters()
        }

    def _shadow_key(self, name: str) -> str | None:
        """Map a parameter name to its shadow key, or ``None`` if unknown.

        Tries the exact name first, then the same name with the
        compiled-module prefix removed or added.
        """
        if name in self._shadow:
            return name
        prefix = self._COMPILED_PREFIX
        if name.startswith(prefix):
            alias = name[len(prefix):]
        else:
            alias = prefix + name
        return alias if alias in self._shadow else None

    def _shadow_for(self, name: str) -> Tensor:
        """Return the shadow tensor for *name*; raise ``KeyError`` if absent."""
        key = self._shadow_key(name)
        if key is None:
            raise KeyError(name)
        return self._shadow[key]

    @torch.no_grad()
    def update(self, model: nn.Module) -> None:
        """Update shadow parameters from *model*.

        Args:
            model: Source model whose parameters are blended in.

        Raises:
            KeyError: If *model* has a parameter with no shadow entry.
        """
        blend = 1.0 - self._decay
        for name, param in model.named_parameters():
            self._shadow_for(name).mul_(self._decay).add_(
                param.detach(), alpha=blend,
            )

    @torch.no_grad()
    def apply_to(self, model: nn.Module) -> None:
        """Copy shadow parameters into *model* (for inference).

        Args:
            model: Target model to overwrite.

        Raises:
            KeyError: If *model* has a parameter with no shadow entry.
        """
        for name, param in model.named_parameters():
            param.copy_(self._shadow_for(name))

    def state_dict(self) -> dict[str, Tensor]:
        """Return shadow parameter dict for serialisation.

        Returns:
            Dict mapping parameter names to cloned EMA tensors.
        """
        return {k: v.clone() for k, v in self._shadow.items()}

    def load_state_dict(self, sd: dict[str, Tensor]) -> None:
        """Restore shadow parameters from *sd*.

        Loading is best-effort: entries of *sd* with no matching shadow
        parameter are ignored and shadow parameters missing from *sd* keep
        their current values. Any such mismatch is logged as a warning.

        Args:
            sd: State dict from a prior ``state_dict()`` call.
        """
        restored: set[str] = set()
        unexpected: list[str] = []
        for name, tensor in sd.items():
            key = self._shadow_key(name)
            if key is None:
                unexpected.append(name)
                continue
            self._shadow[key].copy_(tensor)
            restored.add(key)

        missing = [k for k in self._shadow if k not in restored]
        if unexpected or missing:
            logging.getLogger(__name__).warning(
                "EMA state dict mismatch: %d unexpected key(s) ignored, "
                "%d shadow parameter(s) not restored",
                len(unexpected),
                len(missing),
            )

    def parameters(self):
        """Iterate over shadow parameter tensors.

        Yields:
            EMA parameter tensors.
        """
        yield from self._shadow.values()

    def make_eval_model(self, model: nn.Module) -> nn.Module:
        """Return a deep copy of *model* with EMA weights applied.

        Args:
            model: Template model (architecture); left unmodified.

        Returns:
            New model carrying the shadow parameters, switched to eval mode.
        """
        eval_model = copy.deepcopy(model)
        self.apply_to(eval_model)
        eval_model.eval()
        return eval_model