File size: 9,270 Bytes
4f0238f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
"""Tests for the ear-training module (``EarTrainingModule``)."""

import pytest
import torch

from TouchGrass.models.ear_training_module import EarTrainingModule


class TestEarTrainingModule:
    """Test suite for EarTrainingModule.

    Every test gets a freshly constructed module from ``setup_method``.
    Forward-pass tests build their inputs through ``_make_inputs`` so the
    input convention (random hidden states plus interval ids drawn from
    ``[0, NUM_INTERVAL_CLASSES)``) is defined in exactly one place.
    """

    # Number of logits the tests expect the interval classifier to emit, and
    # the exclusive upper bound for the random interval ids fed to forward().
    # NOTE(review): the name helpers below also cover P8 (12 semitones), which
    # falls outside these 0-11 classes; presumably the octave is not a
    # separate classifier class -- confirm against EarTrainingModule.
    NUM_INTERVAL_CLASSES = 12

    def setup_method(self):
        """Build a fresh, deterministically initialised module for each test."""
        # Seed before constructing the module so parameter initialisation and
        # the random test inputs are reproducible from run to run.
        torch.manual_seed(0)
        self.d_model = 768
        self.batch_size = 4
        self.module = EarTrainingModule(d_model=self.d_model)

    def _make_inputs(self, batch_size=None, seq_len=10, requires_grad=False):
        """Return random ``(hidden_states, interval_ids)`` for a forward pass.

        Args:
            batch_size: Batch dimension; defaults to ``self.batch_size``.
            seq_len: Sequence length of both tensors.
            requires_grad: Whether ``hidden_states`` should track gradients.

        Returns:
            ``hidden_states`` of shape ``(batch_size, seq_len, d_model)`` and
            integer ``interval_ids`` of shape ``(batch_size, seq_len)`` with
            values in ``[0, NUM_INTERVAL_CLASSES)``.
        """
        if batch_size is None:
            batch_size = self.batch_size
        hidden_states = torch.randn(
            batch_size, seq_len, self.d_model, requires_grad=requires_grad
        )
        interval_ids = torch.randint(0, self.NUM_INTERVAL_CLASSES, (batch_size, seq_len))
        return hidden_states, interval_ids

    @staticmethod
    def _mentions(text, phrase):
        """Return True if *phrase* occurs in *text*, ignoring case.

        ``str()`` lets this also accept non-string containers (e.g. a list of
        song titles) without raising ``AttributeError`` on ``.lower()``.
        """
        return phrase.lower() in str(text).lower()

    def test_module_initialization(self):
        """Test that module initializes its expected sub-layers."""
        assert isinstance(self.module.interval_embed, torch.nn.Embedding)
        assert isinstance(self.module.interval_classifier, torch.nn.Linear)
        assert isinstance(self.module.solfege_embed, torch.nn.Embedding)
        assert isinstance(self.module.solfege_generator, torch.nn.LSTM)
        assert isinstance(self.module.quiz_lstm, torch.nn.LSTM)
        assert isinstance(self.module.quiz_head, torch.nn.Linear)

    def test_forward_pass(self):
        """Forward pass returns all three heads with (batch, seq, ...) outputs."""
        seq_len = 10
        hidden_states, interval_ids = self._make_inputs(seq_len=seq_len)

        output = self.module(hidden_states, interval_ids)

        for key in ("interval_logits", "solfege", "quiz_questions"):
            assert key in output, f"forward() output is missing {key!r}"
        assert output["interval_logits"].shape == (
            self.batch_size,
            seq_len,
            self.NUM_INTERVAL_CLASSES,
        )
        assert output["solfege"].shape[:2] == (self.batch_size, seq_len)
        assert output["quiz_questions"].shape[:2] == (self.batch_size, seq_len)

    def test_get_interval_name(self):
        """Test interval name retrieval."""
        assert self.module.get_interval_name(0) == "P1"  # Perfect unison
        assert self.module.get_interval_name(2) == "M2"  # Major 2nd
        assert self.module.get_interval_name(4) == "M3"  # Major 3rd
        assert self.module.get_interval_name(7) == "P5"  # Perfect 5th
        assert self.module.get_interval_name(12) == "P8"  # Perfect octave

    def test_get_song_reference(self):
        """Well-known intervals map to their classic mnemonic songs."""
        cases = {
            "P5": "Star Wars",  # Perfect 5th
            "m2": "Jaws",  # Minor 2nd
            "M3": "Saints",  # Major 3rd ("When the Saints Go Marching In")
        }
        for interval_name, song in cases.items():
            refs = self.module.get_song_reference(interval_name)
            assert self._mentions(refs, song), (
                f"Expected a {song!r} reference for {interval_name}, got {refs!r}"
            )

    def test_generate_solfege_exercise(self):
        """Test solfege exercise generation."""
        exercise = self.module.generate_solfege_exercise(difficulty="beginner", key="C")
        assert "exercise" in exercise or "notes" in exercise
        assert "key" in exercise or "C" in str(exercise)

    def test_generate_interval_quiz(self):
        """Test interval quiz generation returns the requested question count."""
        quiz = self.module.generate_interval_quiz(num_questions=5, difficulty="medium")
        assert "questions" in quiz
        assert len(quiz["questions"]) == 5, f"Expected 5 questions, got {len(quiz['questions'])}"

    def test_describe_interval(self):
        """Test interval description with song reference."""
        description = self.module.describe_interval(7)  # Perfect 5th
        assert "7 semitones" in description or self._mentions(description, "perfect fifth")
        assert self._mentions(description, "Star Wars")

    def test_get_solfege_syllables(self):
        """Test solfege syllable retrieval for a major key."""
        syllables = self.module.get_solfege_syllables(key="C", mode="major")
        expected = ["Do", "Re", "Mi", "Fa", "So", "La", "Ti", "Do"]
        assert syllables == expected

    def test_get_solfege_syllables_minor(self):
        """Test solfege syllables for minor mode."""
        syllables = self.module.get_solfege_syllables(key="A", mode="minor")
        # Minor solfege is either La-based (La Ti Do Re Mi Fa So La) or
        # Do-based (Do Re Me Fa So Le Te Do). "Do" appears in both systems,
        # so only that and the scale length are checked here.
        assert "Do" in syllables
        assert len(syllables) >= 7

    def test_interval_to_name(self):
        """Test converting semitone count to interval name."""
        expected = {
            0: "P1",
            1: "m2",
            2: "M2",
            3: "m3",
            4: "M3",
            5: "P4",
            6: "TT",  # Tritone
            7: "P5",
            11: "M7",
            12: "P8",
        }
        for semitones, name in expected.items():
            actual = self.module.interval_to_name(semitones)
            assert actual == name, f"interval_to_name({semitones}) == {actual!r}, expected {name!r}"

    def test_name_to_interval(self):
        """Test converting interval name to semitone count."""
        expected = {"P1": 0, "m2": 1, "M2": 2, "M3": 4, "P4": 5, "P5": 7, "P8": 12}
        for name, semitones in expected.items():
            actual = self.module.name_to_interval(name)
            assert actual == semitones, f"name_to_interval({name!r}) == {actual}, expected {semitones}"

    def test_quiz_question_format(self):
        """Test that quiz questions are properly formatted."""
        quiz = self.module.generate_interval_quiz(num_questions=3, difficulty="easy")
        for question in quiz["questions"]:
            assert "question" in question
            assert "answer" in question
            assert "options" in question or isinstance(question["answer"], (str, int))

    def test_solfege_output_length(self):
        """Test solfege output has the same sequence length as the input."""
        seq_len = 10
        hidden_states, interval_ids = self._make_inputs(seq_len=seq_len)

        output = self.module(hidden_states, interval_ids)
        assert output["solfege"].shape[1] == seq_len

    def test_different_batch_sizes(self):
        """Test forward pass with different batch sizes."""
        for batch_size in [1, 2, 8]:
            hidden_states, interval_ids = self._make_inputs(batch_size=batch_size, seq_len=10)

            output = self.module(hidden_states, interval_ids)
            assert output["interval_logits"].shape[0] == batch_size, (
                f"Wrong batch dimension for batch_size={batch_size}"
            )

    def test_gradient_flow(self):
        """Test that gradients flow through the module to inputs and embeddings."""
        hidden_states, interval_ids = self._make_inputs(seq_len=5, requires_grad=True)

        output = self.module(hidden_states, interval_ids)
        loss = output["interval_logits"].sum() + output["solfege"].sum()
        loss.backward()

        assert hidden_states.grad is not None
        assert self.module.interval_embed.weight.grad is not None

    def test_interval_classifier_output(self):
        """Test interval classifier produces one logit per interval class."""
        hidden_states, interval_ids = self._make_inputs(seq_len=1)

        output = self.module(hidden_states, interval_ids)
        logits = output["interval_logits"]

        # One logit per class in [0, NUM_INTERVAL_CLASSES), i.e. 0-11 semitones.
        assert logits.shape[-1] == self.NUM_INTERVAL_CLASSES

    def test_quiz_head_output(self):
        """Test quiz head output keeps the (batch, seq) leading dimensions."""
        seq_len = 1
        hidden_states, interval_ids = self._make_inputs(seq_len=seq_len)

        output = self.module(hidden_states, interval_ids)
        quiz_output = output["quiz_questions"]

        assert quiz_output.shape[0] == self.batch_size
        assert quiz_output.shape[1] == seq_len

    def test_song_reference_coverage(self):
        """Test that common intervals have song references."""
        common_intervals = [0, 2, 4, 5, 7, 9, 12]  # P1, M2, M3, P4, P5, M6, P8
        for interval in common_intervals:
            name = self.module.interval_to_name(interval)
            refs = self.module.get_song_reference(name)
            assert len(refs) > 0, f"No song reference for interval {name}"

    def test_musical_accuracy(self):
        """Semitone -> name -> semitone round-trips for every interval 0..12."""
        for semitones in range(13):
            name = self.module.interval_to_name(semitones)
            converted_back = self.module.name_to_interval(name)
            assert converted_back == semitones, f"Round-trip failed for {semitones} ({name})"


if __name__ == "__main__":
    # pytest.main() returns an exit code instead of exiting; propagate it so
    # shells and CI see a non-zero status when tests fail.
    raise SystemExit(pytest.main([__file__, "-v"]))