# Source: TouchGrass-3b / models / music_theory_module.py
# Uploaded by Zandy-Wandy ("Upload 39 files", commit 9071ef9, verified).
"""
Music Theory Engine for TouchGrass.
Understands music theory relationships, scales, chords, progressions.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, List, Dict, Tuple
class MusicTheoryModule(nn.Module):
    """
    Music-theory heads and rule-based helpers for TouchGrass.

    The module has two independent parts:

    - Learned part (``forward``): mean-pools the base model's hidden states
      and applies linear classification heads for chord function, scale
      degree, interval, next chord in a progression, key and mode.
    - Rule-based part (``get_*``, ``suggest_progression``,
      ``validate_progression``): pure-Python helpers that build scales,
      expand named progressions and check chords against a major key.
      They do not use any learned parameter.
    """

    # The 12 pitch classes in ascending chromatic order starting at C, so that
    # CHROMATIC_NOTES[n] is the name of semitone n. Spellings match
    # _semitone_to_note.
    CHROMATIC_NOTES = ["C", "C#", "D", "Eb", "E", "F", "F#", "G", "Ab", "A", "Bb", "B"]

    # Scale degrees in major (Ionian)
    SCALE_DEGREES = ["I", "ii", "iii", "IV", "V", "vi", "vii°"]

    # Common chord types
    CHORD_TYPES = [
        "major", "minor", "diminished", "augmented",
        "major7", "minor7", "dominant7", "half-dim7", "dim7",
        "major9", "minor9", "dominant9",
        "sus2", "sus4", "add9", "6", "maj6",
    ]

    # Modes
    MODES = [
        "ionian", "dorian", "phrygian", "lydian",
        "mixolydian", "aeolian", "locrian"
    ]

    # Common progressions (by scale degrees) mapped to a short description.
    # NOTE(review): the "Plagal cadence" label on IV-V-I is loose (a plagal
    # cadence is strictly IV-I); the string is kept as-is because callers may
    # display it.
    COMMON_PROGRESSIONS = {
        "I-IV-V-I": "Classical cadential",
        "ii-V-I": "Jazz turnaround",
        "I-V-vi-IV": "Pop progression (4-chord)",
        "vi-IV-I-V": "Pop variant",
        "I-vi-ii-V": "Circle progression",
        "I-vi-IV-V": "50s progression",
        "IV-V-I": "Plagal cadence",
        "V-I": "Authentic cadence",
        "12-bar blues": "Blues",
        "i-iv-v": "Minor blues",
    }

    # --- private lookup tables used by the rule-based helpers ---------------

    # Semitone offsets from the root for each mode.
    _MODE_INTERVALS = {
        "ionian": [0, 2, 4, 5, 7, 9, 11],
        "dorian": [0, 2, 3, 5, 7, 9, 10],
        "phrygian": [0, 1, 3, 5, 7, 8, 10],
        "lydian": [0, 2, 4, 6, 7, 9, 11],
        "mixolydian": [0, 2, 4, 5, 7, 9, 10],
        "aeolian": [0, 2, 3, 5, 7, 8, 10],
        "locrian": [0, 1, 3, 5, 6, 8, 10],
    }
    # Letter names in scale order and the semitone of each natural note (C=0).
    _LETTERS = "CDEFGAB"
    _NATURAL_SEMITONES = {"C": 0, "D": 2, "E": 4, "F": 5, "G": 7, "A": 9, "B": 11}
    # Accidental suffix for a signed semitone offset from the natural note.
    _ACCIDENTALS = {-2: "bb", -1: "b", 0: "", 1: "#", 2: "##"}
    # Upper-cased Roman numeral -> zero-based scale index.
    _DEGREE_INDEX = {"I": 0, "II": 1, "III": 2, "IV": 3, "V": 4, "VI": 5, "VII": 6}
    # Harmonic function of each diatonic degree of the major scale.
    _CHORD_FUNCTIONS = {
        "I": "tonic", "vi": "tonic",
        "ii": "subdominant", "IV": "subdominant",
        "V": "dominant", "vii°": "dominant", "iii": "dominant",
    }
    # Standard 12-bar blues form, one numeral per bar.
    _TWELVE_BAR_BLUES = (
        "I", "I", "I", "I",
        "IV", "IV", "I", "I",
        "V", "IV", "I", "I",
    )

    def __init__(self, d_model: int):
        """
        Initialize MusicTheoryModule.

        Args:
            d_model: Hidden dimension from base model
        """
        super().__init__()
        self.d_model = d_model

        # Embeddings (declared here but not consumed by forward()).
        # 12 chromatic notes × 4 octave context = 48 total pitch classes
        self.note_embed = nn.Embedding(48, 128)  # 12 notes × 4 octaves
        # NOTE(review): CHORD_TYPES lists 17 names but this table has 15 rows,
        # so indices 15 and 16 ("6", "maj6") would be out of range. The size is
        # left unchanged because resizing would break loading of existing
        # checkpoints; confirm the intended vocabulary before changing it.
        self.chord_type_embed = nn.Embedding(15, 128)
        self.mode_embed = nn.Embedding(7, 128)
        self.key_embed = nn.Embedding(24, 128)  # 12 major + 12 minor keys

        # Theory relationship head (declared but not consumed by forward()).
        self.relationship_proj = nn.Linear(d_model, d_model)
        # Chord function classifier (tonic, subdominant, dominant)
        self.chord_function_head = nn.Linear(d_model, 3)
        # Scale degree predictor
        self.scale_degree_head = nn.Linear(d_model, 7)
        # Interval classifier (unison through 13th)
        self.interval_head = nn.Linear(d_model, 14)
        # Progression predictor (next chord in progression)
        self.progression_head = nn.Linear(d_model, 7)
        # Key detection head
        self.key_detection_head = nn.Linear(d_model, 24)
        # Mode classifier
        self.mode_classifier = nn.Linear(d_model, 7)

    def forward(
        self,
        hidden_states: torch.Tensor,
        query: Optional[str] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        Forward pass through MusicTheoryModule.

        Args:
            hidden_states: Base model hidden states [batch, seq_len, d_model]
            query: Optional text query about music theory. Accepted for
                interface compatibility; it is not used by the computation.

        Returns:
            Dictionary mapping each ``*_logits`` name to a [batch, n_classes]
            tensor.

        Raises:
            ValueError: If ``hidden_states`` is not a 3-D tensor.
        """
        if hidden_states.dim() != 3:
            raise ValueError(
                "hidden_states must have shape [batch, seq_len, d_model], "
                f"got {tuple(hidden_states.shape)}"
            )

        # Mean-pool over the sequence axis. Every position is weighted
        # equally, so padded positions (if any) contribute to the average.
        pooled = hidden_states.mean(dim=1)  # [batch, d_model]

        return {
            "chord_function_logits": self.chord_function_head(pooled),  # [batch, 3]
            "scale_degree_logits": self.scale_degree_head(pooled),  # [batch, 7]
            "interval_logits": self.interval_head(pooled),  # [batch, 14]
            "progression_logits": self.progression_head(pooled),  # [batch, 7]
            "key_logits": self.key_detection_head(pooled),  # [batch, 24]
            "mode_logits": self.mode_classifier(pooled),  # [batch, 7]
        }

    def get_chord_function(self, scale_degree: str) -> str:
        """
        Get chord function (tonic, subdominant, dominant).

        Args:
            scale_degree: Roman numeral (I, ii, V, etc)

        Returns:
            "tonic", "subdominant" or "dominant" for the seven diatonic
            degrees of the major scale, "unknown" for anything else.
        """
        return self._CHORD_FUNCTIONS.get(scale_degree, "unknown")

    @classmethod
    def _parse_note(cls, name) -> Optional[Tuple[int, int]]:
        """Parse a note name such as "C", "F#" or "Bb".

        Returns (letter index within "CDEFGAB", pitch class 0-11), or None
        when the name is not a letter A-G with at most one "#" or "b".
        """
        if not isinstance(name, str) or not name:
            return None
        letter, accidental = name[0], name[1:]
        if letter not in cls._NATURAL_SEMITONES or accidental not in ("", "#", "b"):
            return None
        shift = {"": 0, "#": 1, "b": -1}[accidental]
        pitch_class = (cls._NATURAL_SEMITONES[letter] + shift) % 12
        return cls._LETTERS.index(letter), pitch_class

    @classmethod
    def _parse_roman(cls, numeral: str) -> Optional[Tuple[int, str]]:
        """Parse a Roman numeral into (scale index, chord quality).

        Upper case means major, lower case means minor and a trailing "°"
        means diminished. Returns None for anything else.
        """
        diminished = numeral.endswith("°")
        core = numeral[:-1] if diminished else numeral
        index = cls._DEGREE_INDEX.get(core.upper())
        # Reject unknown numerals and mixed-case spellings such as "Iv".
        if index is None or core not in (core.upper(), core.lower()):
            return None
        if diminished:
            return index, "diminished"
        return index, ("major" if core.isupper() else "minor")

    @staticmethod
    def _roman_core(text: str) -> str:
        """Return the leading Roman-numeral part of *text* (e.g. "V7" -> "V")."""
        rest = text.lstrip("ivIV")
        core = text[:len(text) - len(rest)]
        if rest.startswith("°"):
            core += "°"
        return core

    def get_scale_from_key(self, key: str, mode: str = "ionian") -> List[str]:
        """
        Generate scale notes from key and mode.

        Notes are spelled diatonically: each of the seven degrees uses the
        next letter name, so the root keeps the spelling that was passed in
        (``"Db"`` stays ``"Db"``) and no letter is repeated.

        Args:
            key: Root note (C, D, E, etc), optionally with one "#" or "b"
            mode: Mode name (ionian, dorian, etc)

        Returns:
            List of the seven note names in the scale, starting at the root.

        Raises:
            ValueError: If ``mode`` or ``key`` is not recognized.
        """
        if mode not in self._MODE_INTERVALS:
            raise ValueError(f"Unknown mode: {mode}")
        parsed = self._parse_note(key)
        if parsed is None:
            raise ValueError(f"Unknown key: {key}")
        root_letter, root_semitone = parsed

        scale = []
        for degree, interval in enumerate(self._MODE_INTERVALS[mode]):
            letter = self._LETTERS[(root_letter + degree) % 7]
            target = (root_semitone + interval) % 12
            # Signed distance from the natural letter to the target pitch.
            # A root with at most one accidental plus a mode alteration of at
            # most one semitone keeps this within -2..2.
            offset = (target - self._NATURAL_SEMITONES[letter] + 6) % 12 - 6
            scale.append(letter + self._ACCIDENTALS[offset])
        return scale

    def _semitone_to_note(self, semitone: int) -> str:
        """Convert semitone number to note name (wraps modulo 12)."""
        semitone_to_note = {
            0: "C", 1: "C#", 2: "D", 3: "Eb", 4: "E", 5: "F",
            6: "F#", 7: "G", 8: "Ab", 9: "A", 10: "Bb", 11: "B",
        }
        return semitone_to_note[semitone % 12]

    def get_progression_chords(
        self,
        progression_name: str,
        key: str = "C",
    ) -> List[Tuple[str, str]]:
        """
        Get chord progression as list of (degree, chord).

        Progressions containing the minor tonic ``i`` are built on the natural
        minor (aeolian) scale of ``key``; all others use the major scale.
        ``"12-bar blues"`` expands to the standard twelve-bar form.

        Args:
            progression_name: Name of progression (e.g., "I-IV-V-I"); must be
                a key of ``COMMON_PROGRESSIONS``
            key: Root key

        Returns:
            List of (scale_degree, chord) tuples such as ("V", "G major")

        Raises:
            ValueError: If the progression name or the key is unknown.
        """
        if progression_name not in self.COMMON_PROGRESSIONS:
            raise ValueError(f"Unknown progression: {progression_name}")

        if progression_name == "12-bar blues":
            # Not a dash-separated numeral list, so it cannot be split.
            degrees = list(self._TWELVE_BAR_BLUES)
        else:
            degrees = progression_name.split("-")

        scale_mode = "aeolian" if "i" in degrees else "ionian"
        scale = self.get_scale_from_key(key, mode=scale_mode)

        chords = []
        for degree in degrees:
            parsed = self._parse_roman(degree)
            if parsed is None:
                # Unparseable entries are skipped rather than raising.
                continue
            idx, quality = parsed
            chords.append((degree, f"{scale[idx]} {quality}"))
        return chords

    def suggest_progression(
        self,
        mood: str = "happy",
        genre: str = "pop",
        num_chords: int = 4,
    ) -> List[str]:
        """
        Suggest chord progression based on mood and genre.

        A template is chosen by simple rules and then truncated or cycled so
        that the result always contains ``num_chords`` entries.

        Args:
            mood: Emotional mood (happy, sad, tense, etc)
            genre: Music genre (only consulted for the happy/pop rule)
            num_chords: Number of chords in progression

        Returns:
            List of Roman numerals; empty when ``num_chords`` is not positive.
        """
        if mood == "happy" and genre == "pop":
            # Three chords get the classic I-IV-V instead of a cut-off
            # four-chord loop.
            template = ["I", "IV", "V"] if num_chords == 3 else ["I", "V", "vi", "IV"]
        elif mood in ("sad", "melancholy"):
            template = ["vi", "IV", "I", "V"]
        elif mood in ("tense", "dramatic"):
            template = ["i", "iv", "V", "i"]  # Minor with dominant
        elif mood == "jazzy":
            template = ["ii", "V", "I", "vi"]
        else:
            template = ["I", "IV", "V", "I"]  # Default

        return [template[i % len(template)] for i in range(num_chords)]

    def validate_progression(
        self,
        progression: List[str],
        key: str = "C",
    ) -> Tuple[bool, List[str]]:
        """
        Validate chord progression against the major key ``key``.

        Named chords ("C major", "Am", "F#m7") are checked by the pitch class
        of their root, so accidentals matter. Roman numerals, optionally
        followed by an extension such as "7", are checked against
        ``SCALE_DEGREES``.

        Args:
            progression: List of Roman numerals or chord names
            key: Key center

        Returns:
            (is_valid, issues) where ``issues`` has one message per
            offending chord.

        Raises:
            ValueError: If ``key`` is not recognized.
        """
        parsed_key = self._parse_note(key)
        if parsed_key is None:
            raise ValueError(f"Unknown key: {key}")
        root_pc = parsed_key[1]
        key_pitch_classes = {
            (root_pc + step) % 12 for step in self._MODE_INTERVALS["ionian"]
        }

        issues = []
        for chord in progression:
            name = chord.strip() if isinstance(chord, str) else ""
            if not name:
                issues.append(f"Chord {chord!r} is empty or not a string")
                continue

            if name[0] in self._NATURAL_SEMITONES:
                # Named chord: the root is the letter plus an optional
                # accidental at the start of the first word.
                token = name.split(" ")[0]
                root = token[:2] if token[1:2] in ("#", "b") else token[:1]
                if self._parse_note(root)[1] not in key_pitch_classes:
                    issues.append(f"Chord {chord} has root {root} not in key {key}")
            elif self._roman_core(name) not in self.SCALE_DEGREES:
                issues.append(f"Chord {chord} is not a diatonic degree of {key} major")

        return not issues, issues
