File size: 6,653 Bytes
4f0238f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
"""

Tests for Tab & Chord Generation Module.

"""

import pytest
import torch

from TouchGrass.models.tab_chord_module import TabChordModule


class TestTabChordModule:
    """Test suite for TabChordModule.

    Every test gets a freshly constructed module (see ``setup_method``) and
    feeds it random hidden states plus string/fret index tensors. The torch
    RNG is re-seeded before each test so that a failure can be reproduced
    exactly instead of depending on whatever random draw happened to occur.
    """

    def setup_method(self):
        """Seed the RNG and build the module under test."""
        # Seed first: the inputs drawn by the tests (and any weight
        # initialisation that uses torch's global RNG) become deterministic.
        torch.manual_seed(0)
        self.d_model = 768
        self.batch_size = 4
        self.num_strings = 6
        self.num_frets = 24
        self.module = TabChordModule(
            d_model=self.d_model,
            num_strings=self.num_strings,
            num_frets=self.num_frets,
        )

    def _make_inputs(self, seq_len, batch_size=None, requires_grad=False):
        """Build random inputs for a forward pass.

        Args:
            seq_len: Size of the sequence dimension.
            batch_size: Size of the batch dimension; defaults to
                ``self.batch_size`` when omitted.
            requires_grad: Whether ``hidden_states`` should track gradients.

        Returns:
            A ``(hidden_states, string_indices, fret_indices)`` tuple where
            ``hidden_states`` has shape ``(batch, seq_len, d_model)`` and the
            two index tensors have shape ``(batch, seq_len)`` with values in
            ``[0, num_strings)`` and ``[0, num_frets + 2)`` respectively.
        """
        if batch_size is None:
            batch_size = self.batch_size
        hidden_states = torch.randn(
            batch_size, seq_len, self.d_model, requires_grad=requires_grad
        )
        string_indices = torch.randint(0, self.num_strings, (batch_size, seq_len))
        # Upper bound is num_frets + 2 to match the fret embedding table size
        # checked in test_module_initialization.
        fret_indices = torch.randint(0, self.num_frets + 2, (batch_size, seq_len))
        return hidden_states, string_indices, fret_indices

    def test_module_initialization(self):
        """Test that module initializes correctly."""
        assert self.module.string_embed.num_embeddings == self.num_strings
        assert self.module.fret_embed.num_embeddings == self.num_frets + 2  # +2 for special tokens
        assert isinstance(self.module.tab_validator, torch.nn.Sequential)
        assert isinstance(self.module.difficulty_head, torch.nn.Linear)
        assert self.module.difficulty_head.out_features == 3  # easy, medium, hard

    def test_forward_pass(self):
        """Test forward pass with dummy inputs."""
        seq_len = 10
        hidden_states, string_indices, fret_indices = self._make_inputs(seq_len)

        output = self.module(hidden_states, string_indices, fret_indices)

        assert "tab_validator" in output
        assert "difficulty" in output
        assert output["tab_validator"].shape == (self.batch_size, seq_len, 1)
        assert output["difficulty"].shape == (self.batch_size, seq_len, 3)

    def test_tab_validator_output_range(self):
        """Test that tab validator outputs are in [0, 1] range."""
        hidden_states, string_indices, fret_indices = self._make_inputs(seq_len=5)

        output = self.module(hidden_states, string_indices, fret_indices)
        validator_output = output["tab_validator"]

        assert torch.all(validator_output >= 0)
        assert torch.all(validator_output <= 1)

    def test_difficulty_head_output(self):
        """Test difficulty head produces logits for 3 classes."""
        seq_len = 5
        hidden_states, string_indices, fret_indices = self._make_inputs(seq_len)

        output = self.module(hidden_states, string_indices, fret_indices)
        difficulty_logits = output["difficulty"]

        # Check that logits are produced (no specific range expected for logits)
        assert difficulty_logits.shape == (self.batch_size, seq_len, 3)

    def test_embedding_dimensions(self):
        """Test embedding layer dimensions."""
        # String embedding: num_strings -> 64
        assert self.module.string_embed.embedding_dim == 64
        # Fret embedding: num_frets+2 -> 64
        assert self.module.fret_embed.embedding_dim == 64

    def test_forward_with_different_seq_lengths(self):
        """Test forward pass with varying sequence lengths."""
        for seq_len in [1, 5, 20, 50]:
            hidden_states, string_indices, fret_indices = self._make_inputs(seq_len)

            output = self.module(hidden_states, string_indices, fret_indices)
            assert output["tab_validator"].shape[1] == seq_len
            assert output["difficulty"].shape[1] == seq_len

    def test_gradient_flow(self):
        """Test that gradients flow through the module."""
        hidden_states, string_indices, fret_indices = self._make_inputs(
            seq_len=5, requires_grad=True
        )

        output = self.module(hidden_states, string_indices, fret_indices)
        # Summing both heads makes the scalar loss depend on every output.
        loss = output["tab_validator"].sum() + output["difficulty"].sum()
        loss.backward()

        assert hidden_states.grad is not None
        assert self.module.string_embed.weight.grad is not None
        assert self.module.fret_embed.weight.grad is not None

    def test_different_batch_sizes(self):
        """Test forward pass with different batch sizes."""
        for batch_size in [1, 2, 8, 16]:
            hidden_states, string_indices, fret_indices = self._make_inputs(
                seq_len=10, batch_size=batch_size
            )

            output = self.module(hidden_states, string_indices, fret_indices)
            assert output["tab_validator"].shape[0] == batch_size
            assert output["difficulty"].shape[0] == batch_size

    def test_special_fret_tokens(self):
        """Test handling of special fret tokens (e.g., mute, open)."""
        # Include special fret indices: 0 for open, 1 for mute.
        # NOTE(review): this index convention comes from the original test
        # comment; confirm it against TabChordModule.
        fret_indices = torch.tensor([[0, 1, 5], [2, 0, 10], [3, 1, 15], [4, 0, 20]])
        # Size the other inputs from the literal tensor so the test does not
        # silently depend on self.batch_size matching its row count.
        batch_size, seq_len = fret_indices.shape
        hidden_states = torch.randn(batch_size, seq_len, self.d_model)
        string_indices = torch.randint(0, self.num_strings, (batch_size, seq_len))

        output = self.module(hidden_states, string_indices, fret_indices)
        assert output["tab_validator"].shape == (batch_size, seq_len, 1)

    def test_tab_validator_confidence_scores(self):
        """Test that validator produces meaningful confidence scores."""
        hidden_states, string_indices, fret_indices = self._make_inputs(seq_len=1)

        output = self.module(hidden_states, string_indices, fret_indices)
        confidence = output["tab_validator"]

        # All confidences should be between 0 and 1
        assert torch.all((confidence >= 0) & (confidence <= 1))


if __name__ == "__main__":
    # Allow running this file directly. pytest.main returns the exit code of
    # the run; pass it on so a failing run does not exit with status 0.
    raise SystemExit(pytest.main([__file__, "-v"]))