"""MDLM ELBO loss with SUBS parameterisation.
Ported from the Craftax JAX implementation (src/diffusion/loss.py).
Computes a cross-entropy loss on masked positions only, optionally scaled
per sample by the analytic SUBS weight (clipped for numerical stability).
Also provides an auxiliary staircase-coordinate regression loss.
"""
from __future__ import annotations
from typing import Callable
import torch
import torch.nn.functional as F
from torch import Tensor
from src.diffusion.schedules import alpha_prime
# Default upper clamp for the SUBS importance weight w(t).
_MAX_WEIGHT: float = 1000.0
# Target value that F.cross_entropy skips entirely (zero loss, zero gradient).
_IGNORE_INDEX: int = -100


def mdlm_loss(
    logits: Tensor,
    x0: Tensor,
    zt: Tensor,
    t: Tensor,
    mask_token: int,
    pad_token: int,
    schedule_fn: Callable[[Tensor], Tensor],
    weight_clip: float = _MAX_WEIGHT,
    label_smoothing: float = 0.0,
    use_importance_weighting: bool = False,
) -> Tensor:
    """Compute masked diffusion loss.

    By default this is a plain masked cross-entropy: the mean over every
    contributing position in the batch (matching the reference
    implementation). When ``use_importance_weighting=True``, the loss is
    computed differently:

    1. Take each sample's mean masked cross-entropy.
    2. Scale it by the SUBS weight ``w(t) = -alpha'(t) / (1 - alpha_t)``,
       clamped to ``[0, weight_clip]``.
    3. Average the result over the batch.

    A position contributes only if ``zt == mask_token`` and
    ``x0 != pad_token``. Every other position is sent to ``ignore_index``.
    It therefore adds exactly zero loss and zero gradient, even if the model
    emits ``-inf`` logits there.

    Args:
        logits: Model output. Shape ``[B, L, vocab]``.
        x0: Clean action sequences. Shape ``[B, L]``, int64.
        zt: Noisy sequences. Shape ``[B, L]``, int64.
        t: Per-sample diffusion time in [0, 1]. Shape ``[B]``.
        mask_token: MASK token ID.
        pad_token: PAD token ID.
        schedule_fn: Noise schedule returning alpha(t).
        weight_clip: Upper clamp for SUBS weight (default 1000).
        label_smoothing: Smoothing epsilon for cross-entropy.
        use_importance_weighting: If ``True``, apply SUBS w(t) per sample.

    Returns:
        Scalar loss. When there are no masked positions it equals ``0.0``.
        In the default (unweighted) mode that zero is still connected to the
        autograd graph of ``logits``, so ``.backward()`` works.
    """
    B, L, V = logits.shape

    # Contributing positions: masked in zt and not padding in x0.
    is_masked = (zt == mask_token) & (x0 != pad_token)  # [B, L]

    # Clamp masked targets into the valid class range. Then send every
    # non-contributing position to ignore_index. Multiplying the per-position
    # loss by a 0/1 mask is NOT sufficient: an infinite loss at an excluded
    # position (e.g. a -inf logit at a clamped PAD target) yields
    # inf * 0 = NaN.
    targets = torch.where(
        is_masked,
        x0.clamp(0, V - 1),
        torch.full_like(x0, _IGNORE_INDEX),
    )  # [B, L]
    ce = F.cross_entropy(
        logits.reshape(-1, V),
        targets.reshape(-1),
        reduction="none",
        ignore_index=_IGNORE_INDEX,
        label_smoothing=label_smoothing,
    ).reshape(B, L)  # [B, L]; exactly 0 at ignored positions

    if not use_importance_weighting:
        # Global average over all masked positions (matches reference).
        # clamp(min=1) turns the empty-mask case into 0 / 1 rather than
        # 0 / 0. That avoids a host-syncing `.any()` early exit, and the
        # result stays on the autograd graph.
        n_masked_total = is_masked.float().sum().clamp(min=1.0)
        return ce.sum() / n_masked_total

    # SUBS weight: w_t = -alpha'(t) / (1 - alpha_t + eps), clipped.
    alpha_t = schedule_fn(t)  # [B]
    d_alpha = alpha_prime(t, schedule_fn)  # [B]
    w_t = (-d_alpha) / (1.0 - alpha_t + 1e-8)  # [B]
    w_t = w_t.clamp(0.0, weight_clip)  # [B]

    # Per-sample mean over that sample's masked positions (needed for SUBS).
    n_masked_per = is_masked.float().sum(dim=1).clamp(min=1.0)  # [B]
    per_sample = ce.sum(dim=1) / n_masked_per  # [B]
    return (per_sample * w_t).mean()
