# rewrite/tests/test_model.py
# morpheuslord's picture
# Add files using upload-large-folder tool
# 3df5819 verified
"""Tests for the core model module."""
import pytest
import torch
from src.model.base_model import load_model_and_tokenizer, ENCODER_DECODER_MODELS, DECODER_ONLY_MODELS
from src.model.style_conditioner import StyleConditioner, prepend_style_prefix
from src.model.lora_adapter import create_lora_config
from peft import TaskType
def test_model_registry_populated():
    """Check that both model registries contain at least one entry."""
    # Encoder-decoder registry is checked first, then the decoder-only one.
    for registry in (ENCODER_DECODER_MODELS, DECODER_ONLY_MODELS):
        assert len(registry) > 0
def test_invalid_model_key():
    """Check that loading an unregistered model key raises ``ValueError``.

    The error message is expected to match "Unknown model key".
    """
    bogus_key = "nonexistent-model"
    with pytest.raises(ValueError, match="Unknown model key"):
        load_model_and_tokenizer(bogus_key)
def test_style_conditioner_output_shape():
    """Check the shape of the prefix produced by ``StyleConditioner``.

    A batch of style vectors of width ``style_dim`` is expected to yield a
    tensor of shape ``(batch, n_prefix_tokens, model_hidden_dim)``.
    """
    style_dim, hidden_dim, n_prefix = 512, 256, 5
    module = StyleConditioner(
        style_dim=style_dim,
        model_hidden_dim=hidden_dim,
        n_prefix_tokens=n_prefix,
    )

    n_samples = 2
    style_batch = torch.randn(n_samples, style_dim)
    out = module(style_batch)

    expected_shape = (n_samples, n_prefix, hidden_dim)
    assert out.shape == expected_shape
def test_prepend_style_prefix():
    """Check the output shape of ``prepend_style_prefix``.

    The result is expected to keep the batch and hidden sizes, with a
    sequence length equal to the prefix length plus the input length.
    """
    batch, seq_len, prefix_len, hidden = 2, 10, 5, 256

    token_embeds = torch.randn(batch, seq_len, hidden)
    style_prefix = torch.randn(batch, prefix_len, hidden)

    combined = prepend_style_prefix(token_embeds, style_prefix)

    assert combined.shape == (batch, prefix_len + seq_len, hidden)
def test_lora_config_creation():
    """Check that ``create_lora_config`` keeps the requested rank and alpha."""
    rank = 8
    alpha = 16

    lora_cfg = create_lora_config(
        TaskType.SEQ_2_SEQ_LM,
        r=rank,
        lora_alpha=alpha,
    )

    # Kept as two separate asserts so a failure points at the exact field.
    assert lora_cfg.r == rank
    assert lora_cfg.lora_alpha == alpha