""" Pytest configuration and shared fixtures for TouchGrass tests. """ import pytest import torch from pathlib import Path @pytest.fixture(scope="session") def project_root(): """Return the project root directory.""" return Path(__file__).parent.parent @pytest.fixture(scope="session") def test_data_dir(project_root): """Return the test data directory.""" data_dir = project_root / "tests" / "data" data_dir.mkdir(parents=True, exist_ok=True) return data_dir @pytest.fixture def sample_music_tokens(): """Return a list of sample music tokens.""" return [ "[GUITAR]", "[PIANO]", "[DRUMS]", "[VOCALS]", "[THEORY]", "[PRODUCTION]", "[FRUSTRATED]", "[CONFUSED]", "[EXCITED]", "[CONFIDENT]", "[EASY]", "[MEDIUM]", "[HARD]", "[TAB]", "[CHORD]", "[SCALE]", "[INTERVAL]", "[PROGRESSION]", "[SIMPLIFY]", "[ENCOURAGE]" ] @pytest.fixture def sample_qa_pair(): """Return a sample QA pair for testing.""" return { "category": "guitar", "messages": [ {"role": "system", "content": "You are a guitar assistant."}, {"role": "user", "content": "How do I play a G major chord?"}, {"role": "assistant", "content": "Place your middle finger on the 3rd fret of the 6th string, index on 2nd fret of 5th string, and ring/pinky on 3rd fret of the 1st and 2nd strings."} ] } @pytest.fixture def mock_tokenizer(): """Create a mock tokenizer for testing.""" class MockTokenizer: def __init__(self): self.vocab_size = 32000 self.pad_token_id = 0 def encode(self, text, **kwargs): # Simple mock encoding return [1, 2, 3, 4, 5] def decode(self, token_ids, **kwargs): return "mocked decoded text" def add_special_tokens(self, tokens_dict): self.vocab_size += len(tokens_dict.get("additional_special_tokens", [])) def add_tokens(self, tokens): if isinstance(tokens, list): self.vocab_size += len(tokens) else: self.vocab_size += 1 def convert_tokens_to_ids(self, token): return 32000 if token.startswith("[") else 1 return MockTokenizer() @pytest.fixture def device(): """Return the device to use for tests.""" return "cuda" if torch.cuda.is_available() else "cpu" @pytest.fixture def d_model(): """Return the model dimension for tests.""" return 768 @pytest.fixture def batch_size(): """Return the batch size for tests.""" return 4 @pytest.fixture def seq_len(): """Return the sequence length for tests.""" return 10 @pytest.fixture def music_theory_module(device, d_model): """Create a MusicTheoryModule instance for testing.""" from TouchGrass.models.music_theory_module import MusicTheoryModule module = MusicTheoryModule(d_model=d_model).to(device) module.eval() return module @pytest.fixture def tab_chord_module(device, d_model): """Create a TabChordModule instance for testing.""" from TouchGrass.models.tab_chord_module import TabChordModule module = TabChordModule(d_model=d_model).to(device) module.eval() return module @pytest.fixture def ear_training_module(device, d_model): """Create an EarTrainingModule instance for testing.""" from TouchGrass.models.ear_training_module import EarTrainingModule module = EarTrainingModule(d_model=d_model).to(device) module.eval() return module @pytest.fixture def eq_adapter_module(device, d_model): """Create a MusicEQAdapter instance for testing.""" from TouchGrass.models.eq_adapter import MusicEQAdapter module = MusicEQAdapter(d_model=d_model).to(device) module.eval() return module @pytest.fixture def songwriting_module(device, d_model): """Create a SongwritingModule instance for testing.""" from TouchGrass.models.songwriting_module import SongwritingModule module = SongwritingModule(d_model=d_model).to(device) module.eval() return module @pytest.fixture def music_qa_generator(): """Create a MusicQAGenerator instance for testing.""" from TouchGrass.data.music_qa_generator import MusicQAGenerator generator = MusicQAGenerator() return generator @pytest.fixture def chat_formatter(): """Create a ChatFormatter instance for testing.""" from TouchGrass.data.chat_formatter import ChatFormatter formatter = ChatFormatter() return formatter @pytest.fixture def touchgrass_loss(): """Create a TouchGrassLoss instance for testing.""" from TouchGrass.training.losses import TouchGrassLoss loss_fn = TouchGrassLoss(lm_loss_weight=1.0, eq_loss_weight=0.1, music_module_loss_weight=0.05) return loss_fn def pytest_configure(config): """Configure pytest with custom markers.""" config.addinivalue_line( "markers", "slow: marks tests as slow (deselect with '-m \"not slow\"')" ) config.addinivalue_line( "markers", "integration: marks tests as integration tests" ) config.addinivalue_line( "markers", "gpu: marks tests that require GPU" ) def pytest_collection_modifyitems(config, items): """Modify test collection to add markers based on file names.""" for item in items: if "test_inference" in item.nodeid: item.add_marker(pytest.mark.integration) if "test_trainer" in item.nodeid: item.add_marker(pytest.mark.slow)