| """Forward process: q(z_t | x_0) by independent per-token masking.""" | |
| from __future__ import annotations | |
| import jax | |
| import jax.numpy as jnp | |
def forward_process(
    rng: jax.Array,
    x_0: jnp.ndarray,
    alpha_t: jnp.ndarray,
    mask_id: int,
) -> jnp.ndarray:
    """Sample z_t ~ q(z_t | x_0) by independent per-token masking.

    Every token of ``x_0`` is kept with probability ``alpha_t`` and is
    otherwise replaced by ``mask_id``; tokens are masked independently.

    Args:
        rng: PRNG key.
        x_0: [B, H] int32, clean actions. Additional trailing axes
            ([B, H, ...]) are accepted; the leading axis is always treated
            as the batch axis.
        alpha_t: [B] or scalar, retention probability. A [B] array gives
            one probability per batch element, shared by all of its tokens.
        mask_id: MASK token index (= num_actions).

    Returns:
        z_t: masked copy of ``x_0`` with the dtype of ``x_0``; same shape as
        ``x_0`` for inputs of rank >= 2.
    """
    # One uniform draw in [0, 1) per token: alpha_t == 1 keeps everything,
    # alpha_t == 0 masks everything.
    draws = jax.random.uniform(rng, shape=x_0.shape)

    # Put alpha_t on the leading (batch) axis and pad with singleton axes so
    # it broadcasts over every remaining axis of x_0. A fixed (-1, 1) shape
    # would align with the *last two* axes of a rank-3+ input instead of the
    # batch axis. At least one trailing axis is kept so that 1-D and 2-D
    # inputs behave exactly as with the (-1, 1) reshape.
    trailing_axes = max(x_0.ndim - 1, 1)
    keep_prob = jnp.reshape(alpha_t, (-1,) + (1,) * trailing_axes)

    # Build the MASK token in x_0's dtype so the result is not promoted.
    mask_token = jnp.array(mask_id, dtype=x_0.dtype)
    return jnp.where(draws < keep_prob, x_0, mask_token)