from __future__ import annotations import torch from hydra.mdlm_decode import ( block_mdlm_decode, mdlm_next_token_logits, validate_mask_token_id, ) class _Out: def __init__(self, logits): self.logits = logits class RecordingMaskModel: def __init__(self, vocab_size: int, mask_id: int): self.vocab_size = vocab_size self.mask_id = mask_id self.calls: list[torch.Tensor] = [] def __call__(self, input_ids): self.calls.append(input_ids.detach().clone()) b, t = input_ids.shape logits = torch.zeros(b, t, self.vocab_size, device=input_ids.device) # Make the best token depend on position, while deliberately making MASK # attractive so decoder helpers must ban it. logits[..., self.mask_id] = 99.0 for pos in range(t): logits[:, pos, (pos + 1) % self.vocab_size] = 100.0 + pos return _Out(logits) def test_validate_mask_token_id_rejects_out_of_vocab_and_bos_collision(): assert validate_mask_token_id(7, vocab_size=8, bos_token_id=0) == 7 try: validate_mask_token_id(8, vocab_size=8) except ValueError as exc: assert "in [0, vocab_size)" in str(exc) else: raise AssertionError("out-of-vocab mask id should fail") try: validate_mask_token_id(0, vocab_size=8, bos_token_id=0) except ValueError as exc: assert "must not equal BOS" in str(exc) else: raise AssertionError("BOS collision should fail") def test_mdlm_next_token_logits_appends_mask_slot_and_bans_mask(): mask_id = 5 model = RecordingMaskModel(vocab_size=8, mask_id=mask_id) prefix = torch.tensor([[1, 2, 3]]) logits = mdlm_next_token_logits(model, prefix, mask_id=mask_id, vocab_size=8) assert model.calls[-1].tolist() == [[1, 2, 3, mask_id]] assert logits.shape == (1, 8) assert torch.isneginf(logits[:, mask_id]).all() assert logits.argmax(dim=-1).item() != mask_id def test_block_mdlm_decode_fills_block_and_never_emits_mask(): mask_id = 5 model = RecordingMaskModel(vocab_size=12, mask_id=mask_id) prefix = torch.tensor([[1, 2]]) out = block_mdlm_decode( model, prefix, mask_id=mask_id, vocab_size=12, block_size=4, refine_steps=2, commit_threshold=0.95, ) assert out.shape == (1, 6) assert out[:, :2].tolist() == [[1, 2]] assert (out[:, 2:] != mask_id).all() # First forward must be prefix + MASK block, not plain AR. assert model.calls[0].tolist() == [[1, 2, mask_id, mask_id, mask_id, mask_id]]