|
|
|
|
|
""" |
|
|
Unit tests for training components. |
|
|
""" |
|
|
|
|
|
import torch |
|
|
import pytest |
|
|
import sys |
|
|
from pathlib import Path |
|
|
|
|
|
|
|
|
project_root = Path(__file__).parent.parent.parent |
|
|
sys.path.insert(0, str(project_root)) |
|
|
|
|
|
from compact_ai_model.training.train import create_sample_data, TextDataset |
|
|
from compact_ai_model.architecture.model import create_compact_model |
|
|
|
|
|
|
|
|
class TestSampleData:
    """Tests covering the create_sample_data helper."""

    def test_create_sample_data(self):
        """A request for 50 samples yields exactly 50 text records."""
        samples = create_sample_data(50)

        assert len(samples) == 50
        # Every record must be a dict carrying a string under "text".
        for sample in samples:
            assert "text" in sample
            assert isinstance(sample["text"], str)

    def test_sample_data_content(self):
        """Generated samples include Question/Answer style formatting."""
        samples = create_sample_data(10)
        texts = [sample["text"] for sample in samples]

        # At least one sample should show each half of the Q/A pattern.
        assert any("Question:" in text for text in texts)
        assert any("Answer:" in text for text in texts)
|
|
|
|
|
|
|
|
class TestTextDataset:
    """Tests for the TextDataset wrapper."""

    def test_dataset_creation(self):
        """Dataset length matches the number of input records."""
        records = create_sample_data(100)
        dataset = TextDataset(records)

        assert len(dataset) == 100

    def test_dataset_item_access(self):
        """Indexing a tokenizer-less dataset returns the raw text dict."""
        dataset = TextDataset(create_sample_data(10))

        first = dataset[0]
        assert "text" in first
        assert isinstance(first["text"], str)

    def test_dataset_with_tokenizer(self):
        """With a tokenizer attached, items come back as tokenized tensors."""
        records = create_sample_data(10)

        class FakeTokenizer:
            # Minimal stand-in: maps a string to a ramp of token ids,
            # truncated to max_length when one is supplied.
            def encode(self, text, max_length=None, truncation=None, padding=None):
                limit = max_length if max_length is not None else 100
                return list(range(min(len(text), limit)))

        dataset = TextDataset(records, tokenizer=FakeTokenizer(), max_length=50)

        item = dataset[0]
        assert "input_ids" in item
        assert "attention_mask" in item
        # Truncation must respect the max_length passed to the dataset.
        assert len(item["input_ids"]) <= 50
|
|
|
|
|
|
|
|
class TestTrainingIntegration:
    """Smoke tests that the model runs in a training-like setting."""

    def test_model_forward_for_training(self):
        """Forward pass without thinking produces logits of the right shape."""
        model = create_compact_model("tiny")
        model.eval()

        vocab = model.model_config.vocab_size
        bsz, length = 2, 16
        # Cap token ids at 100 so tiny vocabularies still index safely.
        tokens = torch.randint(0, min(100, vocab), (bsz, length))

        with torch.no_grad():
            out = model(tokens, use_thinking=False)

        assert out["logits"].shape == (bsz, length, vocab)

    def test_thinking_enabled_training(self):
        """Forward pass with thinking enabled exposes the reasoning outputs."""
        model = create_compact_model("tiny")
        model.eval()

        vocab = model.model_config.vocab_size
        bsz, length = 2, 16
        tokens = torch.randint(0, min(100, vocab), (bsz, length))

        with torch.no_grad():
            out = model(tokens, use_thinking=True, max_reasoning_depth=1)

        # The thinking path must surface all three result keys.
        for key in ("logits", "thinking_results", "final_tokens"):
            assert key in out
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Allow running this test module directly as a script.
    pytest.main(args=[__file__])