remdm-craftax / src /diffusion /forward.py
MathisW78's picture
Upload COMP0258 demo bundle (code + diffusion/PPO checkpoints + ablation assets)
6140064 verified
"""Forward process: q(z_t | x_0) by independent per-token masking."""
from __future__ import annotations
import jax
import jax.numpy as jnp
def forward_process(
    rng: jax.Array,
    x_0: jnp.ndarray,
    alpha_t: jnp.ndarray,
    mask_id: int,
) -> jnp.ndarray:
    """Sample z_t ~ q(z_t | x_0). Each token stays with prob alpha_t, else MASK.

    Tokens are masked independently: every position draws u ~ U[0, 1) and
    keeps its clean value iff u < alpha_t, otherwise it is replaced by
    ``mask_id``. Consequently alpha_t = 1 keeps every token and alpha_t = 0
    masks every token.

    Args:
        rng: PRNG key.
        x_0: [B, H] int32, clean actions. More generally any array whose
            leading axes match ``alpha_t`` (e.g. [H] with a scalar alpha_t,
            or [B, T, H] with alpha_t of shape [B] or [B, T]).
        alpha_t: [B] or scalar, retention probability. Its shape is matched
            against the leading axes of ``x_0`` and broadcast over the
            remaining trailing axes.
        mask_id: MASK token index (= num_actions).

    Returns:
        z_t: array with the same shape and dtype as ``x_0``.

    Raises:
        ValueError: if ``alpha_t`` has more dimensions than ``x_0``.
    """
    alpha_t = jnp.asarray(alpha_t)
    if alpha_t.ndim > x_0.ndim:
        raise ValueError(
            f"alpha_t has {alpha_t.ndim} dims but x_0 only has {x_0.ndim}; "
            "alpha_t must index a prefix of x_0's axes."
        )
    # Append singleton axes so alpha_t lines up with the *leading* axes of
    # x_0. Broadcasting aligns from the right, so a bare [B] vector would
    # otherwise be matched against the token axis instead of the batch axis.
    # A scalar stays a scalar, so the output never gains an extra axis.
    alpha_t = jnp.reshape(
        alpha_t, alpha_t.shape + (1,) * (x_0.ndim - alpha_t.ndim)
    )
    # One uniform draw per token; u < alpha_t happens with probability alpha_t.
    u = jax.random.uniform(rng, shape=x_0.shape)
    return jnp.where(u < alpha_t, x_0, jnp.array(mask_id, dtype=x_0.dtype))