| | """Unit tests for core architecture components"""
|
| | import pytest
|
| | import torch
|
| | import torch.nn as nn
|
| | from src.models.architecture import (
|
| | ModelConfig,
|
| | RMSNorm,
|
| | RotaryPositionalEmbedding,
|
| | SwiGLU,
|
| | GroupedQueryAttention,
|
| | apply_rotary_pos_emb,
|
| | rotate_half
|
| | )
|
| |
|
| |
|
class TestRMSNorm:
    """Unit tests for the RMSNorm normalization layer."""

    def test_initialization(self):
        """A fresh RMSNorm exposes a learnable unit weight of the right size."""
        layer = RMSNorm(hidden_size=768)
        assert layer.weight.shape == (768,)
        assert torch.allclose(layer.weight, torch.ones(768))

    def test_forward_pass(self):
        """The forward pass keeps the input shape and yields finite values."""
        layer = RMSNorm(hidden_size=768)
        inputs = torch.randn(2, 16, 768)

        result = layer(inputs)

        assert result.shape == inputs.shape
        assert not torch.isnan(result).any()
        assert not torch.isinf(result).any()

    def test_normalization(self):
        """Even large-magnitude inputs come out with approximately unit RMS."""
        layer = RMSNorm(hidden_size=768, eps=1e-6)
        inputs = torch.randn(2, 16, 768) * 100

        result = layer(inputs)

        # With the default unit weight, the mean square over the last
        # dimension should be ~1 after normalization.
        mean_square = result.pow(2).mean(-1)
        assert torch.allclose(mean_square, torch.ones_like(mean_square), atol=1e-5)
|
| |
|
| |
|
class TestRotaryPositionalEmbedding:
    """Unit tests for rotary position embeddings (RoPE)."""

    def test_initialization(self):
        """The constructor records its dimensions and builds inverse frequencies."""
        rope = RotaryPositionalEmbedding(dim=64, max_position_embeddings=2048)
        assert rope.dim == 64
        assert rope.max_position_embeddings == 2048
        # One inverse frequency per pair of channels (dim / 2).
        assert rope.inv_freq.shape == (32,)

    def test_forward_pass(self):
        """cos/sin tables broadcast over batch and heads and contain no NaNs."""
        rope = RotaryPositionalEmbedding(dim=64)
        dummy = torch.randn(2, 8, 16, 64)

        cos, sin = rope(dummy, seq_len=16)

        assert cos.shape == (1, 1, 16, 64)
        assert sin.shape == (1, 1, 16, 64)
        assert not torch.isnan(cos).any()
        assert not torch.isnan(sin).any()

    def test_caching(self):
        """Two calls at the same sequence length return identical tables."""
        rope = RotaryPositionalEmbedding(dim=64)
        dummy = torch.randn(2, 8, 16, 64)

        first_cos, first_sin = rope(dummy, seq_len=16)
        second_cos, second_sin = rope(dummy, seq_len=16)

        # Cached values must be bit-for-bit equal across invocations.
        assert torch.equal(first_cos, second_cos)
        assert torch.equal(first_sin, second_sin)
|
| |
|
| |
|
class TestRotaryFunctions:
    """Unit tests for the rotary-embedding helper functions."""

    def test_rotate_half(self):
        """rotate_half moves the second half to the front, negated."""
        original = torch.tensor([[1., 2., 3., 4.]])
        expected = torch.tensor([[-3., -4., 1., 2.]])
        assert torch.allclose(rotate_half(original), expected)

    def test_apply_rotary_pos_emb(self):
        """Applying the rotation preserves Q/K shapes and produces no NaNs."""
        queries = torch.randn(2, 8, 16, 64)
        keys = torch.randn(2, 8, 16, 64)
        cos = torch.randn(1, 1, 16, 64)
        sin = torch.randn(1, 1, 16, 64)

        rotated_q, rotated_k = apply_rotary_pos_emb(queries, keys, cos, sin)

        assert rotated_q.shape == queries.shape
        assert rotated_k.shape == keys.shape
        assert not torch.isnan(rotated_q).any()
        assert not torch.isnan(rotated_k).any()
