# Source: uploaded by likhonsheikh using huggingface_hub (commit b9b1e87, verified).
#!/usr/bin/env python3
"""
Unit tests for training components.
"""
import torch
import pytest
import sys
from pathlib import Path
# Put the project root at the front of sys.path. The root is taken to be the
# directory reached by three `.parent` steps from this file's path.
# NOTE(review): presumably this lets the `compact_ai_model` imports below
# resolve without the package being installed; confirm against the repo layout.
project_root = Path(__file__).parent.parent.parent
sys.path.insert(0, str(project_root))
from compact_ai_model.training.train import create_sample_data, TextDataset
from compact_ai_model.architecture.model import create_compact_model
class TestSampleData:
    """Checks on the records returned by ``create_sample_data``."""

    def test_create_sample_data(self):
        """Requesting 50 samples yields 50 records, each with a string ``text``."""
        expected_count = 50
        samples = create_sample_data(expected_count)

        assert len(samples) == expected_count
        # Key presence is verified for every record before any value is read.
        assert all("text" in sample for sample in samples)
        assert all(isinstance(sample["text"], str) for sample in samples)

    def test_sample_data_content(self):
        """A batch of 10 samples contains both the question and answer markers."""
        bodies = [sample["text"] for sample in create_sample_data(10)]

        # At least one text must contain each marker; they need not co-occur.
        for marker in ("Question:", "Answer:"):
            assert any(marker in body for body in bodies)
class TestTextDataset:
    """Exercises ``TextDataset`` both with and without a tokenizer."""

    def test_dataset_creation(self):
        """The dataset reports as many entries as the data it was built from."""
        sample_count = 100
        dataset = TextDataset(create_sample_data(sample_count))

        assert len(dataset) == sample_count

    def test_dataset_item_access(self):
        """Without a tokenizer, indexing returns a record with a string ``text``."""
        dataset = TextDataset(create_sample_data(10))
        first = dataset[0]

        assert "text" in first
        assert isinstance(first["text"], str)

    def test_dataset_with_tokenizer(self):
        """With a mock tokenizer, items expose ``input_ids`` and ``attention_mask``."""

        class MockTokenizer:
            """Stand-in tokenizer emitting one ID per character, up to a cap."""

            def encode(self, text, max_length=None, truncation=None, padding=None):
                # The cap falls back to 100 when max_length is unset (or falsy).
                cap = max_length or 100
                token_count = min(len(text), cap)
                return list(range(token_count))

        samples = create_sample_data(10)
        dataset = TextDataset(samples, tokenizer=MockTokenizer(), max_length=50)
        encoded = dataset[0]

        for key in ("input_ids", "attention_mask"):
            assert key in encoded
        assert len(encoded["input_ids"]) <= 50
class TestTrainingIntegration:
    """Smoke tests that run the "tiny" compact model on random token batches."""

    def test_model_forward_for_training(self):
        """A forward pass without thinking returns logits of shape (batch, seq, vocab)."""
        model = create_compact_model("tiny")
        model.eval()

        vocab_size = model.model_config.vocab_size
        batch_size, seq_len = 2, 16
        # IDs are drawn below min(100, vocab_size) so they never exceed the vocabulary.
        id_upper_bound = min(100, vocab_size)
        input_ids = torch.randint(0, id_upper_bound, (batch_size, seq_len))

        with torch.no_grad():
            outputs = model(input_ids, use_thinking=False)

        expected_shape = (batch_size, seq_len, vocab_size)
        assert outputs["logits"].shape == expected_shape

    def test_thinking_enabled_training(self):
        """A forward pass with thinking enabled returns the three expected output keys."""
        model = create_compact_model("tiny")
        model.eval()

        batch_shape = (2, 16)
        vocab_size = model.model_config.vocab_size
        input_ids = torch.randint(0, min(100, vocab_size), batch_shape)

        with torch.no_grad():
            outputs = model(input_ids, use_thinking=True, max_reasoning_depth=1)

        for key in ("logits", "thinking_results", "final_tokens"):
            assert key in outputs
if __name__ == "__main__":
    # pytest.main() returns the exit code rather than exiting the process.
    # Forward it so running this file as a script reports test failures
    # through a non-zero exit status.
    sys.exit(pytest.main([__file__]))