File size: 765 Bytes
6140064 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 | """Forward process: q(z_t | x_0) by independent per-token masking."""
from __future__ import annotations
import jax
import jax.numpy as jnp
def forward_process(
rng: jax.Array,
x_0: jnp.ndarray,
alpha_t: jnp.ndarray,
mask_id: int,
) -> jnp.ndarray:
"""Sample z_t ~ q(z_t | x_0). Each token stays with prob alpha_t, else MASK.
Args:
rng: PRNG key.
x_0: [B, H] int32, clean actions.
alpha_t: [B] or scalar, retention probability.
mask_id: MASK token index (= num_actions).
Returns:
z_t: [B, H] int32.
"""
keep = jax.random.uniform(rng, shape=x_0.shape)
alpha_t = jnp.reshape(alpha_t, (-1, 1))
return jnp.where(keep < alpha_t, x_0, jnp.array(mask_id, dtype=x_0.dtype))
|