# LightDiffusion-Next — tests/unit/test_sd15_component.py
# Snapshot: "Deploy ZeroGPU Gradio Space" by Aatricks (commit b701455)
"""
Unit tests for SD1.5 model components.
Tests the SD1.5 model configuration, latent format, CLIP tokenizer/encoder,
and CheckpointLoaderSimple with mocked weights.
"""
import os
import sys
import pytest
import torch
from pathlib import Path
from unittest.mock import patch, MagicMock, PropertyMock
# Add project root to path
project_root = Path(__file__).resolve().parent.parent.parent
sys.path.insert(0, str(project_root))
class TestSD15LatentFormat:
    """Unit tests covering the SD15 latent format configuration."""

    def test_sd15_latent_has_4_channels(self):
        """SD1.5 latent format should have 4 channels."""
        from src.Utilities.Latent import SD15

        fmt = SD15()
        channels = fmt.latent_channels
        assert channels == 4, f"Expected 4 latent channels, got {channels}"

    def test_sd15_default_scale_factor(self):
        """SD1.5 should have default scale factor of 0.18215."""
        from src.Utilities.Latent import SD15

        fmt = SD15()
        delta = abs(fmt.scale_factor - 0.18215)
        assert delta < 1e-6, (
            f"Expected scale factor ~0.18215, got {fmt.scale_factor}"
        )

    def test_sd15_custom_scale_factor(self):
        """SD1.5 scale factor should be configurable."""
        from src.Utilities.Latent import SD15

        wanted = 0.2
        fmt = SD15(scale_factor=wanted)
        assert abs(fmt.scale_factor - wanted) < 1e-6, (
            f"Expected scale factor {wanted}, got {fmt.scale_factor}"
        )

    def test_sd15_has_rgb_factors(self):
        """SD1.5 should have latent RGB factors defined."""
        from src.Utilities.Latent import SD15

        fmt = SD15()
        assert hasattr(fmt, 'latent_rgb_factors'), (
            "SD15 should have latent_rgb_factors attribute"
        )
        factors = fmt.latent_rgb_factors
        assert len(factors) == 4, (
            f"Expected 4 RGB factor rows, got {len(factors)}"
        )
        # One (R, G, B) triple per latent channel.
        for row in factors:
            assert len(row) == 3, f"Each RGB row should have 3 values, got {len(row)}"

    def test_sd15_has_taesd_decoder_name(self):
        """SD1.5 should reference correct TAESD decoder."""
        from src.Utilities.Latent import SD15

        fmt = SD15()
        assert hasattr(fmt, 'taesd_decoder_name'), (
            "SD15 should have taesd_decoder_name attribute"
        )
        name = fmt.taesd_decoder_name
        assert name == "taesd_decoder", f"Expected 'taesd_decoder', got {name}"
