| """
|
| Tests for Music QA Dataset Generator.
|
| """
|
|
|
| import pytest
|
| from unittest.mock import MagicMock, patch
|
|
|
| from TouchGrass.data.music_qa_generator import MusicQAGenerator, MUSIC_QA_TEMPLATES
|
|
|
|
|
class TestMusicQAGenerator:
    """Test suite for MusicQAGenerator.

    Covers construction, template layout, the shape of generated samples and
    their messages, saving to JSON / JSONL, per-category keyword coverage,
    and seeding behaviour.
    """

    def setup_method(self):
        """Create a fresh generator with default arguments before each test."""
        self.generator = MusicQAGenerator()

    # ------------------------------------------------------------------
    # Helpers (underscore-prefixed so pytest does not collect them)
    # ------------------------------------------------------------------

    @staticmethod
    def _conversation_text(sample):
        """Return the lower-cased message contents of *sample* as one string.

        The raw ``content`` values are joined instead of calling ``str()`` on
        the message list: ``str()`` of a list uses ``repr`` for each string,
        and a reply that contains both quote styles is then rendered with a
        backslash-escaped apostrophe, which would defeat a keyword such as
        "don't worry".
        """
        return " ".join(msg["content"] for msg in sample["messages"]).lower()

    def _assert_category_mentions(self, category, keywords, num_samples=50):
        """Assert that *category* is generated and stays on topic.

        Generates ``num_samples`` samples, requires at least one sample of
        *category*, and requires every such sample to contain at least one of
        *keywords* (case-insensitive substring match on message contents).
        """
        dataset = self.generator.generate_dataset(num_samples=num_samples)
        matching = [s for s in dataset if s["category"] == category]
        assert len(matching) > 0, f"no {category!r} samples were generated"
        for sample in matching:
            text = self._conversation_text(sample)
            assert any(word in text for word in keywords), (
                f"{category!r} sample mentions none of {keywords}"
            )

    # ------------------------------------------------------------------
    # Construction and templates
    # ------------------------------------------------------------------

    def test_generator_initialization(self):
        """Test that generator initializes correctly."""
        assert hasattr(self.generator, "templates")
        assert hasattr(self.generator, "generate_dataset")
        assert hasattr(self.generator, "save_dataset")
        assert isinstance(self.generator.templates, dict)

    def test_templates_structure(self):
        """Test that templates have correct structure."""
        expected_categories = [
            "guitar", "piano", "drums", "vocals", "theory",
            "ear_training", "songwriting", "production", "frustration", "general",
        ]
        for category in expected_categories:
            assert category in self.generator.templates
            assert isinstance(self.generator.templates[category], list)
            assert len(self.generator.templates[category]) > 0

    # ------------------------------------------------------------------
    # Dataset and message shape
    # ------------------------------------------------------------------

    def test_generate_dataset_default(self):
        """Test dataset generation with default parameters."""
        dataset = self.generator.generate_dataset(num_samples=100)
        assert isinstance(dataset, list)
        assert len(dataset) == 100

    def test_generate_dataset_categories(self):
        """Test that generated samples have required categories."""
        dataset = self.generator.generate_dataset(num_samples=50)
        categories_seen = set()
        for sample in dataset:
            assert "category" in sample
            assert "messages" in sample
            assert isinstance(sample["messages"], list)
            categories_seen.add(sample["category"])

        assert len(categories_seen) >= 3

    def test_message_structure(self):
        """Test that messages have correct role structure."""
        dataset = self.generator.generate_dataset(num_samples=10)
        for sample in dataset:
            messages = sample["messages"]

            assert len(messages) >= 3
            for msg in messages:
                assert "role" in msg
                assert "content" in msg
                assert msg["role"] in ["system", "user", "assistant"]

    def test_system_messages_present(self):
        """Test that system messages are present."""
        dataset = self.generator.generate_dataset(num_samples=20)
        for sample in dataset:
            roles = [msg["role"] for msg in sample["messages"]]
            assert "system" in roles

    def test_assistant_responses_present(self):
        """Test that assistant responses are present."""
        dataset = self.generator.generate_dataset(num_samples=20)
        for sample in dataset:
            roles = [msg["role"] for msg in sample["messages"]]
            assert "assistant" in roles

    def test_content_not_empty(self):
        """Test that message content is not empty."""
        dataset = self.generator.generate_dataset(num_samples=30)
        for sample in dataset:
            for msg in sample["messages"]:
                assert len(msg["content"].strip()) > 0

    def test_generate_with_custom_templates(self):
        """Test dataset generation with custom templates."""
        custom_templates = {
            "test_category": [
                {
                    "system": "You are a test assistant.",
                    "user": "Test question: {query}",
                    "assistant": "Test answer: {answer}",
                }
            ]
        }
        generator = MusicQAGenerator(templates=custom_templates)
        dataset = generator.generate_dataset(num_samples=5)
        assert len(dataset) == 5
        assert all(s["category"] == "test_category" for s in dataset)

    # ------------------------------------------------------------------
    # Persistence
    # ------------------------------------------------------------------

    def test_save_dataset_jsonl(self, tmp_path):
        """Test saving dataset in JSONL format."""
        import json

        dataset = self.generator.generate_dataset(num_samples=10)
        output_path = tmp_path / "test_dataset.jsonl"
        self.generator.save_dataset(dataset, str(output_path), format="jsonl")
        assert output_path.exists()

        with open(output_path, "r", encoding="utf-8") as f:
            lines = f.readlines()

        # One JSON document per line, one line per sample.
        assert len(lines) == 10
        for line in lines:
            sample = json.loads(line)
            assert "category" in sample
            assert "messages" in sample

    def test_save_dataset_json(self, tmp_path):
        """Test saving dataset in JSON format."""
        import json

        dataset = self.generator.generate_dataset(num_samples=10)
        output_path = tmp_path / "test_dataset.json"
        self.generator.save_dataset(dataset, str(output_path), format="json")
        assert output_path.exists()

        with open(output_path, "r", encoding="utf-8") as f:
            data = json.load(f)

        assert isinstance(data, list)
        assert len(data) == 10

    # ------------------------------------------------------------------
    # Sample counts and category spread
    # ------------------------------------------------------------------

    def test_generate_different_sample_counts(self):
        """Test generating different numbers of samples."""
        for num in [1, 10, 50, 100]:
            dataset = self.generator.generate_dataset(num_samples=num)
            assert len(dataset) == num

    def test_category_distribution(self):
        """Test that category distribution is reasonable."""
        dataset = self.generator.generate_dataset(num_samples=200)
        categories = [s["category"] for s in dataset]
        unique_categories = set(categories)

        assert len(unique_categories) >= 5

    def test_template_variable_substitution(self):
        """Test that template variables are properly substituted."""
        import re

        # An identifier wrapped in braces, e.g. "{query}", is what an
        # unfilled template slot looks like.
        placeholder = re.compile(r"\{[A-Za-z_][A-Za-z0-9_]*\}")

        dataset = self.generator.generate_dataset(num_samples=5)
        for sample in dataset:
            for msg in sample["messages"]:
                content = msg["content"]
                assert len(content) > 0
                leftover = placeholder.findall(content)
                assert not leftover, (
                    f"unsubstituted placeholders {leftover} in {content!r}"
                )

    def test_music_domain_coverage(self):
        """Test that all music domains are covered."""
        domains = ["guitar", "piano", "drums", "vocals", "theory", "production"]
        dataset = self.generator.generate_dataset(num_samples=100)
        categories = set(s["category"] for s in dataset)

        domain_coverage = sum(1 for d in domains if d in categories)
        assert domain_coverage >= 4

    # ------------------------------------------------------------------
    # Per-category content
    # ------------------------------------------------------------------

    def test_frustration_responses(self):
        """Test that frustration responses are generated."""
        self._assert_category_mentions(
            "frustration",
            ["don't worry", "break", "practice", "time", "patience"],
        )

    def test_ear_training_content(self):
        """Test ear training specific content."""
        self._assert_category_mentions(
            "ear_training",
            ["interval", "note", "pitch", "listen", "hear"],
        )

    def test_songwriting_content(self):
        """Test songwriting specific content."""
        self._assert_category_mentions(
            "songwriting",
            ["chord", "lyric", "progression", "hook", "song"],
        )

    def test_production_content(self):
        """Test music production specific content."""
        self._assert_category_mentions(
            "production",
            ["eq", "mix", "compress", "volume", "frequency"],
        )

    def test_theory_content(self):
        """Test music theory specific content."""
        self._assert_category_mentions(
            "theory",
            ["scale", "chord", "interval", "key", "note"],
        )

    def test_guitar_content(self):
        """Test guitar specific content."""
        self._assert_category_mentions(
            "guitar",
            ["fret", "string", "tab", "chord", "guitar"],
        )

    def test_piano_content(self):
        """Test piano specific content."""
        self._assert_category_mentions(
            "piano",
            ["key", "hand", "pedal", "piano", "octave"],
        )

    def test_drums_content(self):
        """Test drums specific content."""
        self._assert_category_mentions(
            "drums",
            ["beat", "fill", "kit", "drum", "cymbal"],
        )

    def test_vocals_content(self):
        """Test vocals specific content."""
        self._assert_category_mentions(
            "vocals",
            ["voice", "range", "breath", "vocal", "sing"],
        )

    # ------------------------------------------------------------------
    # Seeding and scale
    # ------------------------------------------------------------------

    def test_reproducibility_with_seed(self):
        """Test that using a seed produces reproducible results."""
        generator1 = MusicQAGenerator(seed=42)
        dataset1 = generator1.generate_dataset(num_samples=50)

        generator2 = MusicQAGenerator(seed=42)
        dataset2 = generator2.generate_dataset(num_samples=50)

        assert dataset1 == dataset2

    def test_different_seeds_produce_different_results(self):
        """Test that different seeds produce different datasets."""
        generator1 = MusicQAGenerator(seed=42)
        dataset1 = generator1.generate_dataset(num_samples=50)

        generator2 = MusicQAGenerator(seed=123)
        dataset2 = generator2.generate_dataset(num_samples=50)

        assert dataset1 != dataset2

    def test_large_dataset_generation(self):
        """Test generating a larger dataset."""
        dataset = self.generator.generate_dataset(num_samples=1000)
        assert len(dataset) == 1000

        categories = [s["category"] for s in dataset]
        unique_cats = set(categories)
        assert len(unique_cats) >= 8
|
|
|
|
|
if __name__ == "__main__":
    # pytest.main() returns the session's exit code; pass it to the
    # interpreter so running this file directly exits non-zero when a test
    # fails instead of always reporting success.
    raise SystemExit(pytest.main([__file__, "-v"]))
|
|
|