| from __future__ import annotations |
|
|
| import torch |
|
|
| from hydra.mdlm_decode import ( |
| block_mdlm_decode, |
| mdlm_next_token_logits, |
| validate_mask_token_id, |
| ) |
|
|
|
|
| class _Out: |
| def __init__(self, logits): |
| self.logits = logits |
|
|
|
|
class RecordingMaskModel:
    """Fake model that records its inputs and returns deterministic logits.

    Every call appends a detached clone of ``input_ids`` to ``self.calls`` and
    returns an ``_Out`` whose logits are zero everywhere except:

    * the ``mask_id`` column, which is 99.0 at every position, and
    * column ``(pos + 1) % vocab_size`` at position ``pos``, which is
      ``100.0 + pos``. This is written last, so it overrides the 99.0 when
      the two columns coincide.
    """

    def __init__(self, vocab_size: int, mask_id: int):
        self.vocab_size = vocab_size
        self.mask_id = mask_id
        # One cloned snapshot of ``input_ids`` per call, in call order.
        self.calls: list[torch.Tensor] = []

    def __call__(self, input_ids):
        snapshot = input_ids.detach().clone()
        self.calls.append(snapshot)

        batch, seq_len = input_ids.shape
        device = input_ids.device
        logits = torch.zeros(batch, seq_len, self.vocab_size, device=device)

        # Mask column first ...
        logits[..., self.mask_id] = 99.0

        # ... then the per-position target column, for all positions at once.
        positions = torch.arange(seq_len, device=device)
        targets = (positions + 1) % self.vocab_size
        logits[:, positions, targets] = 100.0 + positions
        return _Out(logits)
|
|
|
|
def test_validate_mask_token_id_rejects_out_of_vocab_and_bos_collision():
    """Check one accepted mask id and two rejected ones.

    The accepted call must return the id it was given; each rejected call
    must raise ``ValueError`` whose message contains the expected fragment.
    """
    accepted = validate_mask_token_id(7, vocab_size=8, bos_token_id=0)
    assert accepted == 7

    # (positional args, keyword args, required message fragment, failure note)
    rejected_cases = (
        (
            (8,),
            {"vocab_size": 8},
            "in [0, vocab_size)",
            "out-of-vocab mask id should fail",
        ),
        (
            (0,),
            {"vocab_size": 8, "bos_token_id": 0},
            "must not equal BOS",
            "BOS collision should fail",
        ),
    )
    for args, kwargs, fragment, failure_note in rejected_cases:
        try:
            validate_mask_token_id(*args, **kwargs)
        except ValueError as exc:
            assert fragment in str(exc)
        else:
            # Reaching here means no exception was raised at all.
            raise AssertionError(failure_note)
|
|
|
|
def test_mdlm_next_token_logits_appends_mask_slot_and_bans_mask():
    """Check the model input and the returned logits of one next-token call.

    The recorded model input must be the prefix followed by a single mask id,
    and the returned ``(1, 8)`` logits must have ``-inf`` in the mask column.
    """
    mask_id = 5
    model = RecordingMaskModel(vocab_size=8, mask_id=mask_id)
    prefix = torch.tensor([[1, 2, 3]])

    logits = mdlm_next_token_logits(model, prefix, mask_id=mask_id, vocab_size=8)

    # The most recent forward pass saw the prefix plus one trailing mask slot.
    last_input = model.calls[-1].tolist()
    assert last_input == [[1, 2, 3, mask_id]]

    assert logits.shape == (1, 8)

    mask_column = logits[:, mask_id]
    assert torch.isneginf(mask_column).all()

    chosen = logits.argmax(dim=-1).item()
    assert chosen != mask_id
|
|
|
|
def test_block_mdlm_decode_fills_block_and_never_emits_mask():
    """Check output shape, prefix preservation, and absence of the mask id.

    Also checks that the first recorded model input is the prefix followed by
    four mask ids (one per slot of the requested block).
    """
    mask_id = 5
    model = RecordingMaskModel(vocab_size=12, mask_id=mask_id)
    prefix = torch.tensor([[1, 2]])

    decode_kwargs = {
        "mask_id": mask_id,
        "vocab_size": 12,
        "block_size": 4,
        "refine_steps": 2,
        "commit_threshold": 0.95,
    }
    out = block_mdlm_decode(model, prefix, **decode_kwargs)

    assert out.shape == (1, 6)

    kept_prefix = out[:, :2].tolist()
    assert kept_prefix == [[1, 2]]

    generated = out[:, 2:]
    assert (generated != mask_id).all()

    first_input = model.calls[0].tolist()
    assert first_input == [[1, 2, mask_id, mask_id, mask_id, mask_id]]
|
|