"""Unit tests for dataset loading.

Covers the ``DatasetConfig`` defaults and keyword overrides, plus the
entries registered in ``DATASET_CONFIGS``.
"""

import pytest
import torch  # noqa: F401  # not referenced in this file; kept in case other code relies on the import -- TODO confirm

from src.data.datasets import DATASET_CONFIGS, DatasetConfig


class TestDatasetConfig:
    """Tests for constructing ``DatasetConfig`` directly."""

    def test_default_config(self):
        """A config built with no arguments exposes the expected defaults."""
        config = DatasetConfig()

        assert config.name == "wikitext"
        assert config.max_length == 512
        assert config.tokenizer_name == "gpt2"

    def test_custom_config(self):
        """Keyword overrides are reflected on the resulting config."""
        config = DatasetConfig(
            name="custom",
            max_length=1024,
            tokenizer_name="facebook/opt-125m",
            streaming=True,
        )

        assert config.name == "custom"
        assert config.max_length == 1024
        # The tokenizer override is passed above, so verify it as well.
        assert config.tokenizer_name == "facebook/opt-125m"
        assert config.streaming is True

    def test_validation(self):
        """The length bounds are consistent when only max_length is given."""
        config = DatasetConfig(max_length=512)

        assert config.max_length == 512
        assert config.min_length == 10
        assert config.max_length > config.min_length


class TestPredefinedConfigs:
    """Tests for the entries registered in ``DATASET_CONFIGS``."""

    def test_wikitext_config(self):
        """The "wikitext" entry has the expected name, subset and text column."""
        config = DATASET_CONFIGS["wikitext"]

        assert config.name == "wikitext"
        assert config.subset == "wikitext-2-raw-v1"
        assert config.text_column == "text"

    def test_c4_config(self):
        """The "c4" entry has the expected name and subset, and streams."""
        config = DATASET_CONFIGS["c4"]

        assert config.name == "c4"
        assert config.subset == "en"
        assert config.streaming is True

    def test_all_configs_valid(self):
        """Every registered entry is a ``DatasetConfig`` with basic fields set."""
        for name, config in DATASET_CONFIGS.items():
            # Include the registry key in each message so a failure
            # identifies which predefined config is broken.
            assert isinstance(config, DatasetConfig), (
                f"{name!r} is not a DatasetConfig"
            )
            assert config.name is not None, f"{name!r} has no name"
            assert config.text_column is not None, f"{name!r} has no text_column"
            assert config.max_length > 0, (
                f"{name!r} has non-positive max_length"
            )


if __name__ == "__main__":
    # Propagate pytest's exit code so running this file as a script
    # reports failure to the calling shell / CI step.
    raise SystemExit(pytest.main([__file__, "-v"]))