| """
|
| Tests for Music Theory Engine Module.
|
| """
|
|
|
| import pytest
|
| import torch
|
|
|
| from TouchGrass.models.music_theory_module import MusicTheoryModule
|
|
|
|
|
class TestMusicTheoryModule:
    """Test suite for MusicTheoryModule.

    Covers the neural forward pass (output keys, shapes, edge cases and
    gradient flow) and the symbolic music-theory helpers (scales, modes,
    chords, intervals, progressions and key relationships).
    """

    # Keys every forward-pass output dict is expected to contain.
    OUTPUT_KEYS = ("chord", "scale", "interval", "progression")

    def setup_method(self):
        """Create a fresh, deterministically initialised module for each test."""
        # Seed before constructing the module so both the weight init and the
        # random inputs drawn inside each test are reproducible across runs.
        torch.manual_seed(0)
        self.d_model = 768
        self.batch_size = 4
        self.module = MusicTheoryModule(d_model=self.d_model)

    def test_module_initialization(self):
        """Test that the module exposes the expected sub-layers."""
        assert isinstance(self.module.note_embed, torch.nn.Embedding)
        assert isinstance(self.module.chord_encoder, torch.nn.Linear)
        assert isinstance(self.module.scale_classifier, torch.nn.Linear)
        assert isinstance(self.module.interval_predictor, torch.nn.Linear)
        assert isinstance(self.module.progression_lstm, torch.nn.LSTM)

    def test_forward_pass(self):
        """Forward pass returns every head with a (batch, seq_len, dim) shape."""
        seq_len = 10
        hidden_states = torch.randn(self.batch_size, seq_len, self.d_model)
        # Random note indices in [0, 12).
        note_indices = torch.randint(0, 12, (self.batch_size, seq_len))

        output = self.module(hidden_states, note_indices)

        for key in self.OUTPUT_KEYS:
            assert key in output, f"missing output head {key!r}"

        expected_dims = {"chord": 128, "scale": 12, "interval": 12, "progression": 256}
        for key, dim in expected_dims.items():
            assert output[key].shape == (self.batch_size, seq_len, dim), key

    def test_get_scale_from_key_c_major(self):
        """Test scale generation for C major."""
        scale = self.module.get_scale_from_key("C", "major")
        assert scale == ["C", "D", "E", "F", "G", "A", "B"]

    def test_get_scale_from_key_a_minor(self):
        """Test scale generation for A minor (natural minor)."""
        scale = self.module.get_scale_from_key("A", "natural_minor")
        assert scale == ["A", "B", "C", "D", "E", "F", "G"]

    def test_get_scale_from_key_g_mixolydian(self):
        """Test scale generation for G mixolydian."""
        scale = self.module.get_scale_from_key("G", "mixolydian")
        assert scale == ["G", "A", "B", "C", "D", "E", "F"]

    def test_detect_chord_function_triad(self):
        """Test Roman-numeral function detection for major triads in C."""
        assert self.module.detect_chord_function("C", "major", "C") == "I"
        assert self.module.detect_chord_function("F", "major", "C") == "IV"
        assert self.module.detect_chord_function("G", "major", "C") == "V"

    def test_detect_chord_function_minor(self):
        """Test Roman-numeral function detection for a minor chord in C."""
        assert self.module.detect_chord_function("D", "minor", "C") == "ii"

    def test_get_circle_of_fifths(self):
        """Test circle of fifths generation."""
        circle = self.module.get_circle_of_fifths()
        assert len(circle) == 12
        # Each of the twelve positions must be a different note.
        assert len(set(circle)) == 12
        assert "C" in circle

    def test_get_modes(self):
        """Test mode names retrieval, in scale-degree order."""
        modes = self.module.get_modes()
        expected_modes = ["ionian", "dorian", "phrygian", "lydian", "mixolydian", "aeolian", "locrian"]
        assert modes == expected_modes

    def test_get_scale_for_mode(self):
        """Test getting the scale for a specific mode and tonic."""
        scale = self.module.get_scale_for_mode("dorian", "D")
        assert scale == ["D", "E", "F", "G", "A", "B", "C"]

    def test_interval_to_semitones(self):
        """Test interval-name to semitone conversion."""
        assert self.module.interval_to_semitones("P1") == 0
        assert self.module.interval_to_semitones("M2") == 2
        assert self.module.interval_to_semitones("M3") == 4
        assert self.module.interval_to_semitones("P4") == 5
        assert self.module.interval_to_semitones("P5") == 7
        assert self.module.interval_to_semitones("M6") == 9
        assert self.module.interval_to_semitones("M7") == 11
        assert self.module.interval_to_semitones("P8") == 12

    def test_semitones_to_interval(self):
        """Test semitone to interval-name conversion."""
        assert self.module.semitones_to_interval(0) == "P1"
        assert self.module.semitones_to_interval(2) == "M2"
        assert self.module.semitones_to_interval(4) == "M3"
        assert self.module.semitones_to_interval(5) == "P4"
        assert self.module.semitones_to_interval(7) == "P5"
        assert self.module.semitones_to_interval(9) == "M6"
        assert self.module.semitones_to_interval(11) == "M7"
        assert self.module.semitones_to_interval(12) == "P8"

    def test_chord_construction_major(self):
        """Test major chord construction (order-insensitive)."""
        chord = self.module.construct_chord("C", "major")
        assert set(chord) == {"C", "E", "G"}

    def test_chord_construction_minor(self):
        """Test minor chord construction (order-insensitive)."""
        chord = self.module.construct_chord("A", "minor")
        assert set(chord) == {"A", "C", "E"}

    def test_chord_construction_dominant_7(self):
        """Test dominant 7th chord construction (order-insensitive)."""
        chord = self.module.construct_chord("G", "dominant7")
        assert set(chord) == {"G", "B", "D", "F"}

    def test_progression_analysis(self):
        """Test chord progression analysis of I-IV-V-I in C."""
        progression = ["C", "F", "G", "C"]
        analysis = self.module.analyze_progression(progression, "C")
        assert len(analysis) == 4
        assert analysis[0] == "I"
        assert analysis[1] == "IV"
        assert analysis[2] == "V"
        assert analysis[3] == "I"

    def test_scale_degree_to_note(self):
        """Test converting a 1-based scale degree to a note."""
        assert self.module.scale_degree_to_note(1, "C", "major") == "C"
        assert self.module.scale_degree_to_note(3, "C", "major") == "E"
        assert self.module.scale_degree_to_note(5, "C", "major") == "G"

    def test_note_to_scale_degree(self):
        """Test converting a note to its 1-based scale degree."""
        assert self.module.note_to_scale_degree("C", "C", "major") == 1
        assert self.module.note_to_scale_degree("E", "C", "major") == 3
        assert self.module.note_to_scale_degree("G", "C", "major") == 5

    def test_relative_key(self):
        """Test relative major/minor detection."""
        assert self.module.get_relative_minor("C") == "A"
        assert self.module.get_relative_major("A") == "C"

    def test_parallel_key(self):
        """Test parallel major/minor (same tonic)."""
        assert self.module.get_parallel_minor("C") == "C"
        assert self.module.get_parallel_major("A") == "A"

    def test_forward_with_empty_sequence(self):
        """Test forward pass with an empty sequence (edge case)."""
        # NOTE(review): depending on the PyTorch version, nn.LSTM may reject
        # zero-length sequences; this test presumes MusicTheoryModule guards
        # the seq_len == 0 case itself -- confirm against the implementation.
        seq_len = 0
        hidden_states = torch.randn(self.batch_size, seq_len, self.d_model)
        note_indices = torch.randint(0, 12, (self.batch_size, seq_len))

        output = self.module(hidden_states, note_indices)

        for key in self.OUTPUT_KEYS:
            assert output[key].shape[0] == self.batch_size, key
            assert output[key].shape[1] == seq_len, key

    @pytest.mark.parametrize("batch_size", [1, 2, 8])
    def test_different_batch_sizes(self, batch_size):
        """Forward pass preserves the batch dimension for every output head."""
        seq_len = 10
        hidden_states = torch.randn(batch_size, seq_len, self.d_model)
        note_indices = torch.randint(0, 12, (batch_size, seq_len))

        output = self.module(hidden_states, note_indices)

        for key in self.OUTPUT_KEYS:
            assert output[key].shape[0] == batch_size, key

    def test_gradient_flow(self):
        """Test that gradients reach both the inputs and the note embedding."""
        seq_len = 5
        hidden_states = torch.randn(self.batch_size, seq_len, self.d_model, requires_grad=True)
        note_indices = torch.randint(0, 12, (self.batch_size, seq_len))

        output = self.module(hidden_states, note_indices)
        # Scalar loss touching every output head so all branches get gradients.
        loss = sum(out.sum() for out in output.values())
        loss.backward()

        assert hidden_states.grad is not None
        assert self.module.note_embed.weight.grad is not None
|
|
|
|
|
if __name__ == "__main__":
    # pytest.main() returns the session exit code instead of exiting; raise it
    # so running this file directly (e.g. in CI) fails when any test fails.
    raise SystemExit(pytest.main([__file__, "-v"]))
|
|
|