File size: 765 Bytes
6140064
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
"""Forward process: q(z_t | x_0) by independent per-token masking."""

from __future__ import annotations

import jax
import jax.numpy as jnp


def forward_process(
    rng: jax.Array,
    x_0: jnp.ndarray,
    alpha_t: jnp.ndarray,
    mask_id: int,
) -> jnp.ndarray:
    """Sample z_t ~ q(z_t | x_0).  Each token stays with prob alpha_t, else MASK.

    Every element of ``x_0`` is kept independently with probability
    ``alpha_t`` and replaced by ``mask_id`` otherwise.

    Args:
        rng:     PRNG key.
        x_0:     [B, H] int32, clean actions.  Other ranks are accepted too
                 (e.g. unbatched [H] or [B, H, D]).
        alpha_t: Retention probability.  Scalar, [B], or any shape that
                 broadcasts against the *leading* axes of ``x_0`` (e.g. a
                 per-token [B, H] rate).
        mask_id: MASK token index (= num_actions).

    Returns:
        z_t: same shape and dtype as ``x_0``.

    Raises:
        ValueError: if ``alpha_t`` has more axes than ``x_0``.
    """
    alpha_t = jnp.asarray(alpha_t)
    extra_axes = x_0.ndim - alpha_t.ndim
    if extra_axes < 0:
        raise ValueError(
            f"alpha_t has {alpha_t.ndim} axes but x_0 only has {x_0.ndim}; "
            f"got alpha_t.shape={alpha_t.shape}, x_0.shape={x_0.shape}"
        )
    # Broadcasting aligns *trailing* axes, but alpha_t is indexed by the
    # leading (batch) axes of x_0.  Append singleton axes so e.g. [B] becomes
    # [B, 1, ...] and a scalar stays rank-compatible without adding a batch
    # dimension to unbatched inputs.
    alpha_t = jnp.reshape(alpha_t, alpha_t.shape + (1,) * extra_axes)
    # One uniform draw per token; keep the token when the draw falls below
    # its retention probability (alpha_t == 1 always keeps, 0 always masks,
    # since uniform samples lie in [0, 1)).
    u = jax.random.uniform(rng, shape=x_0.shape)
    return jnp.where(u < alpha_t, x_0, jnp.array(mask_id, dtype=x_0.dtype))