File size: 15,126 Bytes
4f0238f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
"""

Ear Training Module for TouchGrass.

Guides ear training exercises without audio, using descriptive language.

"""

from typing import Any, Dict, List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


class EarTrainingModule(nn.Module):
    """Guide ear-training exercises without audio, using descriptive language.

    The text helpers on this module can:

    - Describe interval sounds in relatable terms
      ("a perfect 5th sounds like the Star Wars theme opening").
    - Generate solfege exercises (Do Re Mi Fa Sol La Ti Do) for
      relative-pitch practice.
    - Create interval identification quizzes in text form.
    - Explain chord quality by ear ("major chords sound happy/bright,
      minor chords sound sad/dark, diminished chords sound tense/unstable").
    - Suggest listening exercises with specific songs/moments.
    - Summarize progress from a caller-supplied exercise history.

    ``forward`` applies small classification heads (difficulty, exercise
    type, interval, chord quality) to mean-pooled base-model hidden states.
    """

    # Simple intervals, keyed by size in semitones.
    INTERVALS = {
        0: "unison",
        1: "minor 2nd",
        2: "major 2nd",
        3: "minor 3rd",
        4: "major 3rd",
        5: "perfect 4th",
        6: "tritone",
        7: "perfect 5th",
        8: "minor 6th",
        9: "major 6th",
        10: "minor 7th",
        11: "major 7th",
        12: "octave",
    }

    # Interval qualities (order matches the rows of ``quality_embed``).
    QUALITIES = ["perfect", "major", "minor", "augmented", "diminished"]

    # Solfege syllables (movable do); index 7 is the upper Do.
    SOLFEGE = ["Do", "Re", "Mi", "Fa", "Sol", "La", "Ti", "Do"]

    # Chord qualities and how they tend to sound.
    CHORD_DESCRIPTIONS = {
        "major": "bright, happy, stable",
        "minor": "sad, dark, melancholic",
        "diminished": "tense, unstable, dissonant",
        "augmented": "bright, dreamy, suspenseful",
        "dominant7": "bluesy, tense, wants to resolve",
        "major7": "smooth, jazzy, dreamy",
        "minor7": "smooth, soulful, mellow",
    }

    # Well-known melodies whose opening (or noted moment) outlines each interval.
    INTERVAL_SONGS = {
        0: "any note played twice",
        1: "Jaws theme (da-dum)",
        2: "Happy Birthday (2nd note)",
        3: "Greensleeves (minor 3rd)",
        4: "Oh When the Saints (major 3rd)",
        5: "Here Comes the Bride (perfect 4th)",
        6: "The Simpsons theme (tritone)",
        7: "Star Wars theme (perfect 5th)",
        8: "The Entertainer (minor 6th)",
        9: "My Bonnie Lies Over the Ocean (major 6th)",
        10: "The Office theme (minor 7th)",
        11: "Take On Me (major 7th)",
        12: "Somewhere Over the Rainbow (octave)",
    }

    # Emotional character of each interval, used by describe_interval("emotion").
    INTERVAL_EMOTIONS = {
        0: "familiar, consonant",
        1: "tense, dissonant",
        2: "slightly tense",
        3: "sad, soulful",
        4: "bright, happy",
        5: "stable, resolved",
        6: "very tense, mysterious",
        7: "strong, stable",
        8: "sweet, melancholic",
        9: "bright, hopeful",
        10: "bluesy, tense",
        11: "smooth, jazzy",
        12: "complete, resolved",
    }

    # Example pieces per chord quality for suggest_listening_exercise.
    # Tuples so the shared class-level data cannot be mutated by accident.
    CHORD_EXAMPLES = {
        "major": ("Happy Birthday", "Let It Be (chorus)"),
        "minor": ("House of the Rising Sun", "Greensleeves"),
        "diminished": ("The Simpsons theme (tritone)",),
        "dominant7": ("Blues progressions", "Purple Haze"),
        "major7": ("Something (The Beatles)", "So What (Miles Davis)"),
    }

    def __init__(self, d_model: int):
        """Build the prediction heads and auxiliary sub-networks.

        Args:
            d_model: Hidden dimension of the base model's hidden states.
        """
        super().__init__()
        self.d_model = d_model

        # Derive table-backed sizes from the class tables so they cannot drift
        # (13 intervals, 5 qualities, 7 chord qualities).
        num_intervals = len(self.INTERVALS)
        num_qualities = len(self.QUALITIES)
        num_chord_qualities = len(self.CHORD_DESCRIPTIONS)

        # NOTE: submodules are created in the original order so parameter
        # initialization and state_dict layout are unchanged.

        # Embeddings (not used by forward()).
        self.interval_embed = nn.Embedding(num_intervals, 64)
        self.quality_embed = nn.Embedding(num_qualities, 64)

        # Difficulty / skill level head (5 levels).
        self.difficulty_tracker = nn.Linear(d_model, 5)

        # Exercise type classifier (6 exercise types).
        self.exercise_type_head = nn.Linear(d_model, 6)

        # Interval prediction head.
        self.interval_predictor = nn.Linear(d_model, num_intervals)

        # Chord quality predictor.
        self.chord_quality_predictor = nn.Linear(d_model, num_chord_qualities)

        # Solfege generator (not used by forward()).
        self.solfege_generator = nn.GRU(
            input_size=d_model + 64,
            hidden_size=d_model,
            num_layers=1,
            batch_first=True,
        )

        # Progress tracker over session history (not used by forward()).
        # NOTE(review): input_size=5 does not match the 6 exercise types of
        # exercise_type_head; kept as-is so existing checkpoints still load.
        self.progress_tracker = nn.GRU(
            input_size=5,
            hidden_size=64,
            num_layers=1,
            batch_first=True,
        )

        # Success rate predictor on top of progress_tracker's hidden state.
        self.success_predictor = nn.Linear(64, 1)

    def forward(
        self,
        hidden_states: torch.Tensor,
        exercise_type: Optional[int] = None,
        user_response: Optional[str] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> Dict[str, torch.Tensor]:
        """Predict ear-training labels from pooled base-model hidden states.

        Args:
            hidden_states: Base model hidden states ``[batch, seq_len, d_model]``.
            exercise_type: Optional exercise type ID (0-5). Accepted for API
                compatibility; currently unused.
            user_response: Optional user answer. Accepted for API
                compatibility; currently unused.
            attention_mask: Optional ``[batch, seq_len]`` mask (1/True = real
                token, 0/False = padding). When given, padded positions are
                excluded from the mean pooling. When omitted, every position
                is averaged, exactly as before.

        Returns:
            Dict with ``difficulty_logits`` ``[batch, 5]``,
            ``exercise_type_logits`` ``[batch, 6]``, ``interval_logits``
            ``[batch, 13]`` and ``chord_quality_logits`` ``[batch, 7]``.

        Raises:
            ValueError: If ``hidden_states`` is not 3-dimensional.
        """
        if hidden_states.dim() != 3:
            raise ValueError(
                f"hidden_states must be [batch, seq_len, d_model], got shape "
                f"{tuple(hidden_states.shape)}"
            )

        if attention_mask is None:
            pooled = hidden_states.mean(dim=1)  # [batch, d_model]
        else:
            mask = attention_mask.to(hidden_states.dtype).unsqueeze(-1)  # [batch, seq, 1]
            # clamp avoids division by zero for an all-padding row.
            pooled = (hidden_states * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)

        return {
            "difficulty_logits": self.difficulty_tracker(pooled),
            "exercise_type_logits": self.exercise_type_head(pooled),
            "interval_logits": self.interval_predictor(pooled),
            "chord_quality_logits": self.chord_quality_predictor(pooled),
        }

    def describe_interval(self, interval_semitones: int, reference: str = "song") -> str:
        """Describe an interval in relatable terms.

        Args:
            interval_semitones: Number of semitones (0-12).
            reference: ``"song"`` for a well-known melody, ``"emotion"`` for
                its emotional character; any other value gives a plain
                technical description.

        Returns:
            A one-sentence description, or an "Unknown interval" message when
            ``interval_semitones`` is outside 0-12.
        """
        if interval_semitones not in self.INTERVALS:
            return f"Unknown interval: {interval_semitones} semitones"

        interval_name = self.INTERVALS[interval_semitones]
        # "An octave", but "A unison" (said "yoo-") / "A major 3rd": among the
        # interval names only a leading a/e/i/o takes "An".
        article = "An" if interval_name[0] in "aeio" else "A"

        if reference == "song":
            song = self.INTERVAL_SONGS.get(interval_semitones, "a generic interval")
            return f"{article} {interval_name} ({interval_semitones} semitones) — like {song}."
        if reference == "emotion":
            emotion = self.INTERVAL_EMOTIONS.get(interval_semitones, "neutral")
            return f"{article} {interval_name} feels {emotion}."
        return f"{article} {interval_name} spans {interval_semitones} semitones."

    def generate_solfege_exercise(
        self,
        key: str = "C",
        difficulty: int = 1,
        num_notes: int = 5,
    ) -> List[str]:
        """Generate a random solfege exercise in movable-do syllables.

        Args:
            key: Key name. Currently unused: the syllables are movable-do, so
                the output does not depend on the key and has no accidentals.
            difficulty: 1-5. ``<= 2`` gives stepwise motion; higher values use
                a random walk with leaps of up to ``difficulty + 2`` scale
                degrees (capped at 7).
            num_notes: Number of syllables to generate.

        Returns:
            List of ``num_notes`` syllables drawn from Do..Ti.
        """
        import random

        if difficulty <= 2:
            # Stepwise ascent from a random start between Do and Sol; the
            # degree wraps from Ti back to Do (upper Do, index 7, is unused).
            start_idx = random.randint(0, 4)
            return [self.SOLFEGE[(start_idx + i) % 7] for i in range(num_notes)]

        # Random walk: the position starts at Do before the first step, each
        # step leaps at most max_jump degrees, and the result is clamped to Do..Ti.
        max_jump = min(difficulty + 2, 7)
        exercise = []
        current = 0
        for _ in range(num_notes):
            jump = random.randint(-max_jump, max_jump)
            current = max(0, min(6, current + jump))
            exercise.append(self.SOLFEGE[current])
        return exercise

    def _interval_quality(self, interval_semitones: int) -> str:
        """Return the quality (one of ``QUALITIES``) of a simple interval."""
        first_word = self.INTERVALS[interval_semitones].split()[0]
        if first_word in self.QUALITIES:
            # "perfect 4th", "major 3rd", "minor 7th", ...
            return first_word
        if interval_semitones == 6:
            # The tritone is an augmented 4th or a diminished 5th depending on
            # spelling; report it as the augmented 4th.
            return "augmented"
        # Unison and octave.
        return "perfect"

    def generate_interval_quiz(
        self,
        num_questions: int = 5,
        max_interval: int = 12,
        include_desc: bool = True,
    ) -> List[Dict]:
        """Generate a random interval identification quiz.

        Args:
            num_questions: Number of questions.
            max_interval: Largest interval size in semitones; values above 12
                are clamped to 12 (the octave).
            include_desc: Include a song-based hint with each question.

        Returns:
            List of dicts with ``interval_semitones``, ``interval_name``,
            ``quality`` (consistent with the interval name) and, optionally,
            ``hint``.

        Raises:
            ValueError: If ``max_interval`` is less than 1.
        """
        import random

        if max_interval < 1:
            raise ValueError(f"max_interval must be at least 1, got {max_interval}")
        # INTERVALS only goes up to the octave; larger values would KeyError.
        max_interval = min(max_interval, max(self.INTERVALS))

        questions = []
        for _ in range(num_questions):
            interval = random.randint(1, max_interval)
            question = {
                "interval_semitones": interval,
                "interval_name": self.INTERVALS[interval],
                "quality": self._interval_quality(interval),
            }
            if include_desc:
                question["hint"] = self.describe_interval(interval, reference="song")
            questions.append(question)

        return questions

    def describe_chord_quality(self, chord_type: str) -> str:
        """Describe how a chord quality sounds.

        Args:
            chord_type: Chord type key such as ``"major"`` or ``"minor7"``.

        Returns:
            A one-sentence description; unknown types are "unique sounding".
        """
        description = self.CHORD_DESCRIPTIONS.get(chord_type, "unique sounding")
        return f"{chord_type} chords sound {description}."

    def suggest_listening_exercise(
        self,
        interval: Optional[int] = None,
        chord_quality: Optional[str] = None,
    ) -> Dict[str, str]:
        """Suggest specific songs/moments to listen for intervals or chords.

        Args:
            interval: Optional interval in semitones (0-12) to practice;
                0 (unison) is a valid choice.
            chord_quality: Optional chord quality to practice.

        Returns:
            Dict with an ``interval`` and/or ``chord`` suggestion plus a
            ``tip``. When both are requested, both tips are combined.
        """
        suggestions: Dict[str, str] = {}
        tips: List[str] = []

        # `is not None` so that unison (0) is not silently skipped.
        if interval is not None:
            if interval in self.INTERVALS:
                song = self.INTERVAL_SONGS.get(interval, "various songs")
                suggestions["interval"] = f"Listen for {self.INTERVALS[interval]} in: {song}"
                tips.append("Try to hum along to internalize the sound.")
            else:
                suggestions["interval"] = f"Unknown interval: {interval} semitones"

        if chord_quality:
            songs = self.CHORD_EXAMPLES.get(chord_quality, ("various songs",))
            suggestions["chord"] = f"Listen for {chord_quality} chords in: {', '.join(songs)}"
            tips.append("Focus on the emotional character.")

        if tips:
            suggestions["tip"] = " ".join(tips)

        return suggestions

    def track_progress(
        self,
        exercise_history: List[Dict],
        current_performance: float,
    ) -> Dict[str, Any]:
        """Summarize the user's progress over a session.

        The level is based on the mean ``score`` of ``exercise_history``
        (missing scores count as 0): below 0.5 is beginner, below 0.7 is
        intermediate, otherwise advanced. ``current_performance`` is echoed
        back but does not affect the level.

        Args:
            exercise_history: Past exercises, each a dict with a ``score`` in 0-1.
            current_performance: Current success rate (0-1).

        Returns:
            Dict with ``level``, ``suggestion``, ``average_score`` (None when
            there is no history), ``current_score`` and ``exercises_completed``.
        """
        if not exercise_history:
            return {
                "level": "beginner",
                "suggestion": "Start with interval identification",
                "average_score": None,
                "current_score": current_performance,
                "exercises_completed": 0,
            }

        avg_performance = sum(ex.get("score", 0) for ex in exercise_history) / len(exercise_history)

        if avg_performance < 0.5:
            level = "beginner"
            suggestion = "Practice more interval identification with smaller intervals (2nd-5th)."
        elif avg_performance < 0.7:
            level = "intermediate"
            suggestion = "Try more complex intervals and chord qualities."
        else:
            level = "advanced"
            suggestion = "Challenge yourself with inversions and advanced chords."

        return {
            "level": level,
            "average_score": avg_performance,
            "current_score": current_performance,
            "suggestion": suggestion,
            "exercises_completed": len(exercise_history),
        }


def test_ear_training_module(d_model: int = 4096):
    """Smoke-test EarTrainingModule and print a demo of its text helpers.

    Args:
        d_model: Hidden size used to build the module and the random input.
            Defaults to 4096 as before; pass a small value (e.g. 64) for a
            quick run, because the solfege GRU's weights grow with d_model
            squared.
    """
    module = EarTrainingModule(d_model=d_model)

    batch_size = 2
    seq_len = 10
    hidden_states = torch.randn(batch_size, seq_len, d_model)

    # Call the module (not .forward) so nn.Module hooks run; no gradients
    # are needed for a smoke test.
    with torch.no_grad():
        outputs = module(hidden_states)

    print("Ear Training Module outputs:")
    for key, value in outputs.items():
        print(f"  {key}: {value.shape}")

    print("\nInterval descriptions:")
    for semitones in [3, 4, 5, 7, 10]:
        desc = module.describe_interval(semitones, reference="song")
        print(f"  {semitones} semitones: {desc}")

    print("\nSolfege exercise (C, difficulty 2):")
    solfege = module.generate_solfege_exercise(key="C", difficulty=2, num_notes=8)
    print(f"  {' '.join(solfege)}")

    print("\nInterval quiz (3 questions):")
    quiz = module.generate_interval_quiz(num_questions=3)
    for i, q in enumerate(quiz, start=1):
        print(f"  Q{i}: {q['interval_name']} ({q['interval_semitones']} semitones)")
        if "hint" in q:
            print(f"      Hint: {q['hint']}")

    print("\nChord quality descriptions:")
    for chord in ["major", "minor", "diminished", "major7"]:
        desc = module.describe_chord_quality(chord)
        print(f"  {chord}: {desc}")

    print("\nListening exercise suggestions:")
    suggestions = module.suggest_listening_exercise(interval=7, chord_quality="major")
    for key, value in suggestions.items():
        print(f"  {key}: {value}")

    print("\nProgress tracking:")
    history = [
        {"exercise": "interval", "score": 0.6},
        {"exercise": "interval", "score": 0.7},
        {"exercise": "chord", "score": 0.5},
    ]
    progress = module.track_progress(history, current_performance=0.8)
    for key, value in progress.items():
        print(f"  {key}: {value}")

    print("\nEar Training Module test complete!")


if __name__ == "__main__":
    # Executing this file as a script runs the smoke test / demo.
    test_ear_training_module()