MathisW78's picture
Upload COMP0258 demo bundle (code + checkpoint + ablation assets)
a4acc07 verified
"""Reverse diffusion sampling with ReMDM remasking (Wang et al.).
Strategies (Section 4.1):
rescale: sigma = eta * sigma_max
cap: sigma = min(eta, sigma_max)
conf: per-token confidence-based remasking
Loop mode (Section 4.2, Algorithm 3):
Phase 1: standard MDLM decode, t in [1, t_on]
Phase 2: constant alpha(t_on), remasking active
Phase 3: standard MDLM decode, t in [t_off, 0]
"""
from __future__ import annotations
from typing import Any, Callable, Optional
import jax
import jax.numpy as jnp
from .schedules import ScheduleFn
# Type of the model forward closure. The call sites in this file invoke it
# positionally as (params, obs, token_sequence, timesteps, rng_or_None) and
# treat the result as logits.
ModelApplyFn = Callable[
    [
        Any,            # params
        jnp.ndarray,    # obs
        jnp.ndarray,    # current (partially masked) token sequence
        jnp.ndarray,    # per-example diffusion time
        Optional[Any],  # PRNG key, or None
    ],
    jnp.ndarray,
]
# ---------------------------------------------------------------------------
# Remasking sigma computation
# ---------------------------------------------------------------------------
def _sigma_max(alpha_t: jnp.ndarray, alpha_s: jnp.ndarray) -> jnp.ndarray:
"""sigma_max = min(1, (1 - alpha_s) / alpha_t). [Eq. 7]"""
return jnp.minimum(1.0, (1.0 - alpha_s) / jnp.maximum(alpha_t, 1e-8))
def sigma_rescale(alpha_t, alpha_s, eta):
    """"rescale" strategy: ``sigma = eta * sigma_max(alpha_t, alpha_s)``."""
    upper_bound = _sigma_max(alpha_t, alpha_s)
    return eta * upper_bound
def sigma_cap(alpha_t, alpha_s, eta):
    """"cap" strategy: ``sigma = min(eta, sigma_max(alpha_t, alpha_s))``."""
    upper_bound = _sigma_max(alpha_t, alpha_s)
    return jnp.minimum(eta, upper_bound)
def sigma_conf(alpha_t, alpha_s, eta, psi, is_unmasked):
    """"conf" strategy: per-token remasking probability from confidences.

    A softmax over ``-psi`` along the last axis, restricted to unmasked
    positions, splits the budget ``eta * sigma_max`` between tokens, so
    tokens with lower ``psi`` receive a larger share. Masked positions get 0.

    Rows without any unmasked token would feed an all ``-inf`` vector to the
    softmax (NaN); for those rows the softmax input is replaced by zeros and
    the final ``where`` still returns 0 everywhere.
    """
    row_has_token = jnp.any(is_unmasked, axis=-1, keepdims=True)
    scores = jnp.where(is_unmasked, -psi, -jnp.inf)
    scores = jnp.where(row_has_token, scores, 0.0)
    weights = jax.nn.softmax(scores, axis=-1)
    budget = eta * _sigma_max(alpha_t, alpha_s)
    return jnp.where(is_unmasked, weights * budget, 0.0)
