File size: 3,590 Bytes
b9b1e87
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
#!/usr/bin/env python3
"""
Unit tests for training components.
"""

import torch
import pytest
import sys
from pathlib import Path

# Add project root to path
project_root = Path(__file__).parent.parent.parent
sys.path.insert(0, str(project_root))

from compact_ai_model.training.train import create_sample_data, TextDataset
from compact_ai_model.architecture.model import create_compact_model


class TestSampleData:
    """Tests for the sample-data generator."""

    def test_create_sample_data(self):
        """A request for 50 samples yields 50 dicts, each carrying a string 'text'."""
        samples = create_sample_data(50)
        assert len(samples) == 50
        for sample in samples:
            assert "text" in sample
            assert isinstance(sample["text"], str)

    def test_sample_data_content(self):
        """Generated texts include Question/Answer style templates."""
        texts = [entry["text"] for entry in create_sample_data(10)]

        # At least one sample should come from each Q/A-style template.
        assert any("Question:" in t for t in texts)
        assert any("Answer:" in t for t in texts)


class TestTextDataset:
    """Tests for TextDataset construction and item access."""

    def test_dataset_creation(self):
        """Dataset length matches the number of records it wraps."""
        records = create_sample_data(100)
        assert len(TextDataset(records)) == 100

    def test_dataset_item_access(self):
        """Indexing without a tokenizer returns the raw text record."""
        ds = TextDataset(create_sample_data(10))

        first = ds[0]
        assert "text" in first
        assert isinstance(first["text"], str)

    def test_dataset_with_tokenizer(self):
        """With a tokenizer attached, items are tokenized and capped at max_length."""

        class StubTokenizer:
            # Fake tokenizer: one token ID per character, capped at max_length.
            def encode(self, text, max_length=None, truncation=None, padding=None):
                capped = min(len(text), max_length or 100)
                return list(range(capped))

        ds = TextDataset(
            create_sample_data(10), tokenizer=StubTokenizer(), max_length=50
        )

        encoded = ds[0]
        assert "input_ids" in encoded
        assert "attention_mask" in encoded
        assert len(encoded["input_ids"]) <= 50


class TestTrainingIntegration:
    """Smoke tests that the model supports training-style forward passes."""

    def test_model_forward_for_training(self):
        """Forward pass without thinking yields logits shaped (batch, seq, vocab)."""
        model = create_compact_model("tiny")
        model.eval()

        # Keep token IDs within a small range so the test works for any vocab size.
        vocab = model.model_config.vocab_size
        batch, seq = 2, 16
        tokens = torch.randint(0, min(100, vocab), (batch, seq))

        with torch.no_grad():
            out = model(tokens, use_thinking=False)

        assert out["logits"].shape == (batch, seq, vocab)

    def test_thinking_enabled_training(self):
        """Forward pass with thinking enabled also returns thinking metadata."""
        model = create_compact_model("tiny")
        model.eval()

        vocab = model.model_config.vocab_size
        tokens = torch.randint(0, min(100, vocab), (2, 16))

        with torch.no_grad():
            out = model(tokens, use_thinking=True, max_reasoning_depth=1)

        for key in ("logits", "thinking_results", "final_tokens"):
            assert key in out


if __name__ == "__main__":
    pytest.main([__file__])