| """
|
| Music Theory Engine for TouchGrass.
|
| Understands music theory relationships, scales, chords, progressions.
|
| """
|
|
|
| import torch
|
| import torch.nn as nn
|
| import torch.nn.functional as F
|
| from typing import Optional, List, Dict, Tuple
|
|
|
|
|
class MusicTheoryModule(nn.Module):
    """
    Music-theory component for TouchGrass.

    Combines two kinds of functionality:

    - Learned heads (``forward``): pool the base model's hidden states over
      the sequence and project them to logits for chord function, scale
      degree, interval, progression, key and mode.
    - Rule-based helpers: spelling scales in the seven diatonic modes
      (Ionian, Dorian, Phrygian, Lydian, Mixolydian, Aeolian, Locrian),
      classifying Roman numerals by chord function (tonic / subdominant /
      dominant), expanding the progressions in ``COMMON_PROGRESSIONS``
      (I-IV-V-I, ii-V-I, I-V-vi-IV, 12-bar blues, ...) into triads,
      suggesting progressions by mood, and checking chord roots against a key.
    """

    # The 12 pitch classes in chromatic order (index == semitones above C).
    # _semitone_to_note reads its spellings from this list.
    CHROMATIC_NOTES = ["C", "C#", "D", "Eb", "E", "F", "F#", "G", "Ab", "A", "Bb", "B"]

    SCALE_DEGREES = ["I", "ii", "iii", "IV", "V", "vi", "vii°"]

    CHORD_TYPES = [
        "major", "minor", "diminished", "augmented",
        "major7", "minor7", "dominant7", "half-dim7", "dim7",
        "major9", "minor9", "dominant9",
        "sus2", "sus4", "add9", "6", "maj6",
    ]

    MODES = [
        "ionian", "dorian", "phrygian", "lydian",
        "mixolydian", "aeolian", "locrian",
    ]

    # Progression name -> short description. New entries are appended at the
    # end so the positions of the existing entries do not shift.
    COMMON_PROGRESSIONS = {
        "I-IV-V-I": "Classical cadential",
        "ii-V-I": "Jazz turnaround",
        "I-V-vi-IV": "Pop progression (4-chord)",
        "vi-IV-I-V": "Pop variant",
        "I-vi-ii-V": "Circle progression",
        "I-vi-IV-V": "50s progression",
        # A plagal cadence is IV-I; IV-V-I approaches the authentic cadence from IV.
        "IV-V-I": "Subdominant-dominant-tonic cadence",
        "V-I": "Authentic cadence",
        "12-bar blues": "Blues",
        "i-iv-v": "Minor blues",
        "IV-I": "Plagal cadence",
    }

    # Semitone offsets from the tonic for each mode (all seven-note scales).
    MODE_INTERVALS = {
        "ionian": (0, 2, 4, 5, 7, 9, 11),
        "dorian": (0, 2, 3, 5, 7, 9, 10),
        "phrygian": (0, 1, 3, 5, 7, 8, 10),
        "lydian": (0, 2, 4, 6, 7, 9, 11),
        "mixolydian": (0, 2, 4, 5, 7, 9, 10),
        "aeolian": (0, 2, 3, 5, 7, 8, 10),
        "locrian": (0, 1, 3, 5, 6, 8, 10),
    }

    # Natural letter names in scale order, and their pitch classes.
    _LETTERS = "CDEFGAB"
    _NATURAL_SEMITONES = {"C": 0, "D": 2, "E": 4, "F": 5, "G": 7, "A": 9, "B": 11}

    # Roman numerals for scale degrees 1-7 (case and a trailing "°" are parsed separately).
    _ROMAN_DEGREES = ("I", "II", "III", "IV", "V", "VI", "VII")

    # Degree sequences for progressions whose names are not "-"-joined numerals.
    PROGRESSION_DEGREES = {
        "12-bar blues": ("I", "I", "I", "I", "IV", "IV", "I", "I", "V", "IV", "I", "I"),
    }

    def __init__(self, d_model: int):
        """
        Initialize MusicTheoryModule.

        Args:
            d_model: Hidden dimension from base model
        """
        super().__init__()
        self.d_model = d_model

        # 128-dim lookup tables. None of them is used by forward() in this
        # file; presumably they are consumed elsewhere -- TODO confirm.
        # NOTE(review): 48 presumably = 12 pitch classes x 4 octaves -- confirm.
        self.note_embed = nn.Embedding(48, 128)
        # One row per CHORD_TYPES entry; a hard-coded 15 left the last two of
        # the 17 chord types without an embedding. Checkpoints saved with 15
        # rows need this weight resized before loading.
        self.chord_type_embed = nn.Embedding(len(self.CHORD_TYPES), 128)
        self.mode_embed = nn.Embedding(len(self.MODES), 128)
        # NOTE(review): 24 presumably = 12 major + 12 minor keys -- confirm.
        self.key_embed = nn.Embedding(24, 128)

        # Not used by forward() in this file -- TODO confirm where it is consumed.
        self.relationship_proj = nn.Linear(d_model, d_model)

        # Classification heads applied to the pooled hidden state in forward().
        # Presumably tonic / subdominant / dominant (cf. get_chord_function).
        self.chord_function_head = nn.Linear(d_model, 3)
        # Presumably one class per SCALE_DEGREES entry.
        self.scale_degree_head = nn.Linear(d_model, 7)
        # NOTE(review): meaning of the 14 interval classes is not defined here.
        self.interval_head = nn.Linear(d_model, 14)
        # NOTE(review): 7 outputs, but COMMON_PROGRESSIONS has more entries --
        # confirm the label set this head is trained on.
        self.progression_head = nn.Linear(d_model, 7)
        self.key_detection_head = nn.Linear(d_model, 24)
        # Presumably one class per MODES entry.
        self.mode_classifier = nn.Linear(d_model, 7)

    def forward(
        self,
        hidden_states: torch.Tensor,
        query: Optional[str] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        Run the learned theory heads on sequence-pooled hidden states.

        Args:
            hidden_states: Base model hidden states [batch, seq_len, d_model]
            query: Optional text query about music theory. Currently unused.
            attention_mask: Optional [batch, seq_len] mask, 1/True for real
                tokens and 0/False for padding. Its values are used as
                pooling weights. When omitted, every position is averaged.

        Returns:
            Dictionary mapping "<head>_logits" names to [batch, n_classes] tensors.

        Raises:
            ValueError: If hidden_states is not 3-dimensional.
        """
        if hidden_states.dim() != 3:
            raise ValueError(
                "hidden_states must have shape [batch, seq_len, d_model], "
                f"got {tuple(hidden_states.shape)}"
            )

        if attention_mask is None:
            pooled = hidden_states.mean(dim=1)
        else:
            # Masked mean: padding positions must not dilute the average.
            weights = attention_mask.to(hidden_states).unsqueeze(-1)
            # clamp avoids 0/0 for rows whose mask is all zeros (they pool to 0).
            pooled = (hidden_states * weights).sum(dim=1) / weights.sum(dim=1).clamp(min=1.0)

        return {
            "chord_function_logits": self.chord_function_head(pooled),
            "scale_degree_logits": self.scale_degree_head(pooled),
            "interval_logits": self.interval_head(pooled),
            "progression_logits": self.progression_head(pooled),
            "key_logits": self.key_detection_head(pooled),
            "mode_logits": self.mode_classifier(pooled),
        }

    def get_chord_function(self, scale_degree: str) -> str:
        """
        Get the harmonic function of a Roman-numeral chord.

        Args:
            scale_degree: Roman numeral such as "I", "ii", "V", "vii°", or a
                minor-key numeral such as "i", "iv", "v", "ii°".

        Returns:
            "tonic", "subdominant", "dominant", or "unknown" for anything else.
        """
        if scale_degree in ("I", "i", "vi"):
            return "tonic"
        if scale_degree in ("ii", "ii°", "IV", "iv"):
            return "subdominant"
        # iii is grouped with the dominant here (it shares two tones with V).
        if scale_degree in ("V", "v", "vii°", "iii"):
            return "dominant"
        return "unknown"

    @classmethod
    def _parse_note(cls, name) -> Optional[int]:
        """Return the pitch class (0-11) of a note like "C", "F#", "Bb", "E##", or None if unparseable."""
        if not isinstance(name, str) or not name or name[0] not in cls._NATURAL_SEMITONES:
            return None
        accidentals = name[1:]
        # Accept only a run of sharps or a run of flats, not a mixture.
        if accidentals and set(accidentals) not in ({"#"}, {"b"}):
            return None
        offset = accidentals.count("#") - accidentals.count("b")
        return (cls._NATURAL_SEMITONES[name[0]] + offset) % 12

    def get_scale_from_key(self, key: str, mode: str = "ionian") -> List[str]:
        """
        Generate the notes of a diatonic mode, spelled with one letter per degree.

        Each degree uses the next letter name after the previous one, with
        sharps/flats added to reach the right pitch. E major is therefore
        E F# G# A B C# D#, not E F# Ab A B C# Eb. Keys that need them can
        produce names such as "Cb" or "F##".

        Args:
            key: Root note: a letter C-G, A or B followed by optional sharps
                ("#") or flats ("b"), e.g. "C", "F#", "Bb".
            mode: Mode name (ionian, dorian, ...); case-insensitive.

        Returns:
            List of the seven note names in the scale, starting on key.

        Raises:
            ValueError: If the mode or key is not recognized.
        """
        mode_key = mode.lower() if isinstance(mode, str) else mode
        intervals = self.MODE_INTERVALS.get(mode_key)
        if intervals is None:
            raise ValueError(f"Unknown mode: {mode}")

        root_semitone = self._parse_note(key)
        if root_semitone is None:
            raise ValueError(f"Unknown key: {key}")

        start = self._LETTERS.index(key[0])
        scale = []
        for step, interval in enumerate(intervals):
            letter = self._LETTERS[(start + step) % 7]
            target = (root_semitone + interval) % 12
            # Signed distance from the natural letter to the target, in [-6, 5].
            offset = (target - self._NATURAL_SEMITONES[letter] + 6) % 12 - 6
            scale.append(letter + ("#" * offset if offset > 0 else "b" * -offset))
        return scale

    def _semitone_to_note(self, semitone: int) -> str:
        """Convert a semitone number (reduced mod 12) to its CHROMATIC_NOTES name."""
        return self.CHROMATIC_NOTES[semitone % 12]

    @classmethod
    def _parse_roman(cls, numeral: str) -> Optional[Tuple[int, str]]:
        """Parse a Roman numeral into (0-based degree index, triad quality), or None if unparseable."""
        diminished = numeral.endswith("°")
        base = numeral[:-1] if diminished else numeral
        # Reject unknown numerals and mixed case such as "Iv".
        if base.upper() not in cls._ROMAN_DEGREES or base not in (base.upper(), base.lower()):
            return None
        if diminished:
            quality = "diminished"
        elif base.isupper():
            quality = "major"
        else:
            quality = "minor"
        return cls._ROMAN_DEGREES.index(base.upper()), quality

    def get_progression_chords(
        self,
        progression_name: str,
        key: str = "C",
    ) -> List[Tuple[str, str]]:
        """
        Get chord progression as list of (degree, chord).

        Chord roots come from the major (ionian) scale of ``key``. The triad
        quality comes from the numeral: upper case is major, lower case is
        minor, and a trailing "°" is diminished. Flat-degree numerals (e.g.
        "bVII") are not supported.

        Args:
            progression_name: Key of COMMON_PROGRESSIONS (e.g., "I-IV-V-I")
            key: Root key

        Returns:
            List of (scale_degree, "<root> <quality>") tuples, e.g. ("V", "G major").

        Raises:
            ValueError: If the progression, the key, or a degree in it is not recognized.
        """
        if progression_name not in self.COMMON_PROGRESSIONS:
            raise ValueError(f"Unknown progression: {progression_name}")

        degrees = self.PROGRESSION_DEGREES.get(progression_name)
        if degrees is None:
            degrees = progression_name.split("-")

        scale = self.get_scale_from_key(key, mode="ionian")

        chords = []
        for degree in degrees:
            parsed = self._parse_roman(degree)
            if parsed is None:
                raise ValueError(
                    f"Unrecognized scale degree {degree!r} in progression {progression_name!r}"
                )
            index, quality = parsed
            chords.append((degree, f"{scale[index]} {quality}"))
        return chords

    def suggest_progression(
        self,
        mood: str = "happy",
        genre: str = "pop",
        num_chords: int = 4,
    ) -> List[str]:
        """
        Suggest a chord progression based on mood and genre.

        A base pattern is chosen by mood (and by genre for happy pop). It is
        then repeated or truncated to exactly ``num_chords`` chords.

        Args:
            mood: Emotional mood (happy, sad/melancholy, tense/dramatic, jazzy, ...)
            genre: Music genre (only "pop" is distinguished, for happy moods)
            num_chords: Number of chords in progression

        Returns:
            List of Roman-numeral chord symbols.

        Raises:
            ValueError: If num_chords is negative.
        """
        if num_chords < 0:
            raise ValueError(f"num_chords must be non-negative, got {num_chords}")

        if mood == "happy" and genre == "pop":
            pattern = ["I", "IV", "V"] if num_chords == 3 else ["I", "V", "vi", "IV"]
        elif mood in ("sad", "melancholy"):
            pattern = ["vi", "IV", "I", "V"]
        elif mood in ("tense", "dramatic"):
            pattern = ["i", "iv", "V", "i"]
        elif mood == "jazzy":
            pattern = ["ii", "V", "I", "vi"]
        else:
            pattern = ["I", "IV", "V", "I"]

        return [pattern[i % len(pattern)] for i in range(num_chords)]

    def validate_progression(
        self,
        progression: List[str],
        key: str = "C",
    ) -> Tuple[bool, List[str]]:
        """
        Check that chord roots belong to the major scale of ``key``.

        Roots are compared by pitch class, so enharmonic spellings (G# vs Ab)
        are treated as equal. Accidentals are not ignored: C# is outside C major.

        Args:
            progression: List of Roman numerals or chord names ("<root> <quality>").
                Only entries containing a space (chord names) are checked;
                Roman numerals are passed through unchecked.
            key: Key center

        Returns:
            (is_valid, issues)

        Raises:
            ValueError: If key is not recognized.
        """
        scale = self.get_scale_from_key(key, mode="ionian")
        scale_pitch_classes = {self._parse_note(note) for note in scale}

        issues = []
        for chord in progression:
            if " " not in chord:
                continue
            root = chord.split(" ")[0]
            pitch_class = self._parse_note(root)
            if pitch_class is None:
                issues.append(f"Chord {chord} has unrecognized root {root}")
            elif pitch_class not in scale_pitch_classes:
                issues.append(f"Chord {chord} has root {root} not in key {key}")

        return not issues, issues
|
|
|
|
|
def test_music_theory_module():
    """Smoke-test MusicTheoryModule: a forward pass plus the rule-based helpers.

    Prints every result, so it doubles as a demo when the file is run as a
    script. It also asserts the values that the rule-based helpers fully
    determine.
    """
    d_model = 4096
    module = MusicTheoryModule(d_model=d_model)

    # Random stand-in for base-model hidden states: [batch, seq_len, d_model].
    batch_size = 2
    seq_len = 10
    hidden_states = torch.randn(batch_size, seq_len, d_model)

    # Inference only; no autograd graph needed.
    with torch.no_grad():
        outputs = module(hidden_states)

    print("Music Theory Module outputs:")
    for key, value in outputs.items():
        print(f" {key}: {value.shape}")
        # Every head runs on the sequence-pooled state, so outputs are [batch, n_classes].
        assert value.dim() == 2 and value.shape[0] == batch_size, key

    print("\nScale from C ionian:")
    scale = module.get_scale_from_key("C", "ionian")
    print(f" {scale}")
    assert scale == ["C", "D", "E", "F", "G", "A", "B"]

    print("\nScale from A dorian:")
    scale = module.get_scale_from_key("A", "dorian")
    print(f" {scale}")
    assert scale == ["A", "B", "C", "D", "E", "F#", "G"]

    print("\nProgression I-V-vi-IV in C:")
    chords = module.get_progression_chords("I-V-vi-IV", "C")
    for degree, chord in chords:
        print(f" {degree}: {chord}")
    assert chords == [("I", "C major"), ("V", "G major"), ("vi", "A minor"), ("IV", "F major")]

    print("\nSuggested progression (happy, pop, 4 chords):")
    prog = module.suggest_progression(mood="happy", genre="pop", num_chords=4)
    print(f" {prog}")
    assert prog == ["I", "V", "vi", "IV"]

    print("\nValidate progression [I, IV, V, I] in C:")
    valid, issues = module.validate_progression(["I", "IV", "V", "I"], "C")
    print(f" Valid: {valid}")
    if issues:
        print(f" Issues: {issues}")
    assert valid and not issues

    print("\nMusic Theory Module test complete!")
|
|
|
|
|
if __name__ == "__main__":
    # Run directly as a script (not imported): execute the smoke test / demo.
    test_music_theory_module()