"""The training-free loop patch: K=1 is a bit-exact no-op, K>1 is active and finite, the decode path is guarded, and LoopConfig validates its layer set.""" from __future__ import annotations import pytest import torch from looped_laguna import LoopConfig, patch, unpatch def _logits(model, ids): with torch.no_grad(): return model(input_ids=ids, use_cache=False).logits def test_baseline_forward(model, ids): out = _logits(model, ids) assert out.shape == (ids.shape[0], ids.shape[1], model.config.vocab_size) assert torch.isfinite(out).all() @pytest.mark.parametrize("mode", ["layer", "block"]) @pytest.mark.parametrize("naive", [False, True]) def test_k1_is_bit_exact(model, ids, window, mode, naive): """K=1 must reproduce the baseline forward exactly, for every mode/strategy.""" base = _logits(model, ids) patch(model, LoopConfig(window=window, K=1, mode=mode, naive=naive)) assert torch.equal(_logits(model, ids), base) def test_unpatch_restores_baseline(model, ids, window): base = _logits(model, ids) patch(model, LoopConfig(window=window, K=3, mode="layer")) unpatch(model) assert torch.equal(_logits(model, ids), base) @pytest.mark.parametrize("mode", ["layer", "block"]) @pytest.mark.parametrize("naive", [False, True]) @pytest.mark.parametrize("K", [2, 3]) def test_k_gt1_is_active_and_finite(model, ids, window, mode, naive, K): base = _logits(model, ids) patch(model, LoopConfig(window=window, K=K, mode=mode, naive=naive)) out = _logits(model, ids) assert torch.isfinite(out).all() assert not torch.equal(out, base) # the loop actually changed the output def test_decode_time_is_guarded(model, ids, window): patch(model, LoopConfig(window=window, K=2)) with pytest.raises(NotImplementedError): model(input_ids=ids, use_cache=True) def test_loopconfig_validation(): assert LoopConfig(window=(2, 5)).loop_layers == (2, 3, 4, 5) assert LoopConfig(layers=(4, 8, 12), mode="layer").loop_layers == (4, 8, 12) with pytest.raises(ValueError): # block-mode needs contiguous layers LoopConfig(layers=(4, 8, 12), mode="block") with pytest.raises(ValueError): # exactly one of window/layers LoopConfig() with pytest.raises(ValueError): # a > b LoopConfig(window=(5, 3)) with pytest.raises(ValueError): # K >= 1 LoopConfig(window=(2, 5), K=0)