"""Tests for Gamma SSM Block."""

import torch
import pytest

from gamma_space_model import GammaSingleBlock


@pytest.fixture(autouse=True)
def _seed_torch():
    """Seed torch's global RNG before every test.

    Many tests below build randomly initialized blocks and feed them random
    inputs, then assert tolerances or that two outputs differ. Seeding makes
    any failure reproducible instead of depending on the run.
    """
    torch.manual_seed(0)


class TestGammaSingleBlockInitialization:
    """Test GammaSingleBlock initialization."""

    def test_direct_parameter_init(self):
        """Test GammaSingleBlock with direct parameters (no config)."""
        block = GammaSingleBlock(
            d_model=16,
            hidden_dim=32,
            delta_t=0.1,
            kernel_length=4,
            A_type="tridiagonal",
            prenorm=True,
            residual_scale=1.0,
            dropout=0.0,
        )
        assert block.d_model == 16
        assert block.prenorm is True
        assert block.residual_scale == 1.0
        assert block.dropout_p == 0.0

    def test_default_parameters(self):
        """Test that default parameters are set correctly."""
        block = GammaSingleBlock(d_model=16, hidden_dim=32)
        assert block.d_model == 16
        assert block.prenorm is True
        assert block.residual_scale == 1.0
        assert block.dropout_p == 0.0
        assert block.ssm.delta_t == 0.1

    def test_ssm_instantiation(self):
        """Test that SSM block is correctly instantiated."""
        block = GammaSingleBlock(
            d_model=16,
            hidden_dim=32,
            delta_t=0.2,
            A_type="tridiagonal",
        )
        assert block.ssm.state_dim == 16
        assert block.ssm.hidden_dim == 32
        assert block.ssm.delta_t == 0.2
        assert block.ssm.A_type == "tridiagonal"


class TestGammaSingleBlockForwardPass:
    """Test GammaSingleBlock forward pass."""

    def test_forward_output_shape(self):
        """Test that forward pass produces correct output shape."""
        batch_size, seq_len, d_model = 4, 32, 16
        hidden_dim = 32
        block = GammaSingleBlock(d_model=d_model, hidden_dim=hidden_dim)
        x = torch.randn(batch_size, seq_len, d_model)
        output, final_state = block(x)
        assert output.shape == (batch_size, seq_len, d_model)
        assert final_state.shape == (batch_size, hidden_dim)

    def test_forward_with_initial_state(self):
        """Test forward pass with provided initial state."""
        batch_size, seq_len, d_model, hidden_dim = 2, 16, 8, 16
        block = GammaSingleBlock(d_model=d_model, hidden_dim=hidden_dim)
        x = torch.randn(batch_size, seq_len, d_model)
        initial_state = torch.zeros(batch_size, hidden_dim)
        output1, final_state1 = block(x, state=initial_state)
        output2, final_state2 = block(x, state=None)
        # state=None is expected to behave like an explicit zero initial
        # state, so both the outputs and the final state should match.
        assert torch.allclose(output1, output2, atol=1e-5)
        assert torch.allclose(final_state1, final_state2, atol=1e-5)


class TestGammaSingleBlockNormalization:
    """Test GammaSingleBlock normalization (prenorm vs postnorm)."""

    def test_prenorm_configuration(self):
        """Test prenorm configuration."""
        block = GammaSingleBlock(d_model=16, hidden_dim=32, prenorm=True)
        x = torch.randn(2, 10, 16)
        output, _ = block(x)
        assert output.shape == (2, 10, 16)

    def test_postnorm_configuration(self):
        """Test postnorm configuration."""
        block = GammaSingleBlock(d_model=16, hidden_dim=32, prenorm=False)
        x = torch.randn(2, 10, 16)
        output, _ = block(x)
        assert output.shape == (2, 10, 16)

    def test_prenorm_vs_postnorm_outputs_differ(self):
        """Test that prenorm and postnorm produce different outputs."""
        x = torch.randn(2, 10, 16)
        prenorm_block = GammaSingleBlock(d_model=16, hidden_dim=32, prenorm=True)
        postnorm_block = GammaSingleBlock(d_model=16, hidden_dim=32, prenorm=False)
        output_pre, _ = prenorm_block(x)
        output_post, _ = postnorm_block(x)
        # Outputs should differ but have same shape
        assert output_pre.shape == output_post.shape
        assert not torch.allclose(output_pre, output_post)


