| """MDLM ELBO loss for masked discrete diffusion training.""" |
|
|
| from __future__ import annotations |
| from typing import Any, Callable, Optional |
|
|
| import jax |
| import jax.numpy as jnp |
|
|
| from .forward import forward_process |
| from .schedules import ScheduleFn |
|
|
# Denoiser call contract: fn(params, obs, z_t, t, rng) -> logits of shape
# [B, H, V] (see `compute_loss` docstring for argument meanings).
ModelApplyFn = Callable[
    [Any, jnp.ndarray, jnp.ndarray, jnp.ndarray, Optional[Any]], jnp.ndarray
]


# Upper clip on the ELBO reweighting factor; bounds the otherwise unbounded
# -alpha'(t) / (1 - alpha(t)) term as alpha(t) -> 1.
_MAX_WEIGHT: float = 1000.0
# Small constant used both as a division floor and as the default lower
# bound for sampling the diffusion time t (avoids t = 0 exactly).
_EPS: float = 1e-5
|
|
|
|
def compute_loss(
    model_apply: ModelApplyFn,
    params: Any,
    rng: jax.Array,
    x_0: jnp.ndarray,
    obs: jnp.ndarray,
    valid: jnp.ndarray,
    num_actions: int,
    schedule_fn: ScheduleFn,
    schedule_deriv_fn: ScheduleFn,
    sigma_t: float = 0.0,
    label_smoothing: float = 0.0,
    advantages: Optional[jnp.ndarray] = None,
    t_min: float | jax.Array = _EPS,
    t_max: float | jax.Array = 1.0,
) -> tuple[jnp.ndarray, dict[str, jnp.ndarray]]:
    """Masked-position cross-entropy weighted by the continuous-time ELBO.

    A diffusion time ``t`` is drawn uniformly per sample, the clean actions
    are partially masked via the forward process, and the model is asked to
    reconstruct the originals. Only masked positions of valid samples
    contribute; each sample's token-averaged cross-entropy is scaled by the
    MDLM weight ``(1 - sigma_t) * (-alpha'(t)) / (1 - alpha(t))``.

    Args:
        model_apply: fn(params, obs, z_t, t, rng) -> logits [B, H, V].
        params: Model parameters.
        rng: PRNG key.
        x_0: [B, H] int32, ground-truth actions.
        obs: [B, obs_dim] float32, observations.
        valid: [B] bool/float, whether each sample is valid.
        num_actions: Size of real action vocabulary.
        schedule_fn: alpha(t).
        schedule_deriv_fn: d(alpha)/dt (analytic).
        sigma_t: Remasking correction for ELBO weight (0 = standard MDLM).
        label_smoothing: Smoothing epsilon (0 = exact ELBO targets).
        advantages: Optional [B] per-sample weights.
        t_min: Lower bound for uniform t sampling (default: _EPS).
        t_max: Upper bound for uniform t sampling (default: 1.0).

    Returns:
        (loss, info_dict) where info_dict holds scalar diagnostics.
    """
    batch = x_0.shape[0]
    # The mask token id sits just past the real action vocabulary.
    mask_token = num_actions
    rng, key_t, key_mask, key_drop = jax.random.split(rng, 4)

    # Per-sample diffusion time and its keep-probability under the schedule.
    t = jax.random.uniform(key_t, (batch,), minval=t_min, maxval=t_max)
    alpha = schedule_fn(t)

    # ELBO reweighting factor, clipped so the high-noise endpoint is bounded.
    elbo_w = (1.0 - sigma_t) * (-schedule_deriv_fn(t)) / jnp.maximum(
        1.0 - alpha, _EPS
    )
    elbo_w = jnp.minimum(elbo_w, _MAX_WEIGHT)

    # Corrupt the clean sequence, then ask the model to denoise it.
    z_t = forward_process(key_mask, x_0, alpha, mask_token)
    logits = model_apply(params, obs, z_t, t, key_drop)

    # Positions that are both masked and belong to a valid sample.
    mask_pos = (z_t == mask_token).astype(jnp.float32)
    active = mask_pos * valid[:, None].astype(jnp.float32)

    # Exact one-hot targets, optionally softened toward uniform.
    target_dist = jax.nn.one_hot(x_0, num_actions)
    if label_smoothing > 0:
        target_dist = (
            (1.0 - label_smoothing) * target_dist + label_smoothing / num_actions
        )

    log_p = jax.nn.log_softmax(logits, axis=-1)
    xent = -jnp.sum(target_dist * log_p, axis=-1)

    # Token-average the cross-entropy per sample (guard against samples with
    # zero active positions), then apply the per-sample ELBO weight.
    tokens = jnp.maximum(active.sum(axis=-1), 1.0)
    sample_loss = elbo_w * (xent * active).sum(axis=-1) / tokens

    if advantages is not None:
        # Advantage weights are treated as constants w.r.t. the parameters.
        sample_loss = sample_loss * jax.lax.stop_gradient(advantages)

    loss = jnp.mean(sample_loss)

    # ---- diagnostics (no gradient relevance) ----
    hits = (jnp.argmax(logits, axis=-1) == x_0).astype(jnp.float32)
    overall_acc = jnp.sum(hits * active) / jnp.maximum(active.sum(), 1.0)

    # Accuracy bucketed by diffusion time: low / mid / high noise.
    edges = jnp.array([0.33, 0.66])
    buckets = {
        "acc_t_low": (t < edges[0])[:, None],
        "acc_t_mid": ((t >= edges[0]) & (t <= edges[1]))[:, None],
        "acc_t_high": (t > edges[1])[:, None],
    }

    def _bucket_acc(sel):
        # Accuracy restricted to active positions within the time bucket.
        m = active * sel
        return jnp.sum(hits * m) / jnp.maximum(m.sum(), 1.0)

    info = {
        "loss": loss,
        "unweighted_loss": jnp.mean((xent * active).sum(axis=-1) / tokens),
        "mean_t": jnp.mean(t),
        "frac_masked": jnp.mean(mask_pos),
        "accuracy": overall_acc,
    }
    for key, sel in buckets.items():
        info[key] = _bucket_acc(sel)
    if advantages is not None:
        info["adv_mean"] = jnp.mean(advantages)
        info["adv_std"] = jnp.std(advantages)
    return loss, info
|
|