| """MDLM-compatible decoding helpers. |
| |
| These helpers make the inference contract explicit for masked-diffusion style |
| checkpoints: predict masked positions, not the ordinary autoregressive final |
| prefix position. They are intentionally model-agnostic so tests can exercise the |
| contract without importing the heavyweight HYDRA model stack. |
| """ |
| from __future__ import annotations |
|
|
| from collections.abc import Iterable |
|
|
| import torch |
|
|
|
|
def validate_mask_token_id(
    mask_id: int,
    vocab_size: int,
    *,
    bos_token_id: int | None = None,
    token_bytes: torch.Tensor | None = None,
) -> int:
    """Check that ``mask_id`` is a safe, reserved MDLM mask token and return it as ``int``.

    The checks guard against the worst silent failure mode: ``MASK_ID ==
    vocab_size`` combined with downstream token clamping, which quietly turns
    MASK into the last real vocabulary entry.

    Args:
        mask_id: Candidate mask token id (anything ``int()`` accepts).
        vocab_size: Vocabulary size; ``mask_id`` must lie in ``[0, vocab_size)``.
        bos_token_id: Optional BOS/PAD/EOS id that the mask id must differ from.
        token_bytes: Optional per-token byte-length table. When given, the entry
            at ``mask_id`` must exist and be 0, i.e. MASK emits no text.

    Returns:
        ``mask_id`` converted to a Python ``int``.

    Raises:
        ValueError: If any of the checks above fails.
    """
    mid = int(mask_id)
    vocab = int(vocab_size)

    # Range check first: an out-of-vocab mask id is the most dangerous case.
    if mid < 0 or mid >= vocab:
        raise ValueError(f"MDLM mask_id={mid} must be in [0, vocab_size) with vocab_size={vocab}")

    if bos_token_id is not None and int(bos_token_id) == mid:
        raise ValueError("MDLM mask_id must not equal BOS/PAD/EOS token id")

    # Without a byte-length table there is nothing more to verify.
    if token_bytes is None:
        return mid

    byte_lengths = token_bytes.detach().cpu()
    n_entries = byte_lengths.numel()
    if mid >= n_entries:
        raise ValueError(f"token_bytes has {n_entries} entries, cannot validate mask_id={mid}")
    # A reserved MASK must decode to zero bytes, never to real text.
    if int(byte_lengths[mid].item()) != 0:
        raise ValueError("MDLM mask_id must be reserved with byte length 0, not a normal text token")
    return mid
|
|
|
|
| def _extract_logits(output) -> torch.Tensor: |
| if isinstance(output, torch.Tensor): |
| return output |
| logits = getattr(output, "logits", None) |
| if logits is None: |
| raise TypeError("model output must be a logits tensor or expose .logits") |
| return logits |
|
|
|
|
| def _ban_ids_(logits: torch.Tensor, banned_ids: Iterable[int] | None, vocab_size: int) -> torch.Tensor: |
| if banned_ids is None: |
| return logits |
| for tok in banned_ids: |
| tok = int(tok) |
| if 0 <= tok < vocab_size: |
| logits[..., tok] = -float("inf") |
| return logits |
|
|
|
|
def mdlm_next_token_logits(
    model,
    prefix_ids: torch.Tensor,
    *,
    mask_id: int,
    vocab_size: int,
    banned_ids: Iterable[int] | None = None,
) -> torch.Tensor:
    """Return next-token logits via a one-token MASK slot.

    Contract: ``[prefix, MASK] -> logits at MASK``. This is the minimal MDLM
    inference repair for checkpoints trained to reconstruct masked positions.

    Args:
        model: Callable taking the concatenated id tensor. It returns a logits
            tensor, or an object with ``.logits``, which is read at ``[:, -1, :]``.
        prefix_ids: Prefix token ids, concatenated along dim 1, so 2-D
            ``(batch, seq)`` is expected.
        mask_id: Reserved MASK token id (validated against ``vocab_size``).
        vocab_size: Vocabulary size used for validation and for banning.
        banned_ids: Optional ids to force to ``-inf``; out-of-range ids are ignored.

    Returns:
        A freshly allocated float32 tensor of logits at the MASK slot, with
        ``mask_id`` and all banned ids set to ``-inf``.

    Raises:
        ValueError: If ``mask_id`` fails ``validate_mask_token_id``.
        TypeError: If the model output carries no logits.
    """
    mask_id = validate_mask_token_id(mask_id, vocab_size)
    mask = torch.full((prefix_ids.shape[0], 1), mask_id, device=prefix_ids.device, dtype=prefix_ids.dtype)
    x = torch.cat([prefix_ids, mask], dim=1)
    # copy=True matters: for float32 outputs ``.float()`` returns a *view* of the
    # model's own tensor, so the in-place -inf writes below would corrupt it.
    logits = _extract_logits(model(x))[:, -1, :].to(torch.float32, copy=True)
    # MASK is never a valid prediction.
    logits[:, mask_id] = -float("inf")
    _ban_ids_(logits, banned_ids, vocab_size)
    return logits
