| """Denoising Transformer for masked discrete diffusion planning. |
| |
| Architecture: obs MLP encoder + sinusoidal time embedding + bidirectional |
| transformer. Two prefix tokens (obs, time) precede the action sequence. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import numpy as np |
| import jax.numpy as jnp |
| import flax.linen as nn |
| from flax.linen.initializers import constant, orthogonal |
|
|
# Shared parameter initializers for every layer in this file.
# Orthogonal kernels scaled by sqrt(2) for hidden / projection layers.
_INIT = orthogonal(scale=np.sqrt(2))
# Orthogonal kernels with a small scale (0.01), so a layer starts near zero.
_INIT_SMALL = orthogonal(scale=0.01)
# All biases start at zero.
_BIAS = constant(value=0.0)
|
|
|
|
class SinusoidalPosEmbed(nn.Module):
    """Sinusoidal embedding for continuous timesteps or integer positions.

    For an input of shape ``[...]`` the output has shape ``[..., dim]``.
    With ``half = dim // 2`` the frequencies are
    ``exp(-log(10000) * i / half)`` for ``i = 0 .. half - 1``. The first
    ``half`` output columns are the sines of ``x * freq``, the next ``half``
    columns are the cosines, and when ``dim`` is odd one trailing zero
    column pads the result up to ``dim``.

    The module creates no parameters.
    """

    dim: int  # Width of the returned embedding (last axis).

    @nn.compact
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        """Embed ``x`` (any shape) into ``[..., dim]`` sinusoidal features."""
        half = self.dim // 2
        freqs = jnp.exp(-jnp.log(10_000.0) * jnp.arange(half) / half)
        angles = x[..., None] * freqs
        emb = jnp.concatenate([jnp.sin(angles), jnp.cos(angles)], axis=-1)
        if self.dim % 2 == 1:
            # Build the pad column from the leading shape instead of slicing
            # `emb`: when dim == 1, `emb` has zero columns, so a slice such as
            # emb[..., :1] would be empty and the output would end up with
            # width 0 rather than 1.
            pad = jnp.zeros(emb.shape[:-1] + (1,), dtype=emb.dtype)
            emb = jnp.concatenate([emb, pad], axis=-1)
        return emb
|
|
|
|
class TransformerBlock(nn.Module):
    """One pre-norm transformer layer.

    Data flow: LayerNorm -> multi-head self-attention -> dropout -> residual
    add, then LayerNorm -> Dense(d_ff) -> GELU -> Dense(d_model) -> dropout
    -> residual add. No attention mask is passed, and no dropout rate is
    passed to the attention module itself; dropout is applied only to the
    output of each sub-layer.
    """

    d_model: int  # Output width of the feed-forward sub-layer.
    n_heads: int  # Number of attention heads.
    d_ff: int  # Hidden width of the feed-forward sub-layer.
    dropout_rate: float = 0.1  # Rate for both sub-layer dropouts.
    deterministic: bool = True  # When True, dropout is disabled.

    @nn.compact
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        """Apply the attention and feed-forward sub-layers to ``x``."""
        # Submodules are created in a fixed order so that Flax's automatic
        # names (LayerNorm_0, Dense_0, ...) stay stable.

        # --- self-attention sub-layer ---
        attn_in = nn.LayerNorm()(x)
        attention = nn.MultiHeadDotProductAttention(
            num_heads=self.n_heads,
            kernel_init=_INIT,
            deterministic=self.deterministic,
        )
        # Query and key/value are the same normalized tensor.
        attn_out = attention(attn_in, attn_in)
        attn_drop = nn.Dropout(
            rate=self.dropout_rate, deterministic=self.deterministic
        )
        after_attn = x + attn_drop(attn_out)

        # --- position-wise feed-forward sub-layer ---
        ffn_in = nn.LayerNorm()(after_attn)
        expand = nn.Dense(self.d_ff, kernel_init=_INIT, bias_init=_BIAS)
        hidden = nn.gelu(expand(ffn_in))
        project = nn.Dense(self.d_model, kernel_init=_INIT, bias_init=_BIAS)
        ffn_out = project(hidden)
        ffn_drop = nn.Dropout(
            rate=self.dropout_rate, deterministic=self.deterministic
        )
        return after_attn + ffn_drop(ffn_out)
|
|
|
|
class DenoisingTransformer(nn.Module):
    """Denoising transformer for masked discrete diffusion planning.

    Input:  (obs [B, D], noisy_actions [B, H], timestep [B])
    Output: logits [B, H, num_actions] (no MASK logit).

    The observation and the timestep are each turned into one ``d_model``
    wide token. These two prefix tokens are placed in front of the embedded
    action sequence, sinusoidal position embeddings are added, and the
    result goes through ``n_layers`` transformer blocks and a final
    LayerNorm. Only the action positions are projected to logits.
    """

    # Number of output logits per position. The action embedding table has
    # num_actions + 1 rows (presumably the extra index is the MASK token).
    num_actions: int
    # Number of action positions; sizes the positional embedding, so it is
    # expected to equal H of `noisy_actions`.
    plan_horizon: int
    d_model: int = 256  # Token width throughout the transformer.
    n_heads: int = 4  # Attention heads per transformer block.
    n_layers: int = 4  # Number of transformer blocks.
    d_ff: int = 512  # Feed-forward hidden width per block.
    # The obs encoder always has one hidden layer, plus
    # (obs_encoder_layers - 1) further ones.
    obs_encoder_layers: int = 2
    obs_encoder_width: int = 512  # Hidden width of the obs encoder.
    dropout_rate: float = 0.1  # Dropout rate inside each block.

    @nn.compact
    def __call__(
        self,
        obs: jnp.ndarray,
        noisy_actions: jnp.ndarray,
        timestep: jnp.ndarray,
        deterministic: bool = True,
    ) -> jnp.ndarray:
        """Return per-position action logits for a noised action plan.

        Args:
            obs: Observation batch, ``[B, D]``.
            noisy_actions: Integer action tokens, ``[B, H]``.
            timestep: Diffusion timestep per batch element; reshaped to
                ``[B]`` before embedding.
            deterministic: When True, dropout in the blocks is disabled.

        Returns:
            Logits of shape ``[B, H, num_actions]``.
        """

        def dense(features: int) -> nn.Dense:
            # Every hidden/projection layer shares the same initializers.
            return nn.Dense(features, kernel_init=_INIT, bias_init=_BIAS)

        # Submodules are created in the same per-class order throughout, so
        # Flax's automatic names (Dense_0, Dense_1, ...) remain stable.
        num_prefix_tokens = 2  # one obs token + one time token
        batch_size = obs.shape[0]

        # Observation encoder: only the first hidden layer is followed by a
        # LayerNorm; the remaining ones are Dense + ReLU.
        feats = dense(self.obs_encoder_width)(obs)
        feats = nn.LayerNorm()(feats)
        feats = nn.relu(feats)
        for _layer in range(self.obs_encoder_layers - 1):
            feats = nn.relu(dense(self.obs_encoder_width)(feats))
        obs_token = jnp.expand_dims(dense(self.d_model)(feats), axis=1)

        # Timestep token: sinusoidal features followed by a two-layer MLP.
        flat_t = timestep.reshape(batch_size)
        time_feats = SinusoidalPosEmbed(self.d_model)(flat_t)
        time_feats = nn.gelu(dense(self.d_model)(time_feats))
        time_token = jnp.expand_dims(dense(self.d_model)(time_feats), axis=1)

        # Action tokens: learned embedding over num_actions + 1 indices.
        action_tokens = nn.Embed(
            num_embeddings=self.num_actions + 1,
            features=self.d_model,
        )(noisy_actions)

        # Full sequence [obs, time, a_0 .. a_{H-1}] plus position embeddings
        # (the prefix tokens receive positions 0 and 1).
        tokens = jnp.concatenate(
            [obs_token, time_token, action_tokens], axis=1
        )
        positions = jnp.arange(num_prefix_tokens + self.plan_horizon)
        tokens = tokens + SinusoidalPosEmbed(self.d_model)(positions)[None]

        # Transformer trunk followed by a final LayerNorm.
        for _layer in range(self.n_layers):
            block = TransformerBlock(
                d_model=self.d_model,
                n_heads=self.n_heads,
                d_ff=self.d_ff,
                dropout_rate=self.dropout_rate,
                deterministic=deterministic,
            )
            tokens = block(tokens)
        tokens = nn.LayerNorm()(tokens)

        # Drop the two prefix tokens and project the action positions to
        # logits with a small-scale kernel init.
        logits_head = nn.Dense(
            self.num_actions, kernel_init=_INIT_SMALL, bias_init=_BIAS
        )
        return logits_head(tokens[:, num_prefix_tokens:, :])
|
|