# Source: TouchGrass-7b / tests/conftest.py
# Uploaded by Zandy-Wandy ("Upload 39 files", commit 4f0238f, verified).
"""
Pytest configuration and shared fixtures for TouchGrass tests.
"""
import pytest
import torch
from pathlib import Path
@pytest.fixture(scope="session")
def project_root():
    """Session-wide path to the directory two levels above this conftest file."""
    conftest_path = Path(__file__)
    tests_dir = conftest_path.parent
    return tests_dir.parent
@pytest.fixture(scope="session")
def test_data_dir(project_root):
    """Session-wide path to ``<project_root>/tests/data``.

    The directory (and any missing parents) is created if it does not exist
    yet, so tests can write into it without checking first.
    """
    target = project_root.joinpath("tests", "data")
    target.mkdir(exist_ok=True, parents=True)
    return target
@pytest.fixture
def sample_music_tokens():
    """Fresh list of bracketed special tokens used by the TouchGrass tests.

    Built from several groups concatenated in a fixed order; a new list is
    returned on every call, so tests may mutate it freely.
    """
    domains = ["[GUITAR]", "[PIANO]", "[DRUMS]", "[VOCALS]", "[THEORY]", "[PRODUCTION]"]
    moods = ["[FRUSTRATED]", "[CONFUSED]", "[EXCITED]", "[CONFIDENT]"]
    levels = ["[EASY]", "[MEDIUM]", "[HARD]"]
    notation = ["[TAB]", "[CHORD]", "[SCALE]", "[INTERVAL]", "[PROGRESSION]"]
    actions = ["[SIMPLIFY]", "[ENCOURAGE]"]
    return domains + moods + levels + notation + actions
@pytest.fixture
def sample_qa_pair():
    """A single chat-style QA example (category + system/user/assistant turns)."""
    system_turn = {"role": "system", "content": "You are a guitar assistant."}
    user_turn = {"role": "user", "content": "How do I play a G major chord?"}
    assistant_turn = {
        "role": "assistant",
        "content": "Place your middle finger on the 3rd fret of the 6th string, index on 2nd fret of 5th string, and ring/pinky on 3rd fret of the 1st and 2nd strings.",
    }
    return {
        "category": "guitar",
        "messages": [system_turn, user_turn, assistant_turn],
    }
@pytest.fixture
def mock_tokenizer():
    """Create a lightweight stand-in for a Hugging Face-style tokenizer.

    The mock tracks ``vocab_size`` as tokens are added and returns fixed
    encode/decode results, so tests can exercise vocabulary-extension code
    without loading a real tokenizer.
    """

    class MockTokenizer:
        def __init__(self):
            self.vocab_size = 32000
            self.pad_token_id = 0

        def encode(self, text, **kwargs):
            # Simple mock encoding: always the same ids, whatever the input.
            return [1, 2, 3, 4, 5]

        def decode(self, token_ids, **kwargs):
            return "mocked decoded text"

        def add_special_tokens(self, tokens_dict):
            """Grow the vocab by the number of ``additional_special_tokens``.

            Returns:
                The number of tokens added, matching the return value of the
                Hugging Face ``add_special_tokens`` API (the original mock
                returned ``None``).
            """
            added = len(tokens_dict.get("additional_special_tokens", []))
            self.vocab_size += added
            return added

        def add_tokens(self, tokens):
            """Grow the vocab by ``len(tokens)`` for a list/tuple, else by 1.

            Returns:
                The number of tokens added (Hugging Face convention).
            """
            # HF accepts either a single token or a list/tuple of tokens.
            added = len(tokens) if isinstance(tokens, (list, tuple)) else 1
            self.vocab_size += added
            return added

        def convert_tokens_to_ids(self, token):
            """Map bracketed tokens to 32000 and any other token to 1.

            Accepts a single token (returns an int) or a list/tuple of
            tokens (returns a list of ints), like the Hugging Face API.
            """
            if isinstance(token, (list, tuple)):
                return [self.convert_tokens_to_ids(t) for t in token]
            return 32000 if token.startswith("[") else 1

    return MockTokenizer()
@pytest.fixture
def device():
    """Device string for tests: ``"cuda"`` when CUDA is available, else ``"cpu"``."""
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"
@pytest.fixture
def d_model():
    """Hidden/model width passed to the module fixtures (768)."""
    hidden_size = 768
    return hidden_size
