TouchGrass-7b / tests /test_music_qa_generator.py
Zandy-Wandy's picture
Upload 39 files
4f0238f verified
"""
Tests for Music QA Dataset Generator.
"""
import json
from unittest.mock import MagicMock, patch

import pytest

from TouchGrass.data.music_qa_generator import MusicQAGenerator, MUSIC_QA_TEMPLATES
class TestMusicQAGenerator:
    """Test suite for MusicQAGenerator.

    The default fixture is seeded so that sampled categories -- and therefore
    every category-dependent assertion below -- are deterministic run-to-run.
    """

    # Seed for the default fixture; 42 is the same value the reproducibility
    # test below already relies on.
    SEED = 42

    # Samples drawn by the per-category keyword tests. Assuming roughly
    # uniform sampling over ~10 categories, 50 samples leaves ~0.5% chance per
    # test that a category is absent; 200 makes that vanishingly unlikely.
    CATEGORY_SAMPLE_COUNT = 200

    def setup_method(self):
        """Create a fresh, seeded generator before each test."""
        self.generator = MusicQAGenerator(seed=self.SEED)

    def _assert_category_mentions(self, category, keywords):
        """Assert `category` is sampled and each of its samples mentions a keyword.

        The whole message list (all roles) is stringified and lower-cased, so
        a keyword appearing anywhere in the conversation counts as a match.

        Args:
            category: Category name to filter generated samples by.
            keywords: Lower-case substrings; at least one must appear per sample.
        """
        dataset = self.generator.generate_dataset(
            num_samples=self.CATEGORY_SAMPLE_COUNT
        )
        samples = [s for s in dataset if s["category"] == category]
        assert samples, f"no {category!r} samples were generated"
        for sample in samples:
            content = str(sample["messages"]).lower()
            assert any(word in content for word in keywords), (
                f"{category!r} sample mentions none of {keywords}"
            )

    def test_generator_initialization(self):
        """Test that generator initializes correctly."""
        assert hasattr(self.generator, "templates")
        assert hasattr(self.generator, "generate_dataset")
        assert hasattr(self.generator, "save_dataset")
        assert isinstance(self.generator.templates, dict)

    def test_templates_structure(self):
        """Test that templates have correct structure."""
        expected_categories = [
            "guitar", "piano", "drums", "vocals", "theory",
            "ear_training", "songwriting", "production", "frustration", "general",
        ]
        for category in expected_categories:
            assert category in self.generator.templates
            assert isinstance(self.generator.templates[category], list)
            assert len(self.generator.templates[category]) > 0

    def test_generate_dataset_default(self):
        """Test dataset generation with default parameters."""
        dataset = self.generator.generate_dataset(num_samples=100)
        assert isinstance(dataset, list)
        assert len(dataset) == 100

    def test_generate_dataset_categories(self):
        """Test that generated samples have required categories."""
        dataset = self.generator.generate_dataset(num_samples=50)
        categories_seen = set()
        for sample in dataset:
            assert "category" in sample
            assert "messages" in sample
            assert isinstance(sample["messages"], list)
            categories_seen.add(sample["category"])
        # Should have at least some variety in categories
        assert len(categories_seen) >= 3

    def test_message_structure(self):
        """Test that messages have correct role structure."""
        dataset = self.generator.generate_dataset(num_samples=10)
        for sample in dataset:
            messages = sample["messages"]
            # Should have at least 3 messages (system, user, assistant)
            assert len(messages) >= 3
            for msg in messages:
                assert "role" in msg
                assert "content" in msg
                assert msg["role"] in ["system", "user", "assistant"]

    def test_system_messages_present(self):
        """Test that system messages are present."""
        dataset = self.generator.generate_dataset(num_samples=20)
        for sample in dataset:
            roles = [msg["role"] for msg in sample["messages"]]
            assert "system" in roles

    def test_assistant_responses_present(self):
        """Test that assistant responses are present."""
        dataset = self.generator.generate_dataset(num_samples=20)
        for sample in dataset:
            roles = [msg["role"] for msg in sample["messages"]]
            assert "assistant" in roles

    def test_content_not_empty(self):
        """Test that message content is not empty."""
        dataset = self.generator.generate_dataset(num_samples=30)
        for sample in dataset:
            for msg in sample["messages"]:
                assert len(msg["content"].strip()) > 0

    def test_generate_with_custom_templates(self):
        """Test dataset generation with custom templates."""
        custom_templates = {
            "test_category": [
                {
                    "system": "You are a test assistant.",
                    "user": "Test question: {query}",
                    "assistant": "Test answer: {answer}",
                }
            ]
        }
        generator = MusicQAGenerator(templates=custom_templates)
        dataset = generator.generate_dataset(num_samples=5)
        assert len(dataset) == 5
        assert all(s["category"] == "test_category" for s in dataset)

    def test_save_dataset_jsonl(self, tmp_path):
        """Test saving dataset in JSONL format (one JSON object per line)."""
        dataset = self.generator.generate_dataset(num_samples=10)
        output_path = tmp_path / "test_dataset.jsonl"
        self.generator.save_dataset(dataset, str(output_path), format="jsonl")
        assert output_path.exists()
        # Verify file content: one parseable record per generated sample.
        with open(output_path, "r", encoding="utf-8") as f:
            lines = f.readlines()
        assert len(lines) == 10
        for line in lines:
            sample = json.loads(line)
            assert "category" in sample
            assert "messages" in sample

    def test_save_dataset_json(self, tmp_path):
        """Test saving dataset in JSON format (a single top-level list)."""
        dataset = self.generator.generate_dataset(num_samples=10)
        output_path = tmp_path / "test_dataset.json"
        self.generator.save_dataset(dataset, str(output_path), format="json")
        assert output_path.exists()
        # Verify file content
        with open(output_path, "r", encoding="utf-8") as f:
            data = json.load(f)
        assert isinstance(data, list)
        assert len(data) == 10

    def test_generate_different_sample_counts(self):
        """Test generating different numbers of samples."""
        for num in [1, 10, 50, 100]:
            dataset = self.generator.generate_dataset(num_samples=num)
            assert len(dataset) == num

    def test_category_distribution(self):
        """Test that category distribution is reasonable."""
        dataset = self.generator.generate_dataset(num_samples=200)
        unique_categories = {s["category"] for s in dataset}
        # Should have multiple categories represented
        assert len(unique_categories) >= 5

    def test_template_variable_substitution(self):
        """Test that template variables are properly substituted."""
        dataset = self.generator.generate_dataset(num_samples=50)
        for sample in dataset:
            for msg in sample["messages"]:
                content = msg["content"]
                assert len(content) > 0
                # Raw template placeholders must never leak into training data.
                for placeholder in ("{query}", "{answer}"):
                    assert placeholder not in content, (
                        f"unsubstituted {placeholder} in {sample['category']!r} sample"
                    )

    def test_music_domain_coverage(self):
        """Test that all music domains are covered."""
        domains = ["guitar", "piano", "drums", "vocals", "theory", "production"]
        dataset = self.generator.generate_dataset(num_samples=100)
        categories = {s["category"] for s in dataset}
        # At least 4 of 6 domains should be represented in 100 samples
        domain_coverage = sum(1 for d in domains if d in categories)
        assert domain_coverage >= 4

    def test_frustration_responses(self):
        """Test that frustration responses are generated."""
        # Frustration samples should have encouraging content
        self._assert_category_mentions(
            "frustration", ["don't worry", "break", "practice", "time", "patience"]
        )

    def test_ear_training_content(self):
        """Test ear training specific content."""
        # Should mention intervals, notes, or listening
        self._assert_category_mentions(
            "ear_training", ["interval", "note", "pitch", "listen", "hear"]
        )

    def test_songwriting_content(self):
        """Test songwriting specific content."""
        # Should mention chords, lyrics, or structure
        self._assert_category_mentions(
            "songwriting", ["chord", "lyric", "progression", "hook", "song"]
        )

    def test_production_content(self):
        """Test music production specific content."""
        # Should mention EQ, mixing, compression, etc.
        self._assert_category_mentions(
            "production", ["eq", "mix", "compress", "volume", "frequency"]
        )

    def test_theory_content(self):
        """Test music theory specific content."""
        # Should mention scales, chords, intervals, etc.
        self._assert_category_mentions(
            "theory", ["scale", "chord", "interval", "key", "note"]
        )

    def test_guitar_content(self):
        """Test guitar specific content."""
        # Should mention frets, strings, tabs, chords, etc.
        self._assert_category_mentions(
            "guitar", ["fret", "string", "tab", "chord", "guitar"]
        )

    def test_piano_content(self):
        """Test piano specific content."""
        # Should mention keys, hands, pedals, etc.
        self._assert_category_mentions(
            "piano", ["key", "hand", "pedal", "piano", "octave"]
        )

    def test_drums_content(self):
        """Test drums specific content."""
        # Should mention beats, fills, kit, etc.
        self._assert_category_mentions(
            "drums", ["beat", "fill", "kit", "drum", "cymbal"]
        )

    def test_vocals_content(self):
        """Test vocals specific content."""
        # Should mention voice, range, breathing, etc.
        self._assert_category_mentions(
            "vocals", ["voice", "range", "breath", "vocal", "sing"]
        )

    def test_reproducibility_with_seed(self):
        """Test that using a seed produces reproducible results."""
        generator1 = MusicQAGenerator(seed=42)
        dataset1 = generator1.generate_dataset(num_samples=50)
        generator2 = MusicQAGenerator(seed=42)
        dataset2 = generator2.generate_dataset(num_samples=50)
        # Should be identical
        assert dataset1 == dataset2

    def test_different_seeds_produce_different_results(self):
        """Test that different seeds produce different datasets."""
        generator1 = MusicQAGenerator(seed=42)
        dataset1 = generator1.generate_dataset(num_samples=50)
        generator2 = MusicQAGenerator(seed=123)
        dataset2 = generator2.generate_dataset(num_samples=50)
        # Should be different (very unlikely to be identical)
        assert dataset1 != dataset2

    def test_large_dataset_generation(self):
        """Test generating a larger dataset."""
        dataset = self.generator.generate_dataset(num_samples=1000)
        assert len(dataset) == 1000
        # Check that we have good category distribution
        unique_cats = {s["category"] for s in dataset}
        assert len(unique_cats) >= 8  # Should cover most categories
if __name__ == "__main__":
    # pytest.main() returns an exit code instead of raising; propagate it so
    # running this file as a script reports failures via a non-zero status.
    raise SystemExit(pytest.main([__file__, "-v"]))