class TestSD15ModelConfig:
    """Test suite for SD1.5 model configuration (sm_SD15).

    Verifies the UNet hyper-parameters expected of SD1.5 (768-dim CLIP
    context, 320 base channels, no ADM or temporal extensions) plus the
    latent-format and CLIP tokenizer/encoder wiring.
    """

    def test_sd15_unet_config_has_required_keys(self):
        """SD1.5 UNet config should have all required keys."""
        from src.SD15.SD15 import sm_SD15
        required_keys = [
            "context_dim",
            "model_channels",
            "use_linear_in_transformer",
            "adm_in_channels",
            "use_temporal_attention",
        ]
        for key in required_keys:
            assert key in sm_SD15.unet_config, (
                f"Missing required key '{key}' in SD15 unet_config"
            )

    def test_sd15_context_dim_is_768(self):
        """SD1.5 should use 768-dimensional context (CLIP embedding dim)."""
        from src.SD15.SD15 import sm_SD15
        assert sm_SD15.unet_config["context_dim"] == 768, (
            f"Expected context_dim=768, got {sm_SD15.unet_config['context_dim']}"
        )

    def test_sd15_model_channels_is_320(self):
        """SD1.5 should use 320 model channels."""
        from src.SD15.SD15 import sm_SD15
        assert sm_SD15.unet_config["model_channels"] == 320, (
            f"Expected model_channels=320, got {sm_SD15.unet_config['model_channels']}"
        )

    def test_sd15_no_linear_in_transformer(self):
        """SD1.5 should not use linear in transformer."""
        from src.SD15.SD15 import sm_SD15
        assert sm_SD15.unet_config["use_linear_in_transformer"] is False, (
            "SD1.5 should not use linear in transformer"
        )

    def test_sd15_no_adm_channels(self):
        """SD1.5 should not have ADM channels (no pooled conditioning)."""
        from src.SD15.SD15 import sm_SD15
        assert sm_SD15.unet_config["adm_in_channels"] is None, (
            f"SD1.5 should have adm_in_channels=None, got {sm_SD15.unet_config['adm_in_channels']}"
        )

    def test_sd15_no_temporal_attention(self):
        """SD1.5 should not use temporal attention."""
        from src.SD15.SD15 import sm_SD15
        assert sm_SD15.unet_config["use_temporal_attention"] is False, (
            "SD1.5 should not use temporal attention"
        )

    def test_sd15_uses_correct_latent_format(self):
        """SD1.5 model config should reference SD15 latent format."""
        from src.SD15.SD15 import sm_SD15
        from src.Utilities.Latent import SD15 as SD15LatentFormat
        # Plain string here: the message has no placeholders (the original
        # carried a useless f-prefix, ruff F541).
        assert sm_SD15.latent_format == SD15LatentFormat, (
            "SD1.5 model should use SD15 latent format"
        )

    def test_sd15_clip_target_returns_valid_target(self):
        """SD1.5 clip_target should return a ClipTarget."""
        from src.SD15.SD15 import sm_SD15
        from src.clip.Clip import ClipTarget
        model = sm_SD15(sm_SD15.unet_config)
        target = model.clip_target()
        assert isinstance(target, ClipTarget), (
            f"Expected ClipTarget, got {type(target)}"
        )

    def test_sd15_clip_target_uses_sd1_tokenizer(self):
        """SD1.5 should use SD1Tokenizer."""
        from src.SD15.SD15 import sm_SD15
        from src.SD15.SDToken import SD1Tokenizer
        model = sm_SD15(sm_SD15.unet_config)
        target = model.clip_target()
        assert target.tokenizer == SD1Tokenizer, (
            "SD1.5 should use SD1Tokenizer"
        )

    def test_sd15_clip_target_uses_sd1_clip_model(self):
        """SD1.5 should use SD1ClipModel."""
        from src.SD15.SD15 import sm_SD15
        from src.SD15.SDClip import SD1ClipModel
        model = sm_SD15(sm_SD15.unet_config)
        target = model.clip_target()
        assert target.clip == SD1ClipModel, (
            "SD1.5 should use SD1ClipModel"
        )
class TestSD15CheckpointLoader:
    """Test suite for CheckpointLoaderSimple with SD1.5 models."""

    def test_loader_instantiation(self):
        """CheckpointLoaderSimple should instantiate without errors."""
        from src.FileManaging.Loader import CheckpointLoaderSimple

        assert CheckpointLoaderSimple() is not None

    def test_loader_calls_correct_functions(self):
        """Loader should call cache check then load if not cached."""
        from src.FileManaging.Loader import CheckpointLoaderSimple

        # Cache miss: get_cached_checkpoint returns None.
        fake_cache = MagicMock()
        fake_cache.get_cached_checkpoint.return_value = None
        loaded = (
            MagicMock(name="mock_model"),
            MagicMock(name="mock_clip"),
            MagicMock(name="mock_vae"),
            None,
        )
        with patch(
            'src.FileManaging.Loader.load_checkpoint_guess_config',
            return_value=loaded,
        ) as mock_load, patch(
            'src.Device.ModelCache.get_model_cache', return_value=fake_cache
        ):
            result = CheckpointLoaderSimple().load_checkpoint("test_model.safetensors")
        # Cache consulted exactly once, then the real load path taken.
        fake_cache.get_cached_checkpoint.assert_called_once()
        mock_load.assert_called_once()
        assert len(result) == 3, f"Expected 3-tuple, got {len(result)}-tuple"

    def test_loader_returns_cached_model(self):
        """Loader should return cached model without calling load."""
        from src.FileManaging.Loader import CheckpointLoaderSimple

        cached = tuple(
            MagicMock(name=f"cached_{part}") for part in ("model", "clip", "vae")
        )
        fake_cache = MagicMock()
        fake_cache.get_cached_checkpoint.return_value = cached
        with patch(
            'src.Device.ModelCache.get_model_cache', return_value=fake_cache
        ):
            result = CheckpointLoaderSimple().load_checkpoint("cached_model.safetensors")
        # Identity checks: the exact cached objects must come back untouched.
        for got, want in zip(result, cached):
            assert got is want

    def test_loader_accepts_vae_flag(self):
        """Loader should accept output_vae parameter."""
        from src.FileManaging.Loader import CheckpointLoaderSimple

        loader = CheckpointLoaderSimple()
        with patch(
            'src.FileManaging.Loader.load_checkpoint_guess_config',
            return_value=(MagicMock(), MagicMock(), MagicMock(), None),
        ), patch('src.Device.ModelCache.get_model_cache') as cache:
            cache.return_value.get_cached_checkpoint.return_value = None
            # A TypeError here would mean the keyword is not accepted.
            loader.load_checkpoint("test.safetensors", output_vae=False)

    def test_loader_accepts_clip_flag(self):
        """Loader should accept output_clip parameter."""
        from src.FileManaging.Loader import CheckpointLoaderSimple

        loader = CheckpointLoaderSimple()
        with patch(
            'src.FileManaging.Loader.load_checkpoint_guess_config',
            return_value=(MagicMock(), MagicMock(), MagicMock(), None),
        ), patch('src.Device.ModelCache.get_model_cache') as cache:
            cache.return_value.get_cached_checkpoint.return_value = None
            # A TypeError here would mean the keyword is not accepted.
            loader.load_checkpoint("test.safetensors", output_clip=False)
