File size: 2,414 Bytes
7ba6795
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
"""The training-free loop patch: K=1 is a bit-exact no-op, K>1 is active and finite,
the decode path is guarded, and LoopConfig validates its layer set."""

from __future__ import annotations

import pytest
import torch

from looped_laguna import LoopConfig, patch, unpatch


def _logits(model, ids):
    with torch.no_grad():
        return model(input_ids=ids, use_cache=False).logits


def test_baseline_forward(model, ids):
    """Unpatched forward: logits are finite and shaped
    ``(ids.shape[0], ids.shape[1], model.config.vocab_size)``."""
    logits = _logits(model, ids)
    expected_shape = (ids.shape[0], ids.shape[1], model.config.vocab_size)
    assert logits.shape == expected_shape
    assert torch.isfinite(logits).all()


@pytest.mark.parametrize("mode", ["layer", "block"])
@pytest.mark.parametrize("naive", [False, True])
def test_k1_is_bit_exact(model, ids, window, mode, naive):
    """Patching with K=1 must leave the logits exactly equal (``torch.equal``)
    to the pre-patch logits, for every parametrized ``mode`` and ``naive`` value."""
    reference = _logits(model, ids)  # captured before the model is patched
    config = LoopConfig(window=window, K=1, mode=mode, naive=naive)
    patch(model, config)
    patched_logits = _logits(model, ids)
    assert torch.equal(patched_logits, reference)


def test_unpatch_restores_baseline(model, ids, window):
    """After ``patch`` (K=3, layer mode) followed by ``unpatch``, the logits must
    again equal the pre-patch logits exactly."""
    reference = _logits(model, ids)  # captured before the model is patched
    config = LoopConfig(window=window, K=3, mode="layer")
    patch(model, config)
    unpatch(model)
    restored_logits = _logits(model, ids)
    assert torch.equal(restored_logits, reference)


@pytest.mark.parametrize("mode", ["layer", "block"])
@pytest.mark.parametrize("naive", [False, True])
@pytest.mark.parametrize("K", [2, 3])
def test_k_gt1_is_active_and_finite(model, ids, window, mode, naive, K):
    """With K > 1 the patched logits must be finite and must differ from the
    pre-patch logits, for every parametrized ``mode``/``naive``/``K`` combination."""
    reference = _logits(model, ids)  # captured before the model is patched
    config = LoopConfig(window=window, K=K, mode=mode, naive=naive)
    patch(model, config)
    looped_logits = _logits(model, ids)
    assert torch.isfinite(looped_logits).all()
    # The loop must actually change the output, not silently no-op.
    assert not torch.equal(looped_logits, reference)


def test_decode_time_is_guarded(model, ids, window):
    """Calling a patched model (K=2) with ``use_cache=True`` must raise
    ``NotImplementedError``."""
    config = LoopConfig(window=window, K=2)
    patch(model, config)
    # Only the forward call is expected to raise; patching itself must succeed.
    with pytest.raises(NotImplementedError):
        model(input_ids=ids, use_cache=True)


def test_loopconfig_validation():
    """``LoopConfig`` resolves ``loop_layers`` for valid input and raises
    ``ValueError`` for each invalid combination listed below."""
    # Accepted configurations and the layer tuples they resolve to.
    from_window = LoopConfig(window=(2, 5))
    assert from_window.loop_layers == (2, 3, 4, 5)
    from_layers = LoopConfig(layers=(4, 8, 12), mode="layer")
    assert from_layers.loop_layers == (4, 8, 12)

    # Rejected configurations, checked in order; each must raise ValueError.
    rejected_kwargs = (
        {"layers": (4, 8, 12), "mode": "block"},  # block-mode needs contiguous layers
        {},  # exactly one of window/layers
        {"window": (5, 3)},  # a > b
        {"window": (2, 5), "K": 0},  # K >= 1
    )
    for kwargs in rejected_kwargs:
        with pytest.raises(ValueError):
            LoopConfig(**kwargs)