class TestGammaSingleBlockResidualConnection:
    """Test GammaSingleBlock residual connection."""

    def test_residual_with_scale_1(self):
        """Test residual connection with scale=1.0."""
        block = GammaSingleBlock(
            d_model=16,
            hidden_dim=32,
            residual_scale=1.0,
            prenorm=True,
        )
        x = torch.randn(2, 10, 16)
        output, _ = block(x)
        # Output should be x + SSM branch. For a freshly initialized block the
        # branch is assumed small, so output stays within a loose tolerance
        # of x.
        assert torch.allclose(output, x, atol=2.0)

    def test_residual_with_scale_0(self):
        """Test residual connection with scale=0.0 (no residual)."""
        block = GammaSingleBlock(
            d_model=16,
            hidden_dim=32,
            residual_scale=0.0,
            prenorm=True,
        )
        x = torch.randn(2, 10, 16)
        output, _ = block(x)
        # With the residual path disabled, the output comes only from the
        # norm + SSM branch. It still depends on x through that branch, but
        # it should not reproduce x itself.
        assert not torch.allclose(output, x)

    def test_residual_scale_effect(self):
        """Test that residual_scale parameter affects output."""
        x = torch.randn(2, 10, 16)
        block1 = GammaSingleBlock(
            d_model=16,
            hidden_dim=32,
            residual_scale=0.5,
            prenorm=True,
        )
        block2 = GammaSingleBlock(
            d_model=16,
            hidden_dim=32,
            residual_scale=2.0,
            prenorm=True,
        )
        output1, _ = block1(x)
        output2, _ = block2(x)
        # Different scales should produce different outputs
        assert not torch.allclose(output1, output2)


class TestGammaSingleBlockDropout:
    """Test GammaSingleBlock dropout."""

    def test_dropout_train_mode(self):
        """Test that dropout is applied during training."""
        block = GammaSingleBlock(
            d_model=16,
            hidden_dim=32,
            dropout=0.5,
        )
        block.train()
        x = torch.randn(2, 10, 16)
        # Multiple forward passes should give different results due to dropout
        output1, _ = block(x)
        output2, _ = block(x)
        assert not torch.allclose(output1, output2)

    def test_dropout_eval_mode(self):
        """Test that dropout is not applied during evaluation."""
        block = GammaSingleBlock(
            d_model=16,
            hidden_dim=32,
            dropout=0.5,
        )
        block.eval()
        x = torch.randn(2, 10, 16)
        # Multiple forward passes should give same results in eval mode
        output1, _ = block(x)
        output2, _ = block(x)
        assert torch.allclose(output1, output2)

    def test_no_dropout_with_zero_dropout_rate(self):
        """Test that no dropout is applied when dropout=0."""
        block = GammaSingleBlock(
            d_model=16,
            hidden_dim=32,
            dropout=0.0,
        )
        # Should not have dropout layer
        assert block.dropout is None


class TestGammaSingleBlockMasking:
    """Test GammaSingleBlock masking functionality."""

    def test_forward_with_mask(self):
        """Test forward pass with masking."""
        batch_size, seq_len, d_model = 2, 10, 16
        hidden_dim = 32
        block = GammaSingleBlock(d_model=d_model, hidden_dim=hidden_dim)
        x = torch.randn(batch_size, seq_len, d_model)
        mask = torch.ones(batch_size, seq_len, dtype=torch.bool)
        # Mark the tail of the second sequence as masked out (False).
        mask[1, 5:] = False
        output, _ = block(x, mask=mask)
        assert output.shape == (batch_size, seq_len, d_model)