@pytest.fixture
def batch_size():
    """Number of examples per test batch (4)."""
    examples_per_batch = 4
    return examples_per_batch
@pytest.fixture
def seq_len():
    """Sequence length used when building test inputs (10)."""
    length = 10
    return length
@pytest.fixture
def music_theory_module(device, d_model):
    """A ``MusicTheoryModule`` of width ``d_model``, moved to ``device``, in eval mode."""
    from TouchGrass.models.music_theory_module import MusicTheoryModule

    model = MusicTheoryModule(d_model=d_model).to(device)
    # Eval mode: train-only behaviour (e.g. dropout, if present) is disabled.
    model.eval()
    return model
@pytest.fixture
def tab_chord_module(device, d_model):
    """A ``TabChordModule`` of width ``d_model``, moved to ``device``, in eval mode."""
    from TouchGrass.models.tab_chord_module import TabChordModule

    model = TabChordModule(d_model=d_model).to(device)
    # Eval mode: train-only behaviour (e.g. dropout, if present) is disabled.
    model.eval()
    return model
@pytest.fixture
def ear_training_module(device, d_model):
    """An ``EarTrainingModule`` of width ``d_model``, moved to ``device``, in eval mode."""
    from TouchGrass.models.ear_training_module import EarTrainingModule

    model = EarTrainingModule(d_model=d_model).to(device)
    # Eval mode: train-only behaviour (e.g. dropout, if present) is disabled.
    model.eval()
    return model
@pytest.fixture
def eq_adapter_module(device, d_model):
    """A ``MusicEQAdapter`` of width ``d_model``, moved to ``device``, in eval mode."""
    from TouchGrass.models.eq_adapter import MusicEQAdapter

    adapter = MusicEQAdapter(d_model=d_model).to(device)
    # Eval mode: train-only behaviour (e.g. dropout, if present) is disabled.
    adapter.eval()
    return adapter
@pytest.fixture
def songwriting_module(device, d_model):
    """A ``SongwritingModule`` of width ``d_model``, moved to ``device``, in eval mode."""
    from TouchGrass.models.songwriting_module import SongwritingModule

    model = SongwritingModule(d_model=d_model).to(device)
    # Eval mode: train-only behaviour (e.g. dropout, if present) is disabled.
    model.eval()
    return model
@pytest.fixture
def music_qa_generator():
    """A default-constructed ``MusicQAGenerator``."""
    from TouchGrass.data.music_qa_generator import MusicQAGenerator

    return MusicQAGenerator()
@pytest.fixture
def chat_formatter():
    """A default-constructed ``ChatFormatter``."""
    from TouchGrass.data.chat_formatter import ChatFormatter

    return ChatFormatter()
@pytest.fixture
def touchgrass_loss():
    """A ``TouchGrassLoss`` with fixed test weights.

    The LM term has weight 1.0; the EQ (0.1) and music-module (0.05) terms
    are weighted below it.
    """
    from TouchGrass.training.losses import TouchGrassLoss

    return TouchGrassLoss(
        lm_loss_weight=1.0,
        eq_loss_weight=0.1,
        music_module_loss_weight=0.05,
    )
def pytest_configure(config):
    """Register the suite's custom markers (``slow``, ``integration``, ``gpu``).

    Registered markers are listed by ``pytest --markers`` and accepted under
    ``--strict-markers``.
    """
    custom_markers = (
        "slow: marks tests as slow (deselect with '-m \"not slow\"')",
        "integration: marks tests as integration tests",
        "gpu: marks tests that require GPU",
    )
    for marker_line in custom_markers:
        config.addinivalue_line("markers", marker_line)
def pytest_collection_modifyitems(config, items):
    """Add markers to collected tests based on the file they were collected from.

    Tests whose file path contains ``test_inference`` are marked
    ``integration``; tests whose file path contains ``test_trainer`` are
    marked ``slow``.

    Only the file-path portion of each node id (everything before the first
    ``::``) is inspected. Matching on the whole node id would also tag tests
    in unrelated files whose function or class name merely contains one of
    these substrings (e.g. ``test_model.py::test_inference_mode``).
    """
    for item in items:
        # A node id has the form "path/to/file.py::Class::test_name[param]".
        file_part = item.nodeid.split("::", 1)[0]
        if "test_inference" in file_part:
            item.add_marker(pytest.mark.integration)
        if "test_trainer" in file_part:
            item.add_marker(pytest.mark.slow)