#!/usr/bin/env python3
"""Unit tests for training components."""

import sys
from pathlib import Path

import pytest
import torch

# Make the project root importable before pulling in project modules.
sys.path.insert(0, str(Path(__file__).parents[2]))

from compact_ai_model.training.train import create_sample_data, TextDataset
from compact_ai_model.architecture.model import create_compact_model


class TestSampleData:
    """Checks for the synthetic training-data generator."""

    def test_create_sample_data(self):
        """The generator yields exactly the requested number of text records."""
        samples = create_sample_data(50)
        assert len(samples) == 50
        for record in samples:
            assert "text" in record
            assert isinstance(record["text"], str)

    def test_sample_data_content(self):
        """Generated records draw from the Q/A-style templates."""
        samples = create_sample_data(10)
        contents = [record["text"] for record in samples]
        # At least one record should come from each template family.
        assert any("Question:" in text for text in contents)
        assert any("Answer:" in text for text in contents)


class TestTextDataset:
    """Checks for TextDataset construction and item access."""

    def test_dataset_creation(self):
        """A dataset reports the same length as its backing sample list."""
        dataset = TextDataset(create_sample_data(100))
        assert len(dataset) == 100

    def test_dataset_item_access(self):
        """Indexing without a tokenizer returns the raw text record."""
        dataset = TextDataset(create_sample_data(10))
        first = dataset[0]
        assert "text" in first
        assert isinstance(first["text"], str)

    def test_dataset_with_tokenizer(self):
        """Indexing with a tokenizer yields capped token IDs plus a mask."""

        class MockTokenizer:
            # Stand-in tokenizer: one fake token ID per character, capped
            # at max_length (or 100 when no cap is supplied).
            def encode(self, text, max_length=None, truncation=None, padding=None):
                capped = min(len(text), max_length or 100)
                return list(range(capped))

        dataset = TextDataset(
            create_sample_data(10), tokenizer=MockTokenizer(), max_length=50
        )
        encoded = dataset[0]
        assert "input_ids" in encoded
        assert "attention_mask" in encoded
        assert len(encoded["input_ids"]) <= 50


class TestTrainingIntegration:
    """Smoke tests that the model accepts training-shaped batches."""

    def test_model_forward_for_training(self):
        """A plain forward pass produces logits of (batch, seq, vocab)."""
        model = create_compact_model("tiny")
        model.eval()

        vocab = model.model_config.vocab_size
        batch, seq = 2, 16
        # Keep IDs well inside the vocabulary even for tiny configs.
        tokens = torch.randint(0, min(100, vocab), (batch, seq))

        with torch.no_grad():
            outputs = model(tokens, use_thinking=False)

        assert outputs["logits"].shape == (batch, seq, vocab)

    def test_thinking_enabled_training(self):
        """A forward pass with thinking enabled returns the extra outputs."""
        model = create_compact_model("tiny")
        model.eval()

        vocab = model.model_config.vocab_size
        tokens = torch.randint(0, min(100, vocab), (2, 16))

        with torch.no_grad():
            outputs = model(tokens, use_thinking=True, max_reasoning_depth=1)

        for key in ("logits", "thinking_results", "final_tokens"):
            assert key in outputs


if __name__ == "__main__":
    pytest.main([__file__])