| | """
|
| | Tests for Ear Training Module.
|
| | """
|
| |
|
| | import pytest
|
| | import torch
|
| |
|
| | from TouchGrass.models.ear_training_module import EarTrainingModule
|
| |
|
| |
|
class TestEarTrainingModule:
    """Test suite for EarTrainingModule."""

    # Interval ids fed to forward() are drawn from [0, NUM_INTERVAL_CLASSES),
    # and the interval classifier is expected to emit one logit per class.
    # NOTE(review): the naming helpers below accept 0..12 semitones (13
    # values, including P8) while the classifier is expected to have 12
    # outputs -- confirm this asymmetry is intended in EarTrainingModule.
    NUM_INTERVAL_CLASSES = 12

    def setup_method(self):
        """Create a fresh, deterministically initialised module for each test."""
        # Seed before building the module so that both its initial weights
        # and the random inputs drawn in each test are reproducible.
        torch.manual_seed(0)
        self.d_model = 768
        self.batch_size = 4
        self.module = EarTrainingModule(d_model=self.d_model)

    # ------------------------------------------------------------------
    # Helpers (not collected by pytest: no ``test`` prefix)
    # ------------------------------------------------------------------

    def _make_inputs(self, batch_size=None, seq_len=10, requires_grad=False):
        """Build random ``(hidden_states, interval_ids)`` inputs for forward().

        Args:
            batch_size: Batch dimension; defaults to ``self.batch_size``.
            seq_len: Sequence length shared by both tensors.
            requires_grad: Whether ``hidden_states`` should track gradients.

        Returns:
            ``(hidden_states, interval_ids)`` with shapes
            ``(batch_size, seq_len, d_model)`` and ``(batch_size, seq_len)``.
        """
        if batch_size is None:
            batch_size = self.batch_size
        hidden_states = torch.randn(
            batch_size, seq_len, self.d_model, requires_grad=requires_grad
        )
        interval_ids = torch.randint(
            0, self.NUM_INTERVAL_CLASSES, (batch_size, seq_len)
        )
        return hidden_states, interval_ids

    @staticmethod
    def _mentions(text_or_items, phrase):
        """Return True if *phrase* occurs case-insensitively in *text_or_items*.

        *text_or_items* may be a single string or an iterable of strings;
        the exact return type of the module's text helpers is not pinned
        down by these tests, so both are accepted.
        """
        if isinstance(text_or_items, str):
            text = text_or_items
        else:
            text = " ".join(map(str, text_or_items))
        return phrase.casefold() in text.casefold()

    # ------------------------------------------------------------------
    # Construction and forward pass
    # ------------------------------------------------------------------

    def test_module_initialization(self):
        """Test that module initializes correctly."""
        expected_types = {
            "interval_embed": torch.nn.Embedding,
            "interval_classifier": torch.nn.Linear,
            "solfege_embed": torch.nn.Embedding,
            "solfege_generator": torch.nn.LSTM,
            "quiz_lstm": torch.nn.LSTM,
            "quiz_head": torch.nn.Linear,
        }
        for attr, expected_cls in expected_types.items():
            assert isinstance(getattr(self.module, attr), expected_cls), (
                f"{attr} should be a {expected_cls.__name__}"
            )

    def test_forward_pass(self):
        """Test forward pass with dummy inputs."""
        seq_len = 10
        hidden_states, interval_ids = self._make_inputs(seq_len=seq_len)

        output = self.module(hidden_states, interval_ids)

        for key in ("interval_logits", "solfege", "quiz_questions"):
            assert key in output, f"forward() output is missing {key!r}"
        assert output["interval_logits"].shape == (
            self.batch_size,
            seq_len,
            self.NUM_INTERVAL_CLASSES,
        )
        # For the generated outputs only the leading (batch, seq) dims are
        # part of the contract checked here.
        assert output["solfege"].shape[:2] == (self.batch_size, seq_len)
        assert output["quiz_questions"].shape[:2] == (self.batch_size, seq_len)

    # ------------------------------------------------------------------
    # Interval naming and references
    # ------------------------------------------------------------------

    def test_get_interval_name(self):
        """Test interval name retrieval."""
        expected = {0: "P1", 2: "M2", 4: "M3", 7: "P5", 12: "P8"}
        for semitones, name in expected.items():
            assert self.module.get_interval_name(semitones) == name, (
                f"get_interval_name({semitones}) should be {name!r}"
            )

    def test_get_song_reference(self):
        """Test song reference retrieval for intervals."""
        # Classic ear-training mnemonics: the opening interval of each tune.
        cases = {
            "P5": "Star Wars",
            "m2": "Jaws",
            "M3": "Saints",
        }
        for interval_name, song in cases.items():
            refs = self.module.get_song_reference(interval_name)
            assert self._mentions(refs, song), (
                f"Expected {song!r} among song references for {interval_name}"
            )

    def test_generate_solfege_exercise(self):
        """Test solfege exercise generation."""
        exercise = self.module.generate_solfege_exercise(difficulty="beginner", key="C")
        assert "exercise" in exercise or "notes" in exercise
        # NOTE(review): `"C" in str(exercise)` is a very weak fallback (any
        # capital C in the repr satisfies it); tighten once the return
        # schema of generate_solfege_exercise is fixed.
        assert "key" in exercise or "C" in str(exercise)

    def test_generate_interval_quiz(self):
        """Test interval quiz generation."""
        num_questions = 5
        quiz = self.module.generate_interval_quiz(
            num_questions=num_questions, difficulty="medium"
        )
        assert "questions" in quiz
        assert len(quiz["questions"]) == num_questions

    def test_describe_interval(self):
        """Test interval description with song reference."""
        description = self.module.describe_interval(7)
        assert "7 semitones" in description or self._mentions(
            description, "perfect fifth"
        )
        assert self._mentions(description, "Star Wars")

    # ------------------------------------------------------------------
    # Solfege
    # ------------------------------------------------------------------

    def test_get_solfege_syllables(self):
        """Test solfege syllable retrieval."""
        syllables = self.module.get_solfege_syllables(key="C", mode="major")
        expected = ["Do", "Re", "Mi", "Fa", "So", "La", "Ti", "Do"]
        assert syllables == expected

    def test_get_solfege_syllables_minor(self):
        """Test solfege syllables for minor mode."""
        syllables = self.module.get_solfege_syllables(key="A", mode="minor")

        assert "Do" in syllables
        assert len(syllables) >= 7

    # ------------------------------------------------------------------
    # Semitone <-> name conversion
    # ------------------------------------------------------------------

    def test_interval_to_name(self):
        """Test converting semitone count to interval name."""
        expected = {
            0: "P1",
            1: "m2",
            2: "M2",
            3: "m3",
            4: "M3",
            5: "P4",
            6: "TT",
            7: "P5",
            11: "M7",
            12: "P8",
        }
        for semitones, name in expected.items():
            assert self.module.interval_to_name(semitones) == name, (
                f"interval_to_name({semitones}) should be {name!r}"
            )

    def test_name_to_interval(self):
        """Test converting interval name to semitone count."""
        expected = {
            "P1": 0,
            "m2": 1,
            "M2": 2,
            "M3": 4,
            "P4": 5,
            "P5": 7,
            "P8": 12,
        }
        for name, semitones in expected.items():
            assert self.module.name_to_interval(name) == semitones, (
                f"name_to_interval({name!r}) should be {semitones}"
            )

    def test_quiz_question_format(self):
        """Test that quiz questions are properly formatted."""
        quiz = self.module.generate_interval_quiz(num_questions=3, difficulty="easy")
        for question in quiz["questions"]:
            assert "question" in question
            assert "answer" in question
            assert "options" in question or isinstance(question["answer"], (str, int))

    # ------------------------------------------------------------------
    # Output shapes, batching and gradients
    # ------------------------------------------------------------------

    def test_solfege_output_length(self):
        """Test solfege output has correct sequence length."""
        seq_len = 10
        hidden_states, interval_ids = self._make_inputs(seq_len=seq_len)

        output = self.module(hidden_states, interval_ids)
        assert output["solfege"].shape[1] == seq_len

    def test_different_batch_sizes(self):
        """Test forward pass with different batch sizes."""
        for batch_size in (1, 2, 8):
            hidden_states, interval_ids = self._make_inputs(
                batch_size=batch_size, seq_len=10
            )

            output = self.module(hidden_states, interval_ids)
            assert output["interval_logits"].shape[0] == batch_size, (
                f"batch dimension mismatch for batch_size={batch_size}"
            )

    def test_gradient_flow(self):
        """Test that gradients flow through the module."""
        hidden_states, interval_ids = self._make_inputs(seq_len=5, requires_grad=True)

        output = self.module(hidden_states, interval_ids)
        # NOTE(review): assumes "solfege" is a floating-point tensor; if it
        # ever becomes discrete token ids this sum will not be differentiable.
        loss = output["interval_logits"].sum() + output["solfege"].sum()
        loss.backward()

        assert hidden_states.grad is not None
        assert self.module.interval_embed.weight.grad is not None

    def test_interval_classifier_output(self):
        """Test interval classifier produces logits for all intervals."""
        hidden_states, interval_ids = self._make_inputs(seq_len=1)

        logits = self.module(hidden_states, interval_ids)["interval_logits"]
        assert logits.shape[-1] == self.NUM_INTERVAL_CLASSES

    def test_quiz_head_output(self):
        """Test quiz head produces appropriate output."""
        seq_len = 1
        hidden_states, interval_ids = self._make_inputs(seq_len=seq_len)

        quiz_output = self.module(hidden_states, interval_ids)["quiz_questions"]
        assert quiz_output.shape[0] == self.batch_size
        assert quiz_output.shape[1] == seq_len

    # ------------------------------------------------------------------
    # Cross-helper consistency
    # ------------------------------------------------------------------

    def test_song_reference_coverage(self):
        """Test that common intervals have song references."""
        common_intervals = [0, 2, 4, 5, 7, 9, 12]
        for interval in common_intervals:
            name = self.module.interval_to_name(interval)
            refs = self.module.get_song_reference(name)
            assert refs, f"No song reference for interval {name}"

    def test_musical_accuracy(self):
        """Test musical accuracy of interval calculations."""
        # Every semitone count from unison (0) to octave (12) must survive a
        # name round-trip unchanged.
        for semitones in range(13):
            name = self.module.interval_to_name(semitones)
            converted_back = self.module.name_to_interval(name)
            assert converted_back == semitones, f"Round-trip failed for {semitones} ({name})"
|
| |
|
| |
|
if __name__ == "__main__":
    # Run this file's tests directly. pytest.main() returns pytest's exit
    # code; propagate it so a failing run is visible to the shell / CI
    # instead of always exiting with status 0.
    raise SystemExit(pytest.main([__file__, "-v"]))
|
| |
|