"""Tests for the core model module."""

import pytest
import torch
from peft import TaskType

from src.model.base_model import (
    DECODER_ONLY_MODELS,
    ENCODER_DECODER_MODELS,
    load_model_and_tokenizer,
)
from src.model.lora_adapter import create_lora_config
from src.model.style_conditioner import StyleConditioner, prepend_style_prefix


def test_model_registry_populated():
    """Both model registries define at least one entry."""
    assert len(ENCODER_DECODER_MODELS) > 0
    assert len(DECODER_ONLY_MODELS) > 0


def test_invalid_model_key():
    """An unknown model key raises ``ValueError`` matching "Unknown model key"."""
    with pytest.raises(ValueError, match="Unknown model key"):
        load_model_and_tokenizer("nonexistent-model")


@pytest.mark.parametrize("batch_size", [1, 2])
def test_style_conditioner_output_shape(batch_size):
    """StyleConditioner maps (batch, style_dim) to (batch, n_prefix_tokens, model_hidden_dim).

    ``batch_size=1`` is included because a single style vector is a realistic
    input and catches implementations that accidentally squeeze the batch axis.
    """
    style_dim, hidden_dim, n_prefix = 512, 256, 5
    conditioner = StyleConditioner(
        style_dim=style_dim,
        model_hidden_dim=hidden_dim,
        n_prefix_tokens=n_prefix,
    )

    style_vec = torch.randn(batch_size, style_dim)
    prefix = conditioner(style_vec)

    assert prefix.shape == (batch_size, n_prefix, hidden_dim)


def test_prepend_style_prefix():
    """prepend_style_prefix concatenates the prefix before the embeddings along the sequence axis."""
    batch, seq_len, n_prefix, hidden = 2, 10, 5, 256
    embeddings = torch.randn(batch, seq_len, hidden)
    prefix = torch.randn(batch, n_prefix, hidden)

    result = prepend_style_prefix(embeddings, prefix)

    assert result.shape == (batch, n_prefix + seq_len, hidden)
    # Shape alone would also pass for an append or an interleave; pin the order:
    # prefix tokens come first, followed by the untouched input embeddings.
    assert torch.equal(result[:, :n_prefix], prefix)
    assert torch.equal(result[:, n_prefix:], embeddings)


def test_lora_config_creation():
    """create_lora_config forwards the rank and alpha hyperparameters."""
    config = create_lora_config(TaskType.SEQ_2_SEQ_LM, r=8, lora_alpha=16)
    assert config.r == 8
    assert config.lora_alpha == 16