| """
|
| Tests for Song Writing Assistant Module.
|
| """
|
|
|
| import pytest
|
| import torch
|
|
|
| from TouchGrass.models.songwriting_module import SongwritingModule
|
|
|
|
|
class TestSongwritingModule:
    """Test suite for SongwritingModule.

    The tests fall into two groups:

    * tensor-level tests that run the module's forward pass on random dummy
      inputs and check output keys, shapes and gradient flow;
    * API-level tests for the helper methods (``suggest_progression``,
      ``generate_lyrics``, ``generate_hook``, ``suggest_production``,
      ``get_available_moods`` and ``get_available_genres``).
    """

    # Exclusive upper bound of the random chord ids fed to the module; the
    # chord embedding table is asserted to hold at least this many entries.
    NUM_CHORD_IDS = 24

    # Keys the forward pass is expected to return.
    OUTPUT_KEYS = ("mood", "genre", "lyrics", "hook", "production")

    def setup_method(self):
        """Set up test fixtures: a fresh module per test with a fixed RNG seed."""
        # Seed first so both the module's initial weights and the random dummy
        # inputs are reproducible from one run to the next.
        torch.manual_seed(0)
        self.d_model = 768
        self.batch_size = 4
        self.module = SongwritingModule(d_model=self.d_model)

    # ------------------------------------------------------------------
    # Helpers (leading underscore keeps pytest from collecting them)
    # ------------------------------------------------------------------

    def _dummy_inputs(self, seq_len, batch_size=None, requires_grad=False):
        """Build random ``(hidden_states, chord_ids)`` inputs for the module.

        Args:
            seq_len: Sequence length of both tensors.
            batch_size: Batch size; defaults to ``self.batch_size``.
            requires_grad: Whether ``hidden_states`` should track gradients.

        Returns:
            A ``(batch, seq_len, d_model)`` float tensor and a
            ``(batch, seq_len)`` integer tensor of chord ids drawn from
            ``[0, NUM_CHORD_IDS)``.
        """
        if batch_size is None:
            batch_size = self.batch_size
        hidden_states = torch.randn(
            batch_size, seq_len, self.d_model, requires_grad=requires_grad
        )
        chord_ids = torch.randint(0, self.NUM_CHORD_IDS, (batch_size, seq_len))
        return hidden_states, chord_ids

    def _forward(self, seq_len, batch_size=None):
        """Run the module on fresh dummy inputs and return its output mapping."""
        hidden_states, chord_ids = self._dummy_inputs(seq_len, batch_size)
        return self.module(hidden_states, chord_ids)

    @staticmethod
    def _assert_batch_and_seq(tensor, batch_size, seq_len, name):
        """Assert that the first two dimensions of *tensor* are (batch, seq)."""
        assert tensor.shape[0] == batch_size, f"{name}: wrong batch dimension"
        assert tensor.shape[1] == seq_len, f"{name}: wrong sequence dimension"

    @staticmethod
    def _first_present(result, *keys):
        """Return the value under the first of *keys* present in *result*.

        Several tests accept alternative key names for the same field; this
        helper keeps the "is the key there" check and the value lookup in
        agreement.

        Raises:
            AssertionError: If none of *keys* is present in *result*.
        """
        for key in keys:
            if key in result:
                return result[key]
        raise AssertionError(
            f"expected one of {keys!r} in result, found keys {list(result)!r}"
        )

    # ------------------------------------------------------------------
    # Construction and forward pass
    # ------------------------------------------------------------------

    def test_module_initialization(self):
        """Test that module initializes correctly."""
        expected_layers = {
            "chord_embed": torch.nn.Embedding,
            "progression_lstm": torch.nn.LSTM,
            "mood_classifier": torch.nn.Linear,
            "genre_classifier": torch.nn.Linear,
            "lyric_lstm": torch.nn.LSTM,
            "rhyme_detector": torch.nn.Linear,
            "hook_generator": torch.nn.Linear,
            "production_advisor": torch.nn.Linear,
        }
        for attr_name, layer_type in expected_layers.items():
            layer = getattr(self.module, attr_name)
            assert isinstance(layer, layer_type), (
                f"{attr_name} should be a {layer_type.__name__}"
            )

    def test_forward_pass(self):
        """Test forward pass with dummy inputs."""
        seq_len = 10
        output = self._forward(seq_len)

        for name in self.OUTPUT_KEYS:
            assert name in output

        assert output["mood"].shape == (self.batch_size, seq_len, 8)
        assert output["genre"].shape == (self.batch_size, seq_len, 8)
        for name in ("lyrics", "hook", "production"):
            self._assert_batch_and_seq(
                output[name], self.batch_size, seq_len, name
            )

    # ------------------------------------------------------------------
    # Chord progression suggestions
    # ------------------------------------------------------------------

    def test_suggest_progression_pop_major(self):
        """Test chord progression suggestion for pop in major key."""
        progression = self.module.suggest_progression(
            mood="happy", genre="pop", num_chords=4, key="C"
        )
        assert len(progression) == 4
        assert all(isinstance(p, tuple) and len(p) == 2 for p in progression)

        for degree, chord in progression:
            assert isinstance(degree, (int, str))
            assert isinstance(chord, str)

    def test_suggest_progression_blues_minor(self):
        """Test chord progression suggestion for blues in minor key."""
        progression = self.module.suggest_progression(
            mood="sad", genre="blues", num_chords=4, key="A"
        )
        assert len(progression) == 4
        for _, chord in progression:
            assert isinstance(chord, str)

    def test_suggest_progression_rock(self):
        """Test chord progression suggestion for rock."""
        progression = self.module.suggest_progression(
            mood="energetic", genre="rock", num_chords=4, key="G"
        )
        assert len(progression) == 4

        # Unpacking also verifies that every entry is a (degree, chord) pair.
        degrees = [degree for degree, _ in progression]
        assert len(degrees) == 4

    # ------------------------------------------------------------------
    # Lyrics and hooks
    # ------------------------------------------------------------------

    def test_generate_lyrics_with_rhyme_scheme(self):
        """Test lyric generation with rhyme scheme."""
        lyrics = self.module.generate_lyrics(
            theme="love", rhyme_scheme="ABAB", num_lines=4, key="C"
        )
        assert "lyrics" in lyrics or "lines" in lyrics
        assert "rhyme_scheme" in lyrics or "scheme" in lyrics

    def test_generate_lyrics_verse_structure(self):
        """Test lyric generation for verse structure."""
        lyrics = self.module.generate_lyrics(
            theme="heartbreak", rhyme_scheme="AABB", num_lines=4, key="D"
        )
        # Accept either key name, matching test_generate_lyrics_with_rhyme_scheme.
        lines = self._first_present(lyrics, "lyrics", "lines")
        assert len(lines) == 4

    def test_generate_hook(self):
        """Test hook generation."""
        hook = self.module.generate_hook(theme="freedom", genre="pop", key="F")
        hook_text = self._first_present(hook, "hook", "line")
        assert isinstance(hook_text, str)
        assert len(hook_text) > 0

    def test_generate_hook_catchy(self):
        """Test that hooks are short and memorable."""
        hook = self.module.generate_hook(theme="summer", genre="reggae", key="G")
        hook_text = self._first_present(hook, "hook", "line")

        # "Short" is pinned here as at most 20 whitespace-separated words.
        assert len(hook_text.split()) <= 20

    # ------------------------------------------------------------------
    # Production suggestions
    # ------------------------------------------------------------------

    def test_suggest_production_elements(self):
        """Test production element suggestions."""
        production = self.module.suggest_production(
            genre="electronic", mood="dark", bpm=128
        )
        elements = self._first_present(production, "elements", "suggestions")
        assert len(elements) > 0

    def test_suggest_production_instruments(self):
        """Test that production suggestions include instruments."""
        production = self.module.suggest_production(
            genre="rock", mood="loud", bpm=180
        )
        elements = self._first_present(production, "elements", "suggestions")

        # Search the stringified suggestions so the check does not depend on
        # how the individual elements are structured.
        all_text = str(elements).lower()
        assert any(
            inst in all_text for inst in ("guitar", "drums", "bass", "vocals")
        )

    # ------------------------------------------------------------------
    # Mood / genre vocabularies
    # ------------------------------------------------------------------

    def test_mood_classification(self):
        """Test mood classification."""
        moods = self.module.get_available_moods()
        expected_moods = [
            "happy", "sad", "energetic", "calm",
            "angry", "romantic", "mysterious", "nostalgic",
        ]
        for mood in expected_moods:
            assert mood in moods, f"missing mood {mood!r}"

    def test_genre_classification(self):
        """Test genre classification."""
        genres = self.module.get_available_genres()
        expected_genres = [
            "pop", "rock", "blues", "jazz",
            "country", "electronic", "hiphop", "classical",
        ]
        for genre in expected_genres:
            assert genre in genres, f"missing genre {genre!r}"

    # ------------------------------------------------------------------
    # Progression behaviour across moods, genres, keys and lengths
    # ------------------------------------------------------------------

    def test_progression_mood_consistency(self):
        """Test that suggested progressions match the requested mood."""
        happy_prog = self.module.suggest_progression(
            mood="happy", genre="pop", num_chords=4, key="C"
        )
        sad_prog = self.module.suggest_progression(
            mood="sad", genre="pop", num_chords=4, key="C"
        )

        happy_chords = [chord for _, chord in happy_prog]
        sad_chords = [chord for _, chord in sad_prog]

        # Different moods must not yield the identical chord sequence.
        assert happy_chords != sad_chords

    def test_progression_genre_consistency(self):
        """Test that suggested progressions match the requested genre."""
        rock_prog = self.module.suggest_progression(
            mood="energetic", genre="rock", num_chords=4, key="E"
        )
        jazz_prog = self.module.suggest_progression(
            mood="calm", genre="jazz", num_chords=4, key="E"
        )

        rock_chords = [chord for _, chord in rock_prog]
        jazz_chords = [chord for _, chord in jazz_prog]
        assert rock_chords != jazz_chords

    def test_key_consistency(self):
        """Test that a progression is produced for every supported key.

        Only the length and the chord type are checked; whether each chord
        actually belongs to the requested key is not verified here.
        """
        keys = ["C", "G", "D", "A", "E", "B", "F#", "F", "Bb", "Eb", "Ab", "Db"]
        for key in keys:
            progression = self.module.suggest_progression(
                mood="happy", genre="pop", num_chords=4, key=key
            )
            # Guard against an empty result passing the loop below vacuously.
            assert len(progression) == 4, f"key {key!r}: wrong number of chords"

            for _, chord in progression:
                assert isinstance(chord, str), f"key {key!r}: chord is not a str"

    def test_different_num_chords(self):
        """Test requesting different numbers of chords."""
        for num in (2, 3, 4, 6, 8):
            progression = self.module.suggest_progression(
                mood="happy", genre="pop", num_chords=num, key="C"
            )
            assert len(progression) == num, f"num_chords={num}"

    # ------------------------------------------------------------------
    # Lyric themes and rhyme schemes
    # ------------------------------------------------------------------

    def test_lyric_theme_relevance(self):
        """Test that lyrics are generated for each theme.

        Thematic relevance itself is not asserted; the check is that every
        theme yields a non-empty set of lines.
        """
        for theme in ("love", "loss", "freedom", "nature"):
            lyrics = self.module.generate_lyrics(
                theme=theme, rhyme_scheme="AABB", num_lines=4, key="C"
            )
            lines = self._first_present(lyrics, "lyrics", "lines")

            # Check the lines themselves: stringifying an empty list gives
            # "[]", which would make a length check on the text always pass.
            assert len(lines) > 0, f"theme {theme!r}: no lyrics generated"

    def test_rhyme_scheme_enforcement(self):
        """Test that the result reports a rhyme scheme for each requested one.

        Only the presence of the scheme field is checked, not that the lines
        actually rhyme accordingly.
        """
        for scheme in ("AABB", "ABAB", "ABBA", "AAAA"):
            lyrics = self.module.generate_lyrics(
                theme="joy", rhyme_scheme=scheme, num_lines=4, key="G"
            )
            assert "rhyme_scheme" in lyrics or "scheme" in lyrics, (
                f"scheme {scheme!r}: no rhyme scheme in result"
            )

    def test_production_tempo_consideration(self):
        """Test that production suggestions are produced at both tempo extremes.

        How the suggestions should differ between slow and fast tempi is not
        asserted; both requests must simply yield non-empty suggestions.
        """
        slow_prod = self.module.suggest_production(
            genre="ambient", mood="calm", bpm=60
        )
        fast_prod = self.module.suggest_production(
            genre="metal", mood="aggressive", bpm=200
        )

        for label, production in (("slow", slow_prod), ("fast", fast_prod)):
            elements = self._first_present(production, "elements", "suggestions")
            assert len(elements) > 0, f"{label} tempo: no production suggestions"

    # ------------------------------------------------------------------
    # Forward-pass edge cases
    # ------------------------------------------------------------------

    def test_forward_with_empty_sequence(self):
        """Test forward pass with empty sequence."""
        seq_len = 0
        output = self._forward(seq_len)

        for key in self.OUTPUT_KEYS:
            self._assert_batch_and_seq(output[key], self.batch_size, seq_len, key)

    def test_different_batch_sizes(self):
        """Test forward pass with different batch sizes."""
        for batch_size in (1, 2, 8):
            output = self._forward(seq_len=10, batch_size=batch_size)
            assert output["mood"].shape[0] == batch_size, (
                f"batch_size={batch_size}"
            )

    def test_gradient_flow(self):
        """Test that gradients flow through the module."""
        hidden_states, chord_ids = self._dummy_inputs(
            seq_len=5, requires_grad=True
        )

        output = self.module(hidden_states, chord_ids)
        # Fold every tensor output into one scalar so a single backward pass
        # exercises all heads.
        loss = sum(
            out.sum() for out in output.values() if isinstance(out, torch.Tensor)
        )
        loss.backward()

        assert hidden_states.grad is not None
        assert self.module.chord_embed.weight.grad is not None

    def test_chord_embedding_vocab_size(self):
        """Test chord embedding vocabulary size."""
        assert self.module.chord_embed.num_embeddings >= self.NUM_CHORD_IDS

    # ------------------------------------------------------------------
    # Individual output heads
    # ------------------------------------------------------------------

    def test_mood_classifier_output(self):
        """Test mood classifier produces logits for all moods."""
        mood_logits = self._forward(seq_len=1)["mood"]
        assert mood_logits.shape[-1] >= 8

    def test_genre_classifier_output(self):
        """Test genre classifier produces logits for all genres."""
        genre_logits = self._forward(seq_len=1)["genre"]
        assert genre_logits.shape[-1] >= 8

    def test_lyric_lstm_output_shape(self):
        """Test lyric LSTM output shape."""
        seq_len = 10
        lyrics = self._forward(seq_len)["lyrics"]
        self._assert_batch_and_seq(lyrics, self.batch_size, seq_len, "lyrics")

    def test_hook_generator_output(self):
        """Test hook generator output."""
        seq_len = 1
        hook = self._forward(seq_len)["hook"]
        self._assert_batch_and_seq(hook, self.batch_size, seq_len, "hook")

    def test_production_advisor_output(self):
        """Test production advisor output."""
        seq_len = 1
        production = self._forward(seq_len)["production"]
        self._assert_batch_and_seq(
            production, self.batch_size, seq_len, "production"
        )
|
|
|
|
|
if __name__ == "__main__":
    # pytest.main() returns the run's exit status; raise it as SystemExit so
    # that running this file directly reports test failures to the shell/CI
    # instead of always exiting with status 0.
    raise SystemExit(pytest.main([__file__, "-v"]))
|
|
|