"""
Ear Training Module for TouchGrass.

Guides ear training exercises without audio, using descriptive language.
"""

import random
from typing import Any, Dict, List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


class EarTrainingModule(nn.Module):
    """
    Guides ear training exercises without audio.

    Can:
    - Describe interval sounds in relatable terms
      ("a perfect 5th sounds like the Star Wars theme opening")
    - Generate solfege exercises (Do Re Mi Fa Sol La Ti Do)
    - Create interval identification quizzes in text form
    - Explain chord quality by ear ("major chords sound happy/bright,
      minor chords sound sad/dark, diminished chords sound tense/unstable")
    - Guide relative pitch training
    - Suggest listening exercises with specific songs/moments

    The neural part (``forward``) only produces classification logits from
    pooled hidden states; the text helpers are plain Python and work from
    the lookup tables defined on the class.
    """

    # Number of skill levels predicted by ``difficulty_tracker`` (1-5).
    NUM_DIFFICULTY_LEVELS = 5

    # Number of exercise types predicted by ``exercise_type_head``.
    NUM_EXERCISE_TYPES = 6

    # Intervals (semitones -> name)
    INTERVALS = {
        0: "unison",
        1: "minor 2nd",
        2: "major 2nd",
        3: "minor 3rd",
        4: "major 3rd",
        5: "perfect 4th",
        6: "tritone",
        7: "perfect 5th",
        8: "minor 6th",
        9: "major 6th",
        10: "minor 7th",
        11: "major 7th",
        12: "octave",
    }

    # Interval qualities
    QUALITIES = ["perfect", "major", "minor", "augmented", "diminished"]

    # Quality of each interval, keyed by semitone count so that it always
    # agrees with the name in INTERVALS. The tritone is spelled here as an
    # augmented 4th (it can equally be spelled as a diminished 5th).
    INTERVAL_QUALITIES = {
        0: "perfect",
        1: "minor",
        2: "major",
        3: "minor",
        4: "major",
        5: "perfect",
        6: "augmented",
        7: "perfect",
        8: "minor",
        9: "major",
        10: "minor",
        11: "major",
        12: "perfect",
    }

    # Solfege syllables (movable do). The final "Do" is the upper octave;
    # the generators below only index positions 0-6.
    SOLFEGE = ["Do", "Re", "Mi", "Fa", "Sol", "La", "Ti", "Do"]

    # Chord qualities and descriptions
    CHORD_DESCRIPTIONS = {
        "major": "bright, happy, stable",
        "minor": "sad, dark, melancholic",
        "diminished": "tense, unstable, dissonant",
        "augmented": "bright, dreamy, suspenseful",
        "dominant7": "bluesy, tense, wants to resolve",
        "major7": "smooth, jazzy, dreamy",
        "minor7": "smooth, soulful, mellow",
    }

    # Famous song references for intervals.
    # Each interval gets its own reference: "When the Saints" opens with a
    # major 3rd, "My Bonnie" with a major 6th and "Somewhere Over the
    # Rainbow" with an octave, so those songs are not reused for the
    # minor 3rd / minor 6th / major 6th entries.
    INTERVAL_SONGS = {
        0: "any note played twice",
        1: "Jaws theme (da-dum)",
        2: "Happy Birthday (2nd note)",
        3: "Greensleeves (minor 3rd)",
        4: "Oh When the Saints (major 3rd)",
        5: "Here Comes the Bride (perfect 4th)",
        6: "The Simpsons theme (tritone)",
        7: "Star Wars theme (perfect 5th)",
        8: "The Entertainer (minor 6th)",
        9: "My Bonnie Lies Over the Ocean (major 6th)",
        10: "The Office theme (minor 7th)",
        11: "Take On Me (major 7th)",
        12: "Somewhere Over the Rainbow (octave)",
    }

    # Emotional descriptors for intervals (semitones -> description)
    INTERVAL_EMOTIONS = {
        0: "familiar, consonant",
        1: "tense, dissonant",
        2: "slightly tense",
        3: "sad, soulful",
        4: "bright, happy",
        5: "stable, resolved",
        6: "very tense, mysterious",
        7: "strong, stable",
        8: "sweet, melancholic",
        9: "bright, hopeful",
        10: "bluesy, tense",
        11: "smooth, jazzy",
        12: "complete, resolved",
    }

    # Famous listening examples per chord quality
    CHORD_LISTENING_EXAMPLES = {
        "major": ["Happy Birthday", "Let It Be (chorus)"],
        "minor": ["House of the Rising Sun", "Greensleeves"],
        "diminished": ["The Simpsons theme (tritone)"],
        "dominant7": ["Blues progressions", "Purple Haze"],
        "major7": ["Something (The Beatles)", "So What (Miles Davis)"],
    }

    def __init__(self, d_model: int):
        """
        Initialize EarTrainingModule.

        Args:
            d_model: Hidden dimension from base model
        """
        super().__init__()
        self.d_model = d_model

        # Embeddings (constructed here but not consumed by ``forward``).
        self.interval_embed = nn.Embedding(len(self.INTERVALS), 64)  # unison through octave
        self.quality_embed = nn.Embedding(len(self.QUALITIES), 64)  # perfect/major/minor/aug/dim

        # Difficulty tracker (skill level 1-5)
        self.difficulty_tracker = nn.Linear(d_model, self.NUM_DIFFICULTY_LEVELS)

        # Exercise type classifier
        self.exercise_type_head = nn.Linear(d_model, self.NUM_EXERCISE_TYPES)

        # Interval prediction head (one logit per entry in INTERVALS)
        self.interval_predictor = nn.Linear(d_model, len(self.INTERVALS))

        # Chord quality predictor (one logit per entry in CHORD_DESCRIPTIONS)
        self.chord_quality_predictor = nn.Linear(d_model, len(self.CHORD_DESCRIPTIONS))

        # Solfege generator (not used by ``forward``)
        self.solfege_generator = nn.GRU(
            input_size=d_model + 64,
            hidden_size=d_model,
            num_layers=1,
            batch_first=True,
        )

        # Progress tracker (simple RNN to track session history; not used by
        # ``forward``).
        # NOTE(review): the original comment calls the input a "one-hot for
        # exercise types", but there are 6 exercise types and the input width
        # is 5. The width is left at 5 so parameter shapes stay compatible
        # with existing checkpoints -- confirm the intended encoding.
        self.progress_tracker = nn.GRU(
            input_size=5,
            hidden_size=64,
            num_layers=1,
            batch_first=True,
        )

        # Success rate predictor (reads the 64-d progress tracker state)
        self.success_predictor = nn.Linear(64, 1)

    def forward(
        self,
        hidden_states: torch.Tensor,
        exercise_type: Optional[int] = None,
        user_response: Optional[str] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        Forward pass through EarTrainingModule.

        Args:
            hidden_states: Base model hidden states [batch, seq_len, d_model]
            exercise_type: Optional exercise type ID (0-5). Accepted for
                interface compatibility; currently not used.
            user_response: Optional user's answer for progress tracking.
                Accepted for interface compatibility; currently not used.

        Returns:
            Dictionary with the logits ``difficulty_logits`` [batch, 5],
            ``exercise_type_logits`` [batch, 6], ``interval_logits``
            [batch, 13] and ``chord_quality_logits`` [batch, 7].

        Raises:
            ValueError: If ``hidden_states`` is not a 3-D tensor.
        """
        # Without this check a 2-D input would be pooled over the wrong axis
        # instead of failing.
        if hidden_states.dim() != 3:
            raise ValueError(
                "hidden_states must have shape [batch, seq_len, d_model], "
                f"got {tuple(hidden_states.shape)}"
            )

        # Mean-pool over the sequence axis
        pooled = hidden_states.mean(dim=1)  # [batch, d_model]

        return {
            "difficulty_logits": self.difficulty_tracker(pooled),
            "exercise_type_logits": self.exercise_type_head(pooled),
            "interval_logits": self.interval_predictor(pooled),
            "chord_quality_logits": self.chord_quality_predictor(pooled),
        }

    def describe_interval(self, interval_semitones: int, reference: str = "song") -> str:
        """
        Describe an interval in relatable terms.

        Args:
            interval_semitones: Number of semitones (0-12)
            reference: Type of reference ("song", "emotion"); any other
                value yields the technical description.

        Returns:
            Descriptive string, or an "Unknown interval" message when the
            semitone count is not in INTERVALS.
        """
        if interval_semitones not in self.INTERVALS:
            return f"Unknown interval: {interval_semitones} semitones"

        interval_name = self.INTERVALS[interval_semitones]

        if reference == "song":
            song = self.INTERVAL_SONGS.get(interval_semitones, "a generic interval")
            return f"A {interval_name} ({interval_semitones} semitones) — like {song}."

        if reference == "emotion":
            emotion = self.INTERVAL_EMOTIONS.get(interval_semitones, "neutral")
            return f"A {interval_name} feels {emotion}."

        return f"A {interval_name} spans {interval_semitones} semitones."

    def generate_solfege_exercise(
        self,
        key: str = "C",
        difficulty: int = 1,
        num_notes: int = 5,
    ) -> List[str]:
        """
        Generate solfege exercise.

        Args:
            key: Key signature. Accepted for interface compatibility; the
                syllables are movable-do, so it does not affect the output.
            difficulty: 1-5. Levels 1-2 give an ascending stepwise line;
                higher levels give random leaps whose maximum size grows
                with the level.
            num_notes: Number of notes in exercise

        Returns:
            List of solfege syllables (empty when ``num_notes`` <= 0)
        """
        if difficulty <= 2:
            # Ascending stepwise motion from a random start between Do and
            # Sol, wrapping from Ti back to Do.
            start_idx = random.randint(0, 4)
            return [self.SOLFEGE[(start_idx + step) % 7] for step in range(num_notes)]

        # More complex: random leaps, clamped to the Do..Ti range.
        max_jump = min(difficulty + 2, 7)  # jump size increases with difficulty
        exercise = []
        current = 0  # Start at Do
        for _ in range(num_notes):
            jump = random.randint(-max_jump, max_jump)
            current = max(0, min(6, current + jump))
            exercise.append(self.SOLFEGE[current])
        return exercise

    def generate_interval_quiz(
        self,
        num_questions: int = 5,
        max_interval: int = 12,
        include_desc: bool = True,
    ) -> List[Dict[str, Any]]:
        """
        Generate interval identification quiz.

        Args:
            num_questions: Number of questions
            max_interval: Maximum interval size in semitones. Values above
                12 are clamped to 12, the largest interval in INTERVALS.
            include_desc: Include descriptive hints

        Returns:
            List of quiz questions. Each has ``interval_semitones``,
            ``interval_name`` and ``quality`` (always consistent with the
            name), plus ``hint`` when ``include_desc`` is true.

        Raises:
            ValueError: If ``max_interval`` is smaller than 1.
        """
        if max_interval < 1:
            raise ValueError(f"max_interval must be at least 1, got {max_interval}")

        # Never draw an interval that INTERVALS cannot name.
        upper = min(max_interval, max(self.INTERVALS))

        questions = []
        for _ in range(num_questions):
            interval = random.randint(1, upper)
            question = {
                "interval_semitones": interval,
                "interval_name": self.INTERVALS[interval],
                # Looked up by semitone count so the quality can never
                # contradict the interval name.
                "quality": self.INTERVAL_QUALITIES[interval],
            }
            if include_desc:
                question["hint"] = self.describe_interval(interval, reference="song")
            questions.append(question)
        return questions

    def describe_chord_quality(self, chord_type: str) -> str:
        """
        Describe how a chord quality sounds.

        Args:
            chord_type: Chord type (major, minor, etc)

        Returns:
            Descriptive string; unknown chord types are described as
            "unique sounding".
        """
        description = self.CHORD_DESCRIPTIONS.get(chord_type, "unique sounding")
        return f"{chord_type} chords sound {description}."

    def suggest_listening_exercise(
        self,
        interval: Optional[int] = None,
        chord_quality: Optional[str] = None,
    ) -> Dict[str, str]:
        """
        Suggest specific songs/moments to listen for intervals or chords.

        Args:
            interval: Optional specific interval to practice (semitones)
            chord_quality: Optional chord quality to practice

        Returns:
            Dictionary with listening suggestions. May contain the keys
            ``interval``, ``chord`` and ``tip``. When both arguments are
            given, ``tip`` holds the chord tip (it replaces the interval
            tip). Empty when neither argument is given.
        """
        suggestions = {}

        # ``is not None`` rather than truthiness: 0 (unison) is a valid interval.
        if interval is not None:
            if interval in self.INTERVALS:
                song = self.INTERVAL_SONGS.get(interval, "various songs")
                suggestions["interval"] = f"Listen for {self.INTERVALS[interval]} in: {song}"
                suggestions["tip"] = "Try to hum along to internalize the sound."
            else:
                # Same message as describe_interval instead of a KeyError.
                suggestions["interval"] = f"Unknown interval: {interval} semitones"

        if chord_quality:
            songs = self.CHORD_LISTENING_EXAMPLES.get(chord_quality, ["various songs"])
            suggestions["chord"] = f"Listen for {chord_quality} chords in: {', '.join(songs)}"
            suggestions["tip"] = "Focus on the emotional character."

        return suggestions

    def track_progress(
        self,
        exercise_history: List[Dict],
        current_performance: float,
    ) -> Dict[str, Any]:
        """
        Track user's progress over session.

        Args:
            exercise_history: List of past exercises; each entry's "score"
                is averaged (a missing score counts as 0)
            current_performance: Current success rate (0-1)

        Returns:
            Progress analysis. With an empty history only ``level`` and
            ``suggestion`` are returned; otherwise also ``average_score``,
            ``current_score`` and ``exercises_completed``.
        """
        if not exercise_history:
            return {"level": "beginner", "suggestion": "Start with interval identification"}

        # Calculate average performance
        total_score = sum(ex.get("score", 0) for ex in exercise_history)
        avg_performance = total_score / len(exercise_history)

        # Determine level
        if avg_performance < 0.5:
            level = "beginner"
            suggestion = "Practice more interval identification with smaller intervals (2nd-5th)."
        elif avg_performance < 0.7:
            level = "intermediate"
            suggestion = "Try more complex intervals and chord qualities."
        else:
            level = "advanced"
            suggestion = "Challenge yourself with inversions and advanced chords."

        return {
            "level": level,
            "average_score": avg_performance,
            "current_score": current_performance,
            "suggestion": suggestion,
            "exercises_completed": len(exercise_history),
        }


def test_ear_training_module(d_model: int = 64):
    """
    Smoke-test the EarTrainingModule and print its outputs.

    Args:
        d_model: Hidden size used for the test module. Kept small by
            default: the printed logit shapes do not depend on it, and a
            production-sized value would allocate a very large GRU only to
            print shapes.
    """
    # Create module
    module = EarTrainingModule(d_model=d_model)

    # Test input
    batch_size = 2
    seq_len = 10
    hidden_states = torch.randn(batch_size, seq_len, d_model)

    # Forward pass (call the module, not .forward, so hooks run)
    with torch.no_grad():
        outputs = module(hidden_states)
    print("Ear Training Module outputs:")
    for key, value in outputs.items():
        print(f" {key}: {value.shape}")

    # Test interval description
    print("\nInterval descriptions:")
    for semitones in [3, 4, 5, 7, 10]:
        desc = module.describe_interval(semitones, reference="song")
        print(f" {semitones} semitones: {desc}")

    # Test solfege exercise
    print("\nSolfege exercise (C, difficulty 2):")
    solfege = module.generate_solfege_exercise(key="C", difficulty=2, num_notes=8)
    print(f" {' '.join(solfege)}")

    # Test interval quiz
    print("\nInterval quiz (3 questions):")
    quiz = module.generate_interval_quiz(num_questions=3)
    for i, q in enumerate(quiz):
        print(f" Q{i+1}: {q['interval_name']} ({q['interval_semitones']} semitones)")
        if 'hint' in q:
            print(f" Hint: {q['hint']}")

    # Test chord description
    print("\nChord quality descriptions:")
    for chord in ["major", "minor", "diminished", "major7"]:
        desc = module.describe_chord_quality(chord)
        print(f" {chord}: {desc}")

    # Test listening suggestions
    print("\nListening exercise suggestions:")
    suggestions = module.suggest_listening_exercise(interval=7, chord_quality="major")
    for key, value in suggestions.items():
        print(f" {key}: {value}")

    # Test progress tracking
    print("\nProgress tracking:")
    history = [
        {"exercise": "interval", "score": 0.6},
        {"exercise": "interval", "score": 0.7},
        {"exercise": "chord", "score": 0.5},
    ]
    progress = module.track_progress(history, current_performance=0.8)
    for key, value in progress.items():
        print(f" {key}: {value}")

    print("\nEar Training Module test complete!")


if __name__ == "__main__":
    test_ear_training_module()