TouchGrass-7b / tests /test_songwriting_module.py
Zandy-Wandy's picture
Upload 39 files
4f0238f verified
"""
Tests for Song Writing Assistant Module.
"""
import pytest
import torch
from TouchGrass.models.songwriting_module import SongwritingModule
class TestSongwritingModule:
    """Test suite for SongwritingModule.

    ``setup_method`` builds a fresh module (``d_model=768``) before every test.
    Forward-pass tests feed random hidden states of shape
    ``(batch, seq_len, d_model)`` together with random chord ids in
    ``[0, NUM_CHORDS)``. The remaining tests exercise the high-level helper
    methods: progressions, lyrics, hooks and production advice.
    """

    # Exclusive upper bound for random chord ids (12 major + 12 minor chords).
    NUM_CHORDS = 24
    # Heads every forward pass is expected to return.
    OUTPUT_KEYS = ("mood", "genre", "lyrics", "hook", "production")

    def setup_method(self):
        """Create a fresh module and the default dimensions before each test."""
        self.d_model = 768
        self.batch_size = 4
        self.module = SongwritingModule(d_model=self.d_model)

    def _forward(self, seq_len, batch_size=None, requires_grad=False):
        """Run the module once on random inputs.

        Args:
            seq_len: Sequence length of the generated inputs (may be 0).
            batch_size: Batch size; defaults to ``self.batch_size``.
            requires_grad: Whether the hidden states should track gradients.

        Returns:
            ``(hidden_states, output)``. ``hidden_states`` is the
            ``(batch_size, seq_len, d_model)`` input tensor and ``output`` is
            whatever the module's forward call returned.
        """
        if batch_size is None:
            batch_size = self.batch_size
        hidden_states = torch.randn(
            batch_size, seq_len, self.d_model, requires_grad=requires_grad
        )
        chord_ids = torch.randint(0, self.NUM_CHORDS, (batch_size, seq_len))
        return hidden_states, self.module(hidden_states, chord_ids)

    def test_module_initialization(self):
        """All sub-layers exist and have the expected layer types."""
        expected_types = {
            "chord_embed": torch.nn.Embedding,
            "progression_lstm": torch.nn.LSTM,
            "mood_classifier": torch.nn.Linear,
            "genre_classifier": torch.nn.Linear,
            "lyric_lstm": torch.nn.LSTM,
            "rhyme_detector": torch.nn.Linear,
            "hook_generator": torch.nn.Linear,
            "production_advisor": torch.nn.Linear,
        }
        for name, layer_type in expected_types.items():
            assert isinstance(getattr(self.module, name), layer_type), name

    def test_forward_pass(self):
        """The forward pass returns every head with (batch, seq, ...) leading dims."""
        seq_len = 10
        _, output = self._forward(seq_len)
        for key in self.OUTPUT_KEYS:
            assert key in output
        assert output["mood"].shape == (self.batch_size, seq_len, 8)  # 8 moods
        assert output["genre"].shape == (self.batch_size, seq_len, 8)  # 8 genres
        for key in ("lyrics", "hook", "production"):
            assert output[key].shape[0] == self.batch_size
            assert output[key].shape[1] == seq_len

    def test_suggest_progression_pop_major(self):
        """A pop progression in C has the requested length of (degree, chord) pairs."""
        progression = self.module.suggest_progression(mood="happy", genre="pop", num_chords=4, key="C")
        assert len(progression) == 4
        # Each element should be a (degree, chord) tuple.
        assert all(isinstance(p, tuple) and len(p) == 2 for p in progression)
        # Only the types are checked here; membership in C major is not validated.
        for degree, chord in progression:
            assert isinstance(degree, (int, str))
            assert isinstance(chord, str)

    def test_suggest_progression_blues_minor(self):
        """A sad blues progression has the requested length and string chords."""
        progression = self.module.suggest_progression(mood="sad", genre="blues", num_chords=4, key="A")
        assert len(progression) == 4
        for _, chord in progression:
            assert isinstance(chord, str)
        # TODO: key="A" carries no explicit minor quality, and the minor /
        # dominant-7th chord qualities typical of blues are not asserted yet.

    def test_suggest_progression_rock(self):
        """A rock progression yields one scale degree per requested chord."""
        progression = self.module.suggest_progression(mood="energetic", genre="rock", num_chords=4, key="G")
        assert len(progression) == 4
        # Rock often uses power chords (5ths) and simple progressions.
        degrees = [degree for degree, _ in progression]
        assert len(degrees) == 4

    def test_generate_lyrics_with_rhyme_scheme(self):
        """Lyric generation reports both the lines and the rhyme scheme."""
        lyrics = self.module.generate_lyrics(theme="love", rhyme_scheme="ABAB", num_lines=4, key="C")
        assert "lyrics" in lyrics or "lines" in lyrics
        assert "rhyme_scheme" in lyrics or "scheme" in lyrics

    def test_generate_lyrics_verse_structure(self):
        """A four-line verse request yields exactly four lines."""
        lyrics = self.module.generate_lyrics(theme="heartbreak", rhyme_scheme="AABB", num_lines=4, key="D")
        lines = lyrics.get("lyrics", [])
        assert len(lines) == 4

    def test_generate_hook(self):
        """Hook generation returns a non-empty string hook."""
        hook = self.module.generate_hook(theme="freedom", genre="pop", key="F")
        assert "hook" in hook or "line" in hook
        hook_text = hook.get("hook", "")
        assert isinstance(hook_text, str)
        assert len(hook_text) > 0

    def test_generate_hook_catchy(self):
        """Hooks are short and memorable (at most 20 words)."""
        hook = self.module.generate_hook(theme="summer", genre="reggae", key="G")
        hook_text = hook.get("hook", "")
        # Hooks should be relatively short (typically 1-2 lines).
        assert len(hook_text.split()) <= 20

    def test_suggest_production_elements(self):
        """Production advice contains at least one element or suggestion."""
        production = self.module.suggest_production(genre="electronic", mood="dark", bpm=128)
        assert "elements" in production or "suggestions" in production
        # Should include instruments, effects, or arrangement tips.
        elements = production.get("elements", production.get("suggestions", []))
        assert len(elements) > 0

    def test_suggest_production_instruments(self):
        """Rock production advice mentions at least one core instrument."""
        production = self.module.suggest_production(genre="rock", mood="loud", bpm=180)
        elements = production.get("elements", production.get("suggestions", []))
        all_text = str(elements).lower()
        assert any(inst in all_text for inst in ["guitar", "drums", "bass", "vocals"])

    def test_mood_classification(self):
        """Every expected mood is exposed by the module."""
        moods = self.module.get_available_moods()
        expected_moods = ["happy", "sad", "energetic", "calm", "angry", "romantic", "mysterious", "nostalgic"]
        missing = [mood for mood in expected_moods if mood not in moods]
        assert not missing, f"missing moods: {missing}"

    def test_genre_classification(self):
        """Every expected genre is exposed by the module."""
        genres = self.module.get_available_genres()
        expected_genres = ["pop", "rock", "blues", "jazz", "country", "electronic", "hiphop", "classical"]
        missing = [genre for genre in expected_genres if genre not in genres]
        assert not missing, f"missing genres: {missing}"

    def test_progression_mood_consistency(self):
        """Happy and sad pop progressions in the same key differ."""
        happy_prog = self.module.suggest_progression(mood="happy", genre="pop", num_chords=4, key="C")
        sad_prog = self.module.suggest_progression(mood="sad", genre="pop", num_chords=4, key="C")
        # Happy progressions typically use major chords and sad ones minor;
        # at least some difference is expected.
        happy_chords = [chord for _, chord in happy_prog]
        sad_chords = [chord for _, chord in sad_prog]
        assert happy_chords != sad_chords

    def test_progression_genre_consistency(self):
        """Rock and jazz progressions in the same key differ."""
        rock_prog = self.module.suggest_progression(mood="energetic", genre="rock", num_chords=4, key="E")
        jazz_prog = self.module.suggest_progression(mood="calm", genre="jazz", num_chords=4, key="E")
        rock_chords = [chord for _, chord in rock_prog]
        jazz_chords = [chord for _, chord in jazz_prog]
        assert rock_chords != jazz_chords

    def test_key_consistency(self):
        """Progressions can be requested in all twelve major keys."""
        for key in ["C", "G", "D", "A", "E", "B", "F#", "F", "Bb", "Eb", "Ab", "Db"]:
            progression = self.module.suggest_progression(mood="happy", genre="pop", num_chords=4, key=key)
            for _, chord in progression:
                assert isinstance(chord, str)
            # TODO: validate that each chord is diatonic to `key`; only the
            # chord type is checked for now.

    def test_different_num_chords(self):
        """The progression length always matches ``num_chords``."""
        for num in [2, 3, 4, 6, 8]:
            progression = self.module.suggest_progression(mood="happy", genre="pop", num_chords=num, key="C")
            assert len(progression) == num

    def test_lyric_theme_relevance(self):
        """Lyrics are produced for a variety of themes."""
        for theme in ["love", "loss", "freedom", "nature"]:
            lyrics = self.module.generate_lyrics(theme=theme, rhyme_scheme="AABB", num_lines=4, key="C")
            # Inspect the payload itself. The stringified form of an empty list
            # ("[]") is non-empty, so a length check on str(...) could never fail.
            lines = lyrics.get("lyrics", lyrics.get("lines", []))
            assert lines, f"no lyrics generated for theme {theme!r}"
            # TODO: semantic relevance to the theme is not evaluated here.

    def test_rhyme_scheme_enforcement(self):
        """Every supported rhyme scheme is reported back in the result."""
        for scheme in ["AABB", "ABAB", "ABBA", "AAAA"]:
            lyrics = self.module.generate_lyrics(theme="joy", rhyme_scheme=scheme, num_lines=4, key="G")
            assert "rhyme_scheme" in lyrics or "scheme" in lyrics

    def test_production_tempo_consideration(self):
        """Production advice can be requested at both tempo extremes.

        Comparing the slow and fast suggestions needs a trained model, so for
        now this only checks that both calls succeed and return a result.
        """
        slow_prod = self.module.suggest_production(genre="ambient", mood="calm", bpm=60)
        fast_prod = self.module.suggest_production(genre="metal", mood="aggressive", bpm=200)
        # TODO: assert tempo-dependent differences once a trained model exists.
        assert slow_prod is not None
        assert fast_prod is not None

    def test_forward_with_empty_sequence(self):
        """A zero-length sequence keeps the batch and (empty) sequence dims."""
        # Assumes the module supports seq_len == 0. PyTorch RNN layers may
        # reject empty sequences, so confirm this against the implementation.
        seq_len = 0
        _, output = self._forward(seq_len)
        for key in self.OUTPUT_KEYS:
            assert output[key].shape[0] == self.batch_size
            assert output[key].shape[1] == seq_len

    def test_different_batch_sizes(self):
        """The batch dimension of the output follows the input batch size."""
        for batch_size in [1, 2, 8]:
            _, output = self._forward(seq_len=10, batch_size=batch_size)
            assert output["mood"].shape[0] == batch_size

    def test_gradient_flow(self):
        """Gradients reach both the inputs and the chord embedding."""
        hidden_states, output = self._forward(seq_len=5, requires_grad=True)
        tensors = [out for out in output.values() if isinstance(out, torch.Tensor)]
        # Without any tensor outputs sum() would return the int 0, which has
        # no .backward(); fail with a clear message instead.
        assert tensors, "forward() returned no tensors to backpropagate through"
        loss = sum(t.sum() for t in tensors)
        loss.backward()
        assert hidden_states.grad is not None
        assert self.module.chord_embed.weight.grad is not None

    def test_chord_embedding_vocab_size(self):
        """The chord vocabulary covers at least 12 major and 12 minor chords."""
        assert self.module.chord_embed.num_embeddings >= self.NUM_CHORDS

    def test_mood_classifier_output(self):
        """The mood head produces logits for at least 8 moods."""
        _, output = self._forward(seq_len=1)
        mood_logits = output["mood"]
        assert mood_logits.shape[-1] >= 8  # At least 8 moods

    def test_genre_classifier_output(self):
        """The genre head produces logits for at least 8 genres."""
        _, output = self._forward(seq_len=1)
        genre_logits = output["genre"]
        assert genre_logits.shape[-1] >= 8  # At least 8 genres

    def test_lyric_lstm_output_shape(self):
        """Lyric output keeps the (batch, seq) leading dimensions."""
        seq_len = 10
        _, output = self._forward(seq_len)
        lyrics = output["lyrics"]
        # Lyrics should be a sequence of token embeddings or logits.
        assert lyrics.shape[0] == self.batch_size
        assert lyrics.shape[1] == seq_len

    def test_hook_generator_output(self):
        """Hook output keeps the (batch, seq) leading dimensions."""
        seq_len = 1
        _, output = self._forward(seq_len)
        hook = output["hook"]
        assert hook.shape[0] == self.batch_size
        assert hook.shape[1] == seq_len

    def test_production_advisor_output(self):
        """Production output keeps the (batch, seq) leading dimensions."""
        seq_len = 1
        _, output = self._forward(seq_len)
        production = output["production"]
        assert production.shape[0] == self.batch_size
        assert production.shape[1] == seq_len
if __name__ == "__main__":
    # Propagate pytest's exit code so running this file directly fails
    # (non-zero status) when any test fails.
    raise SystemExit(pytest.main([__file__, "-v"]))