| | """
|
| | Pytest configuration and shared fixtures for TouchGrass tests.
|
| | """
|
| |
|
| | import pytest
|
| | import torch
|
| | from pathlib import Path
|
| |
|
| |
|
@pytest.fixture(scope="session")
def project_root():
    """Return the project root: the directory two levels above this conftest file."""
    this_file = Path(__file__)
    return this_file.parent.parent
|
| |
|
| |
|
@pytest.fixture(scope="session")
def test_data_dir(project_root):
    """Return ``<project_root>/tests/data``, creating it (and any parents) if missing."""
    path = project_root.joinpath("tests", "data")
    path.mkdir(parents=True, exist_ok=True)
    return path
|
| |
|
| |
|
@pytest.fixture
def sample_music_tokens():
    """Return a fresh list of the bracketed music tokens used in tests."""
    # Grouped exactly as in the original listing; concatenation preserves order.
    domain_tokens = ["[GUITAR]", "[PIANO]", "[DRUMS]", "[VOCALS]", "[THEORY]", "[PRODUCTION]"]
    mood_tokens = ["[FRUSTRATED]", "[CONFUSED]", "[EXCITED]", "[CONFIDENT]"]
    difficulty_tokens = ["[EASY]", "[MEDIUM]", "[HARD]"]
    notation_tokens = ["[TAB]", "[CHORD]", "[SCALE]", "[INTERVAL]", "[PROGRESSION]"]
    behaviour_tokens = ["[SIMPLIFY]", "[ENCOURAGE]"]
    return (
        domain_tokens
        + mood_tokens
        + difficulty_tokens
        + notation_tokens
        + behaviour_tokens
    )
|
| |
|
| |
|
@pytest.fixture
def sample_qa_pair():
    """Return one guitar-category chat sample: a system, a user and an assistant turn."""
    system_turn = {"role": "system", "content": "You are a guitar assistant."}
    user_turn = {"role": "user", "content": "How do I play a G major chord?"}
    assistant_turn = {
        "role": "assistant",
        "content": (
            "Place your middle finger on the 3rd fret of the 6th string, "
            "index on 2nd fret of 5th string, and ring/pinky on 3rd fret "
            "of the 1st and 2nd strings."
        ),
    }
    return {
        "category": "guitar",
        "messages": [system_turn, user_turn, assistant_turn],
    }
|
| |
|
| |
|
@pytest.fixture
def mock_tokenizer():
    """Create a lightweight tokenizer stand-in for tests.

    The mock implements only the surface the tests touch:

    * ``encode`` / ``decode`` return fixed values regardless of input.
    * ``add_special_tokens`` / ``add_tokens`` grow ``vocab_size`` and return
      how many tokens were added (Hugging Face tokenizers return this count).
    * ``len(tokenizer)`` reports the current ``vocab_size``.
    * ``convert_tokens_to_ids`` maps every bracketed token to 32000 and
      everything else to 1.
    """

    class MockTokenizer:
        def __init__(self):
            self.vocab_size = 32000
            self.pad_token_id = 0

        def __len__(self):
            # Supports ``len(tokenizer)``, which callers use to size embedding tables.
            return self.vocab_size

        def encode(self, text, **kwargs):
            # Fixed ids; ``text`` and options are ignored.
            return [1, 2, 3, 4, 5]

        def decode(self, token_ids, **kwargs):
            return "mocked decoded text"

        def add_special_tokens(self, tokens_dict):
            """Grow the vocab by the count of ``additional_special_tokens``.

            Returns:
                The number of tokens added. Other keys of ``tokens_dict`` are ignored.
            """
            num_added = len(tokens_dict.get("additional_special_tokens", []))
            self.vocab_size += num_added
            return num_added

        def add_tokens(self, tokens):
            """Grow the vocab by ``len(tokens)`` for a list/tuple, else by one.

            Returns:
                The number of tokens added.
            """
            # A tuple is a sequence of tokens too; previously it counted as one.
            num_added = len(tokens) if isinstance(tokens, (list, tuple)) else 1
            self.vocab_size += num_added
            return num_added

        def convert_tokens_to_ids(self, token):
            # All bracketed tokens share id 32000 (kept as-is; tests may rely on it).
            return 32000 if token.startswith("[") else 1

    return MockTokenizer()
|
| |
|
| |
|
@pytest.fixture
def device():
    """Return ``"cuda"`` when torch reports a CUDA device, otherwise ``"cpu"``."""
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"
|
| |
|
| |
|
@pytest.fixture
def d_model():
    """Model dimension (768) passed to the module fixtures."""
    width = 768
    return width
