Vedisasi's picture
Upload folder using huggingface_hub
54c5666 verified
"""Unit tests for core architecture components"""
import pytest
import torch
import torch.nn as nn
from src.models.architecture import (
ModelConfig,
RMSNorm,
RotaryPositionalEmbedding,
SwiGLU,
GroupedQueryAttention,
apply_rotary_pos_emb,
rotate_half
)
class TestRMSNorm:
    """Checks for the RMSNorm layer."""

    def test_initialization(self):
        """A new layer exposes a 768-element weight equal to all ones."""
        layer = RMSNorm(hidden_size=768)
        weight = layer.weight
        assert weight.shape == (768,)
        assert torch.allclose(weight, torch.ones(768))

    def test_forward_pass(self):
        """The forward pass keeps the input shape and yields no NaN/Inf."""
        layer = RMSNorm(hidden_size=768)
        inputs = torch.randn(2, 16, 768)
        result = layer(inputs)
        assert result.shape == inputs.shape
        assert not torch.isnan(result).any()
        assert not torch.isinf(result).any()

    def test_normalization(self):
        """Large-magnitude inputs come out with unit mean square per vector."""
        layer = RMSNorm(hidden_size=768, eps=1e-6)
        # Scale the noise up so the check is meaningful for non-unit inputs.
        inputs = torch.randn(2, 16, 768) * 100
        result = layer(inputs)
        # Mean of the squared outputs over the last axis should be ~1.
        mean_square = result.pow(2).mean(-1)
        target = torch.ones_like(mean_square)
        assert torch.allclose(mean_square, target, atol=1e-5)
class TestRotaryPositionalEmbedding:
    """Checks for the rotary position embedding module."""

    def test_initialization(self):
        """Constructor arguments are stored and inv_freq has dim // 2 entries."""
        rotary = RotaryPositionalEmbedding(dim=64, max_position_embeddings=2048)
        assert rotary.dim == 64
        assert rotary.max_position_embeddings == 2048
        # Half of dim=64.
        assert rotary.inv_freq.shape == (32,)

    def test_forward_pass(self):
        """The module returns NaN-free cos/sin tensors of shape (1, 1, 16, 64)."""
        rotary = RotaryPositionalEmbedding(dim=64)
        # Input laid out as (batch, heads, seq, dim).
        hidden = torch.randn(2, 8, 16, 64)
        cos, sin = rotary(hidden, seq_len=16)
        wanted_shape = (1, 1, 16, 64)
        for table in (cos, sin):
            assert table.shape == wanted_shape
        for table in (cos, sin):
            assert not torch.isnan(table).any()

    def test_caching(self):
        """Two calls with the same seq_len give bit-identical cos/sin tensors."""
        rotary = RotaryPositionalEmbedding(dim=64)
        hidden = torch.randn(2, 8, 16, 64)
        first_cos, first_sin = rotary(hidden, seq_len=16)
        second_cos, second_sin = rotary(hidden, seq_len=16)
        # The second call is expected to reuse the values from the first.
        assert torch.equal(first_cos, second_cos)
        assert torch.equal(first_sin, second_sin)
class TestRotaryFunctions:
    """Checks for the rotary embedding helper functions."""

    def test_rotate_half(self):
        """rotate_half maps [1, 2, 3, 4] to [-3, -4, 1, 2]."""
        vector = torch.tensor([[1., 2., 3., 4.]])
        wanted = torch.tensor([[-3., -4., 1., 2.]])
        assert torch.allclose(rotate_half(vector), wanted)

    def test_apply_rotary_pos_emb(self):
        """Rotated Q and K keep their shapes and contain no NaN values."""
        # Q/K laid out as (batch, heads, seq, dim); cos/sin use size-1
        # batch and head axes.
        qk_shape = (2, 8, 16, 64)
        table_shape = (1, 1, 16, 64)
        q = torch.randn(*qk_shape)
        k = torch.randn(*qk_shape)
        cos = torch.randn(*table_shape)
        sin = torch.randn(*table_shape)

        q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)

        assert q_rot.shape == q.shape
        assert k_rot.shape == k.shape
        for rotated in (q_rot, k_rot):
            assert not torch.isnan(rotated).any()
