File size: 5,293 Bytes
f748552
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
"""MDLM ELBO loss with SUBS parameterisation.

Ported from the Craftax JAX implementation (src/diffusion/loss.py).
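Computes the loss on masked, non-PAD positions only. By default this is a
plain masked cross-entropy average; analytic SUBS importance weighting
(clipped for numerical stability) can be enabled per call. Also provides an
auxiliary MSE loss for staircase-coordinate prediction.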
"""

from __future__ import annotations

from typing import Callable

import torch
import torch.nn.functional as F
from torch import Tensor

from src.diffusion.schedules import alpha_prime


# Default upper clamp for the SUBS weight -alpha'(t) / (1 - alpha_t): the
# denominator approaches zero as alpha_t -> 1, so the raw weight can blow up.
_MAX_WEIGHT: float = 1000.0


def mdlm_loss(
    logits: Tensor,
    x0: Tensor,
    zt: Tensor,
    t: Tensor,
    mask_token: int,
    pad_token: int,
    schedule_fn: Callable[[Tensor], Tensor],
    weight_clip: float = _MAX_WEIGHT,
    label_smoothing: float = 0.0,
    use_importance_weighting: bool = False,
) -> Tensor:
    """Compute masked diffusion loss.

    By default uses a simple masked cross-entropy average over every masked,
    non-PAD position in the batch (matching the reference implementation).
    When ``use_importance_weighting=True``, each sample's mean masked
    cross-entropy is scaled by the SUBS weight
    ``w(t) = -alpha'(t) / (1 - alpha_t)`` (clamped to ``[0, weight_clip]``)
    and the result is averaged over the batch.

    Args:
        logits: Model output. Shape ``[B, L, vocab]``.
        x0: Clean action sequences. Shape ``[B, L]``, int64.
        zt: Noisy sequences. Shape ``[B, L]``, int64.
        t: Per-sample diffusion time in [0, 1]. Shape ``[B]``.
        mask_token: MASK token ID.
        pad_token: PAD token ID.
        schedule_fn: Noise schedule returning alpha(t).
        weight_clip: Upper clamp for SUBS weight (default 1000).
        label_smoothing: Smoothing epsilon for cross-entropy.
        use_importance_weighting: If ``True``, apply SUBS w(t) per sample.

    Returns:
        Scalar loss. When no masked positions exist, returns a zero that is
        still attached to ``logits``' autograd graph, so callers can call
        ``loss.backward()`` unconditionally.
    """
    B, L, V = logits.shape

    # Mask: compute loss only on masked, non-PAD positions
    is_masked = (zt == mask_token) & (x0 != pad_token)  # [B, L]

    if not is_masked.any():
        # A detached constant would make ``backward()`` raise. Zero the
        # logits element-wise *before* summing so that a large fp16 sum
        # cannot overflow to inf (inf * 0 would be NaN).
        return (logits * 0.0).sum()

    # Clamp targets to valid vocab range — out-of-range positions (PAD,
    # MASK) are zeroed out by is_masked below anyway.
    safe_targets = x0.clamp(0, V - 1)  # [B, L]
    ce = F.cross_entropy(
        logits.reshape(-1, V),
        safe_targets.reshape(-1),
        reduction="none",
        label_smoothing=label_smoothing,
    ).reshape(B, L)  # [B, L]

    # Zero out non-masked positions (multiplying by a float32 mask also
    # promotes half-precision CE to float32 before the reductions below).
    ce = ce * is_masked.float()  # [B, L]

    if use_importance_weighting:
        # SUBS weight: w_t = -alpha'(t) / (1 - alpha_t + eps)
        alpha_t = schedule_fn(t)  # [B]
        d_alpha = alpha_prime(t, schedule_fn)  # [B]
        w_t = ((-d_alpha) / (1.0 - alpha_t + 1e-8)).clamp(0.0, weight_clip)

        # Per-sample mean over that sample's masked positions; samples with
        # no masked tokens contribute 0 (clamp avoids division by zero).
        n_masked_per = is_masked.float().sum(dim=1).clamp(min=1.0)  # [B]
        per_sample = ce.sum(dim=1) / n_masked_per  # [B]
        return (per_sample * w_t).mean()

    # Global average over all masked positions (matches reference)
    n_masked_total = is_masked.float().sum().clamp(min=1.0)
    return ce.sum() / n_masked_total


def auxiliary_goal_loss(
    goal_pred: Tensor,
    global_obs: Tensor,
    pad_value: float = -1.0,
) -> Tensor:
    """MSE loss for auxiliary staircase-coordinate prediction.

    Args:
        goal_pred: Predicted normalised staircase coords. Shape ``[B, 2]``.
        global_obs: Full map glyphs. Shape ``[B, 21, 79]``, int.
        pad_value: Coordinate value marking "staircase not visible". It must
            equal the sentinel written by ``find_staircase_from_glyphs``
            (``-1.0``); otherwise not-visible samples (target ``-1``) would
            be treated as valid and supervised.

    Returns:
        Scalar MSE loss over samples where the staircase is visible. When no
        staircase is visible in the batch, returns a zero that is still
        attached to ``goal_pred``'s autograd graph, so ``loss.backward()``
        is always safe to call.
    """
    targets = find_staircase_from_glyphs(global_obs).to(
        goal_pred.device, dtype=goal_pred.dtype
    )  # [B, 2]

    # Only supervise where staircase is visible
    valid = targets[:, 0] != pad_value  # [B]
    if not valid.any():
        # A detached constant would make ``backward()`` raise; zero
        # element-wise before summing so half precision cannot overflow.
        return (goal_pred * 0.0).sum()

    diff = (goal_pred[valid] - targets[valid]) ** 2  # [N, 2]
    return diff.mean()


def find_staircase_from_glyphs(global_obs: Tensor) -> Tensor:
    """Locate the staircase '>' in the global glyph map.

    Searches for NLE staircase-down glyph (character code 62 = '>') plus a
    few extra glyph IDs. For each batch element the first match in row-major
    order is returned as normalised coordinates
    ``(row / (H - 1), col / (W - 1))`` (the divisor is at least 1), so the
    values lie in ``[0, 1]``. Returns ``(-1, -1)`` when the staircase is not
    visible.

    Args:
        global_obs: Glyph map. Shape ``[B, H, W]`` or ``[H, W]``, int.

    Returns:
        Normalised coordinates. Shape ``[B, 2]`` (float32), on the same
        device as ``global_obs``.
    """
    if global_obs.ndim == 2:
        global_obs = global_obs.unsqueeze(0)

    B, H, W = global_obs.shape
    # IDs treated as "staircase down": 62 == ord('>'), plus three NLE tile IDs.
    # NOTE(review): verify 2310 / 2368 / 2383 against NLE's glyph constants —
    # these IDs were previously labelled S_dnstair / S_dnstairs / S_vodoor,
    # and S_vodoor names a door, not a staircase.
    is_stair = (
        (global_obs == 62)
        | (global_obs == 2310)
        | (global_obs == 2368)
        | (global_obs == 2383)
    )

    not_found = torch.full(
        (B, 2), -1.0, dtype=torch.float32, device=global_obs.device
    )
    if B == 0 or H * W == 0:
        return not_found

    # Vectorised first-match: argmax returns the index of the *first* maximal
    # value, i.e. the first stair cell in row-major order. This avoids a
    # per-sample nonzero() and its device->host sync.
    flat = is_stair.reshape(B, H * W)
    has_stair = flat.any(dim=1)  # [B]
    first_idx = flat.int().argmax(dim=1)  # [B], int64
    rows = torch.div(first_idx, W, rounding_mode="floor").float() / max(1, H - 1)
    cols = (first_idx % W).float() / max(1, W - 1)
    located = torch.stack((rows, cols), dim=1)  # [B, 2]

    return torch.where(has_stair.unsqueeze(1), located, not_found)