# Provenance (Hugging Face Hub page residue): uploaded by icarus112 with
# "Upload folder using huggingface_hub", commit c383594 (verified).
"""MDLM-compatible decoding helpers.
These helpers make the inference contract explicit for masked-diffusion style
checkpoints: predict masked positions, not the ordinary autoregressive final
prefix position. They are intentionally model-agnostic so tests can exercise the
contract without importing the heavyweight HYDRA model stack.
"""
from __future__ import annotations
from collections.abc import Iterable
import torch
def validate_mask_token_id(
    mask_id: int,
    vocab_size: int,
    *,
    bos_token_id: int | None = None,
    token_bytes: torch.Tensor | None = None,
) -> int:
    """Validate the reserved MDLM mask token contract and return it as an ``int``.

    The most dangerous failure is ``MASK_ID == vocab_size`` followed by generic
    token clamping, which silently turns MASK into the final real token. Refuse
    that here before noising or decoding.

    Args:
        mask_id: Candidate mask token id; coerced with ``int()``.
        vocab_size: Vocabulary size; ``mask_id`` must lie in ``[0, vocab_size)``.
        bos_token_id: If given, ``mask_id`` must differ from it.
        token_bytes: Optional per-token byte-length table. It is flattened and
            indexed by token id; the entry at ``mask_id`` must be 0 (i.e. the
            mask is not a normal text token).

    Returns:
        ``mask_id`` as a Python ``int``.

    Raises:
        ValueError: if ``mask_id`` is out of range, collides with
            ``bos_token_id``, is not covered by ``token_bytes``, or has a
            non-zero byte length there.
    """
    mask_id = int(mask_id)
    vocab_size = int(vocab_size)
    if not (0 <= mask_id < vocab_size):
        raise ValueError(f"MDLM mask_id={mask_id} must be in [0, vocab_size) with vocab_size={vocab_size}")
    if bos_token_id is not None and mask_id == int(bos_token_id):
        raise ValueError("MDLM mask_id must not equal BOS/PAD/EOS token id")
    if token_bytes is not None:
        # Flatten so the bounds check (numel) and the lookup address the same
        # 1-D view; previously a (1, V) table passed the check and then raised
        # IndexError on tb[mask_id].
        tb = token_bytes.detach().cpu().reshape(-1)
        if mask_id >= tb.numel():
            raise ValueError(f"token_bytes has {tb.numel()} entries, cannot validate mask_id={mask_id}")
        if int(tb[mask_id].item()) != 0:
            raise ValueError("MDLM mask_id must be reserved with byte length 0, not a normal text token")
    return mask_id
def _extract_logits(output) -> torch.Tensor:
    """Return the logits carried by a model forward result.

    A bare ``torch.Tensor`` is returned as-is; otherwise the object's
    ``logits`` attribute is returned. Anything lacking a non-None ``logits``
    attribute is rejected with ``TypeError``.
    """
    if not isinstance(output, torch.Tensor):
        candidate = getattr(output, "logits", None)
        if candidate is None:
            raise TypeError("model output must be a logits tensor or expose .logits")
        return candidate
    return output
def _ban_ids_(logits: torch.Tensor, banned_ids: Iterable[int] | None, vocab_size: int) -> torch.Tensor:
    """Set the last-dim logit of every in-range banned id to ``-inf``, in place.

    ``banned_ids=None`` bans nothing. Ids are coerced with ``int()``; ids
    outside ``[0, vocab_size)`` are skipped silently. The same (mutated)
    tensor is returned so the call can be chained.
    """
    if banned_ids is not None:
        neg_inf = -float("inf")
        for column in map(int, banned_ids):
            if vocab_size > column >= 0:
                logits[..., column] = neg_inf
    return logits
def mdlm_next_token_logits(
    model,
    prefix_ids: torch.Tensor,
    *,
    mask_id: int,
    vocab_size: int,
    banned_ids: Iterable[int] | None = None,
) -> torch.Tensor:
    """Return next-token logits via a one-token MASK slot.

    Contract: ``[prefix, MASK] -> logits at MASK``. This is the minimal MDLM
    inference repair for checkpoints trained to reconstruct masked positions.

    Args:
        model: Callable taking the extended id tensor and returning logits,
            either as a tensor or an object exposing ``.logits``; the result
            is indexed as ``[:, -1, :]`` (batch, position, vocab).
        prefix_ids: Prompt ids; dim 0 is the batch and the MASK slot is
            appended along dim 1.
        mask_id: Reserved MASK token id, checked by ``validate_mask_token_id``.
        vocab_size: Vocabulary size used for validation and id banning.
        banned_ids: Optional ids whose logits are forced to ``-inf``.

    Returns:
        A fresh float32 tensor of logits at the MASK slot, with the MASK id
        (and any in-range banned ids) set to ``-inf``.

    Raises:
        ValueError: if ``mask_id`` fails validation.
    """
    mask_id = validate_mask_token_id(mask_id, vocab_size)
    mask = torch.full((prefix_ids.shape[0], 1), mask_id, device=prefix_ids.device, dtype=prefix_ids.dtype)
    x = torch.cat([prefix_ids, mask], dim=1)
    # copy=True: when the model already emits float32, .float() returns a view
    # into the model's output and the in-place bans below would mutate it.
    logits = _extract_logits(model(x))[:, -1, :].to(torch.float32, copy=True)
    logits[:, mask_id] = -float("inf")
    _ban_ids_(logits, banned_ids, vocab_size)
    return logits