class TestSD15CLIPEncoding:
    """Test suite for SD1.5 CLIP text encoding (mocked)."""

    def test_clip_text_encode_instantiation(self):
        """CLIPTextEncode should instantiate without errors."""
        from src.clip.Clip import CLIPTextEncode

        assert CLIPTextEncode() is not None

    @patch('src.clip.Clip.CLIPTextEncode.encode')
    def test_encode_returns_conditioning_format(self, mock_encode):
        """encode() should return list of [tensor, metadata] entries."""
        from src.clip.Clip import CLIPTextEncode

        # encode() itself is mocked; this pins the conditioning data shape
        # that downstream consumers rely on.
        cond = torch.randn(1, 77, 768)
        meta = {"pooled_output": None}
        mock_encode.return_value = ([[cond, meta]],)
        result = CLIPTextEncode().encode(text="test prompt", clip=MagicMock())
        assert isinstance(result, tuple), f"Expected tuple, got {type(result)}"
        entries = result[0]
        assert isinstance(entries, list), f"Expected list, got {type(entries)}"

    @patch('src.clip.Clip.CLIPTextEncode.encode')
    def test_encode_produces_768_dim_embeddings_for_sd15(self, mock_encode):
        """SD1.5 CLIP encoding should produce 768-dim embeddings."""
        from src.clip.Clip import CLIPTextEncode

        expected_dim = 768  # SD1.5 text-encoder embedding width
        mock_encode.return_value = ([[torch.randn(1, 77, expected_dim), {}]],)
        embedding = CLIPTextEncode().encode(text="test", clip=MagicMock())[0][0][0]
        last_dim = embedding.shape[-1]
        assert last_dim == expected_dim, (
            f"Expected embedding dim {expected_dim}, got {last_dim}"
        )
