| """
|
| Pytest configuration and shared fixtures for TouchGrass tests.
|
| """
|
|
|
| import pytest
|
| import torch
|
| from pathlib import Path
|
|
|
|
|
@pytest.fixture(scope="session")
def project_root():
    """Provide the project root: the directory two levels above this file.

    Session-scoped, so the path is computed once and shared by every test.
    """
    containing_dir = Path(__file__).parent
    return containing_dir.parent
|
|
|
|
|
@pytest.fixture(scope="session")
def test_data_dir(project_root):
    """Provide ``<project_root>/tests/data``, creating it (and parents) if missing.

    The directory lives in the source tree rather than a temporary location,
    so anything written there persists between test runs.
    """
    target = project_root.joinpath("tests", "data")
    target.mkdir(exist_ok=True, parents=True)
    return target
|
|
|
|
|
@pytest.fixture
def sample_music_tokens():
    """Provide the bracketed music-domain tokens as a fresh list (20 entries).

    The tokens are assembled group by group in a fixed order; a new list is
    built for every test, so tests may mutate it freely.
    """
    topics = ["[GUITAR]", "[PIANO]", "[DRUMS]", "[VOCALS]", "[THEORY]", "[PRODUCTION]"]
    moods = ["[FRUSTRATED]", "[CONFUSED]", "[EXCITED]", "[CONFIDENT]"]
    levels = ["[EASY]", "[MEDIUM]", "[HARD]"]
    concepts = ["[TAB]", "[CHORD]", "[SCALE]", "[INTERVAL]", "[PROGRESSION]"]
    directives = ["[SIMPLIFY]", "[ENCOURAGE]"]
    return topics + moods + levels + concepts + directives
|
|
|
|
|
@pytest.fixture
def sample_qa_pair():
    """Provide one guitar question/answer example in chat-message form.

    Returns a dict with a ``category`` string and a ``messages`` list of
    ``{"role": ..., "content": ...}`` dicts ordered system, user, assistant.
    """
    system_msg = {"role": "system", "content": "You are a guitar assistant."}
    user_msg = {"role": "user", "content": "How do I play a G major chord?"}
    assistant_msg = {"role": "assistant", "content": "Place your middle finger on the 3rd fret of the 6th string, index on 2nd fret of 5th string, and ring/pinky on 3rd fret of the 1st and 2nd strings."}
    conversation = [system_msg, user_msg, assistant_msg]
    return {"category": "guitar", "messages": conversation}
|
|
|
|
|
@pytest.fixture
def mock_tokenizer():
    """Provide a minimal fake tokenizer for tests that must not load a real one.

    The fake mirrors the method names of a Hugging Face ``transformers``
    tokenizer. It starts with ``vocab_size = 32000`` and ``pad_token_id = 0``.
    ``encode``/``decode`` return fixed values regardless of input, so tests
    using it should rely only on types/shapes, not on real tokenization.
    """

    class MockTokenizer:
        def __init__(self):
            self.vocab_size = 32000
            self.pad_token_id = 0

        def encode(self, text, **kwargs):
            # Fixed ids regardless of ``text`` and any options passed.
            return [1, 2, 3, 4, 5]

        def decode(self, token_ids, **kwargs):
            # Fixed string regardless of ``token_ids``.
            return "mocked decoded text"

        def add_special_tokens(self, tokens_dict):
            """Grow the vocabulary by the count of ``additional_special_tokens``.

            Returns:
                The number of tokens added. This mirrors the integer returned
                by the real API, so callers can use it to resize embeddings.
            """
            added = len(tokens_dict.get("additional_special_tokens", []))
            self.vocab_size += added
            return added

        def add_tokens(self, tokens):
            """Grow the vocabulary by one per token; a single token counts as one.

            Lists and tuples are both treated as sequences of tokens.

            Returns:
                The number of tokens added.
            """
            added = len(tokens) if isinstance(tokens, (list, tuple)) else 1
            self.vocab_size += added
            return added

        def convert_tokens_to_ids(self, token):
            """Map a token, or a list/tuple of tokens, to fake ids.

            Tokens starting with ``"["`` map to 32000; all others map to 1.
            Sequences are mapped element-wise and returned as a list.
            NOTE(review): every bracketed token shares id 32000, so tests
            cannot tell special tokens apart by id.
            """
            if isinstance(token, (list, tuple)):
                return [self.convert_tokens_to_ids(t) for t in token]
            return 32000 if token.startswith("[") else 1

    return MockTokenizer()
|
|
|
|
|
@pytest.fixture
def device():
    """Pick ``"cuda"`` when torch reports CUDA as available, otherwise ``"cpu"``."""
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"
|
|
|
|
|
@pytest.fixture
def d_model():
    """Model width passed as ``d_model`` to the module fixtures in this file."""
    hidden_size = 768
    return hidden_size
|
|
|
|
|
@pytest.fixture
def batch_size():
    """Number of examples per batch that tests should construct."""
    examples_per_batch = 4
    return examples_per_batch
