# File size: 2,612 Bytes
# c475135
from __future__ import annotations
import torch
from hydra.mdlm_decode import (
block_mdlm_decode,
mdlm_next_token_logits,
validate_mask_token_id,
)
class _Out:
    """Bare-bones model-output stand-in exposing a single ``logits`` attribute."""

    def __init__(self, logits):
        # Kept exactly as passed in: no copy, no validation.
        self.logits = logits
class RecordingMaskModel:
    """Callable fake model that logs every input and returns scripted logits.

    Each call stores a detached copy of ``input_ids`` in ``calls`` and returns
    an ``_Out`` whose ``logits`` tensor has shape ``(b, t, vocab_size)``.
    """

    def __init__(self, vocab_size: int, mask_id: int):
        self.mask_id = mask_id
        self.vocab_size = vocab_size
        # One snapshot per ``__call__``, in call order.
        self.calls: list[torch.Tensor] = []

    def __call__(self, input_ids):
        # Snapshot the input so later in-place edits cannot change the record.
        self.calls.append(input_ids.detach().clone())

        batch, seq_len = input_ids.shape
        scores = torch.zeros(
            batch, seq_len, self.vocab_size, device=input_ids.device
        )

        # MASK is deliberately made attractive at every position so that
        # decoder helpers must ban it explicitly.
        scores[..., self.mask_id] = 99.0

        # The best token depends on position: slot ``p`` prefers token
        # ``(p + 1) % vocab_size`` with score ``100 + p``.
        positions = torch.arange(seq_len, device=input_ids.device)
        preferred = (positions + 1) % self.vocab_size
        scores[:, positions, preferred] = 100.0 + positions.to(scores.dtype)

        return _Out(scores)
def test_validate_mask_token_id_rejects_out_of_vocab_and_bos_collision():
    """A valid mask id is returned; out-of-range or BOS-colliding ids raise."""
    accepted = validate_mask_token_id(7, vocab_size=8, bos_token_id=0)
    assert accepted == 7

    # (mask id, keyword args, expected error fragment, message if nothing raises)
    rejected_cases = (
        (
            8,
            {"vocab_size": 8},
            "in [0, vocab_size)",
            "out-of-vocab mask id should fail",
        ),
        (
            0,
            {"vocab_size": 8, "bos_token_id": 0},
            "must not equal BOS",
            "BOS collision should fail",
        ),
    )
    for bad_id, kwargs, fragment, no_raise_message in rejected_cases:
        try:
            validate_mask_token_id(bad_id, **kwargs)
        except ValueError as exc:
            assert fragment in str(exc)
            continue
        # Reaching here means the validator accepted an invalid id.
        raise AssertionError(no_raise_message)
def test_mdlm_next_token_logits_appends_mask_slot_and_bans_mask():
    """The helper feeds prefix + one MASK slot and returns logits with MASK at -inf."""
    vocab, mask_token = 8, 5
    recorder = RecordingMaskModel(vocab_size=vocab, mask_id=mask_token)
    context = torch.tensor([[1, 2, 3]])

    next_logits = mdlm_next_token_logits(
        recorder, context, mask_id=mask_token, vocab_size=vocab
    )

    # The most recent model input must be the prefix followed by one MASK.
    last_input = recorder.calls[-1]
    assert last_input.tolist() == [[1, 2, 3, mask_token]]
    assert next_logits.shape == (1, vocab)

    # The fake model scores MASK highly, so it must be banned explicitly.
    mask_column = next_logits[:, mask_token]
    assert torch.isneginf(mask_column).all()
    assert next_logits.argmax(dim=-1).item() != mask_token
def test_block_mdlm_decode_fills_block_and_never_emits_mask():
    """Block decode keeps the prefix, fills four new slots, and emits no MASK."""
    mask_token = 5
    recorder = RecordingMaskModel(vocab_size=12, mask_id=mask_token)
    context = torch.tensor([[1, 2]])

    decode_kwargs = dict(
        mask_id=mask_token,
        vocab_size=12,
        block_size=4,
        refine_steps=2,
        commit_threshold=0.95,
    )
    decoded = block_mdlm_decode(recorder, context, **decode_kwargs)

    assert decoded.shape == (1, 6)
    kept_prefix, new_block = decoded[:, :2], decoded[:, 2:]
    assert kept_prefix.tolist() == [[1, 2]]
    assert (new_block != mask_token).all()

    # First forward must be prefix + MASK block, not plain AR.
    first_input = recorder.calls[0]
    assert first_input.tolist() == [[1, 2] + [mask_token] * 4]
|