remdm-craftax / src /models /denoiser.py
MathisW78's picture
Upload COMP0258 demo bundle (code + diffusion/PPO checkpoints + ablation assets)
6140064 verified
"""Denoising Transformer for masked discrete diffusion planning.
Architecture: obs MLP encoder + sinusoidal time embedding + bidirectional
transformer. Two prefix tokens (obs, time) precede the action sequence.
"""
from __future__ import annotations
import numpy as np
import jax.numpy as jnp
import flax.linen as nn
from flax.linen.initializers import constant, orthogonal
# Shared parameter initialisers for every layer defined in this file.
# Orthogonal kernels with gain sqrt(2): used for all hidden/projection layers.
_INIT = orthogonal(scale=np.sqrt(2))
# Orthogonal kernels with a tiny gain: used for the final logit head so the
# initial logits are small in magnitude.
_INIT_SMALL = orthogonal(scale=0.01)
# Every bias vector starts at zero.
_BIAS = constant(value=0.0)
class SinusoidalPosEmbed(nn.Module):
    """Sinusoidal embedding for continuous timesteps or integer positions.

    For an input of shape ``[...]`` the output has shape ``[..., dim]``:
    the first ``dim // 2`` features are sines, the next ``dim // 2`` are
    cosines, and when ``dim`` is odd one trailing zero column pads the
    result to exactly ``dim`` features. The module has no parameters.
    """

    # Size of the last axis of the returned embedding.
    dim: int

    @nn.compact
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        """Embed ``x`` element-wise.

        Args:
            x: array of timesteps or positions, any shape.

        Returns:
            Array of shape ``x.shape + (dim,)``.
        """
        half = self.dim // 2
        # Frequencies decay geometrically from 1 towards 1/10000.
        freqs = jnp.exp(-jnp.log(10_000.0) * jnp.arange(half) / half)
        angles = x[..., None] * freqs
        emb = jnp.concatenate([jnp.sin(angles), jnp.cos(angles)], axis=-1)
        if self.dim % 2 == 1:
            # Build the padding column from the leading shape rather than by
            # slicing ``emb``: for dim == 1 ``emb`` has zero features, so a
            # slice of it would be empty and the output would have width 0.
            pad = jnp.zeros(emb.shape[:-1] + (1,), dtype=emb.dtype)
            emb = jnp.concatenate([emb, pad], axis=-1)
        return emb
class TransformerBlock(nn.Module):
    """Pre-norm transformer layer (no attention mask is applied).

    Two residual sub-layers are applied in sequence:

    1. ``y = x + Dropout(SelfAttention(LayerNorm(x)))``
    2. ``out = y + Dropout(Dense(GELU(Dense(LayerNorm(y)))))``

    The last axis of ``x`` must equal ``d_model`` so that the feed-forward
    output can be added back onto the residual stream.
    """

    d_model: int  # output width of the feed-forward sub-layer
    n_heads: int  # number of attention heads
    d_ff: int  # hidden width of the feed-forward sub-layer
    dropout_rate: float = 0.1
    deterministic: bool = True  # True disables dropout

    @nn.compact
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        """Run the block on ``x`` and return an array of the same shape."""
        no_dropout = self.deterministic

        # --- Self-attention sub-layer -------------------------------------
        normed = nn.LayerNorm()(x)
        attended = nn.MultiHeadDotProductAttention(
            num_heads=self.n_heads,
            kernel_init=_INIT,
            deterministic=no_dropout,
        )(normed, normed)
        attended = nn.Dropout(
            rate=self.dropout_rate, deterministic=no_dropout
        )(attended)
        stream = x + attended

        # --- Position-wise feed-forward sub-layer -------------------------
        hidden = nn.LayerNorm()(stream)
        hidden = nn.Dense(self.d_ff, kernel_init=_INIT, bias_init=_BIAS)(hidden)
        hidden = nn.gelu(hidden)
        projected = nn.Dense(
            self.d_model, kernel_init=_INIT, bias_init=_BIAS
        )(hidden)
        projected = nn.Dropout(
            rate=self.dropout_rate, deterministic=no_dropout
        )(projected)
        return stream + projected
