"""
Tests for Dataset Loader.
"""
import pytest
import torch
from unittest.mock import MagicMock, patch
from TouchGrass.data.dataset_loader import TouchGrassDataset
class TestTouchGrassDataset:
    """Test suite for TouchGrassDataset."""

    def setup_method(self):
        """Set up shared fixtures: a mock tokenizer and a default max length."""
        self.tokenizer = MagicMock()
        # Unless overridden per-test, every encode() call returns the same
        # fixed 5-token sequence (a plain Python list, not a tensor).
        self.tokenizer.encode.return_value = [1, 2, 3, 4, 5]
        self.tokenizer.pad_token_id = 0
        self.max_length = 512

    def test_dataset_initialization(self):
        """Test dataset initialization with samples."""
        samples = [
            {"text": "Sample 1"},
            {"text": "Sample 2"},
            {"text": "Sample 3"},
        ]
        dataset = TouchGrassDataset(samples, self.tokenizer, self.max_length)
        assert len(dataset) == 3

    def test_dataset_length(self):
        """Test dataset __len__ method."""
        samples = [{"text": f"Sample {i}"} for i in range(100)]
        dataset = TouchGrassDataset(samples, self.tokenizer, self.max_length)
        assert len(dataset) == 100

    def test_getitem_returns_correct_keys(self):
        """Test that __getitem__ returns expected keys."""
        samples = [{"text": "Test sample"}]
        dataset = TouchGrassDataset(samples, self.tokenizer, self.max_length)
        item = dataset[0]
        assert "input_ids" in item
        assert "attention_mask" in item
        assert "labels" in item

    def test_tokenization(self):
        """Test that text is properly tokenized."""
        samples = [{"text": "Hello world"}]
        # Construction alone should trigger tokenization of each sample.
        dataset = TouchGrassDataset(samples, self.tokenizer, self.max_length)
        self.tokenizer.encode.assert_called_with("Hello world")

    def test_padding_to_max_length(self):
        """Test that sequences are padded to max_length."""
        samples = [{"text": "Short"}]
        dataset = TouchGrassDataset(samples, self.tokenizer, self.max_length)
        item = dataset[0]
        assert len(item["input_ids"]) == self.max_length
        assert len(item["attention_mask"]) == self.max_length
        assert len(item["labels"]) == self.max_length

    def test_attention_mask_correct(self):
        """Test that attention mask is 1 for real tokens, 0 for padding."""
        samples = [{"text": "Test"}]
        dataset = TouchGrassDataset(samples, self.tokenizer, self.max_length)
        item = dataset[0]
        # encode() returns a plain list, so count non-pad tokens with Python's
        # sum(); comparing `list != int` yields a single bool, which has no
        # .sum() method (the original expression raised AttributeError).
        real_token_count = sum(
            1
            for tok in self.tokenizer.encode.return_value
            if tok != self.tokenizer.pad_token_id
        )
        assert item["attention_mask"].sum().item() == real_token_count

    def test_labels_shifted(self):
        """Test that labels align with input_ids for language modeling."""
        samples = [{"text": "Test sample"}]
        dataset = TouchGrassDataset(samples, self.tokenizer, self.max_length)
        item = dataset[0]
        # Implementations may copy input_ids verbatim or mask padding with
        # -100, so assert only what both conventions guarantee: matching
        # shape, and agreement with input_ids at every non-ignored position.
        # (The original `assert ... or True` could never fail.)
        assert item["labels"].shape == item["input_ids"].shape
        non_ignored = item["labels"] != -100
        assert torch.equal(
            item["labels"][non_ignored], item["input_ids"][non_ignored]
        )

    def test_truncation(self):
        """Test that sequences longer than max_length are truncated."""
        long_text = "word " * 200
        samples = [{"text": long_text}]
        dataset = TouchGrassDataset(samples, self.tokenizer, self.max_length)
        item = dataset[0]
        assert len(item["input_ids"]) <= self.max_length

    def test_multiple_samples(self):
        """Test accessing multiple samples."""
        samples = [{"text": f"Sample {i}"} for i in range(10)]
        dataset = TouchGrassDataset(samples, self.tokenizer, self.max_length)
        for i in range(10):
            item = dataset[i]
            assert "input_ids" in item
            assert "attention_mask" in item
            assert "labels" in item

    def test_empty_dataset(self):
        """Test dataset with empty samples list."""
        samples = []
        dataset = TouchGrassDataset(samples, self.tokenizer, self.max_length)
        assert len(dataset) == 0

    def test_special_tokens_handling(self):
        """Test handling of special tokens."""
        samples = [{"text": "Play [GUITAR] chord"}]
        dataset = TouchGrassDataset(samples, self.tokenizer, self.max_length)
        item = dataset[0]
        # The bracketed special token must be passed through to the tokenizer
        # unchanged, not stripped or pre-processed by the dataset.
        self.tokenizer.encode.assert_called_with("Play [GUITAR] chord")

    def test_tensor_types(self):
        """Test that returned tensors have correct type."""
        samples = [{"text": "Test"}]
        dataset = TouchGrassDataset(samples, self.tokenizer, self.max_length)
        item = dataset[0]
        assert isinstance(item["input_ids"], torch.Tensor)
        assert isinstance(item["attention_mask"], torch.Tensor)
        assert isinstance(item["labels"], torch.Tensor)

    def test_dtype(self):
        """Test tensor dtype."""
        samples = [{"text": "Test"}]
        dataset = TouchGrassDataset(samples, self.tokenizer, self.max_length)
        item = dataset[0]
        assert item["input_ids"].dtype == torch.long
        assert item["attention_mask"].dtype == torch.long
        assert item["labels"].dtype == torch.long

    def test_with_music_tokens(self):
        """Test handling of music-specific tokens."""
        samples = [{"text": "Use [TAB] for guitar"}]
        dataset = TouchGrassDataset(samples, self.tokenizer, self.max_length)
        item = dataset[0]
        # Music tokens should not break padding to the fixed length.
        assert item["input_ids"].shape[0] == self.max_length

    def test_batch_consistency(self):
        """Test that multiple accesses to same sample return same result."""
        samples = [{"text": "Consistent"}]
        dataset = TouchGrassDataset(samples, self.tokenizer, self.max_length)
        item1 = dataset[0]
        item2 = dataset[0]
        assert torch.equal(item1["input_ids"], item2["input_ids"])
        assert torch.equal(item1["attention_mask"], item2["attention_mask"])
        assert torch.equal(item1["labels"], item2["labels"])

    def test_different_max_lengths(self):
        """Test dataset with different max_length values."""
        for max_len in [128, 256, 512, 1024]:
            samples = [{"text": "Test"}]
            dataset = TouchGrassDataset(samples, self.tokenizer, max_len)
            item = dataset[0]
            assert len(item["input_ids"]) == max_len

    def test_tokenizer_not_called_multiple_times(self):
        """Test that tokenizer is called once per sample during creation."""
        samples = [{"text": "Test 1"}, {"text": "Test 2"}]
        dataset = TouchGrassDataset(samples, self.tokenizer, self.max_length)
        # Two samples => exactly two encode() calls at initialization.
        assert self.tokenizer.encode.call_count == 2

    def test_labels_ignore_padding(self):
        """Test that labels cover the full padded length."""
        samples = [{"text": "Short"}]
        dataset = TouchGrassDataset(samples, self.tokenizer, self.max_length)
        item = dataset[0]
        # Padding positions in labels may be -100 (the common ignore_index)
        # or mirror input_ids, depending on implementation; only the shape
        # is pinned down here.
        labels = item["labels"]
        assert labels.shape[0] == self.max_length

    def test_with_actual_tokenizer_mock(self):
        """Test with a more realistic tokenizer mock."""

        def mock_encode(text, **kwargs):
            # Simulate tokenization: one token per word, capped at 10.
            tokens = [1] * min(len(text.split()), 10)
            return tokens

        tokenizer = MagicMock()
        tokenizer.encode.side_effect = mock_encode
        tokenizer.pad_token_id = 0
        samples = [{"text": "This is a longer text sample with more words"}]
        dataset = TouchGrassDataset(samples, tokenizer, self.max_length)
        item = dataset[0]
        assert item["input_ids"].shape[0] == self.max_length
# Allow running this test module directly (e.g. `python this_file.py`),
# which invokes pytest verbosely on just this file.
if __name__ == "__main__":
    pytest.main([__file__, "-v"])
|