| """The training-free loop patch: K=1 is a bit-exact no-op, K>1 is active and finite, |
| the decode path is guarded, and LoopConfig validates its layer set.""" |
|
|
| from __future__ import annotations |
|
|
| import pytest |
| import torch |
|
|
| from looped_laguna import LoopConfig, patch, unpatch |
|
|
|
|
| def _logits(model, ids): |
| with torch.no_grad(): |
| return model(input_ids=ids, use_cache=False).logits |
|
|
|
|
def test_baseline_forward(model, ids):
    """The unpatched model yields finite logits shaped (batch, seq, vocab)."""
    logits = _logits(model, ids)
    expected_shape = (ids.shape[0], ids.shape[1], model.config.vocab_size)
    assert logits.shape == expected_shape
    assert torch.isfinite(logits).all()
|
|
|
|
@pytest.mark.parametrize("mode", ["layer", "block"])
@pytest.mark.parametrize("naive", [False, True])
def test_k1_is_bit_exact(model, ids, window, mode, naive):
    """With K=1 the patched forward must equal the unpatched one bit for bit,
    for every combination of loop mode and naive/non-naive strategy."""
    reference = _logits(model, ids)
    config = LoopConfig(window=window, K=1, mode=mode, naive=naive)
    patch(model, config)
    patched = _logits(model, ids)
    assert torch.equal(patched, reference)
|
|
|
|
@pytest.mark.parametrize("mode", ["layer", "block"])
def test_unpatch_restores_baseline(model, ids, window, mode):
    """unpatch() must return the model to its exact pre-patch forward.

    The patched (K=3) forward is first checked to differ from the baseline.
    Without that check, a patch that silently did nothing would let this test
    pass without ever exercising unpatch().
    """
    base = _logits(model, ids)
    patch(model, LoopConfig(window=window, K=3, mode=mode))
    assert not torch.equal(_logits(model, ids), base)
    unpatch(model)
    assert torch.equal(_logits(model, ids), base)
|
|
|
|
@pytest.mark.parametrize("mode", ["layer", "block"])
@pytest.mark.parametrize("naive", [False, True])
@pytest.mark.parametrize("K", [2, 3])
def test_k_gt1_is_active_and_finite(model, ids, window, mode, naive, K):
    """With K>1 the patch must change the logits while keeping them finite
    and shaped exactly like the baseline logits."""
    base = _logits(model, ids)
    patch(model, LoopConfig(window=window, K=K, mode=mode, naive=naive))
    out = _logits(model, ids)
    # torch.equal returns False for tensors of different shapes, so without
    # this check a wrongly-shaped output would pass the "active" assertion.
    assert out.shape == base.shape
    assert torch.isfinite(out).all()
    assert not torch.equal(out, base)
|
|
|
|
def test_decode_time_is_guarded(model, ids, window):
    """A model patched with K=2 must raise NotImplementedError when called
    with the KV cache enabled."""
    config = LoopConfig(window=window, K=2)
    patch(model, config)
    with pytest.raises(NotImplementedError):
        model(input_ids=ids, use_cache=True)
|
|
|
|
def test_loopconfig_validation():
    """LoopConfig resolves its loop layer set and rejects invalid inputs.

    Valid inputs: a (start, end) window expands inclusively, and an explicit
    layer tuple is kept as-is in layer mode. Each invalid input must raise
    ValueError.
    """
    # The window (2, 5) expands to both endpoints and everything between.
    assert LoopConfig(window=(2, 5)).loop_layers == (2, 3, 4, 5)
    explicit = LoopConfig(layers=(4, 8, 12), mode="layer")
    assert explicit.loop_layers == (4, 8, 12)

    invalid_kwargs = (
        dict(layers=(4, 8, 12), mode="block"),  # explicit layer set in block mode
        dict(),                                  # no window and no layers
        dict(window=(5, 3)),                     # window start after its end
        dict(window=(2, 5), K=0),                # K=0
    )
    for kwargs in invalid_kwargs:
        with pytest.raises(ValueError):
            LoopConfig(**kwargs)
|
|