File size: 5,895 Bytes
b9b1e87
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
#!/usr/bin/env python3
"""
Unit tests for the Compact AI Model with Interleaved Thinking.
"""

import torch
import pytest
import sys
from pathlib import Path

# Add project root to path
project_root = Path(__file__).parent.parent.parent
sys.path.insert(0, str(project_root))

from compact_ai_model.architecture.model import create_compact_model, CompactAIModel
from compact_ai_model.configs.config import Config, ModelConfig, InterleavedThinkingConfig


class TestModelCreation:
    """Test model creation and basic properties."""

    def _build_and_check(self, preset: str, limit: int, label: str) -> None:
        # Shared helper: build a model via the factory, confirm its type,
        # and enforce the parameter-count budget for the given preset.
        built = create_compact_model(preset)
        assert built is not None
        assert isinstance(built, CompactAIModel)

        count = built.get_num_params()
        assert count < limit, f"{label} model too large: {count}"

    def test_create_tiny_model(self):
        """Test creating a tiny model."""
        self._build_and_check("tiny", 100_000_000, "Tiny")

    def test_create_small_model(self):
        """Test creating a small model."""
        self._build_and_check("small", 250_000_000, "Small")

    def test_create_medium_model(self):
        """Test creating a medium model."""
        self._build_and_check("medium", 400_000_000, "Medium")

    def test_model_config(self):
        """Test model configuration."""
        arch_cfg = ModelConfig(dim=128, layers=4, heads=4, vocab_size=1000)
        think_cfg = InterleavedThinkingConfig(max_reasoning_paths=2, reasoning_depth=3)
        built = CompactAIModel(arch_cfg, think_cfg)

        # The model must retain the exact configuration it was constructed with.
        assert built.model_config.dim == 128
        assert built.model_config.layers == 4
        assert built.thinking_config.max_reasoning_paths == 2

class TestForwardPass:
    """Test forward pass functionality."""

    @staticmethod
    def _tiny_model_and_batch():
        # Shared fixture: a tiny model in eval mode plus a random token batch
        # clamped to the model's vocabulary. Returns the shape triple as well
        # so callers can assert on output dimensions.
        net = create_compact_model("tiny")
        net.eval()

        n_batch, n_seq = 2, 16
        n_vocab = net.model_config.vocab_size
        token_ids = torch.randint(0, min(1000, n_vocab), (n_batch, n_seq))
        return net, token_ids, (n_batch, n_seq, n_vocab)

    def test_forward_without_thinking(self):
        """Test basic forward pass without thinking."""
        net, token_ids, expected_shape = self._tiny_model_and_batch()

        with torch.no_grad():
            result = net(token_ids, use_thinking=False)

        assert "logits" in result
        assert result["logits"].shape == expected_shape
        # With thinking disabled there must be no reasoning trace at all.
        assert result["thinking_results"] is None

    def test_forward_with_thinking(self):
        """Test forward pass with interleaved thinking."""
        net, token_ids, expected_shape = self._tiny_model_and_batch()

        with torch.no_grad():
            result = net(token_ids, use_thinking=True, max_reasoning_depth=2)

        assert "logits" in result
        assert result["logits"].shape == expected_shape
        assert "thinking_results" in result
        assert result["thinking_results"] is not None
        assert "final_tokens" in result
        assert isinstance(result["final_tokens"], int)


class TestInterleavedThinking:
    """Test interleaved thinking mechanism."""

    def test_thinking_outputs_structure(self):
        """Test that thinking outputs have correct structure."""
        arch_cfg = ModelConfig(dim=64, layers=2, heads=4, vocab_size=1000)
        think_cfg = InterleavedThinkingConfig(max_reasoning_paths=2, reasoning_depth=2)
        net = CompactAIModel(arch_cfg, think_cfg)
        net.eval()

        token_ids = torch.randint(0, 1000, (1, 8))

        with torch.no_grad():
            result = net(token_ids, use_thinking=True, max_reasoning_depth=2)

        # The trace must be a non-empty list of per-step result dicts.
        trace = result["thinking_results"]
        assert isinstance(trace, list)
        assert len(trace) > 0

        # Each step is expected to expose these three keys; checking the first
        # entry is sufficient for the structural contract.
        head = trace[0]
        for expected_key in ("path_logits", "confidence_scores", "complexity"):
            assert expected_key in head


class TestMemoryUsage:
    """Test memory usage constraints."""

    def test_memory_efficiency(self):
        """Test that model stays within memory limits."""
        net = create_compact_model("tiny")

        # This is a basic test - in real scenarios, you'd monitor actual memory usage
        param_count = net.get_num_params()
        # float32 weights: 4 bytes each, converted to MiB.
        estimated_memory_mb = param_count * 4 / 2**20

        # Tiny model should be under 200MB
        assert estimated_memory_mb < 200, f"Model memory estimate too high: {estimated_memory_mb:.1f}MB"


class TestConfiguration:
    """Test configuration loading and validation."""

    def test_config_validation(self):
        """Test that configurations are valid."""
        # The balanced preset must come back with sane, in-range values.
        balanced = Config.get_balanced_config()
        assert balanced.model.dim > 0
        assert balanced.thinking.max_reasoning_paths > 0
        # Threshold is a probability-like value and must lie in [0, 1].
        assert 0 <= balanced.thinking.early_stop_threshold <= 1

    def test_config_serialization(self):
        """Test config save/load."""
        from compact_ai_model.configs.config import load_config_from_dict, save_config_to_dict

        original = Config.get_balanced_config()
        # Round-trip through the dict form and verify key fields survive.
        round_tripped = load_config_from_dict(save_config_to_dict(original))

        assert round_tripped.model.dim == original.model.dim
        assert round_tripped.thinking.max_reasoning_paths == original.thinking.max_reasoning_paths


if __name__ == "__main__":
    # Propagate pytest's exit status: pytest.main() returns the exit code
    # rather than exiting, so without sys.exit() a direct run of this file
    # would terminate with status 0 even when tests fail, hiding failures
    # from CI and shell scripts.
    sys.exit(pytest.main([__file__]))