# Dispatch table: strategy name -> sigma function. Every entry is called
# positionally as (alpha_t, alpha_s, eta, psi, is_unmasked); "rescale" and
# "cap" swallow the two trailing arguments via ``*_``, only "conf" uses them.
_SIGMA_FNS = {
    "rescale": lambda a_t, a_s, eta, *_: sigma_rescale(a_t, a_s, eta),
    "cap": lambda a_t, a_s, eta, *_: sigma_cap(a_t, a_s, eta),
    "conf": lambda a_t, a_s, eta, psi, unm: sigma_conf(a_t, a_s, eta, psi, unm),
}
# ---------------------------------------------------------------------------
# Decoding helpers
# ---------------------------------------------------------------------------
def _nucleus_sample(rng, logits, top_p):
"""Top-p sampling from [B, H, V] logits -> [B, H] int32."""
probs = jax.nn.softmax(logits, axis=-1)
idx = jnp.argsort(-probs, axis=-1)
sorted_p = jnp.take_along_axis(probs, idx, axis=-1)
cum = jnp.cumsum(sorted_p, axis=-1)
cutoff = cum - sorted_p
sorted_p = jnp.where(cutoff >= top_p, 0.0, sorted_p)
sorted_p = sorted_p / jnp.maximum(sorted_p.sum(axis=-1, keepdims=True), 1e-12)
B, H, V = logits.shape
flat = sorted_p.reshape(B * H, V)
tokens = jax.random.categorical(rng, jnp.log(flat + 1e-12)).reshape(B, H)
return jnp.take_along_axis(idx, tokens[..., None], axis=-1).squeeze(-1)
def _decode(rng, logits, temperature, top_p):
"""Sample tokens from logits. Argmax only when temperature <= 0."""
if top_p is not None:
return _nucleus_sample(rng, logits / jnp.maximum(temperature, 1e-8), top_p)
if temperature > 1e-8:
B, H, V = logits.shape
scaled = logits / temperature
return jax.random.categorical(rng, scaled.reshape(-1, V)).reshape(B, H)
return jnp.argmax(logits, axis=-1)
# ---------------------------------------------------------------------------
# Reverse sampling
# ---------------------------------------------------------------------------
def sample_plan(
    model_apply: ModelApplyFn,
    params: Any,
    rng: jax.Array,
    obs: jnp.ndarray,
    num_actions: int,
    plan_horizon: int,
    num_steps: int,
    schedule_fn: ScheduleFn,
    remask_strategy: str = "cap",
    eta: float = 0.5,
    use_loop: bool = False,
    t_on: float = 0.55,
    t_off: float = 0.05,
    temperature: float = 1.0,
    top_p: Optional[float] = None,
) -> jnp.ndarray:
    """Generate an action plan via reverse diffusion with ReMDM remasking.

    Starts from a fully masked ``[B, plan_horizon]`` sequence and runs the
    denoising step under ``jax.lax.scan``. Without loop mode, ``num_steps``
    steps go from t = 1 to t = 0 with remasking enabled at every step. With
    loop mode the steps are split into three phases; remasking is only
    enabled in the middle phase, where alpha is held at ``schedule_fn(t_on)``.
    Any position still masked at the end is filled by an argmax over one
    extra model call at t = 0.

    Args:
        model_apply: Model closure, called as
            ``model_apply(params, obs, z, t, None)`` and expected to return
            logits (assumed ``[B, H, V]`` -- the decoding helper unpacks three
            dimensions; TODO confirm against the model).
        params: Model parameters, passed through to ``model_apply``.
        rng: PRNG key; split four ways on every step.
        obs: Conditioning observations; ``obs.shape[0]`` is the batch size.
        num_actions: Also used as the id of the mask token.
        plan_horizon: Sequence length ``H``.
        num_steps: Total number of denoising steps.
        schedule_fn: Called with a time ``t``; its result is used as alpha(t).
        remask_strategy: One of the keys of ``_SIGMA_FNS``.
        eta: Strategy hyper-parameter forwarded to the sigma function.
        use_loop: Enable the three-phase loop schedule.
        t_on: Time at which phase 1 ends and whose alpha is held in phase 2.
        t_off: Time at which phase 3 starts.
        temperature: Sampling temperature forwarded to ``_decode``.
        top_p: Nucleus threshold forwarded to ``_decode``; ``None`` disables it.

    Returns:
        actions: [B, H] int32.

    Raises:
        ValueError: If ``remask_strategy`` is not a key of ``_SIGMA_FNS``.
    """
    B = obs.shape[0]
    mask_id = num_actions
    mask_val = jnp.array(mask_id, dtype=jnp.int32)
    if remask_strategy not in _SIGMA_FNS:
        raise ValueError(f"Unknown strategy {remask_strategy!r}. Options: {list(_SIGMA_FNS)}")
    get_sigma = _SIGMA_FNS[remask_strategy]
    # Phase allocation for loop mode: steps are split in proportion to the
    # time spans (1 - t_on), (t_on - t_off) and t_off. ``denom`` is
    # algebraically 1. Each phase is floored at one step, so for very small
    # ``num_steps`` the total n1 + n2 + n3 can exceed ``num_steps``.
    if use_loop:
        f1, f3 = 1.0 - t_on, t_off
        denom = f1 + f3 + (t_on - t_off)
        n1 = max(int(round(num_steps * f1 / denom)), 1)
        n3 = max(int(round(num_steps * f3 / denom)), 1)
        n2 = max(num_steps - n1 - n3, 1)
    else:
        n1, n2, n3 = num_steps, 0, 0
    # Alpha held constant during phase 2 (computed even when loop mode is off).
    alpha_loop = schedule_fn(jnp.array(t_on))
    z_init = jnp.full((B, plan_horizon), mask_id, dtype=jnp.int32)
    # psi stores, per position, the model probability of the token at the
    # moment it was unmasked; +inf marks positions that are currently masked.
    psi_init = jnp.full((B, plan_horizon), jnp.inf)
    # ------------------------------------------------------------------
    # Core denoising step (ReMDM Eq. 6)
    # ------------------------------------------------------------------
    def _step(carry, _unused, t_val, alpha_t, alpha_s, sigma_on):
        # carry = (token sequence, PRNG key, confidence history).
        # ``sigma_on`` is a Python bool selecting whether remasking is active.
        z, rng, psi = carry
        rng, s_rng, u_rng, r_rng = jax.random.split(rng, 4)
        t_inp = jnp.full((B,), t_val)
        logits = model_apply(params, obs, z, t_inp, None)
        x_hat = _decode(s_rng, logits, temperature, top_p)
        is_masked = z == mask_id
        is_unmasked = ~is_masked
        sigma = get_sigma(alpha_t, alpha_s, eta, psi, is_unmasked)
        # "rescale"/"cap" return a scalar, "conf" a per-token array.
        sigma = jnp.broadcast_to(sigma, z.shape)
        sigma = jnp.where(sigma_on, sigma, 0.0)
        # Masked -> unmask probability
        denom = jnp.maximum(1.0 - alpha_t, 1e-8)
        p_unmask = jnp.clip((alpha_s - (1.0 - sigma) * alpha_t) / denom, 0.0, 1.0)
        do_unmask = is_masked & (jax.random.uniform(u_rng, z.shape) < p_unmask)
        # Unmasked -> remask with probability sigma
        do_remask = is_unmasked & (jax.random.uniform(r_rng, z.shape) < sigma)
        z_new = jnp.where(do_unmask, x_hat, z)
        z_new = jnp.where(do_remask, mask_val, z_new)
        # Update confidence history (probabilities are taken from the raw
        # logits, i.e. without temperature or nucleus filtering)
        probs = jax.nn.softmax(logits, axis=-1)
        decode_prob = jnp.take_along_axis(probs, x_hat[..., None], axis=-1).squeeze(-1)
        psi_new = jnp.where(do_unmask, decode_prob, psi)
        psi_new = jnp.where(do_remask, jnp.inf, psi_new)
        return (z_new, rng, psi_new), None
    # ------------------------------------------------------------------
    # Phase functions
    # ------------------------------------------------------------------
    def _phase1_step(carry, idx):
        # t runs from 1 towards t_on in n1 equal steps; no remasking.
        t = 1.0 - idx * (1.0 - t_on) / n1
        s = jnp.maximum(1.0 - (idx + 1) * (1.0 - t_on) / n1, t_on)
        return _step(carry, idx, t, schedule_fn(t), schedule_fn(s), False)
    def _phase2_step(carry, idx):
        # alpha_t == alpha_s == alpha(t_on); remasking active.
        return _step(carry, idx, t_on, alpha_loop, alpha_loop, True)
    def _phase3_step(carry, idx):
        # t runs from t_off towards 0 in n3 equal steps; no remasking.
        t = t_off - idx * t_off / n3
        s = jnp.maximum(t_off - (idx + 1) * t_off / n3, 0.0)
        return _step(carry, idx, t, schedule_fn(t), schedule_fn(s), False)
    def _simple_step(carry, idx):
        # Uniform grid from t = 1 to t = 0; remasking active throughout.
        t = (num_steps - idx) / num_steps
        s = jnp.maximum((num_steps - idx - 1) / num_steps, 0.0)
        return _step(carry, idx, t, schedule_fn(t), schedule_fn(s), True)
    # ------------------------------------------------------------------
    # Run
    # ------------------------------------------------------------------
    carry = (z_init, rng, psi_init)
    if use_loop:
        carry, _ = jax.lax.scan(_phase1_step, carry, jnp.arange(n1))
        carry, _ = jax.lax.scan(_phase2_step, carry, jnp.arange(n2))
        # n3 is floored at 1 above, so this guard always passes in loop mode.
        if n3 > 0:
            carry, _ = jax.lax.scan(_phase3_step, carry, jnp.arange(n3))
    else:
        carry, _ = jax.lax.scan(_simple_step, carry, jnp.arange(num_steps))
    z_final = carry[0]
    # Final greedy cleanup for any remaining masks
    final_logits = model_apply(params, obs, z_final, jnp.zeros((B,)), None)
    fallback = jnp.argmax(final_logits, axis=-1)
    return jnp.where(z_final == mask_id, fallback, z_final)