|
| |
|
| |
|
@pytest.fixture
def batch_size():
    """Batch size (4) for tests."""
    size = 4
    return size
|
| |
|
| |
|
@pytest.fixture
def seq_len():
    """Sequence length (10) for tests."""
    length = 10
    return length
|
| |
|
| |
|
@pytest.fixture
def music_theory_module(device, d_model):
    """Build a ``MusicTheoryModule`` with ``d_model`` on ``device``, in eval mode."""
    # Imported inside the fixture so the model package loads only when requested.
    from TouchGrass.models.music_theory_module import MusicTheoryModule

    model = MusicTheoryModule(d_model=d_model)
    model = model.to(device)
    model.eval()
    return model
|
| |
|
| |
|
@pytest.fixture
def tab_chord_module(device, d_model):
    """Build a ``TabChordModule`` with ``d_model`` on ``device``, in eval mode."""
    # Imported inside the fixture so the model package loads only when requested.
    from TouchGrass.models.tab_chord_module import TabChordModule

    model = TabChordModule(d_model=d_model)
    model = model.to(device)
    model.eval()
    return model
|
| |
|
| |
|
@pytest.fixture
def ear_training_module(device, d_model):
    """Build an ``EarTrainingModule`` with ``d_model`` on ``device``, in eval mode."""
    # Imported inside the fixture so the model package loads only when requested.
    from TouchGrass.models.ear_training_module import EarTrainingModule

    model = EarTrainingModule(d_model=d_model)
    model = model.to(device)
    model.eval()
    return model
|
| |
|
| |
|
@pytest.fixture
def eq_adapter_module(device, d_model):
    """Build a ``MusicEQAdapter`` with ``d_model`` on ``device``, in eval mode."""
    # Imported inside the fixture so the model package loads only when requested.
    from TouchGrass.models.eq_adapter import MusicEQAdapter

    adapter = MusicEQAdapter(d_model=d_model)
    adapter = adapter.to(device)
    adapter.eval()
    return adapter
|
| |
|
| |
|
@pytest.fixture
def songwriting_module(device, d_model):
    """Build a ``SongwritingModule`` with ``d_model`` on ``device``, in eval mode."""
    # Imported inside the fixture so the model package loads only when requested.
    from TouchGrass.models.songwriting_module import SongwritingModule

    model = SongwritingModule(d_model=d_model)
    model = model.to(device)
    model.eval()
    return model
|
| |
|
| |
|
@pytest.fixture
def music_qa_generator():
    """Provide a default-constructed ``MusicQAGenerator``."""
    from TouchGrass.data.music_qa_generator import MusicQAGenerator

    return MusicQAGenerator()
|
| |
|
| |
|
@pytest.fixture
def chat_formatter():
    """Provide a default-constructed ``ChatFormatter``."""
    from TouchGrass.data.chat_formatter import ChatFormatter

    return ChatFormatter()
|
| |
|
| |
|
@pytest.fixture
def touchgrass_loss():
    """Provide a ``TouchGrassLoss`` configured with fixed component weights."""
    from TouchGrass.training.losses import TouchGrassLoss

    weights = {
        "lm_loss_weight": 1.0,
        "eq_loss_weight": 0.1,
        "music_module_loss_weight": 0.05,
    }
    return TouchGrassLoss(**weights)
|
| |
|
| |
|
def pytest_configure(config):
    """Register the suite's custom markers: ``slow``, ``integration`` and ``gpu``."""
    marker_lines = (
        "slow: marks tests as slow (deselect with '-m \"not slow\"')",
        "integration: marks tests as integration tests",
        "gpu: marks tests that require GPU",
    )
    for line in marker_lines:
        config.addinivalue_line("markers", line)
|
| |
|
| |
|
def pytest_collection_modifyitems(config, items):
    """Adjust collected tests' markers.

    * A test whose node id contains ``"test_inference"`` is marked
      ``integration``; one containing ``"test_trainer"`` is marked ``slow``.
      This is a plain substring check on the full node id, so it matches
      test function names as well as file names.
    * Tests carrying the ``gpu`` marker are skipped when torch reports no
      CUDA device, instead of running and failing on CPU-only hosts.
    """
    # Check CUDA once, not per item.
    skip_gpu = None
    if not torch.cuda.is_available():
        skip_gpu = pytest.mark.skip(reason="requires a CUDA-capable GPU")

    for item in items:
        if "test_inference" in item.nodeid:
            item.add_marker(pytest.mark.integration)
        if "test_trainer" in item.nodeid:
            item.add_marker(pytest.mark.slow)
        if skip_gpu is not None and item.get_closest_marker("gpu") is not None:
            item.add_marker(skip_gpu)
|
| |
|