| | """
|
| | Tests for Dataset Loader.
|
| | """
|
| |
|
| | import pytest
|
| | import torch
|
| | from unittest.mock import MagicMock, patch
|
| |
|
| | from TouchGrass.data.dataset_loader import TouchGrassDataset
|
| |
|
| |
|
class TestTouchGrassDataset:
    """Test suite for TouchGrassDataset.

    Uses a MagicMock tokenizer whose ``encode`` returns a fixed 5-token
    sequence, so most tests exercise the dataset's padding/masking logic
    rather than real tokenization.
    """

    def setup_method(self):
        """Create a fresh mock tokenizer and default max_length before each test."""
        self.tokenizer = MagicMock()
        # Every encode() call yields the same 5 non-pad tokens unless a
        # specific test installs its own side_effect.
        self.tokenizer.encode.return_value = [1, 2, 3, 4, 5]
        self.tokenizer.pad_token_id = 0
        self.max_length = 512

    def test_dataset_initialization(self):
        """Dataset reports one entry per input sample."""
        samples = [
            {"text": "Sample 1"},
            {"text": "Sample 2"},
            {"text": "Sample 3"},
        ]
        dataset = TouchGrassDataset(samples, self.tokenizer, self.max_length)
        assert len(dataset) == 3

    def test_dataset_length(self):
        """__len__ matches the number of samples for a larger input."""
        samples = [{"text": f"Sample {i}"} for i in range(100)]
        dataset = TouchGrassDataset(samples, self.tokenizer, self.max_length)
        assert len(dataset) == 100

    def test_getitem_returns_correct_keys(self):
        """__getitem__ yields the standard LM-training keys."""
        samples = [{"text": "Test sample"}]
        dataset = TouchGrassDataset(samples, self.tokenizer, self.max_length)
        item = dataset[0]

        assert "input_ids" in item
        assert "attention_mask" in item
        assert "labels" in item

    def test_tokenization(self):
        """Constructing the dataset tokenizes the raw text.

        NOTE(review): this assumes TouchGrassDataset tokenizes eagerly in
        __init__ (the dataset object itself is not indexed here) — confirm
        against the implementation.
        """
        samples = [{"text": "Hello world"}]
        dataset = TouchGrassDataset(samples, self.tokenizer, self.max_length)

        self.tokenizer.encode.assert_called_with("Hello world")

    def test_padding_to_max_length(self):
        """Short sequences are padded out to exactly max_length."""
        samples = [{"text": "Short"}]
        dataset = TouchGrassDataset(samples, self.tokenizer, self.max_length)
        item = dataset[0]

        assert len(item["input_ids"]) == self.max_length
        assert len(item["attention_mask"]) == self.max_length
        assert len(item["labels"]) == self.max_length

    def test_attention_mask_correct(self):
        """Attention mask sums to the number of real (non-pad) tokens."""
        samples = [{"text": "Test"}]
        dataset = TouchGrassDataset(samples, self.tokenizer, self.max_length)
        item = dataset[0]

        # BUG FIX: the original computed
        # (self.tokenizer.encode.return_value != pad_token_id).sum(),
        # but `list != int` is a plain bool, and bool has no .sum() —
        # that line raised AttributeError. Count the non-pad tokens
        # explicitly instead.
        real_token_count = sum(
            1
            for tok in self.tokenizer.encode.return_value
            if tok != self.tokenizer.pad_token_id
        )
        attention_sum = item["attention_mask"].sum()
        assert attention_sum == real_token_count

    def test_labels_shifted(self):
        """Labels align with input_ids for language modeling.

        NOTE(review): the original assertion was
        ``torch.equal(...) or True`` — vacuous, it could never fail.
        Replaced with the shape invariant. Whether the dataset shifts
        labels itself or leaves the shift to the model should be
        confirmed against the implementation.
        """
        samples = [{"text": "Test sample"}]
        dataset = TouchGrassDataset(samples, self.tokenizer, self.max_length)
        item = dataset[0]

        assert item["labels"].shape == item["input_ids"].shape

    def test_truncation(self):
        """Sequences never exceed max_length.

        NOTE(review): the mock tokenizer always returns 5 tokens, so this
        does not actually exercise the truncation branch; a tokenizer
        side_effect producing > max_length tokens would.
        """
        long_text = "word " * 200
        samples = [{"text": long_text}]
        dataset = TouchGrassDataset(samples, self.tokenizer, self.max_length)
        item = dataset[0]

        assert len(item["input_ids"]) <= self.max_length

    def test_multiple_samples(self):
        """Every sample index yields a complete item dict."""
        samples = [{"text": f"Sample {i}"} for i in range(10)]
        dataset = TouchGrassDataset(samples, self.tokenizer, self.max_length)

        for i in range(10):
            item = dataset[i]
            assert "input_ids" in item
            assert "attention_mask" in item
            assert "labels" in item

    def test_empty_dataset(self):
        """An empty samples list gives a zero-length dataset."""
        samples = []
        dataset = TouchGrassDataset(samples, self.tokenizer, self.max_length)
        assert len(dataset) == 0

    def test_special_tokens_handling(self):
        """Bracketed special-token text is passed to the tokenizer verbatim."""
        samples = [{"text": "Play [GUITAR] chord"}]
        dataset = TouchGrassDataset(samples, self.tokenizer, self.max_length)
        item = dataset[0]

        self.tokenizer.encode.assert_called_with("Play [GUITAR] chord")

    def test_tensor_types(self):
        """All returned fields are torch.Tensor instances."""
        samples = [{"text": "Test"}]
        dataset = TouchGrassDataset(samples, self.tokenizer, self.max_length)
        item = dataset[0]

        assert isinstance(item["input_ids"], torch.Tensor)
        assert isinstance(item["attention_mask"], torch.Tensor)
        assert isinstance(item["labels"], torch.Tensor)

    def test_dtype(self):
        """All returned tensors use torch.long (required for embedding lookups)."""
        samples = [{"text": "Test"}]
        dataset = TouchGrassDataset(samples, self.tokenizer, self.max_length)
        item = dataset[0]

        assert item["input_ids"].dtype == torch.long
        assert item["attention_mask"].dtype == torch.long
        assert item["labels"].dtype == torch.long

    def test_with_music_tokens(self):
        """Music-domain token text still produces a max_length sequence."""
        samples = [{"text": "Use [TAB] for guitar"}]
        dataset = TouchGrassDataset(samples, self.tokenizer, self.max_length)
        item = dataset[0]

        assert item["input_ids"].shape[0] == self.max_length

    def test_batch_consistency(self):
        """Repeated access to the same index returns identical tensors."""
        samples = [{"text": "Consistent"}]
        dataset = TouchGrassDataset(samples, self.tokenizer, self.max_length)

        item1 = dataset[0]
        item2 = dataset[0]

        assert torch.equal(item1["input_ids"], item2["input_ids"])
        assert torch.equal(item1["attention_mask"], item2["attention_mask"])
        assert torch.equal(item1["labels"], item2["labels"])

    def test_different_max_lengths(self):
        """Padding target tracks whatever max_length the dataset was built with."""
        for max_len in [128, 256, 512, 1024]:
            samples = [{"text": "Test"}]
            dataset = TouchGrassDataset(samples, self.tokenizer, max_len)
            item = dataset[0]
            assert len(item["input_ids"]) == max_len

    def test_tokenizer_not_called_multiple_times(self):
        """Tokenizer is invoked exactly once per sample at construction.

        NOTE(review): asserts call_count == 2 without indexing the dataset,
        which again assumes eager tokenization in __init__ — confirm.
        """
        samples = [{"text": "Test 1"}, {"text": "Test 2"}]
        dataset = TouchGrassDataset(samples, self.tokenizer, self.max_length)

        assert self.tokenizer.encode.call_count == 2

    def test_labels_ignore_padding(self):
        """Labels tensor spans the full max_length for a short sample.

        NOTE(review): the docstring intent is that padded positions are set
        to -100 (the CrossEntropyLoss ignore_index), but only the shape is
        pinned here; an explicit check of the padded tail would strengthen
        this test.
        """
        samples = [{"text": "Short"}]
        dataset = TouchGrassDataset(samples, self.tokenizer, self.max_length)
        item = dataset[0]

        labels = item["labels"]

        assert labels.shape[0] == self.max_length

    def test_with_actual_tokenizer_mock(self):
        """A length-dependent tokenizer mock still yields max_length outputs."""

        def mock_encode(text, **kwargs):
            # One token per word, capped at 10 — closer to a real tokenizer
            # than the fixed return_value used elsewhere.
            tokens = [1] * min(len(text.split()), 10)
            return tokens

        tokenizer = MagicMock()
        tokenizer.encode.side_effect = mock_encode
        tokenizer.pad_token_id = 0

        samples = [{"text": "This is a longer text sample with more words"}]
        dataset = TouchGrassDataset(samples, tokenizer, self.max_length)
        item = dataset[0]

        assert item["input_ids"].shape[0] == self.max_length
|
| |
|
| | if __name__ == "__main__":
|
| | pytest.main([__file__, "-v"])
|
| |
|