# ---------------------------------------------------------------------------
# Inpainting sampler (MPC / historical prefix)
# ---------------------------------------------------------------------------
def sample_plan_inpainting(
    apply_fn: ModelApplyFn,
    params: Any,
    rng: jax.Array,
    obs: jnp.ndarray,
    history: jnp.ndarray,
    hist_len: jnp.ndarray,
    num_actions: int,
    plan_horizon: int,
    diffusion_steps: int,
    temperature: float,
    top_p: Optional[float],
) -> jnp.ndarray:
    """Diffusion sampling with a locked historical prefix (inpainting).

    Positions ``0 .. hist_len[b] - 1`` are fixed to the values in ``history``
    for each batch element ``b``; the remainder are diffused freely.

    Each step samples a token for every position, keeps the most confident
    ``floor(plan_horizon * step / diffusion_steps)`` (at least one) of them,
    randomly remasks a shrinking fraction of the kept tokens, and finally
    re-imposes the locked prefix. On the last step every position is kept and
    the remask probability is 0, so the result contains no mask tokens.

    Args:
        apply_fn: Model apply closure (eval mode, no dropout).
        params: Model parameter pytree.
        rng: PRNG key.
        obs: ``[B, obs_dim]`` conditioning observations.
        history: ``[B, plan_horizon]`` int32 prefix of executed actions.
        hist_len: ``[B]`` int32 number of valid prefix tokens per element.
        num_actions: Size of the real action vocabulary (mask token = ``num_actions``).
        plan_horizon: Total sequence length.
        diffusion_steps: Number of denoising iterations (must be >= 1).
        temperature: Softmax temperature for token sampling (floored at 1e-8).
        top_p: Nucleus-sampling threshold; ``None`` disables nucleus filtering.
            The most likely token is always retained, so ``top_p <= 0``
            behaves as greedy decoding.

    Returns:
        ``[B, plan_horizon]`` int32 completed action plan.

    Raises:
        ValueError: If ``diffusion_steps < 1``; with no steps the plan would
            be returned with its free positions still set to the mask token.
    """
    if diffusion_steps < 1:
        raise ValueError(f"diffusion_steps must be >= 1, got {diffusion_steps!r}")

    B = obs.shape[0]
    mask_id = num_actions
    safe_temperature = jnp.maximum(temperature, 1e-8)

    # Loop-invariant: which positions belong to the locked prefix.
    pos = jnp.broadcast_to(jnp.arange(plan_horizon)[None, :], (B, plan_horizon))
    is_locked = pos < hist_len[:, None]

    def _step(carry, step):
        seq, rng = carry
        rng, model_rng, sample_rng, remask_rng = jax.random.split(rng, 4)
        # step runs 1..diffusion_steps, so ratio goes from 1/N up to exactly 1.
        ratio = step / diffusion_steps
        t_tensor = jnp.full((B,), 1.0 - ratio)
        logits = apply_fn(params, obs, seq, t_tensor, model_rng) / safe_temperature

        # Optional nucleus filtering
        if top_p is not None:
            probs = jax.nn.softmax(logits, axis=-1)
            sorted_idx = jnp.argsort(-probs, axis=-1)
            sorted_p = jnp.take_along_axis(probs, sorted_idx, axis=-1)
            # Exclusive cumulative mass of the higher-ranked tokens.
            cutoff = jnp.cumsum(sorted_p, axis=-1) - sorted_p
            # Never drop rank 0: otherwise top_p <= 0 would turn every logit
            # into -inf and the softmax below into NaN.
            not_top = jnp.arange(sorted_p.shape[-1]) > 0
            drop_sorted = (cutoff >= top_p) & not_top
            # Scatter the sorted-order mask back to vocabulary order.
            inv_idx = jnp.argsort(sorted_idx, axis=-1)
            nucleus_mask = jnp.take_along_axis(drop_sorted, inv_idx, axis=-1)
            logits = jnp.where(nucleus_mask, -jnp.inf, logits)

        # NOTE(review): every position is re-sampled each step; tokens
        # revealed earlier are not carried over from ``seq``.
        preds = jax.random.categorical(sample_rng, logits, axis=-1)
        conf = jnp.take_along_axis(
            jax.nn.softmax(logits, axis=-1), preds[..., None], axis=-1,
        ).squeeze(-1)

        # Keep top-(ratio * H) most confident predictions unmasked.
        # NOTE(review): the ranking includes locked-prefix positions, so they
        # compete with free positions for the unmask budget.
        num_unmask = jnp.maximum(1, (plan_horizon * ratio).astype(jnp.int32))
        sorted_conf = jnp.sort(conf, axis=-1)[..., ::-1]
        thresh = sorted_conf[jnp.arange(B), num_unmask - 1]
        seq_new = jnp.where(conf < thresh[:, None], mask_id, preds)

        # Light ReMDM-style remasking; probability decays linearly to 0.
        remask_prob = 0.15 * (1.0 - ratio)
        do_remask = (
            (jax.random.uniform(remask_rng, seq_new.shape) < remask_prob)
            & (seq_new != mask_id)
        )
        seq_new = jnp.where(do_remask, mask_id, seq_new)

        # Lock historical prefix
        seq_new = jnp.where(is_locked, history, seq_new)
        return (seq_new, rng), None

    # Initialise: history locked, remainder fully masked
    init_seq = jnp.full((B, plan_horizon), mask_id, dtype=jnp.int32)
    init_seq = jnp.where(is_locked, history, init_seq)
    (final_seq, _), _ = jax.lax.scan(
        _step, (init_seq, rng), jnp.arange(1, diffusion_steps + 1),
    )
    return final_seq