"""
Tests for Tab & Chord Generation Module.
"""

import pytest
import torch

from TouchGrass.models.tab_chord_module import TabChordModule


class TestTabChordModule:
    """Test suite for TabChordModule."""

    def setup_method(self):
        """Set up test fixtures (fresh, seeded module before every test)."""
        # Seed so that weight initialisation and the random inputs below are
        # reproducible; a failing run can then be replayed exactly.
        torch.manual_seed(0)
        self.d_model = 768
        self.batch_size = 4
        self.num_strings = 6
        self.num_frets = 24
        self.module = TabChordModule(
            d_model=self.d_model,
            num_strings=self.num_strings,
            num_frets=self.num_frets,
        )

    def _make_inputs(self, seq_len, batch_size=None, requires_grad=False):
        """Build random ``(hidden_states, string_indices, fret_indices)`` inputs.

        Args:
            seq_len: Number of time steps in each sequence.
            batch_size: Batch dimension; defaults to ``self.batch_size``.
            requires_grad: Whether ``hidden_states`` should track gradients.

        Returns:
            A 3-tuple that can be splatted straight into ``self.module(...)``.
            Fret indices are sampled from the full embedding range
            ``[0, num_frets + 2)``, so the two extra indices may be drawn too.
        """
        if batch_size is None:
            batch_size = self.batch_size
        hidden_states = torch.randn(
            batch_size, seq_len, self.d_model, requires_grad=requires_grad
        )
        string_indices = torch.randint(0, self.num_strings, (batch_size, seq_len))
        fret_indices = torch.randint(0, self.num_frets + 2, (batch_size, seq_len))
        return hidden_states, string_indices, fret_indices

    def test_module_initialization(self):
        """Test that module initializes correctly."""
        assert self.module.string_embed.num_embeddings == self.num_strings
        assert self.module.fret_embed.num_embeddings == self.num_frets + 2  # +2 for special tokens
        assert isinstance(self.module.tab_validator, torch.nn.Sequential)
        assert isinstance(self.module.difficulty_head, torch.nn.Linear)
        assert self.module.difficulty_head.out_features == 3  # easy, medium, hard

    def test_forward_pass(self):
        """Test forward pass with dummy inputs."""
        seq_len = 10
        output = self.module(*self._make_inputs(seq_len))

        assert "tab_validator" in output
        assert "difficulty" in output
        assert output["tab_validator"].shape == (self.batch_size, seq_len, 1)
        assert output["difficulty"].shape == (self.batch_size, seq_len, 3)

    def test_tab_validator_output_range(self):
        """Test that tab validator outputs are in [0, 1] range."""
        output = self.module(*self._make_inputs(seq_len=5))

        validator_output = output["tab_validator"]
        assert torch.all(validator_output >= 0)
        assert torch.all(validator_output <= 1)

    def test_difficulty_head_output(self):
        """Test difficulty head produces logits for 3 classes."""
        seq_len = 5
        output = self.module(*self._make_inputs(seq_len))

        difficulty_logits = output["difficulty"]
        # Logits are unbounded, so only the shape is checked.
        assert difficulty_logits.shape == (self.batch_size, seq_len, 3)

    def test_embedding_dimensions(self):
        """Test embedding layer dimensions."""
        # String embedding: num_strings -> 64
        assert self.module.string_embed.embedding_dim == 64
        # Fret embedding: num_frets+2 -> 64
        assert self.module.fret_embed.embedding_dim == 64

    @pytest.mark.parametrize("seq_len", [1, 5, 20, 50])
    def test_forward_with_different_seq_lengths(self, seq_len):
        """Test forward pass with varying sequence lengths.

        Parametrized (rather than looped) so every length is reported
        independently instead of the first failure hiding the rest.
        """
        output = self.module(*self._make_inputs(seq_len))

        assert output["tab_validator"].shape[1] == seq_len
        assert output["difficulty"].shape[1] == seq_len

    def test_gradient_flow(self):
        """Test that gradients flow through the module."""
        hidden_states, string_indices, fret_indices = self._make_inputs(
            seq_len=5, requires_grad=True
        )

        output = self.module(hidden_states, string_indices, fret_indices)
        loss = output["tab_validator"].sum() + output["difficulty"].sum()
        loss.backward()

        assert hidden_states.grad is not None
        assert self.module.string_embed.weight.grad is not None
        assert self.module.fret_embed.weight.grad is not None

    @pytest.mark.parametrize("batch_size", [1, 2, 8, 16])
    def test_different_batch_sizes(self, batch_size):
        """Test forward pass with different batch sizes."""
        output = self.module(*self._make_inputs(seq_len=10, batch_size=batch_size))

        assert output["tab_validator"].shape[0] == batch_size
        assert output["difficulty"].shape[0] == batch_size

    def test_special_fret_tokens(self):
        """Test handling of boundary and special fret indices.

        Covers the lowest indices (0, 1), a mid-neck fret, and the two
        highest valid indices (num_frets, num_frets + 1), i.e. the "+2"
        slots counted in ``fret_embed.num_embeddings``.
        NOTE(review): which indices mean "open" / "mute" is defined by
        TabChordModule and not visible here; this test only checks that
        every boundary index is accepted and yields well-formed output.
        """
        boundary_frets = torch.tensor(
            [0, 1, self.num_frets // 2, self.num_frets, self.num_frets + 1]
        )
        seq_len = boundary_frets.numel()
        # Derive every tensor from self.batch_size so the test does not
        # silently depend on the fixture's batch size.
        fret_indices = boundary_frets.repeat(self.batch_size, 1)
        hidden_states = torch.randn(self.batch_size, seq_len, self.d_model)
        string_indices = torch.randint(0, self.num_strings, (self.batch_size, seq_len))

        output = self.module(hidden_states, string_indices, fret_indices)

        assert output["tab_validator"].shape == (self.batch_size, seq_len, 1)
        assert output["difficulty"].shape == (self.batch_size, seq_len, 3)
        validator_output = output["tab_validator"]
        assert torch.all((validator_output >= 0) & (validator_output <= 1))

    def test_tab_validator_confidence_scores(self):
        """Test that validator produces confidence scores for a single-step sequence."""
        output = self.module(*self._make_inputs(seq_len=1))

        confidence = output["tab_validator"]
        # All confidences should be between 0 and 1
        assert torch.all((confidence >= 0) & (confidence <= 1))


if __name__ == "__main__":
    pytest.main([__file__, "-v"])