| """Tests for the core model module.""" |
|
|
| import pytest |
| import torch |
| from src.model.base_model import load_model_and_tokenizer, ENCODER_DECODER_MODELS, DECODER_ONLY_MODELS |
| from src.model.style_conditioner import StyleConditioner, prepend_style_prefix |
| from src.model.lora_adapter import create_lora_config |
| from peft import TaskType |
|
|
|
|
def test_model_registry_populated():
    """Check that neither model registry is empty.

    Both the encoder-decoder and the decoder-only registries must define at
    least one entry.
    """
    for registry in (ENCODER_DECODER_MODELS, DECODER_ONLY_MODELS):
        assert len(registry) > 0
|
|
|
|
def test_invalid_model_key():
    """Check that loading an unregistered model key raises ValueError.

    The raised message must match the pattern "Unknown model key".
    """
    unknown_key = "nonexistent-model"
    with pytest.raises(ValueError, match="Unknown model key"):
        load_model_and_tokenizer(unknown_key)
|
|
|
|
def test_style_conditioner_output_shape():
    """Check the output shape of StyleConditioner.

    A (batch, style_dim) input is expected to yield a
    (batch, n_prefix_tokens, model_hidden_dim) output.
    """
    # Name the dimensions once so the constructor arguments, the input tensor
    # and the expected shape cannot drift apart.
    style_dim, hidden_dim, n_tokens = 512, 256, 5
    batch = 2

    conditioner = StyleConditioner(
        style_dim=style_dim,
        model_hidden_dim=hidden_dim,
        n_prefix_tokens=n_tokens,
    )
    prefix = conditioner(torch.randn(batch, style_dim))

    assert prefix.shape == (batch, n_tokens, hidden_dim)
|
|
|
|
def test_prepend_style_prefix():
    """Check that the style prefix is prepended to the embeddings along dim 1.

    Besides the combined shape, verify the content order. The first
    ``n_prefix`` positions must hold the prefix and the remaining positions
    the original embeddings. A shape-only check would also pass if the two
    inputs were concatenated in the wrong order.
    """
    batch, seq_len, n_prefix, hidden = 2, 10, 5, 256
    embeddings = torch.randn(batch, seq_len, hidden)
    prefix = torch.randn(batch, n_prefix, hidden)

    result = prepend_style_prefix(embeddings, prefix)

    assert result.shape == (batch, n_prefix + seq_len, hidden)
    # NOTE(review): assumes prepend_style_prefix concatenates the inputs without
    # transforming their values. Adjust these checks if it ever applies a
    # projection or dtype cast.
    assert torch.equal(result[:, :n_prefix], prefix)
    assert torch.equal(result[:, n_prefix:], embeddings)
|
|
|
|
def test_lora_config_creation():
    """Check that create_lora_config forwards the task type, rank and alpha.

    The original test passed a TaskType but never checked it. A config built
    for the wrong task type would silently have passed.
    """
    task_type = TaskType.SEQ_2_SEQ_LM
    config = create_lora_config(task_type, r=8, lora_alpha=16)

    assert config.task_type == task_type
    assert config.r == 8
    assert config.lora_alpha == 16
|
|