"""Tests for LTI injection in ACT loops."""

import torch

from arbitor.components import LTIInjection, ByteHead

# NOTE(review): VideoHead is not used by any test below; the import is kept so
# collection fails if arbitor.decoders stops importing — confirm this is intended.
from arbitor.decoders import VideoHead  # noqa: F401

# Feature width fed to ByteHead in these tests, and the width of the logits'
# last dimension the tests expect back.
_BYTEHEAD_IN_DIM = 8192
_BYTEHEAD_OUT_DIM = 288


def test_lti_basic_properties():
    """LTIInjection maps (h, e, t) to a finite tensor with the same shape as h."""
    lti = LTIInjection(64)
    h = torch.randn(2, 10, 64)
    e = torch.randn(2, 10, 64)
    t = torch.randn(2, 10, 64)
    out = lti(h, e, t)
    assert out.shape == h.shape
    assert torch.isfinite(out).all()


def test_lti_spectral_radius():
    """Every entry returned by ``get_A()`` lies strictly inside (0, 1).

    NOTE(review): presumably ``get_A()`` returns the diagonal of the
    state-transition matrix, in which case this bounds its spectral radius
    below 1 (a stable, decaying recurrence) — confirm against LTIInjection.
    """
    lti = LTIInjection(64)
    A = lti.get_A()
    assert (A > 0).all()
    assert (A < 1).all()


def test_lti_learnable_params():
    """``log_A``, ``log_dt`` and ``B`` have the expected shapes and are the only parameters."""
    dim = 128
    lti = LTIInjection(dim)
    assert lti.log_A.shape == (dim,)
    assert lti.log_dt.shape == (1,)
    assert lti.B.shape == (dim,)
    # log_A (dim) + log_dt (1) + B (dim): any extra parameter would break this count.
    assert sum(p.numel() for p in lti.parameters()) == dim + 1 + dim


def test_lti_state_decay():
    """With zero ``e`` and ``t``, a hidden state of 100 comes out below 50 in magnitude."""
    lti = LTIInjection(8)
    h = torch.ones(1, 1, 8) * 100.0
    e = torch.zeros(1, 1, 8)
    t = torch.zeros(1, 1, 8)
    out = lti(h, e, t)
    assert (out.abs() < 50).all()


def test_lti_initial_state_small():
    """From a zero hidden state, an input ``e`` of 5 yields outputs strictly in (0, 5)."""
    lti = LTIInjection(8)
    h = torch.zeros(1, 1, 8)
    e = torch.ones(1, 1, 8) * 5.0
    t = torch.zeros(1, 1, 8)
    out = lti(h, e, t)
    assert (out > 0).all()
    assert (out < 5).all()


def test_bytehead_lti_integration():
    """A default ByteHead produces logits of the expected width and owns an LTIInjection."""
    bh = ByteHead()
    x = torch.randn(2, 10, _BYTEHEAD_IN_DIM)
    logits = bh(x)
    assert logits.shape[-1] == _BYTEHEAD_OUT_DIM
    assert bh.lti is not None
    assert isinstance(bh.lti, LTIInjection)


def test_bytehead_no_act():
    """With ``act_max_iters=1`` ByteHead has no LTI module but still produces logits."""
    bh_single = ByteHead(act_max_iters=1)
    assert bh_single.lti is None
    x = torch.randn(1, 5, _BYTEHEAD_IN_DIM)
    logits = bh_single(x)
    assert logits.shape[-1] == _BYTEHEAD_OUT_DIM