File size: 2,612 Bytes
c475135
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
from __future__ import annotations

import torch

from hydra.mdlm_decode import (
    block_mdlm_decode,
    mdlm_next_token_logits,
    validate_mask_token_id,
)


class _Out:
    def __init__(self, logits):
        self.logits = logits


class RecordingMaskModel:
    """Fake masked LM that remembers every ``input_ids`` batch it is called with.

    A call returns logits shaped ``(batch, seq, vocab_size)`` in which:

    * the MASK column is ``99.0`` at every position, so MASK is deliberately
      attractive and decoder helpers are forced to ban it explicitly;
    * at position ``p`` the token ``(p + 1) % vocab_size`` scores
      ``100.0 + p``, which makes the best token position-dependent. This write
      happens after the MASK fill, so it overrides the MASK score at any
      position where that token is MASK itself.
    """

    def __init__(self, vocab_size: int, mask_id: int):
        self.vocab_size = vocab_size
        self.mask_id = mask_id
        # Detached clones of each input, in call order.
        self.calls: list[torch.Tensor] = []

    def __call__(self, input_ids):
        self.calls.append(input_ids.detach().clone())
        batch, seq_len = input_ids.shape
        dev = input_ids.device
        scores = torch.zeros(batch, seq_len, self.vocab_size, device=dev)
        scores[..., self.mask_id] = 99.0
        # Vectorised equivalent of `scores[:, p, (p + 1) % V] = 100.0 + p` for
        # every p; each (position, token) pair is unique, so no write collides.
        pos = torch.arange(seq_len, device=dev)
        scores[:, pos, (pos + 1) % self.vocab_size] = 100.0 + pos
        return _Out(scores)


def test_validate_mask_token_id_rejects_out_of_vocab_and_bos_collision():
    """A valid id is returned unchanged; out-of-vocab ids and BOS clashes raise."""

    def _expect_value_error(call, fragment, failure):
        # Hand-rolled ``pytest.raises``: ``call`` must raise ValueError whose
        # message contains ``fragment``; if nothing is raised, fail with
        # ``failure``. Any other exception type propagates unchanged.
        try:
            call()
        except ValueError as err:
            assert fragment in str(err)
            return
        raise AssertionError(failure)

    assert validate_mask_token_id(7, vocab_size=8, bos_token_id=0) == 7

    _expect_value_error(
        lambda: validate_mask_token_id(8, vocab_size=8),
        "in [0, vocab_size)",
        "out-of-vocab mask id should fail",
    )
    _expect_value_error(
        lambda: validate_mask_token_id(0, vocab_size=8, bos_token_id=0),
        "must not equal BOS",
        "BOS collision should fail",
    )


def test_mdlm_next_token_logits_appends_mask_slot_and_bans_mask():
    """Next-token logits are read from one MASK slot appended to the prefix.

    The fake model gives MASK a high score everywhere, so the returned row must
    have the MASK column set to -inf and its argmax must not be MASK.
    """
    mask_id, vocab = 5, 8
    recorder = RecordingMaskModel(vocab_size=vocab, mask_id=mask_id)
    context = torch.tensor([[1, 2, 3]])

    next_logits = mdlm_next_token_logits(
        recorder, context, mask_id=mask_id, vocab_size=vocab
    )

    # The last forward pass saw the prefix followed by exactly one MASK token.
    assert recorder.calls[-1].tolist() == [[1, 2, 3, mask_id]]
    assert next_logits.shape == (1, vocab)
    assert torch.isneginf(next_logits[:, mask_id]).all()
    assert next_logits.argmax(dim=-1).item() != mask_id


def test_block_mdlm_decode_fills_block_and_never_emits_mask():
    """Block decoding keeps the prefix and fills every new slot with non-MASK."""
    mask_id, vocab, block = 5, 12, 4
    recorder = RecordingMaskModel(vocab_size=vocab, mask_id=mask_id)
    context = torch.tensor([[1, 2]])
    n_ctx = context.shape[1]

    decoded = block_mdlm_decode(
        recorder,
        context,
        mask_id=mask_id,
        vocab_size=vocab,
        block_size=block,
        refine_steps=2,
        commit_threshold=0.95,
    )

    assert decoded.shape == (1, n_ctx + block)
    assert decoded[:, :n_ctx].tolist() == context.tolist()
    generated = decoded[:, n_ctx:]
    assert (generated != mask_id).all()
    # The very first forward pass must already see the whole MASK block after
    # the prefix, i.e. decoding is block-parallel rather than plain AR.
    first_input = context[0].tolist() + [mask_id] * block
    assert recorder.calls[0].tolist() == [first_input]