class DenoisingTransformer(nn.Module):
    """Denoising transformer for masked discrete diffusion planning.

    Inputs:  obs ``[B, D]``, noisy_actions ``[B, H]``, timestep ``[B]``.
    Output:  logits ``[B, H, num_actions]`` (there is no logit for MASK).

    The observation and the timestep are each turned into one prefix token;
    they are followed by one embedded token per action slot, and the whole
    sequence is processed by a stack of unmasked transformer blocks.
    """

    num_actions: int  # number of real actions (MASK excluded)
    plan_horizon: int  # H, number of action slots in a plan
    d_model: int = 256
    n_heads: int = 4
    n_layers: int = 4
    d_ff: int = 512
    obs_encoder_layers: int = 2
    obs_encoder_width: int = 512
    dropout_rate: float = 0.1

    @nn.compact
    def __call__(
        self,
        obs: jnp.ndarray,
        noisy_actions: jnp.ndarray,
        timestep: jnp.ndarray,
        deterministic: bool = True,
    ) -> jnp.ndarray:
        """Predict per-slot action logits.

        Args:
            obs: observations, first axis is the batch.
            noisy_actions: integer token ids, one per plan slot; the embedding
                table has ``num_actions + 1`` rows (presumably the extra id
                is MASK -- confirm which index against the caller).
            timestep: one diffusion time per batch element; reshaped to
                ``[B]`` internally.
            deterministic: when True, dropout in the blocks is disabled.

        Returns:
            Logits over the ``num_actions`` real actions for every action
            slot (the two prefix positions are dropped).
        """
        n_prefix = 2  # one obs token + one time token
        batch_size = obs.shape[0]
        token_vocab = self.num_actions + 1  # +1 for MASK token

        def dense(features: int) -> nn.Dense:
            # Standard hidden/projection layer used throughout this model.
            return nn.Dense(features, kernel_init=_INIT, bias_init=_BIAS)

        # ---- Observation prefix token: MLP encoder -----------------------
        obs_feat = dense(self.obs_encoder_width)(obs)
        obs_feat = nn.relu(nn.LayerNorm()(obs_feat))
        for _ in range(1, self.obs_encoder_layers):
            obs_feat = nn.relu(dense(self.obs_encoder_width)(obs_feat))
        obs_tok = dense(self.d_model)(obs_feat)[:, None, :]

        # ---- Time prefix token: sinusoidal features + 2-layer MLP --------
        t_flat = timestep.reshape(batch_size)
        t_feat = SinusoidalPosEmbed(self.d_model)(t_flat)
        t_feat = nn.gelu(dense(self.d_model)(t_feat))
        t_tok = dense(self.d_model)(t_feat)[:, None, :]

        # ---- Action tokens: learned embedding per token id ---------------
        action_toks = nn.Embed(
            num_embeddings=token_vocab, features=self.d_model
        )(noisy_actions)

        # ---- Sequence [obs, time, actions] plus fixed positional code ----
        tokens = jnp.concatenate([obs_tok, t_tok, action_toks], axis=1)
        positions = jnp.arange(n_prefix + self.plan_horizon)
        pos_code = SinusoidalPosEmbed(self.d_model)(positions)
        tokens = tokens + pos_code[None, :, :]

        # ---- Transformer stack followed by a final LayerNorm -------------
        for _ in range(self.n_layers):
            block = TransformerBlock(
                d_model=self.d_model,
                n_heads=self.n_heads,
                d_ff=self.d_ff,
                dropout_rate=self.dropout_rate,
                deterministic=deterministic,
            )
            tokens = block(tokens)
        tokens = nn.LayerNorm()(tokens)

        # ---- Logit head over real actions (prefix positions dropped) -----
        head = nn.Dense(
            self.num_actions, kernel_init=_INIT_SMALL, bias_init=_BIAS
        )
        return head(tokens[:, n_prefix:, :])