|
|
|
|
|
""" |
|
|
Unit tests for the Compact AI Model with Interleaved Thinking. |
|
|
""" |
|
|
|
|
|
import torch |
|
|
import pytest |
|
|
import sys |
|
|
from pathlib import Path |
|
|
|
|
|
|
|
|
# Make the project package importable when this file is run directly,
# by prepending the repository root (three levels up) to sys.path.
project_root = Path(__file__).parents[2]
sys.path.insert(0, str(project_root))
|
|
|
|
|
from compact_ai_model.architecture.model import create_compact_model, CompactAIModel |
|
|
from compact_ai_model.configs.config import Config, ModelConfig, InterleavedThinkingConfig |
|
|
|
|
|
|
|
|
class TestModelCreation:
    """Test model creation and basic properties.

    Each size preset ("tiny", "small", "medium") is created via the
    factory and checked against a parameter-count budget; a final test
    verifies that explicit config values propagate onto the model.
    """

    def _assert_size_within_budget(self, size: str, budget: int) -> None:
        """Create the *size* preset and assert type and parameter budget.

        Shared helper for the three size tests below; leading underscore
        keeps it out of pytest's test collection.
        """
        model = create_compact_model(size)
        assert model is not None
        assert isinstance(model, CompactAIModel)

        num_params = model.get_num_params()
        # size.capitalize() reproduces the original per-size messages
        # ("Tiny model too large: ...", etc.).
        assert num_params < budget, f"{size.capitalize()} model too large: {num_params}"

    def test_create_tiny_model(self):
        """Tiny preset must stay under 100M parameters."""
        self._assert_size_within_budget("tiny", 100_000_000)

    def test_create_small_model(self):
        """Small preset must stay under 250M parameters."""
        self._assert_size_within_budget("small", 250_000_000)

    def test_create_medium_model(self):
        """Medium preset must stay under 400M parameters."""
        self._assert_size_within_budget("medium", 400_000_000)

    def test_model_config(self):
        """Explicitly supplied config values must be visible on the model."""
        model_config = ModelConfig(dim=128, layers=4, heads=4, vocab_size=1000)
        thinking_config = InterleavedThinkingConfig(max_reasoning_paths=2, reasoning_depth=3)
        model = CompactAIModel(model_config, thinking_config)

        assert model.model_config.dim == 128
        assert model.model_config.layers == 4
        assert model.thinking_config.max_reasoning_paths == 2
|
|
|
|
|
class TestForwardPass:
    """Test forward pass functionality, with and without interleaved thinking."""

    @staticmethod
    def _tiny_model_and_inputs(batch_size: int = 2, seq_len: int = 16):
        """Shared fixture: a tiny model in eval mode plus a random token batch.

        Returns ``(model, input_ids, vocab_size)``. Token ids are capped at
        1000 so they remain valid even for small vocabularies.
        """
        model = create_compact_model("tiny")
        model.eval()
        vocab_size = model.model_config.vocab_size
        input_ids = torch.randint(0, min(1000, vocab_size), (batch_size, seq_len))
        return model, input_ids, vocab_size

    def test_forward_without_thinking(self):
        """Plain forward pass: logits only, no thinking results."""
        batch_size, seq_len = 2, 16
        model, input_ids, vocab_size = self._tiny_model_and_inputs(batch_size, seq_len)

        with torch.no_grad():
            outputs = model(input_ids, use_thinking=False)

        assert "logits" in outputs
        assert outputs["logits"].shape == (batch_size, seq_len, vocab_size)
        # Thinking disabled, so the model must not report thinking results.
        assert outputs["thinking_results"] is None

    def test_forward_with_thinking(self):
        """Forward pass with interleaved thinking enabled."""
        batch_size, seq_len = 2, 16
        model, input_ids, vocab_size = self._tiny_model_and_inputs(batch_size, seq_len)

        with torch.no_grad():
            outputs = model(input_ids, use_thinking=True, max_reasoning_depth=2)

        assert "logits" in outputs
        assert outputs["logits"].shape == (batch_size, seq_len, vocab_size)
        assert "thinking_results" in outputs
        assert outputs["thinking_results"] is not None
        # final_tokens is expected to be a plain int count of emitted tokens
        # (per the original assertions; semantics defined by the model).
        assert "final_tokens" in outputs
        assert isinstance(outputs["final_tokens"], int)
|
|
|
|
|
|
|
class TestInterleavedThinking:
    """Test interleaved thinking mechanism."""

    def test_thinking_outputs_structure(self):
        """Thinking results must be a non-empty list of dicts with the expected keys."""
        cfg_model = ModelConfig(dim=64, layers=2, heads=4, vocab_size=1000)
        cfg_thinking = InterleavedThinkingConfig(max_reasoning_paths=2, reasoning_depth=2)
        net = CompactAIModel(cfg_model, cfg_thinking)
        net.eval()

        tokens = torch.randint(0, 1000, (1, 8))

        with torch.no_grad():
            result = net(tokens, use_thinking=True, max_reasoning_depth=2)

        steps = result["thinking_results"]
        assert isinstance(steps, list)
        assert len(steps) > 0

        # Only the first step is inspected; each required key must be present.
        for key in ("path_logits", "confidence_scores", "complexity"):
            assert key in steps[0]
|
|
|
|
|
|
|
class TestMemoryUsage:
    """Test memory usage constraints."""

    def test_memory_efficiency(self):
        """Estimated fp32 weight footprint of the tiny model must stay under 200MB."""
        model = create_compact_model("tiny")

        param_count = model.get_num_params()

        # Rough estimate: 4 bytes per fp32 parameter, converted to MiB.
        bytes_per_param = 4
        estimated_memory_mb = param_count * bytes_per_param / (1024 * 1024)

        assert estimated_memory_mb < 200, f"Model memory estimate too high: {estimated_memory_mb:.1f}MB"
|
|
|
|
|
|
|
|
class TestConfiguration:
    """Test configuration loading and validation."""

    def test_config_validation(self):
        """Balanced config must have sane, in-range values."""
        cfg = Config.get_balanced_config()
        assert cfg.model.dim > 0
        assert cfg.thinking.max_reasoning_paths > 0
        # Threshold is a probability-like value, so it must lie in [0, 1].
        assert 0 <= cfg.thinking.early_stop_threshold <= 1

    def test_config_serialization(self):
        """Round-trip a config through its dict form and compare key fields."""
        from compact_ai_model.configs.config import load_config_from_dict, save_config_to_dict

        original = Config.get_balanced_config()
        restored = load_config_from_dict(save_config_to_dict(original))

        assert restored.model.dim == original.model.dim
        assert restored.thinking.max_reasoning_paths == original.thinking.max_reasoning_paths
|
|
|
|
|
|
|
|
# Allow running this module directly (python <file>.py) by delegating to
# pytest, so the tests behave the same as under `pytest <file>.py`.
if __name__ == "__main__":
    pytest.main([__file__])