# TouchGrass-7b / tests/test_tab_chord_module.py
# Author: Zandy-Wandy; commit 4f0238f ("Upload 39 files", verified)
"""
Tests for Tab & Chord Generation Module.
"""
import pytest
import torch
from TouchGrass.models.tab_chord_module import TabChordModule
class TestTabChordModule:
    """Test suite for TabChordModule.

    Each test builds a fresh module in ``setup_method`` and feeds it random
    hidden states plus random (string, fret) index tensors. It then checks the
    keys and shapes of the returned dict and, where relevant, value ranges and
    gradient flow.
    """

    def setup_method(self):
        """Set up test fixtures.

        The global torch RNG is seeded first so that module weight
        initialisation and every random input are reproducible: a failing
        run fails the same way when re-run.
        """
        torch.manual_seed(0)
        self.d_model = 768
        self.batch_size = 4
        self.num_strings = 6
        self.num_frets = 24
        self.module = TabChordModule(
            d_model=self.d_model,
            num_strings=self.num_strings,
            num_frets=self.num_frets,
        )

    def _make_inputs(self, seq_len, batch_size=None, requires_grad=False):
        """Build random ``(hidden_states, string_indices, fret_indices)``.

        Args:
            seq_len: Length of the sequence dimension.
            batch_size: Batch dimension; defaults to ``self.batch_size``.
            requires_grad: Whether ``hidden_states`` should track gradients.

        Returns:
            ``hidden_states`` of shape ``(batch, seq_len, d_model)`` and two
            integer index tensors of shape ``(batch, seq_len)``. String indices
            are drawn from ``[0, num_strings)``. Fret indices are drawn from
            ``[0, num_frets + 2)``, i.e. the full fret vocabulary including the
            two extra special tokens.
        """
        if batch_size is None:
            batch_size = self.batch_size
        hidden_states = torch.randn(
            batch_size, seq_len, self.d_model, requires_grad=requires_grad
        )
        string_indices = torch.randint(0, self.num_strings, (batch_size, seq_len))
        fret_indices = torch.randint(0, self.num_frets + 2, (batch_size, seq_len))
        return hidden_states, string_indices, fret_indices

    def test_module_initialization(self):
        """Test that module initializes correctly."""
        assert self.module.string_embed.num_embeddings == self.num_strings
        # +2 for special tokens
        assert self.module.fret_embed.num_embeddings == self.num_frets + 2
        assert isinstance(self.module.tab_validator, torch.nn.Sequential)
        assert isinstance(self.module.difficulty_head, torch.nn.Linear)
        assert self.module.difficulty_head.out_features == 3  # easy, medium, hard

    def test_forward_pass(self):
        """Test forward pass with dummy inputs."""
        seq_len = 10
        output = self.module(*self._make_inputs(seq_len))
        assert "tab_validator" in output
        assert "difficulty" in output
        assert output["tab_validator"].shape == (self.batch_size, seq_len, 1)
        assert output["difficulty"].shape == (self.batch_size, seq_len, 3)

    def test_tab_validator_output_range(self):
        """Test that tab validator outputs are in [0, 1] range."""
        output = self.module(*self._make_inputs(seq_len=5))
        validator_output = output["tab_validator"]
        assert torch.all(validator_output >= 0)
        assert torch.all(validator_output <= 1)

    def test_difficulty_head_output(self):
        """Test difficulty head produces finite logits for 3 classes."""
        seq_len = 5
        output = self.module(*self._make_inputs(seq_len))
        difficulty_logits = output["difficulty"]
        assert difficulty_logits.shape == (self.batch_size, seq_len, 3)
        # Logits have no fixed range, but they must not be NaN or +/-inf.
        assert torch.all(torch.isfinite(difficulty_logits))

    def test_embedding_dimensions(self):
        """Test embedding layer dimensions."""
        # String embedding: num_strings -> 64
        assert self.module.string_embed.embedding_dim == 64
        # Fret embedding: num_frets+2 -> 64
        assert self.module.fret_embed.embedding_dim == 64

    def test_forward_with_different_seq_lengths(self):
        """Test forward pass with varying sequence lengths."""
        for seq_len in [1, 5, 20, 50]:
            output = self.module(*self._make_inputs(seq_len))
            # Messages identify which length failed, since this is one test.
            assert output["tab_validator"].shape == (
                self.batch_size,
                seq_len,
                1,
            ), f"seq_len={seq_len}"
            assert output["difficulty"].shape == (
                self.batch_size,
                seq_len,
                3,
            ), f"seq_len={seq_len}"

    def test_gradient_flow(self):
        """Test that gradients flow through the module."""
        hidden_states, string_indices, fret_indices = self._make_inputs(
            seq_len=5, requires_grad=True
        )
        output = self.module(hidden_states, string_indices, fret_indices)
        loss = output["tab_validator"].sum() + output["difficulty"].sum()
        loss.backward()
        assert hidden_states.grad is not None
        assert self.module.string_embed.weight.grad is not None
        assert self.module.fret_embed.weight.grad is not None
        # The loss includes the difficulty logits, so the head must get grads.
        assert self.module.difficulty_head.weight.grad is not None

    def test_different_batch_sizes(self):
        """Test forward pass with different batch sizes."""
        seq_len = 10
        for batch_size in [1, 2, 8, 16]:
            output = self.module(*self._make_inputs(seq_len, batch_size=batch_size))
            assert output["tab_validator"].shape == (
                batch_size,
                seq_len,
                1,
            ), f"batch_size={batch_size}"
            assert output["difficulty"].shape == (
                batch_size,
                seq_len,
                3,
            ), f"batch_size={batch_size}"

    def test_special_fret_tokens(self):
        """Test the boundary fret indices, including the two extra tokens.

        The fret vocabulary has ``num_frets + 2`` entries (asserted in
        ``test_module_initialization``). Which two indices are the special
        tokens (e.g. open / mute) is not visible from this file. Both ends of
        the range are therefore exercised: 0, 1, num_frets and num_frets + 1.
        """
        boundary_frets = [0, 1, self.num_frets, self.num_frets + 1]
        seq_len = len(boundary_frets)
        hidden_states, string_indices, _ = self._make_inputs(seq_len)
        # Build one row per example from self.batch_size rather than
        # hard-coding rows. The test then keeps working if the fixture batch
        # size changes.
        fret_indices = torch.tensor(boundary_frets).repeat(self.batch_size, 1)
        output = self.module(hidden_states, string_indices, fret_indices)
        assert output["tab_validator"].shape == (self.batch_size, seq_len, 1)
        assert output["difficulty"].shape == (self.batch_size, seq_len, 3)

    def test_tab_validator_confidence_scores(self):
        """Test validator confidences for a single-step sequence (seq_len=1)."""
        output = self.module(*self._make_inputs(seq_len=1))
        confidence = output["tab_validator"]
        # All confidences should be between 0 and 1
        assert torch.all((confidence >= 0) & (confidence <= 1))
if __name__ == "__main__":
    # Propagate pytest's exit status so running this file directly as a script
    # exits non-zero when any test fails (e.g. in CI).
    raise SystemExit(pytest.main([__file__, "-v"]))