# TouchGrass-3b: tests/test_music_theory_module.py
# Uploaded by Zandy-Wandy ("Upload 39 files", commit 9071ef9).
"""
Tests for Music Theory Engine Module.
"""
import pytest
import torch
from TouchGrass.models.music_theory_module import MusicTheoryModule
class TestMusicTheoryModule:
    """Test suite for MusicTheoryModule.

    Covers layer construction, the tensor forward pass (output shapes, edge
    cases, gradient flow) and the symbolic music-theory helper methods.
    """

    # Expected size of the last dimension of each forward() output head.
    OUTPUT_DIMS = {"chord": 128, "scale": 12, "interval": 12, "progression": 256}
    # Dummy note indices are drawn from [0, NUM_PITCH_CLASSES): the 12 chromatic notes.
    NUM_PITCH_CLASSES = 12
    # (interval name, semitone count) pairs shared by both conversion tests.
    INTERVAL_SEMITONES = [
        ("P1", 0),
        ("M2", 2),
        ("M3", 4),
        ("P4", 5),
        ("P5", 7),
        ("M6", 9),
        ("M7", 11),
        ("P8", 12),
    ]

    def setup_method(self):
        """Build a fresh module before every test so no state leaks between tests."""
        self.d_model = 768
        self.batch_size = 4
        self.module = MusicTheoryModule(d_model=self.d_model)

    def _random_inputs(self, batch_size, seq_len, requires_grad=False):
        """Return dummy ``(hidden_states, note_indices)`` inputs for ``forward``.

        hidden_states has shape (batch_size, seq_len, d_model); note_indices
        has shape (batch_size, seq_len) with values in [0, NUM_PITCH_CLASSES).
        """
        hidden_states = torch.randn(
            batch_size, seq_len, self.d_model, requires_grad=requires_grad
        )
        note_indices = torch.randint(0, self.NUM_PITCH_CLASSES, (batch_size, seq_len))
        return hidden_states, note_indices

    def test_module_initialization(self):
        """Test that module initializes correctly."""
        assert isinstance(self.module.note_embed, torch.nn.Embedding)
        assert isinstance(self.module.chord_encoder, torch.nn.Linear)
        assert isinstance(self.module.scale_classifier, torch.nn.Linear)
        assert isinstance(self.module.interval_predictor, torch.nn.Linear)
        assert isinstance(self.module.progression_lstm, torch.nn.LSTM)

    def test_forward_pass(self):
        """Forward pass returns every head with shape (batch, seq, head_dim)."""
        seq_len = 10
        hidden_states, note_indices = self._random_inputs(self.batch_size, seq_len)
        output = self.module(hidden_states, note_indices)
        for key in self.OUTPUT_DIMS:
            assert key in output
        for key, dim in self.OUTPUT_DIMS.items():
            assert output[key].shape == (self.batch_size, seq_len, dim)

    def test_get_scale_from_key_c_major(self):
        """Test scale generation for C major."""
        scale = self.module.get_scale_from_key("C", "major")
        assert scale == ["C", "D", "E", "F", "G", "A", "B"]

    def test_get_scale_from_key_a_minor(self):
        """Test scale generation for A minor (natural minor)."""
        scale = self.module.get_scale_from_key("A", "natural_minor")
        assert scale == ["A", "B", "C", "D", "E", "F", "G"]

    def test_get_scale_from_key_g_mixolydian(self):
        """Test scale generation for G mixolydian."""
        scale = self.module.get_scale_from_key("G", "mixolydian")
        assert scale == ["G", "A", "B", "C", "D", "E", "F"]

    def test_detect_chord_function_triad(self):
        """Test chord function detection for triads."""
        # C major in C major key should be tonic (I)
        assert self.module.detect_chord_function("C", "major", "C") == "I"
        # F major in C major should be subdominant (IV)
        assert self.module.detect_chord_function("F", "major", "C") == "IV"
        # G major in C major should be dominant (V)
        assert self.module.detect_chord_function("G", "major", "C") == "V"

    def test_detect_chord_function_minor(self):
        """Test chord function detection for minor chords."""
        # D minor in C major should be ii
        assert self.module.detect_chord_function("D", "minor", "C") == "ii"

    def test_get_circle_of_fifths(self):
        """Circle of fifths contains each of the 12 pitch classes exactly once."""
        circle = self.module.get_circle_of_fifths()
        assert len(circle) == 12
        # A circle of fifths visits every pitch class once, so no duplicates.
        assert len(set(circle)) == 12
        # The starting note (C or F) depends on direction, so only check membership.
        assert "C" in circle

    def test_get_modes(self):
        """Test mode names retrieval."""
        modes = self.module.get_modes()
        expected_modes = ["ionian", "dorian", "phrygian", "lydian", "mixolydian", "aeolian", "locrian"]
        assert modes == expected_modes

    def test_get_scale_for_mode(self):
        """Test getting scale for specific mode."""
        scale = self.module.get_scale_for_mode("dorian", "D")
        # D dorian: D E F G A B C
        assert scale == ["D", "E", "F", "G", "A", "B", "C"]

    def test_interval_to_semitones(self):
        """Test interval to semitone conversion."""
        for name, semitones in self.INTERVAL_SEMITONES:
            assert self.module.interval_to_semitones(name) == semitones, name

    def test_semitones_to_interval(self):
        """Test semitone to interval conversion."""
        for name, semitones in self.INTERVAL_SEMITONES:
            assert self.module.semitones_to_interval(semitones) == name, semitones

    def test_chord_construction_major(self):
        """Test major chord construction."""
        chord = self.module.construct_chord("C", "major")
        # C major: C E G
        assert set(chord) == {"C", "E", "G"}

    def test_chord_construction_minor(self):
        """Test minor chord construction."""
        chord = self.module.construct_chord("A", "minor")
        # A minor: A C E
        assert set(chord) == {"A", "C", "E"}

    def test_chord_construction_dominant_7(self):
        """Test dominant 7th chord construction."""
        chord = self.module.construct_chord("G", "dominant7")
        # G7: G B D F
        assert set(chord) == {"G", "B", "D", "F"}

    def test_progression_analysis(self):
        """Test chord progression analysis."""
        # I-IV-V-I in C major
        progression = ["C", "F", "G", "C"]
        analysis = self.module.analyze_progression(progression, "C")
        assert len(analysis) == 4
        assert list(analysis) == ["I", "IV", "V", "I"]

    def test_scale_degree_to_note(self):
        """Test converting scale degree to note."""
        # In C major, scale degree 1 = C, 3 = E, 5 = G
        for degree, note in [(1, "C"), (3, "E"), (5, "G")]:
            assert self.module.scale_degree_to_note(degree, "C", "major") == note

    def test_note_to_scale_degree(self):
        """Test converting note to scale degree."""
        # In C major, C=1, E=3, G=5
        for note, degree in [("C", 1), ("E", 3), ("G", 5)]:
            assert self.module.note_to_scale_degree(note, "C", "major") == degree

    def test_relative_key(self):
        """Test relative major/minor detection."""
        # C major's relative minor is A minor
        assert self.module.get_relative_minor("C") == "A"
        # A minor's relative major is C major
        assert self.module.get_relative_major("A") == "C"

    def test_parallel_key(self):
        """Test parallel major/minor."""
        # C major's parallel minor is C minor
        assert self.module.get_parallel_minor("C") == "C"
        # A minor's parallel major is A major
        assert self.module.get_parallel_major("A") == "A"

    def test_forward_with_empty_sequence(self):
        """Test forward pass with empty sequence (edge case)."""
        seq_len = 0
        hidden_states, note_indices = self._random_inputs(self.batch_size, seq_len)
        output = self.module(hidden_states, note_indices)
        # Should handle an empty sequence gracefully.
        for key in self.OUTPUT_DIMS:
            assert output[key].shape[0] == self.batch_size
            assert output[key].shape[1] == seq_len

    def test_different_batch_sizes(self):
        """Every output head keeps the batch and sequence dims for various batch sizes."""
        seq_len = 10
        for batch_size in (1, 2, 8):
            hidden_states, note_indices = self._random_inputs(batch_size, seq_len)
            output = self.module(hidden_states, note_indices)
            for key in self.OUTPUT_DIMS:
                assert output[key].shape[0] == batch_size, key
                assert output[key].shape[1] == seq_len, key

    def test_gradient_flow(self):
        """Test that gradients flow through the module."""
        seq_len = 5
        hidden_states, note_indices = self._random_inputs(
            self.batch_size, seq_len, requires_grad=True
        )
        output = self.module(hidden_states, note_indices)
        # Sum every head into one scalar so backward() reaches all branches.
        loss = sum(out.sum() for out in output.values())
        loss.backward()
        assert hidden_states.grad is not None
        assert self.module.note_embed.weight.grad is not None
if __name__ == "__main__":
    # pytest.main returns an exit code; propagate it so a failing run exits non-zero.
    raise SystemExit(pytest.main([__file__, "-v"]))