@torch.no_grad()
def block_mdlm_decode(
    model,
    prefix_ids: torch.Tensor,
    *,
    mask_id: int,
    vocab_size: int,
    block_size: int = 8,
    refine_steps: int = 4,
    commit_threshold: float = 0.90,
    banned_ids: Iterable[int] | None = None,
) -> torch.Tensor:
    """Semi-autoregressive masked block decoding.

    Appends ``block_size`` MASK slots to ``prefix_ids`` and runs
    ``refine_steps`` refinement passes. Each pass predicts every slot of the
    block and commits a still-masked slot when either

    * its top probability is at least ``commit_threshold``, or
    * it is among the most confident still-masked slots needed to reach the
      cumulative schedule of ``max(1, (step + 1) * block_size // refine_steps)``
      committed slots in that row.

    Committed slots are never revised. The last pass's schedule covers the
    whole block, so every slot ends up committed; a final argmax fill of any
    slot that is somehow still masked is kept as a safety net.

    Args:
        model: Callable taking the extended id tensor and returning logits,
            either as a tensor or an object exposing ``.logits``; the result is
            indexed as ``[:, -block_size:, :]`` (batch, position, vocab).
        prefix_ids: Prompt ids; dim 0 is the batch, the block is appended
            along dim 1.
        mask_id: Reserved MASK token id, checked by ``validate_mask_token_id``.
        vocab_size: Vocabulary size used for validation and id banning.
        block_size: Number of MASK slots to decode.
        refine_steps: Number of refinement passes.
        commit_threshold: Probability at which a slot is committed early.
        banned_ids: Ids that may never be generated. Any iterable is accepted;
            it is materialized once so one-shot iterators apply to every pass.

    Returns:
        ``prefix_ids`` with the decoded block concatenated along dim 1.

    Raises:
        ValueError: for an invalid ``mask_id`` or a non-positive
            ``block_size`` / ``refine_steps``.
    """
    mask_id = validate_mask_token_id(mask_id, vocab_size)
    if block_size <= 0:
        raise ValueError("block_size must be positive")
    if refine_steps <= 0:
        raise ValueError("refine_steps must be positive")
    # Materialize once: a generator would otherwise be exhausted by the first
    # pass, and every later pass would silently stop banning.
    banned = None if banned_ids is None else tuple(banned_ids)

    batch = prefix_ids.shape[0]
    device = prefix_ids.device
    block = torch.full((batch, block_size), mask_id, dtype=prefix_ids.dtype, device=device)
    committed = torch.zeros((batch, block_size), dtype=torch.bool, device=device)
    # Per-row ordinals 0..block_size-1, scattered below to turn a sort order
    # into each slot's rank.
    ordinals = torch.arange(block_size, device=device).repeat(batch, 1)

    def predict(current_block: torch.Tensor) -> torch.Tensor:
        """Float32 logits for the block slots, with MASK and banned ids at -inf."""
        x = torch.cat([prefix_ids, current_block], dim=1)
        # copy=True: if the model already returns float32, .float() would be a
        # view and the in-place bans below would mutate the model's output.
        logits = _extract_logits(model(x))[:, -block_size:, :].to(torch.float32, copy=True)
        logits[:, :, mask_id] = -float("inf")
        _ban_ids_(logits, banned, vocab_size)
        return logits

    for step in range(refine_steps):
        conf, tok = predict(block).softmax(dim=-1).max(dim=-1)
        threshold_commit = conf >= float(commit_threshold)

        # Cumulative schedule in exact integer arithmetic; the last step
        # always yields block_size, so the whole block gets committed.
        target = max(1, (step + 1) * block_size // refine_steps)
        n_new = (target - committed.sum(dim=-1, keepdim=True)).clamp(min=0)
        # Rank only still-masked slots: committed ones get -1 (probabilities
        # are >= 0) so they sort last and cannot consume schedule slots.
        ranked = conf.masked_fill(committed, -1.0)
        order = ranked.argsort(dim=-1, descending=True)
        rank = torch.empty_like(order).scatter_(1, order, ordinals)
        schedule_commit = rank < n_new

        update = (~committed) & (threshold_commit | schedule_commit)
        block = torch.where(update, tok.to(block.dtype), block)
        committed |= update

    if not bool(committed.all()):
        # Safety net; unreachable when the schedule above completes normally.
        tok = predict(block).argmax(dim=-1).to(block.dtype)
        block = torch.where(committed, block, tok)
    return torch.cat([prefix_ids, block], dim=1)