class TestSwiGLU:
    """Checks for the SwiGLU feed-forward block."""

    @pytest.fixture
    def config(self):
        """Provide a 768 -> 3072 MLP configuration with biases disabled."""
        return ModelConfig(n_embd=768, intermediate_size=3072, mlp_bias=False)

    def test_initialization(self, config):
        """Sizes are copied from the config and the projections match them."""
        block = SwiGLU(config)
        assert block.hidden_size == 768
        assert block.intermediate_size == 3072
        # Expected weight shapes, checked in the original assertion order.
        wanted_shapes = (
            ("gate_proj", (3072, 768)),
            ("up_proj", (3072, 768)),
            ("down_proj", (768, 3072)),
        )
        for name, shape in wanted_shapes:
            assert getattr(block, name).weight.shape == shape

    def test_forward_pass(self, config):
        """The forward pass keeps the input shape and yields no NaN values."""
        block = SwiGLU(config)
        inputs = torch.randn(2, 16, 768)
        result = block(inputs)
        assert result.shape == inputs.shape
        assert not torch.isnan(result).any()

    def test_no_bias(self, config):
        """With mlp_bias=False none of the three projections has a bias."""
        block = SwiGLU(config)
        for name in ("gate_proj", "up_proj", "down_proj"):
            assert getattr(block, name).bias is None
class TestGroupedQueryAttention:
    """Checks for the grouped-query attention layer."""

    @pytest.fixture
    def config(self):
        """Provide a 768-wide config with 12 query heads and 4 KV heads."""
        return ModelConfig(
            n_embd=768,
            n_head=12,
            n_kv_head=4,  # 12 / 4 = 3 query heads per KV head
            attention_bias=False,
            attention_dropout=0.0,
        )

    def test_initialization(self, config):
        """Head counts, head size and projection shapes follow the config."""
        attention = GroupedQueryAttention(config)
        assert attention.num_heads == 12
        assert attention.num_kv_heads == 4
        assert attention.num_kv_groups == 3
        assert attention.head_dim == 64  # 768 / 12

        # Q spans all 12 heads (12 * 64); K and V span the 4 KV heads (4 * 64).
        assert attention.q_proj.weight.shape == (768, 768)
        for projection in (attention.k_proj, attention.v_proj):
            assert projection.weight.shape == (256, 768)

    def test_forward_pass(self, config):
        """A plain forward call keeps the input shape and yields no NaN."""
        attention = GroupedQueryAttention(config)
        hidden = torch.randn(2, 16, 768)  # (batch, seq, hidden)
        result, _ = attention(hidden)
        assert result.shape == hidden.shape
        assert not torch.isnan(result).any()

    def test_with_attention_mask(self, config):
        """A boolean mask hiding the last half of the sequence is accepted."""
        attention = GroupedQueryAttention(config)
        hidden = torch.randn(2, 16, 768)
        mask = torch.ones(2, 16, dtype=torch.bool)
        mask[:, 8:] = False  # positions 8..15 are masked out
        result, _ = attention(hidden, attention_mask=mask)
        assert result.shape == hidden.shape
        assert not torch.isnan(result).any()

    def test_causal_attention(self, config):
        """Causal (autoregressive) mode keeps the shape and yields no NaN."""
        attention = GroupedQueryAttention(config)
        hidden = torch.randn(2, 16, 768)
        result, _ = attention(hidden, is_causal=True)
        assert result.shape == hidden.shape
        assert not torch.isnan(result).any()
class TestModelConfig:
    """Test model configuration"""

    def test_default_config(self):
        """Test default configuration values"""
        config = ModelConfig()
        assert config.vocab_size == 100352
        assert config.n_positions == 8192
        assert config.n_embd == 4096
        assert config.n_layer == 64
        assert config.activation == "swiglu"
        assert config.norm_type == "rmsnorm"

    def test_custom_config(self):
        """Test that explicitly passed values override the defaults"""
        config = ModelConfig(
            vocab_size=50000,
            n_embd=512,
            n_layer=6,
            n_head=8,
        )
        assert config.vocab_size == 50000
        assert config.n_embd == 512
        assert config.n_layer == 6
        assert config.n_head == 8

    def test_gqa_compatibility(self):
        """Test that a 12-head / 4-KV-head config builds a 3-group GQA layer"""
        # n_embd is set explicitly: the default (4096, see test_default_config)
        # is not divisible by 12 heads, so the per-head dimension would not be
        # an integer. 768 / 12 = 64 keeps the geometry valid and the layer
        # small enough to construct quickly.
        config = ModelConfig(n_embd=768, n_head=12, n_kv_head=4)
        assert config.n_embd % config.n_head == 0
        assert config.n_head % config.n_kv_head == 0
        # This should work - 12 heads, 4 KV heads = 3 groups
        gqa = GroupedQueryAttention(config)
        assert gqa.num_kv_groups == 3
if __name__ == "__main__":
    # Run this file's tests verbosely and hand pytest's exit code back to the
    # shell, so running the file as a script fails when any test fails.
    raise SystemExit(pytest.main([__file__, "-v"]))