"""
Tests for Dataset Loader.
"""
import pytest
import torch
from unittest.mock import MagicMock, patch
from TouchGrass.data.dataset_loader import TouchGrassDataset
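
# NOTE: these tests exercise an interface inferred from the assertions below
# rather than restated from the loader's source: TouchGrassDataset(samples,
# tokenizer, max_length), where each sample is {"text": str} and __getitem__
# returns a dict of fixed-length torch.long tensors keyed "input_ids",
# "attention_mask", and "labels".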


class TestTouchGrassDataset:
    """Test suite for TouchGrassDataset."""

    def setup_method(self):
        """Set up test fixtures."""
        self.tokenizer = MagicMock()
        self.tokenizer.encode.return_value = [1, 2, 3, 4, 5]
        self.tokenizer.pad_token_id = 0
        self.max_length = 512

    def test_dataset_initialization(self):
        """Test dataset initialization with samples."""
        samples = [
            {"text": "Sample 1"},
            {"text": "Sample 2"},
            {"text": "Sample 3"},
        ]
        dataset = TouchGrassDataset(samples, self.tokenizer, self.max_length)
        assert len(dataset) == 3

    def test_dataset_length(self):
        """Test dataset __len__ method."""
        samples = [{"text": f"Sample {i}"} for i in range(100)]
        dataset = TouchGrassDataset(samples, self.tokenizer, self.max_length)
        assert len(dataset) == 100

    def test_getitem_returns_correct_keys(self):
        """Test that __getitem__ returns expected keys."""
        samples = [{"text": "Test sample"}]
        dataset = TouchGrassDataset(samples, self.tokenizer, self.max_length)
        item = dataset[0]
        assert "input_ids" in item
        assert "attention_mask" in item
        assert "labels" in item

    def test_tokenization(self):
        """Test that text is tokenized when the dataset is built."""
        samples = [{"text": "Hello world"}]
        dataset = TouchGrassDataset(samples, self.tokenizer, self.max_length)
        # Tokenization is expected to happen eagerly during dataset creation
        # and to be reused (cached) on later item accesses.
        self.tokenizer.encode.assert_called_with("Hello world")
        assert len(dataset) == 1
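
    def test_getitem_reuses_cached_tokenization(self):
        """Sketch (not in the original suite): repeated __getitem__ calls
        should not re-invoke the tokenizer, assuming the loader tokenizes
        eagerly in __init__ as the call-count test further below implies."""
        samples = [{"text": "Cached"}]
        dataset = TouchGrassDataset(samples, self.tokenizer, self.max_length)
        calls_after_init = self.tokenizer.encode.call_count
        _ = dataset[0]
        _ = dataset[0]
        # If tokenization is cached at construction, item access adds no calls.
        assert self.tokenizer.encode.call_count == calls_after_init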

    def test_padding_to_max_length(self):
        """Test that sequences are padded to max_length."""
        samples = [{"text": "Short"}]
        dataset = TouchGrassDataset(samples, self.tokenizer, self.max_length)
        item = dataset[0]
        assert len(item["input_ids"]) == self.max_length
        assert len(item["attention_mask"]) == self.max_length
        assert len(item["labels"]) == self.max_length

    def test_attention_mask_correct(self):
        """Test that attention mask is 1 for real tokens, 0 for padding."""
        samples = [{"text": "Test"}]
        dataset = TouchGrassDataset(samples, self.tokenizer, self.max_length)
        item = dataset[0]
        # The number of 1s in the mask should equal the number of real
        # (non-padding) tokens produced by the mock tokenizer.
        real_token_count = sum(
            1 for token in self.tokenizer.encode.return_value
            if token != self.tokenizer.pad_token_id
        )
        attention_sum = item["attention_mask"].sum().item()
        assert attention_sum == real_token_count

    def test_labels_shifted(self):
        """Test that labels are aligned with input_ids for causal language modeling."""
        samples = [{"text": "Test sample"}]
        dataset = TouchGrassDataset(samples, self.tokenizer, self.max_length)
        item = dataset[0]
        input_ids = item["input_ids"]
        labels = item["labels"]
        assert labels.shape == input_ids.shape
        # Accept either common convention: labels identical to input_ids, or
        # labels matching input_ids with padding positions masked to -100.
        mask = labels != -100
        assert torch.equal(labels[mask], input_ids[mask])

    def test_truncation(self):
        """Test that sequences longer than max_length are truncated."""
        long_text = "word " * 200
        samples = [{"text": long_text}]
        dataset = TouchGrassDataset(samples, self.tokenizer, self.max_length)
        item = dataset[0]
        assert len(item["input_ids"]) <= self.max_length

    def test_multiple_samples(self):
        """Test accessing multiple samples."""
        samples = [{"text": f"Sample {i}"} for i in range(10)]
        dataset = TouchGrassDataset(samples, self.tokenizer, self.max_length)
        for i in range(10):
            item = dataset[i]
            assert "input_ids" in item
            assert "attention_mask" in item
            assert "labels" in item

    def test_empty_dataset(self):
        """Test dataset with empty samples list."""
        samples = []
        dataset = TouchGrassDataset(samples, self.tokenizer, self.max_length)
        assert len(dataset) == 0

    def test_special_tokens_handling(self):
        """Test handling of special tokens."""
        samples = [{"text": "Play [GUITAR] chord"}]
        dataset = TouchGrassDataset(samples, self.tokenizer, self.max_length)
        item = dataset[0]
        # Should tokenize the special token
        self.tokenizer.encode.assert_called_with("Play [GUITAR] chord")

    def test_tensor_types(self):
        """Test that returned tensors have correct type."""
        samples = [{"text": "Test"}]
        dataset = TouchGrassDataset(samples, self.tokenizer, self.max_length)
        item = dataset[0]
        assert isinstance(item["input_ids"], torch.Tensor)
        assert isinstance(item["attention_mask"], torch.Tensor)
        assert isinstance(item["labels"], torch.Tensor)

    def test_dtype(self):
        """Test tensor dtype."""
        samples = [{"text": "Test"}]
        dataset = TouchGrassDataset(samples, self.tokenizer, self.max_length)
        item = dataset[0]
        assert item["input_ids"].dtype == torch.long
        assert item["attention_mask"].dtype == torch.long
        assert item["labels"].dtype == torch.long

    def test_with_music_tokens(self):
        """Test handling of music-specific tokens."""
        samples = [{"text": "Use [TAB] for guitar"}]
        dataset = TouchGrassDataset(samples, self.tokenizer, self.max_length)
        item = dataset[0]
        # Should properly tokenize music tokens
        assert item["input_ids"].shape[0] == self.max_length

    def test_batch_consistency(self):
        """Test that multiple accesses to same sample return same result."""
        samples = [{"text": "Consistent"}]
        dataset = TouchGrassDataset(samples, self.tokenizer, self.max_length)
        item1 = dataset[0]
        item2 = dataset[0]
        assert torch.equal(item1["input_ids"], item2["input_ids"])
        assert torch.equal(item1["attention_mask"], item2["attention_mask"])
        assert torch.equal(item1["labels"], item2["labels"])

    def test_different_max_lengths(self):
        """Test dataset with different max_length values."""
        for max_len in [128, 256, 512, 1024]:
            samples = [{"text": "Test"}]
            dataset = TouchGrassDataset(samples, self.tokenizer, max_len)
            item = dataset[0]
            assert len(item["input_ids"]) == max_len

    def test_tokenizer_not_called_multiple_times(self):
        """Test that the tokenizer is called exactly once per sample during dataset creation."""
        samples = [{"text": "Test 1"}, {"text": "Test 2"}]
        dataset = TouchGrassDataset(samples, self.tokenizer, self.max_length)
        # Two samples should produce exactly two encode calls at initialization.
        assert self.tokenizer.encode.call_count == 2
        assert len(dataset) == 2

    def test_labels_ignore_padding(self):
        """Test label handling at padding positions."""
        samples = [{"text": "Short"}]
        dataset = TouchGrassDataset(samples, self.tokenizer, self.max_length)
        item = dataset[0]
        # Depending on the implementation, padded positions in labels are
        # either set to -100 (the common ignore_index) or left equal to
        # input_ids; here we only verify that labels span the full length.
        labels = item["labels"]
        assert labels.shape[0] == self.max_length

    def test_with_actual_tokenizer_mock(self):
        """Test with a more realistic tokenizer mock."""
        def mock_encode(text, **kwargs):
            # Simulate tokenization: one token per word, capped at 10 tokens.
            tokens = [1] * min(len(text.split()), 10)
            return tokens

        tokenizer = MagicMock()
        tokenizer.encode.side_effect = mock_encode
        tokenizer.pad_token_id = 0
        samples = [{"text": "This is a longer text sample with more words"}]
        dataset = TouchGrassDataset(samples, tokenizer, self.max_length)
        item = dataset[0]
        assert item["input_ids"].shape[0] == self.max_length


if __name__ == "__main__":
    pytest.main([__file__, "-v"])