class TestSD15EmptyLatent:
    """Test suite for EmptyLatentImage generation."""

    def test_empty_latent_instantiation(self):
        """EmptyLatentImage should instantiate without errors."""
        from src.Utilities.Latent import EmptyLatentImage

        assert EmptyLatentImage() is not None

    def test_empty_latent_generates_correct_shape(self):
        """EmptyLatentImage should generate correct latent dimensions."""
        from src.Utilities.Latent import EmptyLatentImage

        width, height, batch_size = 512, 512, 1
        result = EmptyLatentImage().generate(
            width=width, height=height, batch_size=batch_size
        )
        assert isinstance(result, tuple), f"Expected tuple, got {type(result)}"
        payload = result[0]
        assert "samples" in payload, "Result should have 'samples' key"
        samples = payload["samples"]
        # SD1.5 latent space is the image size divided by 8 per axis.
        expected_shape = (batch_size, 4, height // 8, width // 8)
        assert samples.shape == expected_shape, (
            f"Expected shape {expected_shape}, got {samples.shape}"
        )

    def test_empty_latent_with_different_sizes(self):
        """EmptyLatentImage should work with various image sizes."""
        from src.Utilities.Latent import EmptyLatentImage

        generator = EmptyLatentImage()
        for width, height, batch in [
            (512, 512, 1),
            (768, 512, 1),
            (1024, 1024, 1),
            (512, 768, 2),
        ]:
            samples = generator.generate(
                width=width, height=height, batch_size=batch
            )[0]["samples"]
            expected_shape = (batch, 4, height // 8, width // 8)
            assert samples.shape == expected_shape, (
                f"For {width}x{height} batch={batch}: "
                f"expected {expected_shape}, got {samples.shape}"
            )

    def test_empty_latent_is_zeros(self):
        """EmptyLatentImage should produce zero-initialized latents."""
        from src.Utilities.Latent import EmptyLatentImage

        samples = EmptyLatentImage().generate(
            width=512, height=512, batch_size=1
        )[0]["samples"]
        # Noise is injected later by the sampler; the seed latent is all zeros.
        assert torch.allclose(samples, torch.zeros_like(samples)), (
            "EmptyLatentImage should produce zero-initialized latents"
        )
class TestSD15TokenizerBasics:
    """Test suite for SD1.5 tokenizer functionality."""

    def test_sd1_tokenizer_class_exists(self):
        """SD1Tokenizer class should exist."""
        import src.SD15.SDToken as sdtoken

        # Smoke check: the concrete SD1.5 tokenizer is importable.
        assert getattr(sdtoken, "SD1Tokenizer", None) is not None

    def test_sd_tokenizer_base_class_exists(self):
        """SDTokenizer base class should exist."""
        import src.SD15.SDToken as sdtoken

        # Smoke check: the shared tokenizer base class is importable.
        assert getattr(sdtoken, "SDTokenizer", None) is not None
class TestSD15ProcessClipStateDict:
    """Test suite for CLIP state dict processing."""

    def test_process_clip_state_dict_handles_prefix_replacement(self):
        """process_clip_state_dict should handle cond_stage_model prefix."""
        from src.SD15.SD15 import sm_SD15

        model = sm_SD15(sm_SD15.unet_config)
        # Legacy checkpoints store the text encoder under "cond_stage_model.".
        state_dict = {
            "cond_stage_model.transformer.text_model.weight": torch.randn(10, 10),
            "cond_stage_model.other.weight": torch.randn(5, 5),
        }
        result = model.process_clip_state_dict(state_dict)
        # After processing, keys should use clip_l prefix
        for key in result.keys():
            assert key.startswith("clip_l."), (
                f"Expected key to start with 'clip_l.', got {key}"
            )

    def test_process_clip_state_dict_handles_position_ids_dtype(self):
        """process_clip_state_dict should convert float32 position_ids to int."""
        from src.SD15.SD15 import sm_SD15

        model = sm_SD15(sm_SD15.unet_config)
        # Use deliberately non-integral float32 values. With the previous
        # input torch.arange(77).float(), value.round() == value held
        # trivially, so the assertion below could never fail regardless of
        # whether the implementation rounded anything.
        pos_key = "cond_stage_model.transformer.text_model.embeddings.position_ids"
        state_dict = {
            pos_key: torch.arange(77).float() + 0.1,  # float32, non-integral
        }
        result = model.process_clip_state_dict(state_dict)
        # The position_ids may be renamed and/or cast to an integer dtype;
        # any that remain float32 must at least have been rounded.
        for key, value in result.items():
            if "position_ids" in key and value.dtype == torch.float32:
                rounded = value.round()
                assert torch.allclose(value, rounded), (
                    "Float32 position_ids should be rounded"
                )
class TestSD15SamplerIntegration:
    """Test suite for SD1.5 sampler integration (mocked)."""

    def test_ksampler_instantiation(self):
        """KSampler should instantiate without errors."""
        from src.sample.sampling import KSampler

        assert KSampler() is not None

    def test_ksampler_sample_signature_includes_required_params(self):
        """KSampler.sample should accept all required parameters."""
        import inspect
        from src.sample.sampling import KSampler

        # Inspect the bound method's signature rather than calling it.
        params = inspect.signature(KSampler().sample).parameters
        expected = (
            'seed', 'steps', 'cfg', 'sampler_name', 'scheduler',
            'denoise', 'model', 'positive', 'negative', 'latent_image',
        )
        for name in expected:
            assert name in params, (
                f"KSampler.sample missing required parameter: {name}"
            )

    def test_ksampler_sample_accepts_pipeline_flag(self):
        """KSampler.sample should accept pipeline flag."""
        import inspect
        from src.sample.sampling import KSampler

        signature = inspect.signature(KSampler().sample)
        assert 'pipeline' in signature.parameters, (
            "KSampler.sample should accept 'pipeline' parameter"
        )
class TestSD15ModelInModelsRegistry:
    """Test that SD1.5 model is properly registered."""

    def test_sd15_in_models_list(self):
        """sm_SD15 should be in the models registry."""
        from src.SD15.SD15 import models, sm_SD15

        assert sm_SD15 in models, (
            "sm_SD15 should be in the models registry list"
        )

    def test_models_list_not_empty(self):
        """Models list should contain multiple model types."""
        from src.SD15.SD15 import models

        count = len(models)
        assert count > 0, "Models list should not be empty"
        assert count >= 3, (
            f"Expected at least 3 model types, got {count}"
        )