# TaoNet-mini-T2/code/Taotern_SSM/tests/test_block.py
# Uploaded by StarMist0012 ("Add files using upload-large-folder tool"), commit 388fd6e (verified).
"""Tests for Gamma SSM Block."""
import torch
import pytest
from gamma_space_model import GammaSingleBlock
class TestGammaSingleBlockInitialization:
    """Constructor checks: stored hyperparameters and the wrapped SSM module."""

    def test_direct_parameter_init(self):
        """Every explicitly supplied constructor argument is reflected on the block."""
        config = dict(
            d_model=16,
            hidden_dim=32,
            delta_t=0.1,
            kernel_length=4,
            A_type="tridiagonal",
            prenorm=True,
            residual_scale=1.0,
            dropout=0.0,
        )
        block = GammaSingleBlock(**config)

        assert block.d_model == config["d_model"]
        assert block.prenorm is True
        assert block.residual_scale == config["residual_scale"]
        # The dropout *rate* is stored as ``dropout_p``; ``block.dropout`` is
        # the optional layer (see the zero-rate dropout test).
        assert block.dropout_p == config["dropout"]

    def test_default_parameters(self):
        """Omitting the optional arguments yields the expected defaults."""
        block = GammaSingleBlock(d_model=16, hidden_dim=32)

        assert block.d_model == 16
        # Expected defaults: prenorm on, unit residual, no dropout, delta_t 0.1.
        assert block.prenorm is True
        assert block.residual_scale == 1.0
        assert block.dropout_p == 0.0
        assert block.ssm.delta_t == 0.1

    def test_ssm_instantiation(self):
        """The inner ``ssm`` module receives the block's SSM hyperparameters."""
        ssm = GammaSingleBlock(
            d_model=16,
            hidden_dim=32,
            delta_t=0.2,
            A_type="tridiagonal",
        ).ssm

        # state_dim is expected to equal d_model (16) here, not hidden_dim.
        assert ssm.state_dim == 16
        assert ssm.hidden_dim == 32
        assert ssm.delta_t == 0.2
        assert ssm.A_type == "tridiagonal"
class TestGammaSingleBlockForwardPass:
    """Forward-pass shapes and initial-state handling of GammaSingleBlock."""

    def test_forward_output_shape(self):
        """The block returns ``(output, final_state)`` with the expected shapes."""
        batch_size, seq_len, d_model = 4, 32, 16
        hidden_dim = 32
        block = GammaSingleBlock(d_model=d_model, hidden_dim=hidden_dim)
        x = torch.randn(batch_size, seq_len, d_model)

        output, final_state = block(x)

        # Output keeps the input's (batch, seq, d_model) layout; the final
        # state holds one hidden_dim vector per batch element.
        assert output.shape == (batch_size, seq_len, d_model)
        assert final_state.shape == (batch_size, hidden_dim)

    def test_forward_with_initial_state(self):
        """An explicit all-zero state behaves exactly like ``state=None``.

        The block is expected to default its initial state to zeros, so both
        the per-step outputs AND the final state must agree between the two
        calls. The default dropout rate is 0.0, so the forward is deterministic.
        """
        batch_size, seq_len, d_model, hidden_dim = 2, 16, 8, 16
        block = GammaSingleBlock(d_model=d_model, hidden_dim=hidden_dim)
        x = torch.randn(batch_size, seq_len, d_model)
        initial_state = torch.zeros(batch_size, hidden_dim)

        output1, final_state1 = block(x, state=initial_state)
        output2, final_state2 = block(x, state=None)

        assert torch.allclose(output1, output2, atol=1e-5)
        # The final recurrent states must match too, not just the outputs.
        assert final_state1.shape == final_state2.shape == (batch_size, hidden_dim)
        assert torch.allclose(final_state1, final_state2, atol=1e-5)
class TestGammaSingleBlockNormalization:
    """Prenorm vs postnorm configurations of GammaSingleBlock."""

    def test_prenorm_configuration(self):
        """A prenorm block preserves the input shape."""
        block = GammaSingleBlock(d_model=16, hidden_dim=32, prenorm=True)
        x = torch.randn(2, 10, 16)
        output, _ = block(x)
        assert output.shape == (2, 10, 16)

    def test_postnorm_configuration(self):
        """A postnorm block preserves the input shape."""
        block = GammaSingleBlock(d_model=16, hidden_dim=32, prenorm=False)
        x = torch.randn(2, 10, 16)
        output, _ = block(x)
        assert output.shape == (2, 10, 16)

    def test_prenorm_vs_postnorm_outputs_differ(self):
        """Prenorm and postnorm give different outputs from identical weights.

        Both blocks are constructed under the same RNG seed, so any difference
        comes from where normalisation is applied. Without this, two
        independent random initialisations would differ anyway and the check
        would be vacuous. This assumes both configurations allocate the same
        parameters in the same order — NOTE(review): confirm.
        """
        x = torch.randn(2, 10, 16)
        torch.manual_seed(0)
        prenorm_block = GammaSingleBlock(d_model=16, hidden_dim=32, prenorm=True)
        torch.manual_seed(0)
        postnorm_block = GammaSingleBlock(d_model=16, hidden_dim=32, prenorm=False)

        output_pre, _ = prenorm_block(x)
        output_post, _ = postnorm_block(x)

        # Same shape, different values.
        assert output_pre.shape == output_post.shape
        assert not torch.allclose(output_pre, output_post)
