File size: 4,323 Bytes
6140064
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
"""MDLM ELBO loss for masked discrete diffusion training."""

from __future__ import annotations
from typing import Any, Callable, Optional

import jax
import jax.numpy as jnp

from .forward import forward_process
from .schedules import ScheduleFn

# Denoiser signature: (params, obs, z_t, t, rng) -> logits; the last
# argument is the dropout PRNG key and may be None.
ModelApplyFn = Callable[
    [Any, jnp.ndarray, jnp.ndarray, jnp.ndarray, Optional[Any]],
    jnp.ndarray,
]

# Upper clip for the ELBO weight in compute_loss; the ratio
# -alpha'(t) / (1 - alpha(t)) can grow very large as 1 - alpha(t) -> 0.
_MAX_WEIGHT: float = 1e3
# Small positive floor: default lower bound for sampled t and the minimum
# value allowed for the 1 - alpha(t) denominator.
_EPS: float = 1e-5


def compute_loss(
    model_apply: ModelApplyFn,
    params: Any,
    rng: jax.Array,
    x_0: jnp.ndarray,
    obs: jnp.ndarray,
    valid: jnp.ndarray,
    num_actions: int,
    schedule_fn: ScheduleFn,
    schedule_deriv_fn: ScheduleFn,
    sigma_t: float = 0.0,
    label_smoothing: float = 0.0,
    advantages: Optional[jnp.ndarray] = None,
    t_min: float | jax.Array = _EPS,
    t_max: float | jax.Array = 1.0,
) -> tuple[jnp.ndarray, dict[str, jnp.ndarray]]:
    """Monte-Carlo continuous-time ELBO loss, evaluated on masked positions only.

    Each sample gets its own diffusion time ``t ~ U(t_min, t_max)``.
    ``x_0`` is corrupted with ``forward_process`` at ``alpha(t)``. The model's
    cross-entropy is then averaged over positions that are masked AND belong to
    a valid sample. The result is scaled by the clipped weight
    ``min((1 - sigma_t) * (-alpha'(t)) / max(1 - alpha(t), _EPS), _MAX_WEIGHT)``.

    Args:
        model_apply:       fn(params, obs, z_t, t, rng) -> logits [B, H, V]. The
                           logits are multiplied elementwise against
                           ``num_actions``-wide one-hot targets, so the last
                           axis must broadcast with ``num_actions``.
        params:            Model parameters.
        rng:               PRNG key.
        x_0:               [B, H] int32, ground-truth actions.
        obs:               [B, obs_dim] float32, observations.
        valid:             [B] bool/float, whether each sample is valid.
        num_actions:       Size of real action vocabulary; also used as the
                           mask token id.
        schedule_fn:       alpha(t).
        schedule_deriv_fn: d(alpha)/dt (analytic).
        sigma_t:           Remasking correction for ELBO weight (0 = standard MDLM).
        label_smoothing:   Smoothing epsilon (0 = exact ELBO targets). It drives
                           a Python ``if``, so it must be a concrete number.
        advantages:        Optional [B] per-sample weights (gradient is stopped).
        t_min:             Lower bound for uniform t sampling (default: _EPS).
        t_max:             Upper bound for uniform t sampling (default: 1.0).

    Returns:
        ``(loss, info)``:
            - ``loss`` is the scalar mean of the per-sample weighted losses.
            - ``info`` holds diagnostics: the unweighted loss, the mean t, the
              masked fraction, masked-token accuracy overall and per t bin,
              and advantage stats when given.

    NOTE(review): the final mean runs over all B samples, including those with
    ``valid == 0`` (they contribute 0). The loss scale therefore shrinks as the
    fraction of invalid samples grows, unlike ``accuracy``, which normalizes by
    valid masked tokens. Confirm this normalization is intended.
    """
    batch_size = x_0.shape[0]
    mask_token = num_actions  # the mask id sits just past the real vocabulary
    # Four keys are split (first one unused) so the random streams stay fixed.
    _, t_key, corrupt_key, dropout_key = jax.random.split(rng, 4)

    # Diffusion time and analytic ELBO weight. Narrowing [t_min, t_max]
    # gives a partial-range ablation; the defaults cover the full ELBO.
    t = jax.random.uniform(t_key, (batch_size,), minval=t_min, maxval=t_max)
    alpha_t = schedule_fn(t)
    alpha_dot_neg = -schedule_deriv_fn(t)  # positive quantity
    raw_weight = (1.0 - sigma_t) * alpha_dot_neg / jnp.maximum(1.0 - alpha_t, _EPS)
    weight = jnp.minimum(raw_weight, _MAX_WEIGHT)

    # Corrupt the clean actions, then ask the model to reconstruct them.
    z_t = forward_process(corrupt_key, x_0, alpha_t, mask_token)
    logits = model_apply(params, obs, z_t, t, dropout_key)  # [B, H, V]

    # Loss is only taken where a token was masked in a valid sample.
    is_masked = (z_t == mask_token).astype(jnp.float32)     # [B, H]
    sample_valid = valid.astype(jnp.float32)[:, None]       # [B, 1]
    valid_masked = is_masked * sample_valid                 # [B, H]

    targets = jax.nn.one_hot(x_0, num_actions)
    if label_smoothing > 0:
        targets = (1.0 - label_smoothing) * targets + label_smoothing / num_actions
    ce = -(targets * jax.nn.log_softmax(logits, axis=-1)).sum(axis=-1)  # [B, H]

    masked_ce_sum = (ce * valid_masked).sum(axis=-1)        # [B]
    n_masked = jnp.maximum(valid_masked.sum(axis=-1), 1.0)  # [B], avoid /0
    per_sample = weight * masked_ce_sum / n_masked

    if advantages is not None:
        per_sample = per_sample * jax.lax.stop_gradient(advantages)

    loss = jnp.mean(per_sample)

    # Diagnostics: greedy-prediction accuracy on the loss positions.
    correct = (jnp.argmax(logits, axis=-1) == x_0).astype(jnp.float32)

    def masked_accuracy(weights):
        # Fraction of correct predictions among positions with nonzero weight.
        return jnp.sum(correct * weights) / jnp.maximum(weights.sum(), 1.0)

    edges = jnp.array([0.33, 0.66])
    in_low = (t < edges[0])[:, None]
    in_mid = ((t >= edges[0]) & (t <= edges[1]))[:, None]
    in_high = (t > edges[1])[:, None]

    info = {
        "loss": loss,
        "unweighted_loss": jnp.mean(masked_ce_sum / n_masked),
        "mean_t": jnp.mean(t),
        "frac_masked": jnp.mean(is_masked),
        "accuracy": masked_accuracy(valid_masked),
        "acc_t_low": masked_accuracy(valid_masked * in_low),
        "acc_t_mid": masked_accuracy(valid_masked * in_mid),
        "acc_t_high": masked_accuracy(valid_masked * in_high),
    }
    if advantages is not None:
        info["adv_mean"] = jnp.mean(advantages)
        info["adv_std"] = jnp.std(advantages)
    return loss, info