"""
Tests for Ear Training Module.
"""

import pytest
import torch

from TouchGrass.models.ear_training_module import EarTrainingModule

# Number of interval classes exercised by these tests: interval ids are drawn
# from [0, NUM_INTERVAL_CLASSES) and the classifier is expected to emit one
# logit per class (0-11 semitones).
NUM_INTERVAL_CLASSES = 12


def _mentions(text, phrase):
    """Return True if *phrase* appears in *text*, exactly or case-insensitively.

    The exact match is tried first, so ``text.lower()`` is only evaluated when
    it fails. This is the same short-circuit order as writing
    ``phrase in text or phrase.lower() in text.lower()`` inline.
    """
    return phrase in text or phrase.lower() in text.lower()


class TestEarTrainingModule:
    """Test suite for EarTrainingModule."""

    def setup_method(self):
        """Build a fresh module before every test so tests cannot share state."""
        self.d_model = 768
        self.batch_size = 4
        self.module = EarTrainingModule(d_model=self.d_model)

    def _make_inputs(self, seq_len, batch_size=None, requires_grad=False):
        """Return random ``(hidden_states, interval_ids)`` for a forward pass.

        Args:
            seq_len: Sequence length of both tensors.
            batch_size: Batch size; defaults to ``self.batch_size``.
            requires_grad: Whether ``hidden_states`` should track gradients.

        Returns:
            ``hidden_states`` of shape (batch, seq_len, d_model) and
            ``interval_ids`` of shape (batch, seq_len) with values in
            [0, NUM_INTERVAL_CLASSES).
        """
        if batch_size is None:
            batch_size = self.batch_size
        hidden_states = torch.randn(
            batch_size, seq_len, self.d_model, requires_grad=requires_grad
        )
        interval_ids = torch.randint(0, NUM_INTERVAL_CLASSES, (batch_size, seq_len))
        return hidden_states, interval_ids

    def test_module_initialization(self):
        """Test that module initializes correctly."""
        assert isinstance(self.module.interval_embed, torch.nn.Embedding)
        assert isinstance(self.module.interval_classifier, torch.nn.Linear)
        assert isinstance(self.module.solfege_embed, torch.nn.Embedding)
        assert isinstance(self.module.solfege_generator, torch.nn.LSTM)
        assert isinstance(self.module.quiz_lstm, torch.nn.LSTM)
        assert isinstance(self.module.quiz_head, torch.nn.Linear)

    def test_forward_pass(self):
        """Test forward pass with dummy inputs."""
        seq_len = 10
        hidden_states, interval_ids = self._make_inputs(seq_len)

        output = self.module(hidden_states, interval_ids)

        assert "interval_logits" in output
        assert "solfege" in output
        assert "quiz_questions" in output
        assert output["interval_logits"].shape == (
            self.batch_size,
            seq_len,
            NUM_INTERVAL_CLASSES,
        )
        assert output["solfege"].shape[0] == self.batch_size
        assert output["solfege"].shape[1] == seq_len
        assert output["quiz_questions"].shape[0] == self.batch_size
        assert output["quiz_questions"].shape[1] == seq_len

    def test_get_interval_name(self):
        """Test interval name retrieval."""
        assert self.module.get_interval_name(0) == "P1"  # Perfect unison
        assert self.module.get_interval_name(2) == "M2"  # Major 2nd
        assert self.module.get_interval_name(4) == "M3"  # Major 3rd
        assert self.module.get_interval_name(7) == "P5"  # Perfect 5th
        assert self.module.get_interval_name(12) == "P8"  # Perfect octave

    def test_get_song_reference(self):
        """Test song reference retrieval for intervals."""
        # Perfect 5th - Star Wars
        p5_refs = self.module.get_song_reference("P5")
        assert _mentions(p5_refs, "Star Wars")

        # Minor 2nd - Jaws
        m2_refs = self.module.get_song_reference("m2")
        assert _mentions(m2_refs, "Jaws")

        # Major 3rd - When the Saints
        m3_refs = self.module.get_song_reference("M3")
        assert _mentions(m3_refs, "Saints")

    def test_generate_solfege_exercise(self):
        """Test solfege exercise generation."""
        exercise = self.module.generate_solfege_exercise(difficulty="beginner", key="C")
        assert "exercise" in exercise or "notes" in exercise
        assert "key" in exercise or "C" in str(exercise)

    def test_generate_interval_quiz(self):
        """Test interval quiz generation."""
        quiz = self.module.generate_interval_quiz(num_questions=5, difficulty="medium")
        assert "questions" in quiz
        assert len(quiz["questions"]) == 5

    def test_describe_interval(self):
        """Test interval description with song reference."""
        description = self.module.describe_interval(7)  # Perfect 5th
        assert "7 semitones" in description or "perfect fifth" in description.lower()
        assert _mentions(description, "Star Wars")

    def test_get_solfege_syllables(self):
        """Test solfege syllable retrieval."""
        syllables = self.module.get_solfege_syllables(key="C", mode="major")
        expected = ["Do", "Re", "Mi", "Fa", "So", "La", "Ti", "Do"]
        assert syllables == expected

    def test_get_solfege_syllables_minor(self):
        """Test solfege syllables for minor mode."""
        syllables = self.module.get_solfege_syllables(key="A", mode="minor")
        # Minor solfege: Do Re Me Fa Se Le Te Do (or variations)
        assert "Do" in syllables
        assert len(syllables) >= 7

    def test_interval_to_name(self):
        """Test converting semitone count to interval name."""
        assert self.module.interval_to_name(0) == "P1"
        assert self.module.interval_to_name(1) == "m2"
        assert self.module.interval_to_name(2) == "M2"
        assert self.module.interval_to_name(3) == "m3"
        assert self.module.interval_to_name(4) == "M3"
        assert self.module.interval_to_name(5) == "P4"
        assert self.module.interval_to_name(6) == "TT"  # Tritone
        assert self.module.interval_to_name(7) == "P5"
        assert self.module.interval_to_name(11) == "M7"
        assert self.module.interval_to_name(12) == "P8"

    def test_name_to_interval(self):
        """Test converting interval name to semitone count."""
        assert self.module.name_to_interval("P1") == 0
        assert self.module.name_to_interval("m2") == 1
        assert self.module.name_to_interval("M2") == 2
        assert self.module.name_to_interval("M3") == 4
        assert self.module.name_to_interval("P4") == 5
        assert self.module.name_to_interval("P5") == 7
        assert self.module.name_to_interval("P8") == 12

    def test_quiz_question_format(self):
        """Test that quiz questions are properly formatted."""
        quiz = self.module.generate_interval_quiz(num_questions=3, difficulty="easy")
        for question in quiz["questions"]:
            assert "question" in question
            assert "answer" in question
            assert "options" in question or isinstance(question["answer"], (str, int))

    def test_solfege_output_length(self):
        """Test solfege output has correct sequence length."""
        seq_len = 10
        hidden_states, interval_ids = self._make_inputs(seq_len)

        output = self.module(hidden_states, interval_ids)

        assert output["solfege"].shape[1] == seq_len

    # Parametrized (rather than looping inside one test) so each batch size is
    # reported separately and one failure does not hide the others.
    @pytest.mark.parametrize("batch_size", [1, 2, 8])
    def test_different_batch_sizes(self, batch_size):
        """Test forward pass with different batch sizes."""
        hidden_states, interval_ids = self._make_inputs(10, batch_size=batch_size)

        output = self.module(hidden_states, interval_ids)

        assert output["interval_logits"].shape[0] == batch_size

    def test_gradient_flow(self):
        """Test that gradients flow through the module."""
        hidden_states, interval_ids = self._make_inputs(5, requires_grad=True)

        output = self.module(hidden_states, interval_ids)
        loss = output["interval_logits"].sum() + output["solfege"].sum()
        loss.backward()

        assert hidden_states.grad is not None
        assert self.module.interval_embed.weight.grad is not None

    def test_interval_classifier_output(self):
        """Test interval classifier produces logits for all intervals."""
        hidden_states, interval_ids = self._make_inputs(1)

        output = self.module(hidden_states, interval_ids)

        # Should have logits for 12 intervals (0-11 semitones)
        assert output["interval_logits"].shape[-1] == NUM_INTERVAL_CLASSES

    def test_quiz_head_output(self):
        """Test quiz head produces appropriate output."""
        seq_len = 1
        hidden_states, interval_ids = self._make_inputs(seq_len)

        output = self.module(hidden_states, interval_ids)
        quiz_output = output["quiz_questions"]

        # Quiz output should have some dimension for question generation
        assert quiz_output.shape[0] == self.batch_size
        assert quiz_output.shape[1] == seq_len

    def test_song_reference_coverage(self):
        """Test that common intervals have song references."""
        common_intervals = [0, 2, 4, 5, 7, 9, 12]  # P1, M2, M3, P4, P5, M6, P8
        for interval in common_intervals:
            name = self.module.interval_to_name(interval)
            refs = self.module.get_song_reference(name)
            assert len(refs) > 0, f"No song reference for interval {name}"

    def test_musical_accuracy(self):
        """Test musical accuracy of interval calculations."""
        # Round-trip every interval from unison (0) to octave (12) semitones.
        for semitones in range(13):
            name = self.module.interval_to_name(semitones)
            converted_back = self.module.name_to_interval(name)
            assert converted_back == semitones, f"Round-trip failed for {semitones} ({name})"


if __name__ == "__main__":
    # Propagate pytest's exit code so a direct run reports failures to the shell.
    raise SystemExit(pytest.main([__file__, "-v"]))