class TestGammaSingleBlockResidualConnection:
    """Residual (skip) path of GammaSingleBlock in the prenorm configuration.

    These tests assume the prenorm forward computes
    ``residual_scale * x + SSM(norm(x))``, i.e. ``residual_scale`` multiplies
    the skip path and ``residual_scale=0.0`` disables it.
    NOTE(review): confirm against GammaSingleBlock.forward.
    """

    @staticmethod
    def _seeded_block(residual_scale, seed=0):
        """Return a prenorm block whose weights depend only on ``seed``.

        Re-seeding before construction gives blocks that differ only in
        ``residual_scale`` identical parameters. Any output difference is
        then attributable to the scale alone.
        """
        torch.manual_seed(seed)
        return GammaSingleBlock(
            d_model=16,
            hidden_dim=32,
            residual_scale=residual_scale,
            prenorm=True,
        )

    def test_residual_with_scale_1(self):
        """With scale=1.0 the output equals ``x`` plus the SSM branch.

        A scale=0.0 twin with identical weights produces the branch alone.
        Subtracting the two outputs must therefore recover ``x`` (up to float
        error), whatever the branch's magnitude.
        """
        x = torch.randn(2, 10, 16)
        with_skip, _ = self._seeded_block(1.0)(x)
        without_skip, _ = self._seeded_block(0.0)(x)
        assert torch.allclose(with_skip - without_skip, x, atol=1e-5)

    def test_residual_with_scale_0(self):
        """With scale=0.0 the skip path is disabled.

        The output is then only the SSM branch applied to the normalised
        input. It still depends on ``x`` but should not reproduce it.
        """
        x = torch.randn(2, 10, 16)
        output, _ = self._seeded_block(0.0)(x)
        assert not torch.allclose(output, x)

    def test_residual_scale_effect(self):
        """Blocks with identical weights but different scales disagree."""
        x = torch.randn(2, 10, 16)
        output_half, _ = self._seeded_block(0.5)(x)
        output_double, _ = self._seeded_block(2.0)(x)
        assert not torch.allclose(output_half, output_double)
class TestGammaSingleBlockDropout:
    """Dropout behaviour of GammaSingleBlock across train/eval modes."""

    @staticmethod
    def _two_passes(block):
        """Feed one random input through ``block`` twice and return both outputs."""
        inputs = torch.randn(2, 10, 16)
        first, _ = block(inputs)
        second, _ = block(inputs)
        return first, second

    def test_dropout_train_mode(self):
        """In training mode, dropout makes repeated passes stochastic."""
        block = GammaSingleBlock(d_model=16, hidden_dim=32, dropout=0.5)
        block.train()
        first, second = self._two_passes(block)
        assert not torch.allclose(first, second)

    def test_dropout_eval_mode(self):
        """In eval mode, dropout is inactive and repeated passes agree."""
        block = GammaSingleBlock(d_model=16, hidden_dim=32, dropout=0.5)
        block.eval()
        first, second = self._two_passes(block)
        assert torch.allclose(first, second)

    def test_no_dropout_with_zero_dropout_rate(self):
        """A zero dropout rate means no dropout layer is created at all."""
        block = GammaSingleBlock(d_model=16, hidden_dim=32, dropout=0.0)
        assert block.dropout is None
class TestGammaSingleBlockMasking:
    """Forward pass of GammaSingleBlock with a sequence mask."""

    def test_forward_with_mask(self):
        """A partially masked batch yields well-formed, finite results.

        Row 0 is fully unmasked; row 1 has its last five steps set to False.
        False is assumed to mean padding — TODO confirm mask polarity.
        """
        batch_size, seq_len, d_model = 2, 10, 16
        hidden_dim = 32
        block = GammaSingleBlock(d_model=d_model, hidden_dim=hidden_dim)
        x = torch.randn(batch_size, seq_len, d_model)
        mask = torch.ones(batch_size, seq_len, dtype=torch.bool)
        mask[1, 5:] = False

        output, final_state = block(x, mask=mask)

        assert output.shape == (batch_size, seq_len, d_model)
        assert final_state.shape == (batch_size, hidden_dim)
        # Masking must not leak NaN/inf into either the outputs or the state.
        assert torch.isfinite(output).all()
        assert torch.isfinite(final_state).all()