class TestGammaSingleBlockGradients:
    """Test gradient flow through GammaSingleBlock."""

    def test_backward_pass(self):
        """Test that gradients flow correctly."""
        batch_size, seq_len, d_model = 2, 10, 16
        hidden_dim = 32
        block = GammaSingleBlock(d_model=d_model, hidden_dim=hidden_dim)
        x = torch.randn(batch_size, seq_len, d_model, requires_grad=True)
        output, _ = block(x)
        loss = output.sum()
        loss.backward()
        assert x.grad is not None
        assert x.grad.shape == x.shape
        # Check that block parameters have gradients
        for param in block.parameters():
            if param.requires_grad:
                assert param.grad is not None

    def test_gradient_flow_prenorm(self):
        """Test gradient flow with prenorm."""
        block = GammaSingleBlock(d_model=16, hidden_dim=32, prenorm=True)
        x = torch.randn(2, 10, 16, requires_grad=True)
        output, _ = block(x)
        loss = output.sum()
        loss.backward()
        assert x.grad is not None

    def test_gradient_flow_postnorm(self):
        """Test gradient flow with postnorm."""
        block = GammaSingleBlock(d_model=16, hidden_dim=32, prenorm=False)
        x = torch.randn(2, 10, 16, requires_grad=True)
        output, _ = block(x)
        loss = output.sum()
        loss.backward()
        assert x.grad is not None


class TestGammaSingleBlockIntegration:
    """Integration tests for GammaSingleBlock."""

    def test_stacked_blocks(self):
        """Test stacking multiple blocks together."""
        d_model, hidden_dim = 16, 32
        num_blocks = 3
        blocks = [
            GammaSingleBlock(d_model=d_model, hidden_dim=hidden_dim)
            for _ in range(num_blocks)
        ]
        x = torch.randn(2, 10, d_model)
        states = []
        # Forward through blocks
        for block in blocks:
            x, state = block(x)
            states.append(state)
        assert x.shape == (2, 10, d_model)
        assert len(states) == num_blocks

    def test_device_transfer(self):
        """Test that block can be transferred between devices."""
        block = GammaSingleBlock(d_model=16, hidden_dim=32)
        # Test on CPU
        x_cpu = torch.randn(2, 10, 16)
        output_cpu, _ = block(x_cpu)
        assert output_cpu.device.type == "cpu"
        if torch.cuda.is_available():
            # Transfer to GPU
            block = block.cuda()
            x_gpu = torch.randn(2, 10, 16).cuda()
            output_gpu, _ = block(x_gpu)
            assert output_gpu.device.type == "cuda"
            # Transfer back to CPU
            block = block.cpu()
            output_cpu2, _ = block(x_cpu)
            assert output_cpu2.device.type == "cpu"

    def test_state_dict_save_load(self):
        """Test saving and loading state dict."""
        block1 = GammaSingleBlock(d_model=16, hidden_dim=32)
        block2 = GammaSingleBlock(d_model=16, hidden_dim=32)
        # Save state dict from block1
        state_dict = block1.state_dict()
        # Load into block2
        block2.load_state_dict(state_dict)
        # They should produce same output
        x = torch.randn(2, 10, 16)
        with torch.no_grad():
            out1, _ = block1(x)
            out2, _ = block2(x)
        assert torch.allclose(out1, out2, atol=1e-6)

    def test_train_eval_mode_switching(self):
        """Test switching between train and eval modes."""
        block = GammaSingleBlock(d_model=16, hidden_dim=32, dropout=0.5)
        # Train mode
        block.train()
        assert block.training
        # Eval mode
        block.eval()
        assert not block.training
        # Train mode again
        block.train()
        assert block.training

    @pytest.mark.parametrize(
        ("d_model", "hidden_dim"),
        [
            (8, 16),
            (16, 32),
            (64, 128),
            (256, 512),
        ],
    )
    def test_different_d_models_and_hidden_dims(self, d_model, hidden_dim):
        """Test blocks with various dimensions.

        Parametrized so each (d_model, hidden_dim) pair is reported as its
        own test case. A failure for one size does not hide the others.
        """
        block = GammaSingleBlock(d_model=d_model, hidden_dim=hidden_dim)
        x = torch.randn(2, 10, d_model)
        output, state = block(x)
        assert output.shape == (2, 10, d_model)
        assert state.shape == (2, hidden_dim)


if __name__ == "__main__":
    # Propagate pytest's exit code so direct invocation reports failures.
    raise SystemExit(pytest.main([__file__, "-v"]))