# feather-a10g-large-runtime / overlay /tests /test_mdlm_decode.py
# icarus112's picture
# Update Feather a10g-large training runtime image
# c475135 verified
from __future__ import annotations
import torch
from hydra.mdlm_decode import (
block_mdlm_decode,
mdlm_next_token_logits,
validate_mask_token_id,
)
class _Out:
    """Bare container that exposes a ``logits`` attribute."""

    def __init__(self, logits):
        # Kept by reference: no copy, conversion or validation happens here.
        self.logits = logits
class RecordingMaskModel:
    """Callable fake model that records its inputs and returns fixed logits.

    Each call stores a detached clone of ``input_ids`` in ``calls`` and
    returns an ``_Out`` whose ``logits`` has shape
    ``(batch, seq_len, vocab_size)``.
    """

    def __init__(self, vocab_size: int, mask_id: int):
        self.vocab_size = vocab_size
        self.mask_id = mask_id
        # One snapshot of ``input_ids`` per call, in call order.
        self.calls: list[torch.Tensor] = []

    def __call__(self, input_ids):
        """Record ``input_ids`` and return position-dependent logits."""
        snapshot = input_ids.detach().clone()
        self.calls.append(snapshot)

        batch, seq_len = input_ids.shape
        device = input_ids.device
        scores = torch.zeros(batch, seq_len, self.vocab_size, device=device)

        # MASK is deliberately given a high score everywhere so that decoder
        # helpers must explicitly ban it.
        scores[:, :, self.mask_id] = 99.0

        # At position p the token (p + 1) % vocab_size gets score 100 + p, so
        # the top-scoring token depends on the position. Written after the
        # MASK fill, so it wins where the two target the same entry.
        positions = torch.arange(seq_len, device=device)
        preferred = (positions + 1) % self.vocab_size
        scores[:, positions, preferred] = 100.0 + positions
        return _Out(scores)
def test_validate_mask_token_id_rejects_out_of_vocab_and_bos_collision():
    """A valid mask id is returned; bad ids raise ``ValueError``.

    Covers an id equal to ``vocab_size`` and an id equal to the BOS token id,
    checking a fragment of each error message.
    """
    assert validate_mask_token_id(7, vocab_size=8, bos_token_id=0) == 7

    # (mask id, keyword arguments, expected message fragment, failure note)
    rejected = (
        (8, {"vocab_size": 8}, "in [0, vocab_size)",
         "out-of-vocab mask id should fail"),
        (0, {"vocab_size": 8, "bos_token_id": 0}, "must not equal BOS",
         "BOS collision should fail"),
    )
    for candidate, kwargs, fragment, failure_note in rejected:
        try:
            validate_mask_token_id(candidate, **kwargs)
        except ValueError as exc:
            assert fragment in str(exc)
        else:
            raise AssertionError(failure_note)
def test_mdlm_next_token_logits_appends_mask_slot_and_bans_mask():
    """The model sees prefix + one MASK; returned logits exclude MASK.

    Asserts the last recorded model input, the ``(1, vocab_size)`` output
    shape, a ``-inf`` score at the MASK column, and a non-MASK argmax.
    """
    vocab, mask = 8, 5
    recorder = RecordingMaskModel(vocab_size=vocab, mask_id=mask)
    context = torch.tensor([[1, 2, 3]])

    next_logits = mdlm_next_token_logits(
        recorder, context, mask_id=mask, vocab_size=vocab
    )

    last_input = recorder.calls[-1]
    assert last_input.tolist() == [[1, 2, 3, mask]]
    assert next_logits.shape == (1, vocab)
    assert torch.isneginf(next_logits[:, mask]).all()
    assert next_logits.argmax(dim=-1).item() != mask
def test_block_mdlm_decode_fills_block_and_never_emits_mask():
    """Decoding keeps the prefix, appends a full block, and emits no MASK.

    Also checks that the first model call received the prefix followed by
    ``block_size`` MASK ids.
    """
    mask = 5
    prompt = [1, 2]
    block_len = 4
    recorder = RecordingMaskModel(vocab_size=12, mask_id=mask)

    decoded = block_mdlm_decode(
        recorder,
        torch.tensor([prompt]),
        mask_id=mask,
        vocab_size=12,
        block_size=block_len,
        refine_steps=2,
        commit_threshold=0.95,
    )

    prompt_len = len(prompt)
    assert decoded.shape == (1, prompt_len + block_len)
    assert decoded[:, :prompt_len].tolist() == [prompt]
    assert (decoded[:, prompt_len:] != mask).all()
    # First forward must be prefix + MASK block, not plain AR.
    assert recorder.calls[0].tolist() == [prompt + [mask] * block_len]