class TestGammaSingleBlockGradients:
    """Backpropagation through GammaSingleBlock."""

    @staticmethod
    def _backprop_sum(block, *shape):
        """Run a fresh leaf input of ``shape`` through ``block`` and backprop ``output.sum()``.

        Returns the input tensor so callers can inspect its ``.grad``.
        """
        inputs = torch.randn(*shape, requires_grad=True)
        output, _ = block(inputs)
        output.sum().backward()
        return inputs

    def test_backward_pass(self):
        """Gradients reach both the input and every trainable parameter."""
        batch_size, seq_len, d_model = 2, 10, 16
        hidden_dim = 32
        block = GammaSingleBlock(d_model=d_model, hidden_dim=hidden_dim)

        inputs = self._backprop_sum(block, batch_size, seq_len, d_model)

        assert inputs.grad is not None
        assert inputs.grad.shape == inputs.shape
        trainable = [p for p in block.parameters() if p.requires_grad]
        assert all(p.grad is not None for p in trainable)

    def test_gradient_flow_prenorm(self):
        """The input receives a gradient in the prenorm configuration."""
        block = GammaSingleBlock(d_model=16, hidden_dim=32, prenorm=True)
        assert self._backprop_sum(block, 2, 10, 16).grad is not None

    def test_gradient_flow_postnorm(self):
        """The input receives a gradient in the postnorm configuration."""
        block = GammaSingleBlock(d_model=16, hidden_dim=32, prenorm=False)
        assert self._backprop_sum(block, 2, 10, 16).grad is not None
class TestGammaSingleBlockIntegration:
    """Integration tests: stacking, device moves, serialisation, modes, sizes."""

    def test_stacked_blocks(self):
        """Chaining several blocks keeps shapes consistent end to end."""
        d_model, hidden_dim = 16, 32
        num_blocks = 3
        blocks = [
            GammaSingleBlock(d_model=d_model, hidden_dim=hidden_dim)
            for _ in range(num_blocks)
        ]
        x = torch.randn(2, 10, d_model)
        states = []
        # Feed each block's output into the next, collecting the final states.
        for block in blocks:
            x, state = block(x)
            states.append(state)
        assert x.shape == (2, 10, d_model)
        assert len(states) == num_blocks
        # Every block reports a (batch, hidden_dim) final state.
        assert all(state.shape == (2, hidden_dim) for state in states)

    def test_device_transfer(self):
        """The block runs on CPU and, if CUDA exists, survives a GPU round trip."""
        block = GammaSingleBlock(d_model=16, hidden_dim=32)
        x_cpu = torch.randn(2, 10, 16)
        output_cpu, _ = block(x_cpu)
        assert output_cpu.device.type == "cpu"
        if torch.cuda.is_available():
            block = block.cuda()
            x_gpu = torch.randn(2, 10, 16).cuda()
            output_gpu, _ = block(x_gpu)
            assert output_gpu.device.type == "cuda"
            # Moving back must restore the exact weights, so the CPU output for
            # the same input is reproduced.
            block = block.cpu()
            output_cpu2, _ = block(x_cpu)
            assert output_cpu2.device.type == "cpu"
            assert torch.allclose(output_cpu, output_cpu2, atol=1e-6)

    def test_state_dict_save_load(self):
        """Loading one block's state dict into another makes them output-equivalent."""
        block1 = GammaSingleBlock(d_model=16, hidden_dim=32)
        block2 = GammaSingleBlock(d_model=16, hidden_dim=32)
        block2.load_state_dict(block1.state_dict())
        x = torch.randn(2, 10, 16)
        with torch.no_grad():
            out1, _ = block1(x)
            out2, _ = block2(x)
        assert torch.allclose(out1, out2, atol=1e-6)

    def test_train_eval_mode_switching(self):
        """train()/eval() toggle the mode on the block and all its submodules."""
        block = GammaSingleBlock(d_model=16, hidden_dim=32, dropout=0.5)

        block.train()
        assert block.training
        assert all(module.training for module in block.modules())

        block.eval()
        assert not block.training
        # eval() must reach children such as the dropout layer, or dropout
        # would stay active at inference time.
        assert not any(module.training for module in block.modules())

        block.train()
        assert block.training
        assert all(module.training for module in block.modules())

    def test_different_d_models_and_hidden_dims(self):
        """Output and state shapes follow (d_model, hidden_dim) across sizes."""
        configs = [
            (8, 16),
            (16, 32),
            (64, 128),
            (256, 512),
        ]
        for d_model, hidden_dim in configs:
            block = GammaSingleBlock(d_model=d_model, hidden_dim=hidden_dim)
            x = torch.randn(2, 10, d_model)
            output, state = block(x)
            assert output.shape == (2, 10, d_model)
            assert state.shape == (2, hidden_dim)
if __name__ == "__main__":
    # Propagate pytest's exit status so that running this file directly
    # fails (non-zero exit) when any test fails, e.g. in CI scripts.
    raise SystemExit(pytest.main([__file__, "-v"]))