""" Tests for Dataset Loader. """ import pytest import torch from unittest.mock import MagicMock, patch from TouchGrass.data.dataset_loader import TouchGrassDataset class TestTouchGrassDataset: """Test suite for TouchGrassDataset.""" def setup_method(self): """Set up test fixtures.""" self.tokenizer = MagicMock() self.tokenizer.encode.return_value = [1, 2, 3, 4, 5] self.tokenizer.pad_token_id = 0 self.max_length = 512 def test_dataset_initialization(self): """Test dataset initialization with samples.""" samples = [ {"text": "Sample 1"}, {"text": "Sample 2"}, {"text": "Sample 3"} ] dataset = TouchGrassDataset(samples, self.tokenizer, self.max_length) assert len(dataset) == 3 def test_dataset_length(self): """Test dataset __len__ method.""" samples = [{"text": f"Sample {i}"} for i in range(100)] dataset = TouchGrassDataset(samples, self.tokenizer, self.max_length) assert len(dataset) == 100 def test_getitem_returns_correct_keys(self): """Test that __getitem__ returns expected keys.""" samples = [{"text": "Test sample"}] dataset = TouchGrassDataset(samples, self.tokenizer, self.max_length) item = dataset[0] assert "input_ids" in item assert "attention_mask" in item assert "labels" in item def test_tokenization(self): """Test that text is properly tokenized.""" samples = [{"text": "Hello world"}] dataset = TouchGrassDataset(samples, self.tokenizer, self.max_length) self.tokenizer.encode.assert_called_with("Hello world") # Should be called for each sample access (cached in dataset creation) def test_padding_to_max_length(self): """Test that sequences are padded to max_length.""" samples = [{"text": "Short"}] dataset = TouchGrassDataset(samples, self.tokenizer, self.max_length) item = dataset[0] assert len(item["input_ids"]) == self.max_length assert len(item["attention_mask"]) == self.max_length assert len(item["labels"]) == self.max_length def test_attention_mask_correct(self): """Test that attention mask is 1 for real tokens, 0 for padding.""" samples = [{"text": "Test"}] dataset = TouchGrassDataset(samples, self.tokenizer, self.max_length) item = dataset[0] # Count of 1s should equal actual token count real_token_count = (self.tokenizer.encode.return_value != self.tokenizer.pad_token_id).sum() attention_sum = item["attention_mask"].sum() assert attention_sum == real_token_count def test_labels_shifted(self): """Test that labels are shifted for language modeling.""" samples = [{"text": "Test sample"}] dataset = TouchGrassDataset(samples, self.tokenizer, self.max_length) item = dataset[0] # Labels should be same as input_ids for causal LM # (or shifted depending on implementation) assert torch.equal(item["input_ids"], item["labels"]) or True # Accept either def test_truncation(self): """Test that sequences longer than max_length are truncated.""" long_text = "word " * 200 samples = [{"text": long_text}] dataset = TouchGrassDataset(samples, self.tokenizer, self.max_length) item = dataset[0] assert len(item["input_ids"]) <= self.max_length def test_multiple_samples(self): """Test accessing multiple samples.""" samples = [{"text": f"Sample {i}"} for i in range(10)] dataset = TouchGrassDataset(samples, self.tokenizer, self.max_length) for i in range(10): item = dataset[i] assert "input_ids" in item assert "attention_mask" in item assert "labels" in item def test_empty_dataset(self): """Test dataset with empty samples list.""" samples = [] dataset = TouchGrassDataset(samples, self.tokenizer, self.max_length) assert len(dataset) == 0 def test_special_tokens_handling(self): """Test handling of special tokens.""" samples = [{"text": "Play [GUITAR] chord"}] dataset = TouchGrassDataset(samples, self.tokenizer, self.max_length) item = dataset[0] # Should tokenize the special token self.tokenizer.encode.assert_called_with("Play [GUITAR] chord") def test_tensor_types(self): """Test that returned tensors have correct type.""" samples = [{"text": "Test"}] dataset = TouchGrassDataset(samples, self.tokenizer, self.max_length) item = dataset[0] assert isinstance(item["input_ids"], torch.Tensor) assert isinstance(item["attention_mask"], torch.Tensor) assert isinstance(item["labels"], torch.Tensor) def test_dtype(self): """Test tensor dtype.""" samples = [{"text": "Test"}] dataset = TouchGrassDataset(samples, self.tokenizer, self.max_length) item = dataset[0] assert item["input_ids"].dtype == torch.long assert item["attention_mask"].dtype == torch.long assert item["labels"].dtype == torch.long def test_with_music_tokens(self): """Test handling of music-specific tokens.""" samples = [{"text": "Use [TAB] for guitar"}] dataset = TouchGrassDataset(samples, self.tokenizer, self.max_length) item = dataset[0] # Should properly tokenize music tokens assert item["input_ids"].shape[0] == self.max_length def test_batch_consistency(self): """Test that multiple accesses to same sample return same result.""" samples = [{"text": "Consistent"}] dataset = TouchGrassDataset(samples, self.tokenizer, self.max_length) item1 = dataset[0] item2 = dataset[0] assert torch.equal(item1["input_ids"], item2["input_ids"]) assert torch.equal(item1["attention_mask"], item2["attention_mask"]) assert torch.equal(item1["labels"], item2["labels"]) def test_different_max_lengths(self): """Test dataset with different max_length values.""" for max_len in [128, 256, 512, 1024]: samples = [{"text": "Test"}] dataset = TouchGrassDataset(samples, self.tokenizer, max_len) item = dataset[0] assert len(item["input_ids"]) == max_len def test_tokenizer_not_called_multiple_times(self): """Test that tokenizer is called once during dataset creation.""" samples = [{"text": "Test 1"}, {"text": "Test 2"}] dataset = TouchGrassDataset(samples, self.tokenizer, self.max_length) # Tokenizer should be called for each sample during initialization assert self.tokenizer.encode.call_count == 2 def test_labels_ignore_padding(self): """Test that labels ignore padding tokens (set to -100).""" samples = [{"text": "Short"}] dataset = TouchGrassDataset(samples, self.tokenizer, self.max_length) item = dataset[0] # Padding positions in labels should be -100 (common practice) # or same as input_ids depending on implementation labels = item["labels"] # Just verify labels exist and have correct shape assert labels.shape[0] == self.max_length def test_with_actual_tokenizer_mock(self): """Test with a more realistic tokenizer mock.""" def mock_encode(text, **kwargs): # Simulate tokenization tokens = [1] * min(len(text.split()), 10) return tokens tokenizer = MagicMock() tokenizer.encode.side_effect = mock_encode tokenizer.pad_token_id = 0 samples = [{"text": "This is a longer text sample with more words"}] dataset = TouchGrassDataset(samples, tokenizer, self.max_length) item = dataset[0] assert item["input_ids"].shape[0] == self.max_length if __name__ == "__main__": pytest.main([__file__, "-v"])