|
|
|
|
|
@pytest.fixture
def seq_len():
    """Sequence length that tests should use when building inputs."""
    length = 10
    return length
|
|
|
|
|
@pytest.fixture
def music_theory_module(device, d_model):
    """Build a MusicTheoryModule of width ``d_model`` on ``device``, in eval mode.

    The import is done inside the fixture, so only tests requesting this
    fixture trigger it.
    """
    from TouchGrass.models.music_theory_module import MusicTheoryModule

    theory = MusicTheoryModule(d_model=d_model)
    theory = theory.to(device)
    theory.eval()
    return theory
|
|
|
|
|
@pytest.fixture
def tab_chord_module(device, d_model):
    """Build a TabChordModule of width ``d_model`` on ``device``, in eval mode.

    The import is done inside the fixture, so only tests requesting this
    fixture trigger it.
    """
    from TouchGrass.models.tab_chord_module import TabChordModule

    tab_chord = TabChordModule(d_model=d_model)
    tab_chord = tab_chord.to(device)
    tab_chord.eval()
    return tab_chord
|
|
|
|
|
@pytest.fixture
def ear_training_module(device, d_model):
    """Build an EarTrainingModule of width ``d_model`` on ``device``, in eval mode.

    The import is done inside the fixture, so only tests requesting this
    fixture trigger it.
    """
    from TouchGrass.models.ear_training_module import EarTrainingModule

    ear = EarTrainingModule(d_model=d_model)
    ear = ear.to(device)
    ear.eval()
    return ear
|
|
|
|
|
@pytest.fixture
def eq_adapter_module(device, d_model):
    """Build a MusicEQAdapter of width ``d_model`` on ``device``, in eval mode.

    The import is done inside the fixture, so only tests requesting this
    fixture trigger it.
    """
    from TouchGrass.models.eq_adapter import MusicEQAdapter

    adapter = MusicEQAdapter(d_model=d_model)
    adapter = adapter.to(device)
    adapter.eval()
    return adapter
|
|
|
|
|
@pytest.fixture
def songwriting_module(device, d_model):
    """Build a SongwritingModule of width ``d_model`` on ``device``, in eval mode.

    The import is done inside the fixture, so only tests requesting this
    fixture trigger it.
    """
    from TouchGrass.models.songwriting_module import SongwritingModule

    songwriter = SongwritingModule(d_model=d_model)
    songwriter = songwriter.to(device)
    songwriter.eval()
    return songwriter
|
|
|
|
|
@pytest.fixture
def music_qa_generator():
    """Provide a MusicQAGenerator built with its default constructor arguments."""
    from TouchGrass.data.music_qa_generator import MusicQAGenerator

    return MusicQAGenerator()
|
|
|
|
|
@pytest.fixture
def chat_formatter():
    """Provide a ChatFormatter built with its default constructor arguments."""
    from TouchGrass.data.chat_formatter import ChatFormatter

    return ChatFormatter()
|
|
|
|
|
@pytest.fixture
def touchgrass_loss():
    """Provide a TouchGrassLoss with fixed component weights.

    Weights: LM loss 1.0, EQ loss 0.1, music-module loss 0.05.
    """
    from TouchGrass.training.losses import TouchGrassLoss

    weights = {
        "lm_loss_weight": 1.0,
        "eq_loss_weight": 0.1,
        "music_module_loss_weight": 0.05,
    }
    return TouchGrassLoss(**weights)
|
|
|
|
|
def pytest_configure(config):
    """Register this suite's custom markers (``slow``, ``integration``, ``gpu``).

    Registering them lets pytest recognize the marks. Unregistered marks
    trigger warnings, or errors under ``--strict-markers``.
    """
    marker_specs = (
        "slow: marks tests as slow (deselect with '-m \"not slow\"')",
        "integration: marks tests as integration tests",
        "gpu: marks tests that require GPU",
    )
    for spec in marker_specs:
        config.addinivalue_line("markers", spec)
|
|
|
|
|
def pytest_collection_modifyitems(config, items):
    """Auto-mark collected tests according to the test file they come from.

    Tests whose file path contains ``test_inference`` are marked
    ``integration``. Tests whose file path contains ``test_trainer`` are
    marked ``slow``. Only the file-path part of each node id (before the
    first ``::``) is inspected. Test function names or parametrize ids that
    merely contain these substrings therefore do not trigger the markers.

    Args:
        config: The pytest config object (unused).
        items: Collected test items; markers are added to them in place.
    """
    # (substring of the test file path, marker name to apply)
    path_markers = (
        ("test_inference", "integration"),
        ("test_trainer", "slow"),
    )
    for item in items:
        # A node id looks like e.g. "tests/test_x.py::TestCls::test_fn[param]";
        # keep only the file path so function/param names are ignored.
        file_part = item.nodeid.split("::", 1)[0]
        for fragment, marker in path_markers:
            if fragment in file_part:
                # add_marker accepts a marker name string as well as a
                # MarkDecorator; the string form is equivalent to pytest.mark.<name>.
                item.add_marker(marker)
|
|
|