# TouchGrass-7b / tests/test_ear_training_module.py
# Uploaded by Zandy-Wandy (commit message: "Upload 39 files"),
# commit 4f0238f (verified).
"""
Tests for Ear Training Module.
"""
import pytest
import torch
from TouchGrass.models.ear_training_module import EarTrainingModule
class TestEarTrainingModule:
    """Test suite for EarTrainingModule.

    Covers the learned sub-layers created at construction time, the shapes of
    the tensors returned by the forward pass, gradient flow, and the symbolic
    helpers (interval naming, song references, solfege and quiz generation).
    """

    # Number of interval ids the forward-pass tests sample from, and the size
    # of the last dimension these tests expect for ``interval_logits``.
    NUM_INTERVALS = 12

    def setup_method(self):
        """Set up a fresh module (and a seeded RNG) before every test."""
        # Seed so the random inputs drawn by the tests -- and any randomly
        # initialised weights -- are reproducible if a failure is value-dependent.
        torch.manual_seed(0)
        self.d_model = 768
        self.batch_size = 4
        self.module = EarTrainingModule(d_model=self.d_model)

    def _forward(self, batch_size=None, seq_len=10, requires_grad=False):
        """Run the module on random inputs.

        Args:
            batch_size: Batch dimension; defaults to ``self.batch_size``.
            seq_len: Sequence length of the generated inputs.
            requires_grad: Whether ``hidden_states`` should track gradients.

        Returns:
            ``(hidden_states, output)`` where ``hidden_states`` has shape
            ``(batch_size, seq_len, d_model)`` and ``output`` is whatever the
            module's forward pass returns.
        """
        if batch_size is None:
            batch_size = self.batch_size
        hidden_states = torch.randn(
            batch_size, seq_len, self.d_model, requires_grad=requires_grad
        )
        interval_ids = torch.randint(0, self.NUM_INTERVALS, (batch_size, seq_len))
        return hidden_states, self.module(hidden_states, interval_ids)

    @staticmethod
    def _mentions(text, phrase):
        """Return True if ``phrase`` occurs in ``text``, ignoring case.

        ``text`` goes through ``str`` first so the check also works if a helper
        returns a list of references rather than a single string.
        """
        return phrase.lower() in str(text).lower()

    def test_module_initialization(self):
        """Test that module initializes correctly."""
        assert isinstance(self.module.interval_embed, torch.nn.Embedding)
        assert isinstance(self.module.interval_classifier, torch.nn.Linear)
        assert isinstance(self.module.solfege_embed, torch.nn.Embedding)
        assert isinstance(self.module.solfege_generator, torch.nn.LSTM)
        assert isinstance(self.module.quiz_lstm, torch.nn.LSTM)
        assert isinstance(self.module.quiz_head, torch.nn.Linear)

    def test_forward_pass(self):
        """Test forward pass with dummy inputs."""
        seq_len = 10
        _, output = self._forward(seq_len=seq_len)
        for key in ("interval_logits", "solfege", "quiz_questions"):
            assert key in output, f"missing output key {key!r}"
        assert output["interval_logits"].shape == (
            self.batch_size,
            seq_len,
            self.NUM_INTERVALS,
        )
        # Only the leading (batch, seq) dims of these outputs are pinned down.
        for key in ("solfege", "quiz_questions"):
            assert output[key].shape[0] == self.batch_size, key
            assert output[key].shape[1] == seq_len, key

    def test_get_interval_name(self):
        """Test interval name retrieval."""
        assert self.module.get_interval_name(0) == "P1"  # Perfect unison
        assert self.module.get_interval_name(2) == "M2"  # Major 2nd
        assert self.module.get_interval_name(4) == "M3"  # Major 3rd
        assert self.module.get_interval_name(7) == "P5"  # Perfect 5th
        assert self.module.get_interval_name(12) == "P8"  # Perfect octave

    def test_get_song_reference(self):
        """Test song reference retrieval for intervals."""
        # Perfect 5th - Star Wars
        p5_refs = self.module.get_song_reference("P5")
        assert self._mentions(p5_refs, "Star Wars")
        # Minor 2nd - Jaws
        min2_refs = self.module.get_song_reference("m2")
        assert self._mentions(min2_refs, "Jaws")
        # Major 3rd - When the Saints
        maj3_refs = self.module.get_song_reference("M3")
        assert self._mentions(maj3_refs, "Saints")

    def test_generate_solfege_exercise(self):
        """Test solfege exercise generation."""
        exercise = self.module.generate_solfege_exercise(difficulty="beginner", key="C")
        assert "exercise" in exercise or "notes" in exercise
        assert "key" in exercise or "C" in str(exercise)

    def test_generate_interval_quiz(self):
        """Test interval quiz generation."""
        quiz = self.module.generate_interval_quiz(num_questions=5, difficulty="medium")
        assert "questions" in quiz
        assert len(quiz["questions"]) == 5

    def test_describe_interval(self):
        """Test interval description with song reference."""
        description = self.module.describe_interval(7)  # Perfect 5th
        assert self._mentions(description, "7 semitones") or self._mentions(
            description, "perfect fifth"
        )
        assert self._mentions(description, "Star Wars")

    def test_get_solfege_syllables(self):
        """Test solfege syllable retrieval."""
        syllables = self.module.get_solfege_syllables(key="C", mode="major")
        expected = ["Do", "Re", "Mi", "Fa", "So", "La", "Ti", "Do"]
        assert syllables == expected

    def test_get_solfege_syllables_minor(self):
        """Test solfege syllables for minor mode."""
        syllables = self.module.get_solfege_syllables(key="A", mode="minor")
        # Minor solfege depends on the system: do-based minor is
        # Do Re Me Fa So Le Te Do, la-based minor is La Ti Do Re Mi Fa So La.
        # Both contain "Do" and at least seven syllables.
        assert "Do" in syllables
        assert len(syllables) >= 7

    def test_interval_to_name(self):
        """Test converting semitone count to interval name."""
        expected = {
            0: "P1",
            1: "m2",
            2: "M2",
            3: "m3",
            4: "M3",
            5: "P4",
            6: "TT",  # Tritone
            7: "P5",
            11: "M7",
            12: "P8",
        }
        for semitones, name in expected.items():
            assert self.module.interval_to_name(semitones) == name, semitones

    def test_name_to_interval(self):
        """Test converting interval name to semitone count."""
        expected = {"P1": 0, "m2": 1, "M2": 2, "M3": 4, "P4": 5, "P5": 7, "P8": 12}
        for name, semitones in expected.items():
            assert self.module.name_to_interval(name) == semitones, name

    def test_quiz_question_format(self):
        """Test that quiz questions are properly formatted."""
        quiz = self.module.generate_interval_quiz(num_questions=3, difficulty="easy")
        for question in quiz["questions"]:
            assert "question" in question
            assert "answer" in question
            assert "options" in question or isinstance(question["answer"], (str, int))

    def test_solfege_output_length(self):
        """Test solfege output has correct sequence length."""
        seq_len = 10
        _, output = self._forward(seq_len=seq_len)
        assert output["solfege"].shape[1] == seq_len

    def test_different_batch_sizes(self):
        """Test forward pass with different batch sizes."""
        for batch_size in (1, 2, 8):
            _, output = self._forward(batch_size=batch_size)
            assert output["interval_logits"].shape[0] == batch_size

    def test_gradient_flow(self):
        """Test that gradients flow through the module."""
        hidden_states, output = self._forward(seq_len=5, requires_grad=True)
        loss = output["interval_logits"].sum() + output["solfege"].sum()
        loss.backward()
        assert hidden_states.grad is not None
        assert self.module.interval_embed.weight.grad is not None

    def test_interval_classifier_output(self):
        """Test interval classifier produces logits for all intervals."""
        _, output = self._forward(seq_len=1)
        logits = output["interval_logits"]
        # Should have logits for 12 intervals (0-11 semitones).
        # NOTE(review): interval_to_name/get_interval_name also name 12
        # semitones ("P8"), which has no classifier slot here -- confirm the
        # octave is intentionally excluded from the classifier.
        assert logits.shape[-1] == self.NUM_INTERVALS

    def test_quiz_head_output(self):
        """Test quiz head produces appropriate output."""
        seq_len = 1
        _, output = self._forward(seq_len=seq_len)
        quiz_output = output["quiz_questions"]
        # Quiz output should have some dimension for question generation
        assert quiz_output.shape[0] == self.batch_size
        assert quiz_output.shape[1] == seq_len

    def test_song_reference_coverage(self):
        """Test that common intervals have song references."""
        common_intervals = [0, 2, 4, 5, 7, 9, 12]  # P1, M2, M3, P4, P5, M6, P8
        for interval in common_intervals:
            name = self.module.interval_to_name(interval)
            refs = self.module.get_song_reference(name)
            assert len(refs) > 0, f"No song reference for interval {name}"

    def test_musical_accuracy(self):
        """Test musical accuracy of interval calculations."""
        # Every interval from unison (0) to octave (12) must round-trip
        # through name conversion.
        for semitones in range(13):
            name = self.module.interval_to_name(semitones)
            converted_back = self.module.name_to_interval(name)
            assert converted_back == semitones, f"Round-trip failed for {semitones} ({name})"
if __name__ == "__main__":
    # Propagate pytest's exit status so running this file directly reports
    # failures to the shell/CI instead of always exiting 0.
    raise SystemExit(pytest.main([__file__, "-v"]))