|
|
|
|
def _masked_block_logits(
    model,
    prefix_ids: torch.Tensor,
    block: torch.Tensor,
    *,
    mask_id: int,
    vocab_size: int,
    banned_ids: Iterable[int] | None,
) -> torch.Tensor:
    """Score ``[prefix, block]`` and return float32 logits for the block slots only.

    MASK and banned ids are set to ``-inf``. The result is a private copy:
    ``.float()`` alone would alias a float32 model output, and the in-place
    writes would then corrupt it.
    """
    block_size = block.shape[1]
    x = torch.cat([prefix_ids, block], dim=1)
    logits = _extract_logits(model(x))[:, -block_size:, :].to(torch.float32, copy=True)
    logits[:, :, mask_id] = -float("inf")
    _ban_ids_(logits, banned_ids, vocab_size)
    return logits


def _schedule_commit_mask(conf: torch.Tensor, committed: torch.Tensor, target: int) -> torch.Tensor:
    """Select the most confident uncommitted slots needed to reach ``target`` commits per row."""
    # Rank only still-masked slots; already-committed slots sink to the bottom
    # so they can neither use up nor inflate the per-step budget.
    ranked_conf = conf.masked_fill(committed, -float("inf"))
    order = ranked_conf.argsort(dim=-1, descending=True)
    ranks = order.argsort(dim=-1)  # inverse permutation = rank of each slot
    need = (target - committed.sum(dim=-1, keepdim=True)).clamp(min=0)
    return (ranks < need) & ~committed


@torch.no_grad()
def block_mdlm_decode(
    model,
    prefix_ids: torch.Tensor,
    *,
    mask_id: int,
    vocab_size: int,
    block_size: int = 8,
    refine_steps: int = 4,
    commit_threshold: float = 0.90,
    banned_ids: Iterable[int] | None = None,
) -> torch.Tensor:
    """Semi-autoregressive masked block decoding.

    Appends ``block_size`` MASK slots to ``prefix_ids`` and runs up to
    ``refine_steps`` forward passes. Each pass predicts every slot. A still-masked
    slot is committed (frozen to its argmax token) when either:

    * its top softmax probability is ``>= commit_threshold``, or
    * it is among the most confident still-masked slots needed to bring the
      row's commit count up to the cumulative linear schedule
      ``max(1, int((step + 1) / refine_steps * block_size))``.

    The schedule reaches ``block_size`` on the last step, so every slot is
    committed by then. Decoding stops early once every slot is committed.
    A final argmax fill for any slot still masked is kept as a defensive guard.

    Args:
        model: Callable on the concatenated id tensor. It returns a logits tensor,
            or an object with ``.logits``, whose last ``block_size`` positions are read.
        prefix_ids: Prefix ids, concatenated along dim 1, so ``(batch, seq)``.
        mask_id: Reserved MASK token id (validated against ``vocab_size``).
        vocab_size: Vocabulary size used for validation and for banning.
        block_size: Number of slots to decode; must be positive.
        refine_steps: Maximum number of refinement passes; must be positive.
        commit_threshold: Confidence at or above which a slot commits immediately.
        banned_ids: Optional ids that are never generated.

    Returns:
        ``prefix_ids`` with the decoded block appended along dim 1.

    Raises:
        ValueError: On an invalid ``mask_id``, ``block_size`` or ``refine_steps``.
    """
    mask_id = validate_mask_token_id(mask_id, vocab_size)
    if block_size <= 0:
        raise ValueError("block_size must be positive")
    if refine_steps <= 0:
        raise ValueError("refine_steps must be positive")

    score = dict(mask_id=mask_id, vocab_size=vocab_size, banned_ids=banned_ids)
    threshold = float(commit_threshold)
    batch = prefix_ids.shape[0]
    block = torch.full((batch, block_size), mask_id, dtype=prefix_ids.dtype, device=prefix_ids.device)
    committed = torch.zeros((batch, block_size), dtype=torch.bool, device=prefix_ids.device)

    for step in range(refine_steps):
        logits = _masked_block_logits(model, prefix_ids, block, **score)
        conf, tok = logits.softmax(dim=-1).max(dim=-1)

        # Cumulative commit target for this step (reaches block_size at the end).
        target = max(1, int((step + 1) / refine_steps * block_size))
        update = ~committed & ((conf >= threshold) | _schedule_commit_mask(conf, committed, target))
        block = torch.where(update, tok.to(block.dtype), block)
        committed |= update
        if bool(committed.all()):
            # Further passes could not change anything; skip the forward calls.
            break

    if not bool(committed.all()):
        # Defensive: the schedule above commits everything on the last step.
        logits = _masked_block_logits(model, prefix_ids, block, **score)
        tok = logits.argmax(dim=-1).to(block.dtype)
        block = torch.where(committed, block, tok)

    return torch.cat([prefix_ids, block], dim=1)
|
|