|
| |
|
| |
|
class TestSwiGLU:
    """Unit tests for the SwiGLU feed-forward block."""

    @pytest.fixture
    def config(self):
        """A small configuration with biasless MLP projections."""
        return ModelConfig(n_embd=768, intermediate_size=3072, mlp_bias=False)

    def test_initialization(self, config):
        """Projection weights match the configured hidden/intermediate sizes."""
        mlp = SwiGLU(config)
        assert mlp.hidden_size == 768
        assert mlp.intermediate_size == 3072
        assert mlp.gate_proj.weight.shape == (3072, 768)
        assert mlp.up_proj.weight.shape == (3072, 768)
        assert mlp.down_proj.weight.shape == (768, 3072)

    def test_forward_pass(self, config):
        """The forward pass is shape-preserving and yields finite values."""
        mlp = SwiGLU(config)
        hidden = torch.randn(2, 16, 768)

        out = mlp(hidden)

        assert out.shape == hidden.shape
        assert not torch.isnan(out).any()

    def test_no_bias(self, config):
        """With mlp_bias=False none of the projections carries a bias term."""
        mlp = SwiGLU(config)
        for projection in (mlp.gate_proj, mlp.up_proj, mlp.down_proj):
            assert projection.bias is None
|
| |
|
| |
|
class TestGroupedQueryAttention:
    """Unit tests for grouped-query attention (GQA)."""

    @pytest.fixture
    def config(self):
        """Twelve query heads sharing four KV heads, no bias, no dropout."""
        return ModelConfig(
            n_embd=768,
            n_head=12,
            n_kv_head=4,
            attention_bias=False,
            attention_dropout=0.0,
        )

    def test_initialization(self, config):
        """Head bookkeeping and projection shapes follow from the config."""
        attn = GroupedQueryAttention(config)

        assert attn.num_heads == 12
        assert attn.num_kv_heads == 4
        assert attn.num_kv_groups == 3  # 12 query heads / 4 KV heads
        assert attn.head_dim == 64      # 768 / 12

        # K/V project down to num_kv_heads * head_dim = 256 output channels.
        assert attn.q_proj.weight.shape == (768, 768)
        assert attn.k_proj.weight.shape == (256, 768)
        assert attn.v_proj.weight.shape == (256, 768)

    def test_forward_pass(self, config):
        """A plain forward pass preserves shape and produces no NaNs."""
        attn = GroupedQueryAttention(config)
        hidden = torch.randn(2, 16, 768)

        out, attn_weights = attn(hidden)

        assert out.shape == hidden.shape
        assert not torch.isnan(out).any()

    def test_with_attention_mask(self, config):
        """Masking out the tail of the sequence still yields a clean output."""
        attn = GroupedQueryAttention(config)
        hidden = torch.randn(2, 16, 768)
        # Mask the second half of every sequence.
        mask = torch.ones(2, 16, dtype=torch.bool)
        mask[:, 8:] = False

        out, _ = attn(hidden, attention_mask=mask)

        assert out.shape == hidden.shape
        assert not torch.isnan(out).any()

    def test_causal_attention(self, config):
        """Causal (autoregressive) attention behaves like the plain pass."""
        attn = GroupedQueryAttention(config)
        hidden = torch.randn(2, 16, 768)

        out, attn_weights = attn(hidden, is_causal=True)

        assert out.shape == hidden.shape
        assert not torch.isnan(out).any()
|
| |
|
| |
|
class TestModelConfig:
    """Unit tests for ModelConfig defaults and overrides."""

    def test_default_config(self):
        """A zero-argument config carries the documented default values."""
        cfg = ModelConfig()

        assert cfg.vocab_size == 100352
        assert cfg.n_positions == 8192
        assert cfg.n_embd == 4096
        assert cfg.n_layer == 64
        assert cfg.activation == "swiglu"
        assert cfg.norm_type == "rmsnorm"

    def test_custom_config(self):
        """Keyword overrides are stored verbatim on the config object."""
        overrides = {"vocab_size": 50000, "n_embd": 512, "n_layer": 6, "n_head": 8}

        cfg = ModelConfig(**overrides)

        for field, expected in overrides.items():
            assert getattr(cfg, field) == expected

    def test_gqa_compatibility(self):
        """Query heads divide evenly among KV heads, and GQA agrees."""
        cfg = ModelConfig(n_head=12, n_kv_head=4)
        assert cfg.n_head % cfg.n_kv_head == 0

        attn = GroupedQueryAttention(cfg)
        assert attn.num_kv_groups == 3
|
| |
|
| |
|
if __name__ == "__main__":
    # Allow running this test module directly (outside a pytest invocation).
    pytest.main([__file__, "-v"])
|
| |
|