"""MDLM-compatible decoding helpers. These helpers make the inference contract explicit for masked-diffusion style checkpoints: predict masked positions, not the ordinary autoregressive final prefix position. They are intentionally model-agnostic so tests can exercise the contract without importing the heavyweight HYDRA model stack. """ from __future__ import annotations from collections.abc import Iterable import torch def validate_mask_token_id( mask_id: int, vocab_size: int, *, bos_token_id: int | None = None, token_bytes: torch.Tensor | None = None, ) -> int: """Validate the reserved MDLM mask token contract. The most dangerous failure is ``MASK_ID == vocab_size`` followed by generic token clamping, which silently turns MASK into the final real token. Refuse that here before noising or decoding. """ mask_id = int(mask_id) vocab_size = int(vocab_size) if not (0 <= mask_id < vocab_size): raise ValueError(f"MDLM mask_id={mask_id} must be in [0, vocab_size) with vocab_size={vocab_size}") if bos_token_id is not None and mask_id == int(bos_token_id): raise ValueError("MDLM mask_id must not equal BOS/PAD/EOS token id") if token_bytes is not None: tb = token_bytes.detach().cpu() if mask_id >= tb.numel(): raise ValueError(f"token_bytes has {tb.numel()} entries, cannot validate mask_id={mask_id}") if int(tb[mask_id].item()) != 0: raise ValueError("MDLM mask_id must be reserved with byte length 0, not a normal text token") return mask_id def _extract_logits(output) -> torch.Tensor: if isinstance(output, torch.Tensor): return output logits = getattr(output, "logits", None) if logits is None: raise TypeError("model output must be a logits tensor or expose .logits") return logits def _ban_ids_(logits: torch.Tensor, banned_ids: Iterable[int] | None, vocab_size: int) -> torch.Tensor: if banned_ids is None: return logits for tok in banned_ids: tok = int(tok) if 0 <= tok < vocab_size: logits[..., tok] = -float("inf") return logits def mdlm_next_token_logits( model, prefix_ids: torch.Tensor, *, mask_id: int, vocab_size: int, banned_ids: Iterable[int] | None = None, ) -> torch.Tensor: """Return next-token logits via a one-token MASK slot. Contract: ``[prefix, MASK] -> logits at MASK``. This is the minimal MDLM inference repair for checkpoints trained to reconstruct masked positions. """ mask_id = validate_mask_token_id(mask_id, vocab_size) mask = torch.full((prefix_ids.shape[0], 1), mask_id, device=prefix_ids.device, dtype=prefix_ids.dtype) x = torch.cat([prefix_ids, mask], dim=1) logits = _extract_logits(model(x))[:, -1, :].float() logits[:, mask_id] = -float("inf") _ban_ids_(logits, banned_ids, vocab_size) return logits @torch.no_grad() def block_mdlm_decode( model, prefix_ids: torch.Tensor, *, mask_id: int, vocab_size: int, block_size: int = 8, refine_steps: int = 4, commit_threshold: float = 0.90, banned_ids: Iterable[int] | None = None, ) -> torch.Tensor: """Semi-autoregressive masked block decoding. Appends ``block_size`` MASK slots, repeatedly predicts all masked slots, and commits a confidence/schedule-selected subset each refinement step. Anything still masked after the final refinement is force-filled by argmax. """ mask_id = validate_mask_token_id(mask_id, vocab_size) if block_size <= 0: raise ValueError("block_size must be positive") if refine_steps <= 0: raise ValueError("refine_steps must be positive") B = prefix_ids.shape[0] block = torch.full((B, block_size), mask_id, dtype=prefix_ids.dtype, device=prefix_ids.device) committed = torch.zeros((B, block_size), dtype=torch.bool, device=prefix_ids.device) for step in range(refine_steps): x = torch.cat([prefix_ids, block], dim=1) logits = _extract_logits(model(x))[:, -block_size:, :].float() logits[:, :, mask_id] = -float("inf") _ban_ids_(logits, banned_ids, vocab_size) probs = logits.softmax(dim=-1) conf, tok = probs.max(dim=-1) threshold_commit = conf >= float(commit_threshold) n_commit = max(1, int((step + 1) / refine_steps * block_size)) top_pos = conf.topk(min(n_commit, block_size), dim=-1).indices schedule_commit = torch.zeros_like(committed) schedule_commit.scatter_(1, top_pos, True) update = (~committed) & (threshold_commit | schedule_commit) block = torch.where(update, tok.to(block.dtype), block) committed |= update if (~committed).any(): x = torch.cat([prefix_ids, block], dim=1) logits = _extract_logits(model(x))[:, -block_size:, :].float() logits[:, :, mask_id] = -float("inf") _ban_ids_(logits, banned_ids, vocab_size) tok = logits.argmax(dim=-1).to(block.dtype) block = torch.where(committed, block, tok) return torch.cat([prefix_ids, block], dim=1)