def auxiliary_goal_loss(
    goal_pred: Tensor,
    global_obs: Tensor,
    pad_value: float = -1.0,
) -> Tensor:
    """Mean-squared error between predicted and true staircase coordinates.

    The true coordinates come from ``find_staircase_from_glyphs``. Only
    samples whose target row coordinate differs from ``pad_value`` are
    supervised. ``find_staircase_from_glyphs`` fills missing staircases with
    ``-1.0``, so ``pad_value`` should normally keep its default.

    Args:
        goal_pred: Predicted normalised staircase coords. Shape ``[B, 2]``.
        global_obs: Full map glyphs. Shape ``[B, 21, 79]``, int.
        pad_value: Coordinate value used when staircase is not visible.

    Returns:
        Scalar MSE averaged over the supervised samples and both coordinates.
        Returns a fresh ``0.0`` tensor (not attached to the autograd graph)
        when no sample in the batch has a visible staircase.
    """
    # Move targets onto the prediction's device/dtype so the subtraction is valid.
    stair_xy = find_staircase_from_glyphs(global_obs).to(
        goal_pred.device, dtype=goal_pred.dtype
    )  # [B, 2]

    visible = stair_xy[:, 0] != pad_value  # [B]
    if not visible.any():
        return goal_pred.new_tensor(0.0)

    sq_err = (goal_pred[visible] - stair_xy[visible]).pow(2)  # [N, 2]
    return sq_err.mean()
def find_staircase_from_glyphs(
    global_obs: Tensor,
    stair_glyphs: tuple[int, ...] = (62, 2310, 2368, 2383),
) -> Tensor:
    """Locate the staircase '>' in the global glyph map.

    A cell counts as a staircase when its value appears in ``stair_glyphs``.
    The default set contains the character code ``ord('>') == 62`` plus
    three NLE glyph IDs carried over from the original implementation.
    If several cells match, the first one in row-major order is used.

    Returns normalised ``(row / (H-1), col / (W-1))`` coordinates per batch
    element, or ``(-1, -1)`` when no staircase cell is present. A dimension
    of size 1 is divided by 1 instead of 0.

    Args:
        global_obs: Glyph map. Shape ``[B, H, W]`` or ``[H, W]``, int.
        stair_glyphs: Cell values treated as a staircase.

    Returns:
        Normalised coordinates. Shape ``[B, 2]`` (float32), on the same
        device as ``global_obs``.
    """
    if global_obs.ndim == 2:
        global_obs = global_obs.unsqueeze(0)
    B, H, W = global_obs.shape

    # NOTE(review): an earlier comment labelled 2310, 2368 and 2383 as
    # S_dnstair, S_dnstairs and S_vodoor. S_vodoor is an open-door symbol,
    # not a staircase, so the labels (and possibly the IDs) look
    # inconsistent. Verify against nle.nethack GLYPH_CMAP_OFF + S_dnstair
    # before relying on them. Pass ``stair_glyphs`` to override.
    is_stair = torch.zeros_like(global_obs, dtype=torch.bool)
    for glyph in stair_glyphs:
        is_stair |= global_obs == glyph

    coords = torch.full(
        (B, 2), -1.0, dtype=torch.float32, device=global_obs.device
    )
    flat = is_stair.reshape(B, H * W)  # [B, H*W], row-major
    if flat.numel() == 0:
        # Empty batch or empty map: nothing to locate (argmax would fail).
        return coords

    found = flat.any(dim=1)  # [B]
    # argmax returns the FIRST maximal index, i.e. the first match in
    # row-major order (the same cell nonzero()[0] would select). This
    # replaces a per-sample Python loop that synced with the device once
    # per batch element.
    first = flat.to(torch.int32).argmax(dim=1)  # [B]
    rows = torch.div(first, W, rounding_mode="floor").float() / max(1, H - 1)
    cols = (first % W).float() / max(1, W - 1)
    located = torch.stack((rows, cols), dim=1)  # [B, 2]
    return torch.where(found.unsqueeze(1), located, coords)
|