def test_music_theory_module():
    """Smoke-test MusicTheoryModule and assert on its results.

    Runs one forward pass on random hidden states and exercises the
    rule-based helpers. Each result is printed and also asserted, so a
    regression fails the run instead of only changing the output.
    """
    batch_size, seq_len, d_model = 2, 10, 4096

    # Create module
    module = MusicTheoryModule(d_model=d_model)
    module.eval()

    # Test input
    hidden_states = torch.randn(batch_size, seq_len, d_model)

    # Forward pass. Calling the module (not .forward) goes through
    # nn.Module.__call__ so registered hooks run; no gradients are needed.
    with torch.no_grad():
        outputs = module(hidden_states)

    expected_widths = {
        "chord_function_logits": 3,
        "scale_degree_logits": 7,
        "interval_logits": 14,
        "progression_logits": 7,
        "key_logits": 24,
        "mode_logits": 7,
    }
    assert set(outputs) == set(expected_widths)
    print("Music Theory Module outputs:")
    for name, value in outputs.items():
        print(f" {name}: {value.shape}")
        assert tuple(value.shape) == (batch_size, expected_widths[name])

    # Test scale generation
    print("\nScale from C ionian:")
    scale = module.get_scale_from_key("C", "ionian")
    print(f" {scale}")
    assert scale == ["C", "D", "E", "F", "G", "A", "B"]

    print("\nScale from A dorian:")
    scale = module.get_scale_from_key("A", "dorian")
    print(f" {scale}")
    assert scale == ["A", "B", "C", "D", "E", "F#", "G"]

    # Test progression
    print("\nProgression I-V-vi-IV in C:")
    chords = module.get_progression_chords("I-V-vi-IV", "C")
    for degree, chord in chords:
        print(f" {degree}: {chord}")
    assert chords == [
        ("I", "C major"),
        ("V", "G major"),
        ("vi", "A minor"),
        ("IV", "F major"),
    ]

    # Test suggestion
    print("\nSuggested progression (happy, pop, 4 chords):")
    prog = module.suggest_progression(mood="happy", genre="pop", num_chords=4)
    print(f" {prog}")
    assert prog == ["I", "V", "vi", "IV"]

    # Test validation
    print("\nValidate progression [I, IV, V, I] in C:")
    valid, issues = module.validate_progression(["I", "IV", "V", "I"], "C")
    print(f" Valid: {valid}")
    if issues:
        print(f" Issues: {issues}")
    assert valid and not issues

    print("\nMusic Theory Module test complete!")
# Script entry point: run the smoke test only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    test_music_theory_module()