# looped-laguna / tests/test_loop.py
# e-p — "use pytest" (7ba6795)
"""The training-free loop patch: K=1 is a bit-exact no-op, K>1 is active and finite,
the decode path is guarded, and LoopConfig validates its layer set."""
from __future__ import annotations
import pytest
import torch
from looped_laguna import LoopConfig, patch, unpatch
@torch.no_grad()
def _logits(model, ids):
    """Run a cache-free forward pass with autograd disabled and return its logits.

    ``use_cache=False`` keeps the call on the full-sequence path. ``@torch.no_grad()``
    applies the same gradient-disabled context the body would otherwise open explicitly.
    """
    outputs = model(input_ids=ids, use_cache=False)
    return outputs.logits
def test_baseline_forward(model, ids):
    """The unpatched model yields finite logits shaped like ``ids`` plus a vocab axis."""
    logits = _logits(model, ids)
    expected_shape = (ids.shape[0], ids.shape[1], model.config.vocab_size)
    assert logits.shape == expected_shape
    assert torch.isfinite(logits).all()
@pytest.mark.parametrize("mode", ["layer", "block"])
@pytest.mark.parametrize("naive", [False, True])
def test_k1_is_bit_exact(model, ids, window, mode, naive):
    """K=1 must reproduce the baseline forward exactly, for every mode/strategy.

    The patch is removed in ``finally`` even if the assertion fails. The ``model``
    fixture's scope is defined outside this file; if it is shared across tests, a
    leftover patch would otherwise leak into later tests and their baselines.
    """
    base = _logits(model, ids)
    patch(model, LoopConfig(window=window, K=1, mode=mode, naive=naive))
    try:
        assert torch.equal(_logits(model, ids), base)
    finally:
        unpatch(model)
def test_unpatch_restores_baseline(model, ids, window):
    """``unpatch`` undoes an active K=3 loop and restores the baseline logits exactly."""
    base = _logits(model, ids)
    patch(model, LoopConfig(window=window, K=3, mode="layer"))
    try:
        # Guard against a vacuous pass: if the patch changed nothing, "restored"
        # would hold trivially. K=3 "layer" is required to be active elsewhere too.
        assert not torch.equal(_logits(model, ids), base)
    finally:
        # Always unpatch exactly once, so a failure above cannot leave the model patched.
        unpatch(model)
    assert torch.equal(_logits(model, ids), base)
@pytest.mark.parametrize("mode", ["layer", "block"])
@pytest.mark.parametrize("naive", [False, True])
@pytest.mark.parametrize("K", [2, 3])
def test_k_gt1_is_active_and_finite(model, ids, window, mode, naive, K):
    """For K>1 the patched forward stays finite and differs from the baseline.

    The patch is removed in ``finally``, so later parametrizations and tests compute
    their baseline on an unpatched model even if the ``model`` fixture is shared
    (its scope is defined outside this file).
    """
    base = _logits(model, ids)
    patch(model, LoopConfig(window=window, K=K, mode=mode, naive=naive))
    try:
        out = _logits(model, ids)
        assert torch.isfinite(out).all()
        assert not torch.equal(out, base)  # the loop actually changed the output
    finally:
        unpatch(model)
def test_decode_time_is_guarded(model, ids, window):
    """With a K=2 loop patched in, a ``use_cache=True`` forward raises NotImplementedError.

    The patch is removed in ``finally`` so it cannot leak into later tests if the
    ``model`` fixture is shared (its scope is defined outside this file).
    """
    patch(model, LoopConfig(window=window, K=2))
    try:
        with pytest.raises(NotImplementedError):
            model(input_ids=ids, use_cache=True)
    finally:
        unpatch(model)
def test_loopconfig_validation():
    """LoopConfig expands windows, keeps explicit layer sets, and rejects invalid input."""
    # A window (a, b) expands to every layer index from a through b inclusive.
    assert LoopConfig(window=(2, 5)).loop_layers == (2, 3, 4, 5)
    # An explicit (non-contiguous) layer set is accepted as given in layer mode.
    assert LoopConfig(layers=(4, 8, 12), mode="layer").loop_layers == (4, 8, 12)
    with pytest.raises(ValueError):  # block-mode needs contiguous layers
        LoopConfig(layers=(4, 8, 12), mode="block")
    with pytest.raises(ValueError):  # exactly one of window/layers: neither given
        LoopConfig()
    with pytest.raises(ValueError):  # exactly one of window/layers: both given
        LoopConfig(window=(2, 5), layers=(2, 3, 4, 5))
    with pytest.raises(ValueError):  # a > b
        LoopConfig(window=(5, 3))
    with pytest.raises(ValueError):  # K >= 1
        LoopConfig(window=(2, 5), K=0)