"""
Tests for Tab & Chord Generation Module.
"""
import pytest
import torch
from TouchGrass.models.tab_chord_module import TabChordModule
class TestTabChordModule:
    """Test suite for TabChordModule.

    ``setup_method`` builds a fresh module for every test, so parameters
    and accumulated gradients never leak between cases.
    """

    def setup_method(self):
        """Set up test fixtures shared by all tests."""
        self.d_model = 768
        self.batch_size = 4
        self.num_strings = 6
        self.num_frets = 24
        self.module = TabChordModule(
            d_model=self.d_model,
            num_strings=self.num_strings,
            num_frets=self.num_frets,
        )

    def _random_inputs(self, seq_len, batch_size=None, requires_grad=False):
        """Build ``(hidden_states, string_indices, fret_indices)`` dummy tensors.

        Fret indices are drawn from the full vocabulary (``num_frets + 2``)
        so the two special tokens are exercised alongside ordinary frets.
        ``batch_size`` defaults to the fixture batch size.
        """
        if batch_size is None:
            batch_size = self.batch_size
        hidden_states = torch.randn(
            batch_size, seq_len, self.d_model, requires_grad=requires_grad
        )
        string_indices = torch.randint(0, self.num_strings, (batch_size, seq_len))
        fret_indices = torch.randint(0, self.num_frets + 2, (batch_size, seq_len))
        return hidden_states, string_indices, fret_indices

    def test_module_initialization(self):
        """Test that module initializes correctly."""
        assert self.module.string_embed.num_embeddings == self.num_strings
        assert self.module.fret_embed.num_embeddings == self.num_frets + 2  # +2 for special tokens
        assert isinstance(self.module.tab_validator, torch.nn.Sequential)
        assert isinstance(self.module.difficulty_head, torch.nn.Linear)
        assert self.module.difficulty_head.out_features == 3  # easy, medium, hard

    def test_forward_pass(self):
        """Test forward pass with dummy inputs."""
        seq_len = 10
        output = self.module(*self._random_inputs(seq_len))
        assert "tab_validator" in output
        assert "difficulty" in output
        assert output["tab_validator"].shape == (self.batch_size, seq_len, 1)
        assert output["difficulty"].shape == (self.batch_size, seq_len, 3)

    def test_tab_validator_output_range(self):
        """Test that tab validator outputs are in [0, 1] range."""
        output = self.module(*self._random_inputs(5))
        validator_output = output["tab_validator"]
        assert torch.all(validator_output >= 0)
        assert torch.all(validator_output <= 1)

    def test_difficulty_head_output(self):
        """Test difficulty head produces logits for 3 classes."""
        seq_len = 5
        output = self.module(*self._random_inputs(seq_len))
        # Logits are unbounded, so only the shape is checked.
        assert output["difficulty"].shape == (self.batch_size, seq_len, 3)

    def test_embedding_dimensions(self):
        """Test embedding layer dimensions."""
        # Both index embeddings project into a 64-dim space.
        assert self.module.string_embed.embedding_dim == 64
        assert self.module.fret_embed.embedding_dim == 64

    def test_forward_with_different_seq_lengths(self):
        """Test forward pass with varying sequence lengths."""
        for seq_len in [1, 5, 20, 50]:
            output = self.module(*self._random_inputs(seq_len))
            assert output["tab_validator"].shape[1] == seq_len
            assert output["difficulty"].shape[1] == seq_len

    def test_gradient_flow(self):
        """Test that gradients flow through the module."""
        hidden_states, string_indices, fret_indices = self._random_inputs(
            5, requires_grad=True
        )
        output = self.module(hidden_states, string_indices, fret_indices)
        loss = output["tab_validator"].sum() + output["difficulty"].sum()
        loss.backward()
        assert hidden_states.grad is not None
        assert self.module.string_embed.weight.grad is not None
        assert self.module.fret_embed.weight.grad is not None

    def test_different_batch_sizes(self):
        """Test forward pass with different batch sizes."""
        for batch_size in [1, 2, 8, 16]:
            output = self.module(*self._random_inputs(10, batch_size=batch_size))
            assert output["tab_validator"].shape[0] == batch_size
            assert output["difficulty"].shape[0] == batch_size

    def test_special_fret_tokens(self):
        """Test handling of special fret tokens (e.g., mute, open)."""
        seq_len = 3
        hidden_states, string_indices, _ = self._random_inputs(seq_len)
        # Indices 0 and 1 are the special tokens (presumably open/mute --
        # confirm against the module's vocabulary); the remaining column is
        # an ordinary fret. Built relative to self.batch_size instead of
        # hard-coding four rows, so changing the fixture batch size cannot
        # silently break the tensor shapes.
        special = torch.tensor([[0, 1]]).repeat(self.batch_size, 1)
        ordinary = torch.randint(
            2, self.num_frets + 2, (self.batch_size, seq_len - 2)
        )
        fret_indices = torch.cat([special, ordinary], dim=1)
        output = self.module(hidden_states, string_indices, fret_indices)
        assert output["tab_validator"].shape == (self.batch_size, seq_len, 1)

    def test_tab_validator_confidence_scores(self):
        """Test that validator produces meaningful confidence scores."""
        output = self.module(*self._random_inputs(1))
        confidence = output["tab_validator"]
        # All confidences should be between 0 and 1.
        assert torch.all((confidence >= 0) & (confidence <= 1))
if __name__ == "__main__":
    # Propagate pytest's exit code so shells and CI see test failures;
    # the original discarded it and always exited 0.
    raise SystemExit(pytest.main([__file__, "-v"]))