| """
|
| Comprehensive evaluation benchmarks for TouchGrass music modules.
|
|
|
| This script evaluates:
|
| 1. Tab & Chord Generation accuracy
|
| 2. Music Theory knowledge
|
| 3. Ear Training interval identification
|
| 4. EQ Adapter emotion detection
|
| 5. Songwriting coherence and creativity
|
| """
|
|
|
import argparse
import json
from pathlib import Path
from typing import Any, Dict, List, Optional

import torch
from tqdm import tqdm

from TouchGrass.models.tab_chord_module import TabChordModule
from TouchGrass.models.music_theory_module import MusicTheoryModule
from TouchGrass.models.ear_training_module import EarTrainingModule
from TouchGrass.models.eq_adapter import MusicEQAdapter
from TouchGrass.models.songwriting_module import SongwritingModule
|
|
|
|
|
class MusicModuleEvaluator:
    """Evaluator for all TouchGrass music modules.

    Builds each music module with the requested ``d_model`` on ``device``,
    switches them to eval mode and runs a set of per-module checks. Every
    check returns a score in [0, 1]; per-module results are accumulated in
    ``self.results`` and can be written to disk with ``save_results``.

    NOTE(review): the modules are constructed here without loading any
    checkpoint, and the neural checks feed random hidden states, so they
    mainly verify output shapes/ranges rather than learned quality.
    """

    def __init__(self, device: str = "cpu", d_model: int = 768):
        """Build every module on ``device`` and put them in eval mode.

        Args:
            device: Torch device string, e.g. ``"cpu"`` or ``"cuda"``.
            d_model: Hidden size passed to every module constructor and used
                for the synthetic hidden states fed to them.
        """
        self.device = device
        self.d_model = d_model
        self.results: Dict[str, Any] = {}

        self.tab_chord = TabChordModule(d_model=d_model).to(device)
        self.music_theory = MusicTheoryModule(d_model=d_model).to(device)
        self.ear_training = EarTrainingModule(d_model=d_model).to(device)
        self.eq_adapter = MusicEQAdapter(d_model=d_model).to(device)
        self.songwriting = SongwritingModule(d_model=d_model).to(device)

        self._set_eval_mode()

    def _set_eval_mode(self):
        """Set all modules to evaluation mode."""
        self.tab_chord.eval()
        self.music_theory.eval()
        self.ear_training.eval()
        self.eq_adapter.eval()
        self.songwriting.eval()

    # ------------------------------------------------------------------ #
    # Shared helpers
    # ------------------------------------------------------------------ #

    def _random_hidden_states(self, batch_size: int, seq_len: int) -> torch.Tensor:
        """Return random hidden states of shape (batch_size, seq_len, d_model)."""
        # Allocate on self.device: the modules were moved there in __init__,
        # and CPU inputs would fail against CUDA weights.
        return torch.randn(batch_size, seq_len, self.d_model, device=self.device)

    @staticmethod
    def _run_test_suite(tests: List[Any]) -> Dict[str, float]:
        """Run ``(name, test_func)`` pairs and return ``{name: score}``.

        Each score is printed as a percentage. A check that raises is scored
        0.0 and its error is printed, so one broken check does not abort the
        whole evaluation.
        """
        results: Dict[str, float] = {}
        for name, test_func in tests:
            try:
                score = test_func()
            except Exception as exc:  # evaluation boundary: report and continue
                print(f" - {name}: ERROR ({type(exc).__name__}: {exc})")
                results[name] = 0.0
                continue
            results[name] = score
            print(f" - {name}: {score:.2%}")
        return results

    # ------------------------------------------------------------------ #
    # Top-level driver
    # ------------------------------------------------------------------ #

    def evaluate_all(self, test_data_path: Optional[str] = None) -> Dict[str, Any]:
        """Run all evaluations and return comprehensive results.

        Args:
            test_data_path: Currently unused; accepted for interface
                compatibility. All checks use built-in test cases.

        Returns:
            ``self.results``, with one entry per module plus
            ``"overall_score"``, the unweighted mean of the five headline
            scores.
        """
        print("=" * 60)
        print("TouchGrass Music Module Evaluation")
        print("=" * 60)

        self.results["tab_chord"] = self.evaluate_tab_chord()
        print(f"✓ Tab & Chord: {self.results['tab_chord']['accuracy']:.2%}")

        self.results["music_theory"] = self.evaluate_music_theory()
        print(f"✓ Music Theory: {self.results['music_theory']['accuracy']:.2%}")

        self.results["ear_training"] = self.evaluate_ear_training()
        print(f"✓ Ear Training: {self.results['ear_training']['accuracy']:.2%}")

        self.results["eq_adapter"] = self.evaluate_eq_adapter()
        print(f"✓ EQ Adapter: {self.results['eq_adapter']['accuracy']:.2%}")

        self.results["songwriting"] = self.evaluate_songwriting()
        print(f"✓ Songwriting: {self.results['songwriting']['coherence_score']:.2%}")

        scores = [
            self.results["tab_chord"]["accuracy"],
            self.results["music_theory"]["accuracy"],
            self.results["ear_training"]["accuracy"],
            self.results["eq_adapter"]["accuracy"],
            self.results["songwriting"]["coherence_score"],
        ]
        self.results["overall_score"] = sum(scores) / len(scores)
        print(f"\nOverall Score: {self.results['overall_score']:.2%}")

        return self.results

    # ------------------------------------------------------------------ #
    # 1. Tab & Chord
    # ------------------------------------------------------------------ #

    def evaluate_tab_chord(self) -> Dict[str, Any]:
        """Evaluate the Tab & Chord Generation module.

        Feeds random hidden states plus fixed (string, fret) index sequences
        through ``self.tab_chord``. A case is predicted "valid" when the mean
        of its ``"tab_validator"`` output exceeds 0.5.

        Returns:
            ``{"accuracy": float, "correct": int, "total": int}``.
        """
        print("\n[1] Evaluating Tab & Chord Module...")

        # (string indices, fret indices, expected validity), each shaped (1, seq_len).
        # NOTE(review): every case is expected valid, so a validator that always
        # scores > 0.5 reaches 100% — consider adding invalid cases.
        test_cases = [
            (torch.tensor([[0, 1, 2]]), torch.tensor([[0, 3, 5]]), True),
            (torch.tensor([[5, 4, 3, 2, 1, 0]]), torch.tensor([[1, 1, 2, 2, 3, 3]]), True),
            (torch.tensor([[0, 0, 0]]), torch.tensor([[0, 0, 0]]), True),
            (torch.tensor([[0, 0, 0]]), torch.tensor([[1, 1, 1]]), True),
        ]

        correct = 0
        total = len(test_cases)

        for string_indices, fret_indices, expected_valid in test_cases:
            # Index tensors must live on the module's device as well.
            string_indices = string_indices.to(self.device)
            fret_indices = fret_indices.to(self.device)
            batch_size, seq_len = string_indices.shape
            hidden_states = self._random_hidden_states(batch_size, seq_len)

            with torch.no_grad():
                output = self.tab_chord(hidden_states, string_indices, fret_indices)
                validator_score = output["tab_validator"].mean().item()

            predicted_valid = validator_score > 0.5
            if predicted_valid == expected_valid:
                correct += 1

        accuracy = correct / total if total > 0 else 0.0

        return {
            "accuracy": accuracy,
            "correct": correct,
            "total": total,
        }

    # ------------------------------------------------------------------ #
    # 2. Music Theory
    # ------------------------------------------------------------------ #

    def evaluate_music_theory(self) -> Dict[str, Any]:
        """Evaluate the Music Theory Engine.

        Returns:
            ``{"accuracy": mean score, "detailed": {check name: score}}``.
        """
        print("\n[2] Evaluating Music Theory Module...")

        tests = [
            ("scale_c_major", self._test_scale_c_major),
            ("scale_a_minor", self._test_scale_a_minor),
            ("chord_functions", self._test_chord_functions),
            ("circle_of_fifths", self._test_circle_of_fifths),
            ("interval_conversion", self._test_interval_conversion),
        ]

        results = self._run_test_suite(tests)
        avg_accuracy = sum(results.values()) / len(results) if results else 0.0
        return {
            "accuracy": avg_accuracy,
            "detailed": results,
        }

    def _test_scale_c_major(self) -> float:
        """Test C major scale generation (1.0 on exact match, else 0.0)."""
        scale = self.music_theory.get_scale_from_key("C", "major")
        expected = ["C", "D", "E", "F", "G", "A", "B"]
        return 1.0 if scale == expected else 0.0

    def _test_scale_a_minor(self) -> float:
        """Test A natural minor scale (1.0 on exact match, else 0.0)."""
        scale = self.music_theory.get_scale_from_key("A", "natural_minor")
        expected = ["A", "B", "C", "D", "E", "F", "G"]
        return 1.0 if scale == expected else 0.0

    def _test_chord_functions(self) -> float:
        """Test diatonic chord-function detection in C major; returns the fraction correct."""
        tests = [
            ("C", "major", "C", "I"),
            ("F", "major", "C", "IV"),
            ("G", "major", "C", "V"),
            ("D", "minor", "C", "ii"),
            ("E", "minor", "C", "iii"),
            ("A", "minor", "C", "vi"),
            ("B", "dim", "C", "vii°"),
        ]

        correct = 0
        for root, chord_type, key, expected in tests:
            result = self.music_theory.detect_chord_function(root, chord_type, key)
            if result == expected:
                correct += 1

        return correct / len(tests)

    def _test_circle_of_fifths(self) -> float:
        """Test that the circle of fifths has 12 entries covering the expected keys."""
        circle = self.music_theory.get_circle_of_fifths()

        if len(circle) != 12:
            return 0.0

        # Order is not checked, only membership.
        expected_keys = {"C", "G", "D", "A", "E", "B", "F#", "Db", "Ab", "Eb", "Bb", "F"}
        return 1.0 if set(circle) == expected_keys else 0.0

    def _test_interval_conversion(self) -> float:
        """Test semitone-count to interval-name conversion; returns the fraction correct."""
        tests = [
            (0, "P1"), (1, "m2"), (2, "M2"), (3, "m3"), (4, "M3"),
            (5, "P4"), (6, "TT"), (7, "P5"), (8, "m6"), (9, "M6"),
            (10, "m7"), (11, "M7"), (12, "P8"),
        ]

        correct = 0
        for semitones, expected_name in tests:
            name = self.music_theory.semitones_to_interval(semitones)
            if name == expected_name:
                correct += 1

        return correct / len(tests)

    # ------------------------------------------------------------------ #
    # 3. Ear Training
    # ------------------------------------------------------------------ #

    def evaluate_ear_training(self) -> Dict[str, Any]:
        """Evaluate the Ear Training module.

        Returns:
            ``{"accuracy": mean score, "detailed": {check name: score}}``.
        """
        print("\n[3] Evaluating Ear Training Module...")

        tests = [
            ("interval_names", self._test_interval_names),
            ("interval_to_semitones", self._test_interval_to_semitones),
            ("solfege_syllables", self._test_solfege_syllables),
            ("song_references", self._test_song_references),
        ]

        results = self._run_test_suite(tests)
        avg_accuracy = sum(results.values()) / len(results) if results else 0.0
        return {
            "accuracy": avg_accuracy,
            "detailed": results,
        }

    def _test_interval_names(self) -> float:
        """Test semitone-count to interval-name lookup; returns the fraction correct."""
        tests = [
            (0, "P1"), (2, "M2"), (4, "M3"), (5, "P4"),
            (7, "P5"), (9, "M6"), (11, "M7"), (12, "P8"),
        ]

        correct = 0
        for semitones, expected in tests:
            name = self.ear_training.get_interval_name(semitones)
            if name == expected:
                correct += 1

        return correct / len(tests)

    def _test_interval_to_semitones(self) -> float:
        """Test interval-name to semitone-count conversion; returns the fraction correct."""
        tests = [
            ("P1", 0), ("M2", 2), ("M3", 4), ("P4", 5),
            ("P5", 7), ("M6", 9), ("M7", 11), ("P8", 12),
        ]

        correct = 0
        for name, expected_semitones in tests:
            semitones = self.ear_training.name_to_interval(name)
            if semitones == expected_semitones:
                correct += 1

        return correct / len(tests)

    def _test_solfege_syllables(self) -> float:
        """Test solfege syllables for C major (1.0 on exact match, else 0.0)."""
        c_major = self.ear_training.get_solfege_syllables("C", "major")
        expected = ["Do", "Re", "Mi", "Fa", "So", "La", "Ti", "Do"]

        return 1.0 if c_major == expected else 0.0

    def _test_song_references(self) -> float:
        """Test that song references exist for common intervals; returns the fraction found."""
        common_intervals = ["P5", "M3", "m3", "P4", "M2"]
        correct = 0

        for interval in common_intervals:
            refs = self.ear_training.get_song_reference(interval)
            if len(refs) > 0:
                correct += 1

        return correct / len(common_intervals)

    # ------------------------------------------------------------------ #
    # 4. EQ Adapter
    # ------------------------------------------------------------------ #

    def evaluate_eq_adapter(self) -> Dict[str, Any]:
        """Evaluate EQ Adapter emotion-detection outputs.

        Returns:
            ``{"accuracy": mean score, "detailed": {check name: score}}``.
        """
        print("\n[4] Evaluating EQ Adapter...")

        tests = [
            ("frustration_range", self._test_frustration_range),
            ("emotion_classifier_output", self._test_emotion_classifier),
            ("encouragement_output", self._test_encouragement_output),
            ("simplification_output", self._test_simplification_output),
        ]

        results = self._run_test_suite(tests)
        avg_accuracy = sum(results.values()) / len(results) if results else 0.0
        return {
            "accuracy": avg_accuracy,
            "detailed": results,
        }

    def _test_frustration_range(self) -> float:
        """Test that frustration scores lie in [0, 1]."""
        batch_size, seq_len = 2, 5
        hidden_states = self._random_hidden_states(batch_size, seq_len)

        with torch.no_grad():
            output = self.eq_adapter(hidden_states)
            frustration = output["frustration"]

        in_range = ((frustration >= 0) & (frustration <= 1)).all().item()
        return 1.0 if in_range else 0.0

    def _test_emotion_classifier(self) -> float:
        """Test that the emotion output has shape (batch, seq, 4)."""
        batch_size, seq_len = 2, 5
        hidden_states = self._random_hidden_states(batch_size, seq_len)

        with torch.no_grad():
            output = self.eq_adapter(hidden_states)
            emotion = output["emotion"]

        correct_shape = emotion.shape == (batch_size, seq_len, 4)
        return 1.0 if correct_shape else 0.0

    def _test_encouragement_output(self) -> float:
        """Test that an encouragement output is produced with a matching batch dimension."""
        batch_size, seq_len = 2, 5
        hidden_states = self._random_hidden_states(batch_size, seq_len)

        with torch.no_grad():
            output = self.eq_adapter(hidden_states)

        # Check presence before indexing: a missing key scores 0, not KeyError.
        if "encouragement" not in output:
            return 0.0
        return 1.0 if output["encouragement"].shape[0] == batch_size else 0.0

    def _test_simplification_output(self) -> float:
        """Test that the simplification output matches the input shape."""
        batch_size, seq_len = 2, 5
        hidden_states = self._random_hidden_states(batch_size, seq_len)

        with torch.no_grad():
            output = self.eq_adapter(hidden_states)
            correct_shape = output["simplification"].shape == hidden_states.shape

        return 1.0 if correct_shape else 0.0

    # ------------------------------------------------------------------ #
    # 5. Songwriting
    # ------------------------------------------------------------------ #

    def evaluate_songwriting(self) -> Dict[str, Any]:
        """Evaluate the Songwriting module.

        Returns:
            ``{"coherence_score": mean score, "detailed": {check name: score}}``.
        """
        print("\n[5] Evaluating Songwriting Module...")

        tests = [
            ("progression_generation", self._test_progression_generation),
            ("mood_classifier", self._test_mood_classifier),
            ("genre_classifier", self._test_genre_classifier),
            ("hook_generation", self._test_hook_generation),
            ("production_suggestions", self._test_production_suggestions),
        ]

        results = self._run_test_suite(tests)
        avg_accuracy = sum(results.values()) / len(results) if results else 0.0
        return {
            "coherence_score": avg_accuracy,
            "detailed": results,
        }

    def _test_progression_generation(self) -> float:
        """Test that a 4-chord progression is returned as a list of 2-tuples."""
        try:
            progression = self.songwriting.suggest_progression(
                mood="happy", genre="pop", num_chords=4, key="C"
            )

            if not isinstance(progression, list):
                return 0.0
            if len(progression) != 4:
                return 0.0
            if not all(isinstance(p, tuple) and len(p) == 2 for p in progression):
                return 0.0
            return 1.0
        except Exception:
            return 0.0

    def _test_mood_classifier(self) -> float:
        """Test that the mood output has at least 8 classes in its last dimension."""
        batch_size, seq_len = 2, 5
        hidden_states = self._random_hidden_states(batch_size, seq_len)
        # 24 presumably covers 12 roots x {major, minor} — TODO confirm vocab size.
        chord_ids = torch.randint(0, 24, (batch_size, seq_len), device=self.device)

        with torch.no_grad():
            output = self.songwriting(hidden_states, chord_ids)
            mood = output["mood"]

        correct_shape = mood.shape[-1] >= 8
        return 1.0 if correct_shape else 0.0

    def _test_genre_classifier(self) -> float:
        """Test that the genre output has at least 8 classes in its last dimension."""
        batch_size, seq_len = 2, 5
        hidden_states = self._random_hidden_states(batch_size, seq_len)
        chord_ids = torch.randint(0, 24, (batch_size, seq_len), device=self.device)

        with torch.no_grad():
            output = self.songwriting(hidden_states, chord_ids)
            genre = output["genre"]

        correct_shape = genre.shape[-1] >= 8
        return 1.0 if correct_shape else 0.0

    def _test_hook_generation(self) -> float:
        """Test that hook generation returns a dict with a non-empty ``"hook"`` string."""
        try:
            hook = self.songwriting.generate_hook(
                theme="freedom", genre="pop", key="C"
            )

            if not isinstance(hook, dict):
                return 0.0
            if "hook" not in hook:
                return 0.0
            if not isinstance(hook["hook"], str):
                return 0.0
            if len(hook["hook"]) == 0:
                return 0.0
            return 1.0
        except Exception:
            return 0.0

    def _test_production_suggestions(self) -> float:
        """Test that production suggestions are a dict with ``"elements"`` or ``"suggestions"``."""
        try:
            production = self.songwriting.suggest_production(
                genre="rock", mood="energetic", bpm=120
            )

            if not isinstance(production, dict):
                return 0.0
            has_elements = "elements" in production or "suggestions" in production
            return 1.0 if has_elements else 0.0
        except Exception:
            return 0.0

    # ------------------------------------------------------------------ #
    # Persistence
    # ------------------------------------------------------------------ #

    def save_results(self, output_path: str):
        """Save ``self.results`` as indented JSON, creating parent directories as needed."""
        output_path = Path(output_path)
        output_path.parent.mkdir(parents=True, exist_ok=True)

        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(self.results, f, indent=2)

        print(f"\n✓ Results saved to {output_path}")
|
|
|
|
|
def main():
    """Parse CLI arguments, run every module evaluation and save the JSON report."""
    parser = argparse.ArgumentParser(description="Evaluate TouchGrass music modules")
    parser.add_argument("--device", type=str, default="cpu", help="Device to use (cpu or cuda)")
    parser.add_argument("--d_model", type=int, default=768, help="Model dimension")
    parser.add_argument("--output", type=str, default="benchmarks/results/music_module_eval.json",
                        help="Output path for results")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")

    args = parser.parse_args()

    # Seed before the evaluator is built so that every torch random draw
    # afterwards (the random hidden states, and presumably any random weight
    # init in the module constructors) is reproducible across runs.
    torch.manual_seed(args.seed)

    evaluator = MusicModuleEvaluator(device=args.device, d_model=args.d_model)

    # Results are stored on the evaluator; save_results writes them out.
    evaluator.evaluate_all()

    evaluator.save_results(args.output)

    print("\n" + "=" * 60)
    print("Evaluation complete!")
    print("=" * 60)


if __name__ == "__main__":
    main()
|
|
|