| | """
|
| | Ear Training Module for TouchGrass.
|
| | Guides ear training exercises without audio, using descriptive language.
|
| | """
|
| |
|
| | import torch
|
| | import torch.nn as nn
|
| | import torch.nn.functional as F
|
| | from typing import Optional, List, Dict, Tuple
|
| |
|
| |
|
class EarTrainingModule(nn.Module):
    """
    Guides ear training exercises without audio, using descriptive language.

    Capabilities:
    - Describe interval sounds in relatable terms
      ("a perfect 5th sounds like the Star Wars theme opening")
    - Generate solfege exercises (Do Re Mi Fa Sol La Ti Do)
    - Create interval identification quizzes in text form
    - Explain chord quality by ear ("major chords sound happy/bright,
      minor chords sound sad/dark, diminished chords sound tense/unstable")
    - Guide relative pitch training
    - Suggest listening exercises with specific songs/moments

    Tracks user progress through session context.
    """

    # Semitone distance (0-12 inclusive) -> interval name.
    INTERVALS = {
        0: "unison",
        1: "minor 2nd",
        2: "major 2nd",
        3: "minor 3rd",
        4: "major 3rd",
        5: "perfect 4th",
        6: "tritone",
        7: "perfect 5th",
        8: "minor 6th",
        9: "major 6th",
        10: "minor 7th",
        11: "major 7th",
        12: "octave",
    }

    QUALITIES = ["perfect", "major", "minor", "augmented", "diminished"]

    # Movable-do solfege syllables; the trailing "Do" closes the octave.
    SOLFEGE = ["Do", "Re", "Mi", "Fa", "Sol", "La", "Ti", "Do"]

    # Chord quality -> short descriptive phrase used in explanations.
    CHORD_DESCRIPTIONS = {
        "major": "bright, happy, stable",
        "minor": "sad, dark, melancholic",
        "diminished": "tense, unstable, dissonant",
        "augmented": "bright, dreamy, suspenseful",
        "dominant7": "bluesy, tense, wants to resolve",
        "major7": "smooth, jazzy, dreamy",
        "minor7": "smooth, soulful, mellow",
    }

    # Semitone distance -> well-known song moment featuring that interval.
    INTERVAL_SONGS = {
        0: "any note played twice",
        1: "Jaws theme (da-dum)",
        2: "Happy Birthday (2nd note)",
        3: "When the Saints Go Marching In (minor 3rd)",
        4: "Oh When the Saints (major 3rd)",
        5: "Here Comes the Bride (perfect 4th)",
        6: "The Simpsons theme (tritone)",
        7: "Star Wars theme (perfect 5th)",
        8: "My Bonnie Lies Over the Ocean (minor 6th)",
        9: "Somewhere Over the Rainbow (major 6th)",
        10: "The Office theme (minor 7th)",
        11: "Take On Me (major 7th)",
        12: "Somewhere Over the Rainbow (octave)",
    }

    # Semitone counts of the perfect intervals: unison, P4, P5, octave.
    _PERFECT_SEMITONES = frozenset({0, 5, 7, 12})

    def __init__(self, d_model: int):
        """
        Initialize EarTrainingModule.

        Args:
            d_model: Hidden dimension from base model
        """
        super().__init__()
        self.d_model = d_model

        # NOTE(review): the two embeddings and both GRUs below are defined
        # but never referenced in forward(); presumably reserved for future
        # exercise conditioning / sequence generation — confirm before removing.
        self.interval_embed = nn.Embedding(13, 64)  # 13 = intervals 0-12 semitones
        self.quality_embed = nn.Embedding(5, 64)    # 5 = len(QUALITIES)

        # Classification heads applied to the mean-pooled hidden state.
        self.difficulty_tracker = nn.Linear(d_model, 5)       # difficulty levels 1-5
        self.exercise_type_head = nn.Linear(d_model, 6)       # 6 exercise types
        self.interval_predictor = nn.Linear(d_model, 13)      # intervals 0-12 semitones
        self.chord_quality_predictor = nn.Linear(d_model, 7)  # 7 chord qualities

        self.solfege_generator = nn.GRU(
            input_size=d_model + 64,
            hidden_size=d_model,
            num_layers=1,
            batch_first=True,
        )

        self.progress_tracker = nn.GRU(
            input_size=5,
            hidden_size=64,
            num_layers=1,
            batch_first=True,
        )

        self.success_predictor = nn.Linear(64, 1)

    def forward(
        self,
        hidden_states: torch.Tensor,
        exercise_type: Optional[int] = None,
        user_response: Optional[str] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        Forward pass through EarTrainingModule.

        Args:
            hidden_states: Base model hidden states [batch, seq_len, d_model]
            exercise_type: Optional exercise type ID (0-5).
                NOTE(review): accepted but currently unused in the body.
            user_response: Optional user's answer for progress tracking.
                NOTE(review): accepted but currently unused in the body.

        Returns:
            Dictionary with ear training predictions:
            difficulty [batch, 5], exercise type [batch, 6],
            interval [batch, 13], chord quality [batch, 7] logits.
        """
        batch_size, seq_len, _ = hidden_states.shape

        # Mean-pool over the sequence dimension to get one vector per example.
        pooled = hidden_states.mean(dim=1)

        difficulty_logits = self.difficulty_tracker(pooled)
        exercise_logits = self.exercise_type_head(pooled)
        interval_logits = self.interval_predictor(pooled)
        chord_quality_logits = self.chord_quality_predictor(pooled)

        outputs = {
            "difficulty_logits": difficulty_logits,
            "exercise_type_logits": exercise_logits,
            "interval_logits": interval_logits,
            "chord_quality_logits": chord_quality_logits,
        }

        return outputs

    def describe_interval(self, interval_semitones: int, reference: str = "song") -> str:
        """
        Describe an interval in relatable terms.

        Args:
            interval_semitones: Number of semitones (0-12)
            reference: Type of reference ("song", "emotion", "technical");
                any unrecognized value falls back to the technical form.

        Returns:
            Descriptive string
        """
        if interval_semitones not in self.INTERVALS:
            return f"Unknown interval: {interval_semitones} semitones"

        interval_name = self.INTERVALS[interval_semitones]

        if reference == "song":
            song = self.INTERVAL_SONGS.get(interval_semitones, "a generic interval")
            return f"A {interval_name} ({interval_semitones} semitones) — like {song}."
        elif reference == "emotion":
            # Rough emotional character of each interval, keyed by semitones.
            emotion_map = {
                0: "familiar, consonant",
                1: "tense, dissonant",
                2: "slightly tense",
                3: "sad, soulful",
                4: "bright, happy",
                5: "stable, resolved",
                6: "very tense, mysterious",
                7: "strong, stable",
                8: "sweet, melancholic",
                9: "bright, hopeful",
                10: "bluesy, tense",
                11: "smooth, jazzy",
                12: "complete, resolved",
            }
            emotion = emotion_map.get(interval_semitones, "neutral")
            return f"A {interval_name} feels {emotion}."
        else:
            return f"A {interval_name} spans {interval_semitones} semitones."

    def generate_solfege_exercise(
        self,
        key: str = "C",
        difficulty: int = 1,
        num_notes: int = 5,
    ) -> List[str]:
        """
        Generate a solfege exercise.

        Args:
            key: Key signature. NOTE(review): currently unused — the docstring
                previously claimed it affects accidentals, but the body never
                reads it; confirm intended behavior.
            difficulty: 1-5; <= 2 yields a stepwise run, higher allows jumps
                whose maximum size grows with difficulty.
            num_notes: Number of notes in exercise

        Returns:
            List of solfege syllables
        """
        import random

        if difficulty <= 2:
            # Easy: a stepwise ascending run from a random starting degree.
            start_idx = random.randint(0, 4)
            exercise = []
            for i in range(num_notes):
                idx = (start_idx + i) % 7
                exercise.append(self.SOLFEGE[idx])
            return exercise
        else:
            # Harder: random jumps, clamped to scale degrees 0-6.
            exercise = []
            current = 0
            for _ in range(num_notes):
                max_jump = min(difficulty + 2, 7)
                jump = random.randint(-max_jump, max_jump)
                current = max(0, min(6, current + jump))
                exercise.append(self.SOLFEGE[current])
            return exercise

    def generate_interval_quiz(
        self,
        num_questions: int = 5,
        max_interval: int = 12,
        include_desc: bool = True,
    ) -> List[Dict]:
        """
        Generate an interval identification quiz.

        Args:
            num_questions: Number of questions
            max_interval: Maximum interval size in semitones (clamped to 12,
                the largest interval in INTERVALS)
            include_desc: Include descriptive hints

        Returns:
            List of quiz questions; each has interval_semitones,
            interval_name, quality, and optionally a hint.
        """
        import random

        # Clamp so random.randint never produces a key outside INTERVALS.
        max_interval = min(max_interval, 12)

        questions = []
        for _ in range(num_questions):
            interval = random.randint(1, max_interval)
            # Derive the quality from the interval itself so it always agrees
            # with interval_name: 5/7/12 semitones are perfect (unison too,
            # but the quiz starts at 1); the tritone is conventionally a
            # diminished 5th; every other entry is named "major ..."/"minor ...".
            if interval in self._PERFECT_SEMITONES:
                quality = "perfect"
            elif interval == 6:
                quality = "diminished"
            else:
                quality = self.INTERVALS[interval].split()[0]

            question = {
                "interval_semitones": interval,
                "interval_name": self.INTERVALS[interval],
                "quality": quality,
            }

            if include_desc:
                question["hint"] = self.describe_interval(interval, reference="song")

            questions.append(question)

        return questions

    def describe_chord_quality(self, chord_type: str) -> str:
        """
        Describe how a chord quality sounds.

        Args:
            chord_type: Chord type (major, minor, etc); unknown types get a
                generic description.

        Returns:
            Descriptive string
        """
        description = self.CHORD_DESCRIPTIONS.get(chord_type, "unique sounding")
        return f"{chord_type} chords sound {description}."

    def suggest_listening_exercise(
        self,
        interval: Optional[int] = None,
        chord_quality: Optional[str] = None,
    ) -> Dict[str, str]:
        """
        Suggest specific songs/moments to listen for intervals or chords.

        Args:
            interval: Optional specific interval (semitones) to practice
            chord_quality: Optional chord quality to practice

        Returns:
            Dictionary with listening suggestions. If both arguments are
            given, the chord "tip" overwrites the interval "tip".
        """
        suggestions = {}

        if interval:
            song = self.INTERVAL_SONGS.get(interval, "various songs")
            suggestions["interval"] = f"Listen for {self.INTERVALS[interval]} in: {song}"
            suggestions["tip"] = "Try to hum along to internalize the sound."

        if chord_quality:
            # Example songs per chord quality; unknown qualities fall back.
            examples = {
                "major": ["Happy Birthday", "Let It Be (chorus)"],
                "minor": ["House of the Rising Sun", "Greensleeves"],
                "diminished": ["The Simpsons theme (tritone)"],
                "dominant7": ["Blues progressions", "Purple Haze"],
                "major7": ["Something (The Beatles)", "So What (Miles Davis)"],
            }
            songs = examples.get(chord_quality, ["various songs"])
            suggestions["chord"] = f"Listen for {chord_quality} chords in: {', '.join(songs)}"
            suggestions["tip"] = "Focus on the emotional character."

        return suggestions

    def track_progress(
        self,
        exercise_history: List[Dict],
        current_performance: float,
    ) -> Dict[str, object]:
        """
        Track user's progress over session.

        Args:
            exercise_history: List of past exercises; each dict may carry a
                "score" key (missing scores count as 0)
            current_performance: Current success rate (0-1)

        Returns:
            Progress analysis: level, average_score, current_score,
            suggestion, exercises_completed. With an empty history, only
            level and suggestion are returned.
        """
        if not exercise_history:
            return {"level": "beginner", "suggestion": "Start with interval identification"}

        avg_performance = sum(ex.get("score", 0) for ex in exercise_history) / len(exercise_history)

        # Thresholds: <0.5 beginner, <0.7 intermediate, else advanced.
        if avg_performance < 0.5:
            level = "beginner"
            suggestion = "Practice more interval identification with smaller intervals (2nd-5th)."
        elif avg_performance < 0.7:
            level = "intermediate"
            suggestion = "Try more complex intervals and chord qualities."
        else:
            level = "advanced"
            suggestion = "Challenge yourself with inversions and advanced chords."

        return {
            "level": level,
            "average_score": avg_performance,
            "current_score": current_performance,
            "suggestion": suggestion,
            "exercises_completed": len(exercise_history),
        }
|
| |
|
| |
|
def test_ear_training_module():
    """Smoke-test EarTrainingModule: forward pass plus every helper, printing results."""
    import torch

    ear = EarTrainingModule(d_model=4096)

    # Fake base-model activations: [batch=2, seq=10, d_model=4096].
    states = torch.randn(2, 10, 4096)
    results = ear.forward(states)

    print("Ear Training Module outputs:")
    for name, tensor in results.items():
        print(f" {name}: {tensor.shape}")

    print("\nInterval descriptions:")
    for semitones in (3, 4, 5, 7, 10):
        desc = ear.describe_interval(semitones, reference="song")
        print(f" {semitones} semitones: {desc}")

    print("\nSolfege exercise (C, difficulty 2):")
    solfege = ear.generate_solfege_exercise(key="C", difficulty=2, num_notes=8)
    print(f" {' '.join(solfege)}")

    print("\nInterval quiz (3 questions):")
    for i, q in enumerate(ear.generate_interval_quiz(num_questions=3)):
        print(f" Q{i+1}: {q['interval_name']} ({q['interval_semitones']} semitones)")
        if 'hint' in q:
            print(f" Hint: {q['hint']}")

    print("\nChord quality descriptions:")
    for chord in ("major", "minor", "diminished", "major7"):
        desc = ear.describe_chord_quality(chord)
        print(f" {chord}: {desc}")

    print("\nListening exercise suggestions:")
    tips = ear.suggest_listening_exercise(interval=7, chord_quality="major")
    for key, value in tips.items():
        print(f" {key}: {value}")

    print("\nProgress tracking:")
    session = [
        {"exercise": "interval", "score": 0.6},
        {"exercise": "interval", "score": 0.7},
        {"exercise": "chord", "score": 0.5},
    ]
    report = ear.track_progress(session, current_performance=0.8)
    for key, value in report.items():
        print(f" {key}: {value}")

    print("\nEar Training Module test complete!")


if __name__ == "__main__":
    test_ear_training_module()