| | """
|
| | Tests for Tab & Chord Generation Module.
|
| | """
|
| |
|
| | import pytest
|
| | import torch
|
| |
|
| | from TouchGrass.models.tab_chord_module import TabChordModule
|
| |
|
| |
|
class TestTabChordModule:
    """Test suite for TabChordModule.

    ``setup_method`` builds a fresh module before every test, so parameters
    and accumulated gradients are never shared between tests. Random inputs
    come from ``_make_inputs``, which draws indices from the full valid range
    of each embedding table.
    """

    def setup_method(self):
        """Set up test fixtures.

        The global torch RNG is seeded first so the module's initial weights
        and every random input are reproducible; a value-dependent failure
        can then be replayed exactly.
        """
        torch.manual_seed(0)
        self.d_model = 768
        self.batch_size = 4
        self.num_strings = 6
        self.num_frets = 24
        self.module = TabChordModule(
            d_model=self.d_model,
            num_strings=self.num_strings,
            num_frets=self.num_frets,
        )

    def _make_inputs(self, seq_len, batch_size=None, requires_grad=False):
        """Return random ``(hidden_states, string_indices, fret_indices)``.

        Args:
            seq_len: Sequence length of every returned tensor.
            batch_size: Batch size; defaults to ``self.batch_size``.
            requires_grad: Whether ``hidden_states`` should track gradients.

        ``hidden_states`` has shape ``(batch_size, seq_len, d_model)``; the
        index tensors have shape ``(batch_size, seq_len)`` with values in
        ``[0, num_strings)`` and ``[0, num_frets + 2)`` respectively, i.e. the
        sizes of the string and fret embedding tables checked in
        ``test_module_initialization``.
        """
        if batch_size is None:
            batch_size = self.batch_size
        hidden_states = torch.randn(
            batch_size, seq_len, self.d_model, requires_grad=requires_grad
        )
        string_indices = torch.randint(0, self.num_strings, (batch_size, seq_len))
        fret_indices = torch.randint(0, self.num_frets + 2, (batch_size, seq_len))
        return hidden_states, string_indices, fret_indices

    def test_module_initialization(self):
        """Test that module initializes correctly."""
        assert self.module.string_embed.num_embeddings == self.num_strings
        # Two extra fret ids beyond num_frets (exercised in
        # test_special_fret_tokens).
        assert self.module.fret_embed.num_embeddings == self.num_frets + 2
        assert isinstance(self.module.tab_validator, torch.nn.Sequential)
        assert isinstance(self.module.difficulty_head, torch.nn.Linear)
        assert self.module.difficulty_head.out_features == 3

    def test_forward_pass(self):
        """Test forward pass with dummy inputs."""
        seq_len = 10
        output = self.module(*self._make_inputs(seq_len))

        assert "tab_validator" in output
        assert "difficulty" in output
        assert output["tab_validator"].shape == (self.batch_size, seq_len, 1)
        assert output["difficulty"].shape == (self.batch_size, seq_len, 3)

    def test_tab_validator_output_range(self):
        """Test that tab validator outputs are in [0, 1] range."""
        seq_len = 5
        output = self.module(*self._make_inputs(seq_len))
        validator_output = output["tab_validator"]

        assert torch.all(validator_output >= 0)
        assert torch.all(validator_output <= 1)

    def test_difficulty_head_output(self):
        """Test difficulty head produces logits for 3 classes.

        Logits are unnormalised, so no value range is asserted. Beyond the
        shape, the values must be finite floating-point numbers: NaN/Inf
        would break any softmax or loss computed from them.
        """
        seq_len = 5
        output = self.module(*self._make_inputs(seq_len))
        difficulty_logits = output["difficulty"]

        assert difficulty_logits.shape == (self.batch_size, seq_len, 3)
        assert difficulty_logits.is_floating_point()
        assert torch.isfinite(difficulty_logits).all()

    def test_embedding_dimensions(self):
        """Test embedding layer dimensions."""
        assert self.module.string_embed.embedding_dim == 64
        assert self.module.fret_embed.embedding_dim == 64

    def test_forward_with_different_seq_lengths(self):
        """Test forward pass with varying sequence lengths."""
        for seq_len in [1, 5, 20, 50]:
            output = self.module(*self._make_inputs(seq_len))
            assert output["tab_validator"].shape[1] == seq_len
            assert output["difficulty"].shape[1] == seq_len

    def test_gradient_flow(self):
        """Test that gradients flow through the module.

        Gradients must reach the input hidden states and both embedding
        tables when backpropagating from the sum of both output heads.
        """
        seq_len = 5
        hidden_states, string_indices, fret_indices = self._make_inputs(
            seq_len, requires_grad=True
        )

        output = self.module(hidden_states, string_indices, fret_indices)
        loss = output["tab_validator"].sum() + output["difficulty"].sum()
        loss.backward()

        assert hidden_states.grad is not None
        assert self.module.string_embed.weight.grad is not None
        assert self.module.fret_embed.weight.grad is not None

    def test_different_batch_sizes(self):
        """Test forward pass with different batch sizes."""
        seq_len = 10
        for batch_size in [1, 2, 8, 16]:
            output = self.module(*self._make_inputs(seq_len, batch_size=batch_size))
            assert output["tab_validator"].shape[0] == batch_size
            assert output["difficulty"].shape[0] == batch_size

    def test_special_fret_tokens(self):
        """Test handling of special fret tokens (e.g., mute, open).

        The fret embedding has ``num_frets + 2`` rows, so ids ``num_frets``
        and ``num_frets + 1`` are the extra entries at the top of the table;
        random inputs only hit them occasionally, so they are pinned here
        together with the lowest ids.

        NOTE(review): which id encodes "mute" or "open" is decided inside
        TabChordModule -- confirm there. This test only pins that the lowest
        and highest valid fret ids are accepted and yield well-formed outputs.
        """
        # The same row is repeated for every batch element so the tensor
        # follows self.batch_size instead of hard-coding the number of rows.
        fret_row = torch.tensor([0, 1, self.num_frets, self.num_frets + 1])
        seq_len = fret_row.numel()
        hidden_states, string_indices, _ = self._make_inputs(seq_len)
        fret_indices = fret_row.repeat(self.batch_size, 1)

        output = self.module(hidden_states, string_indices, fret_indices)
        validator_output = output["tab_validator"]

        assert validator_output.shape == (self.batch_size, seq_len, 1)
        assert output["difficulty"].shape == (self.batch_size, seq_len, 3)
        assert torch.all((validator_output >= 0) & (validator_output <= 1))

    def test_tab_validator_confidence_scores(self):
        """Test that validator produces meaningful confidence scores.

        Even for a single-position sequence, every score must be a valid
        confidence in [0, 1].
        """
        seq_len = 1
        output = self.module(*self._make_inputs(seq_len))
        confidence = output["tab_validator"]

        assert torch.all((confidence >= 0) & (confidence <= 1))
|
| |
|
| |
|
if __name__ == "__main__":
    # pytest.main returns the session exit code. Raising SystemExit with it
    # makes `python <this file>` exit non-zero when tests fail, instead of
    # always reporting success.
    raise SystemExit(pytest.main([__file__, "-v"]))
|
| |
|