# tests/unit/test_sdxl_component.py
"""
Unit tests for SDXL model components.
Tests the SDXL model configuration, latent format, dual CLIP tokenizer/encoder,
resolution handling, and differences from SD1.5.
"""
import sys
import pytest
import torch
from pathlib import Path
from unittest.mock import MagicMock
# Add project root to path
project_root = Path(__file__).resolve().parent.parent.parent
sys.path.insert(0, str(project_root))
class TestSDXLLatentFormat:
"""Test suite for SDXL latent format configuration."""
def test_sdxl_latent_has_4_channels(self):
"""SDXL latent format should have 4 channels."""
from src.Utilities.Latent import SDXL
latent = SDXL()
assert latent.latent_channels == 4, (
f"Expected 4 latent channels, got {latent.latent_channels}"
)
def test_sdxl_scale_factor_differs_from_sd15(self):
"""SDXL scale factor should differ from SD1.5."""
from src.Utilities.Latent import SDXL, SD15
sdxl = SDXL()
sd15 = SD15()
assert sdxl.scale_factor != sd15.scale_factor, (
"SDXL scale factor should differ from SD1.5"
)
assert abs(sdxl.scale_factor - 0.13025) < 1e-6, (
f"Expected SDXL scale factor ~0.13025, got {sdxl.scale_factor}"
)
def test_sdxl_has_rgb_factors(self):
"""SDXL should have latent RGB factors defined."""
from src.Utilities.Latent import SDXL
latent = SDXL()
assert hasattr(latent, 'latent_rgb_factors'), (
"SDXL should have latent_rgb_factors attribute"
)
assert len(latent.latent_rgb_factors) == 4, (
f"Expected 4 RGB factor rows, got {len(latent.latent_rgb_factors)}"
)
def test_sdxl_has_rgb_factors_bias(self):
"""SDXL should have RGB factors bias (unlike SD1.5)."""
from src.Utilities.Latent import SDXL
latent = SDXL()
assert hasattr(latent, 'latent_rgb_factors_bias'), (
"SDXL should have latent_rgb_factors_bias attribute"
)
assert len(latent.latent_rgb_factors_bias) == 3, (
f"Expected 3 bias values (RGB), got {len(latent.latent_rgb_factors_bias)}"
)
def test_sdxl_uses_taesdxl_decoder(self):
"""SDXL should reference correct TAESD decoder (xl version)."""
from src.Utilities.Latent import SDXL
latent = SDXL()
assert hasattr(latent, 'taesd_decoder_name'), (
"SDXL should have taesd_decoder_name attribute"
)
assert latent.taesd_decoder_name == "taesdxl_decoder", (
f"Expected 'taesdxl_decoder', got {latent.taesd_decoder_name}"
)
class TestSDXLModelConfig:
"""Test suite for SDXL model configuration."""
def test_sdxl_class_exists(self):
"""SDXL model config class should exist."""
from src.SD15.SDXL import SDXL
assert SDXL is not None
def test_sdxl_unet_config_has_required_keys(self):
"""SDXL UNet config should have all required keys."""
from src.SD15.SDXL import SDXL
required_keys = [
"model_channels",
"use_linear_in_transformer",
"transformer_depth",
"context_dim",
"adm_in_channels",
"use_temporal_attention",
]
for key in required_keys:
assert key in SDXL.unet_config, (
f"Missing required key '{key}' in SDXL unet_config"
)
def test_sdxl_context_dim_is_2048(self):
"""SDXL should use 2048-dimensional context (concatenated CLIP dims)."""
from src.SD15.SDXL import SDXL
assert SDXL.unet_config["context_dim"] == 2048, (
f"Expected context_dim=2048, got {SDXL.unet_config['context_dim']}"
)
def test_sdxl_model_channels_is_320(self):
"""SDXL should use 320 model channels."""
from src.SD15.SDXL import SDXL
assert SDXL.unet_config["model_channels"] == 320, (
f"Expected model_channels=320, got {SDXL.unet_config['model_channels']}"
)
def test_sdxl_uses_linear_in_transformer(self):
"""SDXL should use linear in transformer (unlike SD1.5)."""
from src.SD15.SDXL import SDXL
assert SDXL.unet_config["use_linear_in_transformer"] is True, (
"SDXL should use linear in transformer"
)
def test_sdxl_has_adm_channels(self):
"""SDXL should have ADM channels for pooled conditioning."""
from src.SD15.SDXL import SDXL
assert SDXL.unet_config["adm_in_channels"] is not None, (
"SDXL should have adm_in_channels set"
)
assert SDXL.unet_config["adm_in_channels"] == 2816, (
f"Expected adm_in_channels=2816, got {SDXL.unet_config['adm_in_channels']}"
)
def test_sdxl_transformer_depth_is_list(self):
"""SDXL transformer depth should be a list."""
from src.SD15.SDXL import SDXL
depth = SDXL.unet_config["transformer_depth"]
assert isinstance(depth, (list, tuple)), (
f"Expected transformer_depth to be list, got {type(depth)}"
)
def test_sdxl_uses_correct_latent_format(self):
"""SDXL model config should reference SDXL latent format."""
from src.SD15.SDXL import SDXL as SDXLModel
from src.Utilities.Latent import SDXL as SDXLLatentFormat
assert SDXLModel.latent_format == SDXLLatentFormat, (
"SDXL model should use SDXL latent format"
)
def test_sdxl_has_memory_usage_factor(self):
"""SDXL should have memory_usage_factor defined."""
from src.SD15.SDXL import SDXL
assert hasattr(SDXL, 'memory_usage_factor'), (
"SDXL should have memory_usage_factor attribute"
)
assert 0 < SDXL.memory_usage_factor <= 1.0, (
f"memory_usage_factor should be 0-1, got {SDXL.memory_usage_factor}"
)
class TestSDXLClipTarget:
"""Test suite for SDXL CLIP target configuration."""
def test_sdxl_clip_target_returns_valid_target(self):
"""SDXL clip_target should return a ClipTarget."""
from src.SD15.SDXL import SDXL
from src.clip.Clip import ClipTarget
model = SDXL(SDXL.unet_config)
target = model.clip_target()
assert isinstance(target, ClipTarget), (
f"Expected ClipTarget, got {type(target)}"
)
def test_sdxl_clip_target_uses_sdxl_tokenizer(self):
"""SDXL should use SDXLTokenizer (dual tokenizer)."""
from src.SD15.SDXL import SDXL
from src.SD15.SDXLClip import SDXLTokenizer
model = SDXL(SDXL.unet_config)
target = model.clip_target()
assert target.tokenizer == SDXLTokenizer, (
f"SDXL should use SDXLTokenizer, got {target.tokenizer}"
)
def test_sdxl_clip_target_uses_sdxl_clip_model(self):
"""SDXL should use SDXLClipModel (dual CLIP model)."""
from src.SD15.SDXL import SDXL
from src.SD15.SDXLClip import SDXLClipModel
model = SDXL(SDXL.unet_config)
target = model.clip_target()
assert target.clip == SDXLClipModel, (
f"SDXL should use SDXLClipModel, got {target.clip}"
)
class TestSDXLTokenizer:
"""Test suite for SDXL dual tokenizer."""
def test_sdxl_tokenizer_exists(self):
"""SDXLTokenizer class should exist."""
from src.SD15.SDXLClip import SDXLTokenizer
assert SDXLTokenizer is not None
def test_sdxl_tokenizer_has_dual_tokenizers(self):
"""SDXLTokenizer should have both L and G tokenizers."""
from src.SD15.SDXLClip import SDXLTokenizer
tokenizer = SDXLTokenizer()
assert hasattr(tokenizer, 'clip_l'), (
"SDXLTokenizer should have clip_l attribute"
)
assert hasattr(tokenizer, 'clip_g'), (
"SDXLTokenizer should have clip_g attribute"
)
def test_sdxl_tokenizer_tokenize_returns_dict(self):
"""tokenize_with_weights should return dict with 'g' and 'l' keys."""
from src.SD15.SDXLClip import SDXLTokenizer
tokenizer = SDXLTokenizer()
        # The tokenizer may need to load vocab files, so skip if they don't exist
try:
result = tokenizer.tokenize_with_weights("test prompt")
assert isinstance(result, dict), (
f"Expected dict result, got {type(result)}"
)
assert 'g' in result, "Result should have 'g' (ClipG) key"
assert 'l' in result, "Result should have 'l' (ClipL) key"
except FileNotFoundError:
pytest.skip("Tokenizer vocabulary files not available")
class TestSDXLClipModel:
"""Test suite for SDXL CLIP model (dual L+G)."""
def test_sdxl_clip_model_exists(self):
"""SDXLClipModel class should exist."""
from src.SD15.SDXLClip import SDXLClipModel
assert SDXLClipModel is not None
def test_sdxl_clip_model_is_torch_module(self):
"""SDXLClipModel should be a torch.nn.Module."""
from src.SD15.SDXLClip import SDXLClipModel
import torch.nn as nn
assert issubclass(SDXLClipModel, nn.Module), (
"SDXLClipModel should be a torch.nn.Module subclass"
)
    def test_sdxl_clip_model_has_dual_clips(self):
        """SDXLClipModel should have both clip_l and clip_g."""
        from src.SD15.SDXLClip import SDXLClipModel
        # Bypass __init__ (which would build the real sub-models) and attach
        # mocks directly; only the attribute layout is under test here.
        model = SDXLClipModel.__new__(SDXLClipModel)
        model.clip_l = MagicMock()
        model.clip_g = MagicMock()
        model.dtypes = set()
        assert hasattr(model, 'clip_l'), "Should have clip_l"
        assert hasattr(model, 'clip_g'), "Should have clip_g"
class TestSDXLDifferencesFromSD15:
"""Test that SDXL properly differs from SD1.5 where expected."""
def test_context_dim_difference(self):
"""SDXL context_dim (2048) should differ from SD1.5 (768)."""
from src.SD15.SD15 import sm_SD15
from src.SD15.SDXL import SDXL
sd15_ctx = sm_SD15.unet_config["context_dim"]
sdxl_ctx = SDXL.unet_config["context_dim"]
assert sdxl_ctx != sd15_ctx, (
f"SDXL context_dim ({sdxl_ctx}) should differ from SD1.5 ({sd15_ctx})"
)
        assert sdxl_ctx == 2048, f"SDXL context_dim should be 2048, got {sdxl_ctx}"
        assert sd15_ctx == 768, f"SD1.5 context_dim should be 768, got {sd15_ctx}"
def test_linear_in_transformer_difference(self):
"""SDXL uses linear in transformer, SD1.5 does not."""
from src.SD15.SD15 import sm_SD15
from src.SD15.SDXL import SDXL
sd15_linear = sm_SD15.unet_config["use_linear_in_transformer"]
sdxl_linear = SDXL.unet_config["use_linear_in_transformer"]
assert sd15_linear is False, "SD1.5 should not use linear in transformer"
assert sdxl_linear is True, "SDXL should use linear in transformer"
def test_adm_channels_difference(self):
"""SDXL has ADM channels, SD1.5 does not."""
from src.SD15.SD15 import sm_SD15
from src.SD15.SDXL import SDXL
sd15_adm = sm_SD15.unet_config["adm_in_channels"]
sdxl_adm = SDXL.unet_config["adm_in_channels"]
assert sd15_adm is None, "SD1.5 should not have ADM channels"
assert sdxl_adm is not None, "SDXL should have ADM channels"
assert sdxl_adm == 2816, f"SDXL ADM channels should be 2816, got {sdxl_adm}"
class TestSDXLResolutionHandling:
"""CRITICAL: Test SDXL resolution handling and warnings for small resolutions."""
def test_sdxl_latent_accepts_512_resolution(self):
"""SDXL latent generation should accept 512x512 resolution.
CRITICAL TEST: Per requirements, verify whether SDXL accepts or warns
for resolutions like 512x512 (below SDXL's native 1024x1024).
"""
from src.Utilities.Latent import EmptyLatentImage
generator = EmptyLatentImage()
# This should not raise an error even for small resolution
result = generator.generate(width=512, height=512, batch_size=1)
assert result is not None, "SDXL latent generation should succeed for 512x512"
samples = result[0]["samples"]
# Verify shape: 512/8 = 64
expected_shape = (1, 4, 64, 64)
assert samples.shape == expected_shape, (
f"Expected shape {expected_shape}, got {samples.shape}"
)
def test_sdxl_latent_accepts_1024_resolution(self):
"""SDXL latent generation should work naturally at 1024x1024."""
from src.Utilities.Latent import EmptyLatentImage
generator = EmptyLatentImage()
result = generator.generate(width=1024, height=1024, batch_size=1)
samples = result[0]["samples"]
# 1024/8 = 128
expected_shape = (1, 4, 128, 128)
assert samples.shape == expected_shape, (
f"Expected shape {expected_shape}, got {samples.shape}"
)
    def test_sdxl_latent_handles_non_64_divisible_sizes(self):
        """Test latent generation with sizes that are not multiples of 64."""
        from src.Utilities.Latent import EmptyLatentImage
        generator = EmptyLatentImage()
        # 800 and 600 are multiples of 8 (so // 8 is exact) but not of 64
        result = generator.generate(width=800, height=600, batch_size=1)
        samples = result[0]["samples"]
        expected_shape = (1, 4, 75, 100)  # (batch, channels, height // 8, width // 8)
        assert samples.shape == expected_shape, (
            f"Expected shape {expected_shape}, got {samples.shape}"
        )
def test_sdxl_recommended_resolution_is_1024(self):
"""Document that SDXL's recommended base resolution is 1024x1024.
Note: This is a documentation test. SDXL models are trained at
1024x1024 and may produce suboptimal results at smaller sizes.
"""
# This is informational - SDXL works at any resolution but
# is optimized for 1024x1024
SDXL_RECOMMENDED_BASE = 1024
assert SDXL_RECOMMENDED_BASE == 1024
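    @pytest.mark.parametrize("width,height", [(1024, 1024), (1152, 896), (896, 1152), (1216, 832)])
    def test_common_sdxl_buckets_are_64_multiples_near_1_megapixel(self, width, height):
        """Documentation test: the widely used SDXL training buckets are
        multiples of 64 with a pixel count close to 1024 * 1024.
        """
        assert width % 64 == 0 and height % 64 == 0
        assert abs(width * height - 1024 * 1024) / (1024 * 1024) < 0.05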
class TestSDXLRefiner:
"""Test suite for SDXL Refiner model configuration."""
def test_sdxl_refiner_exists(self):
"""SDXLRefiner class should exist."""
from src.SD15.SDXL import SDXLRefiner
assert SDXLRefiner is not None
def test_sdxl_refiner_has_different_unet_config(self):
"""SDXLRefiner should have different UNet config than base SDXL."""
from src.SD15.SDXL import SDXL, SDXLRefiner
# Refiner has different model_channels
assert SDXLRefiner.unet_config["model_channels"] != SDXL.unet_config["model_channels"], (
"Refiner should have different model_channels than base SDXL"
)
assert SDXLRefiner.unet_config["model_channels"] == 384, (
f"Expected refiner model_channels=384, got {SDXLRefiner.unet_config['model_channels']}"
)
def test_sdxl_refiner_uses_g_clip_only(self):
"""SDXL Refiner should use only ClipG (not L+G like base SDXL)."""
from src.SD15.SDXL import SDXLRefiner
from src.SD15.SDXLClip import SDXLRefinerClipModel
model = SDXLRefiner(SDXLRefiner.unet_config)
target = model.clip_target()
assert target.clip == SDXLRefinerClipModel, (
"SDXL Refiner should use SDXLRefinerClipModel"
)
class TestSDXLModelType:
"""Test suite for SDXL model type detection from state dict."""
def test_sdxl_model_type_default_is_eps(self):
"""Default model type should be EPS."""
from src.SD15.SDXL import SDXL
from src.sample.sampling import ModelType
model = SDXL(SDXL.unet_config)
# With empty state dict, should return EPS
result = model.model_type({}, "")
assert result == ModelType.EPS, (
f"Expected default ModelType.EPS, got {result}"
)
def test_sdxl_model_type_detects_v_prediction(self):
"""Model type should detect V_PREDICTION from state dict.
Note: This test documents expected behavior but skips if
ModelType.V_PREDICTION is not defined in the sampling.ModelType enum.
"""
from src.SD15.SDXL import SDXL
from src.sample.sampling import ModelType
# Skip if V_PREDICTION is not in ModelType
if not hasattr(ModelType, 'V_PREDICTION'):
pytest.skip("ModelType.V_PREDICTION not defined in enum")
model = SDXL(SDXL.unet_config)
state_dict = {"v_pred": torch.tensor([1.0])}
result = model.model_type(state_dict, "")
assert result == ModelType.V_PREDICTION, (
f"Expected ModelType.V_PREDICTION, got {result}"
)
def test_sdxl_model_type_detects_edm(self):
"""Model type should detect EDM (Playground V2.5) from state dict.
Note: This test documents expected behavior but skips if
ModelType.EDM is not defined in the sampling.ModelType enum.
"""
from src.SD15.SDXL import SDXL
from src.sample.sampling import ModelType
# Skip if EDM is not in ModelType
if not hasattr(ModelType, 'EDM'):
pytest.skip("ModelType.EDM not defined in enum")
model = SDXL(SDXL.unet_config)
state_dict = {
"edm_mean": torch.tensor([0.5]),
"edm_std": torch.tensor([1.0]),
}
result = model.model_type(state_dict, "")
assert result == ModelType.EDM, (
f"Expected ModelType.EDM, got {result}"
)
class TestSDXLVariants:
"""Test suite for SDXL variant models (SSD-1B, Segmind Vega, etc.)."""
def test_ssd1b_exists_and_inherits_sdxl(self):
"""SSD-1B should exist and inherit from SDXL."""
from src.SD15.SDXL import SSD1B, SDXL
assert SSD1B is not None
assert issubclass(SSD1B, SDXL), (
"SSD1B should inherit from SDXL"
)
def test_ssd1b_has_reduced_transformer_depth(self):
"""SSD-1B should have fewer transformer blocks."""
from src.SD15.SDXL import SSD1B, SDXL
ssd1b_depth = SSD1B.unet_config["transformer_depth"]
sdxl_depth = SDXL.unet_config["transformer_depth"]
        # SSD-1B uses [0, 0, 2, 2, 4, 4] vs SDXL's [0, 0, 2, 2, 10, 10]
        assert sum(ssd1b_depth) < sum(sdxl_depth), (
            f"SSD-1B total transformer depth {sum(ssd1b_depth)} should be "
            f"less than SDXL's {sum(sdxl_depth)}"
        )
def test_segmind_vega_exists(self):
"""Segmind Vega model should exist."""
from src.SD15.SDXL import Segmind_Vega
assert Segmind_Vega is not None
def test_koala_models_exist(self):
"""KOALA 700M and 1B models should exist."""
from src.SD15.SDXL import KOALA_700M, KOALA_1B
assert KOALA_700M is not None
assert KOALA_1B is not None
class TestSDXLProcessClipStateDict:
"""Test suite for SDXL CLIP state dict processing."""
def test_sdxl_process_clip_state_dict_handles_dual_embedders(self):
"""process_clip_state_dict should handle both L and G embedders."""
from src.SD15.SDXL import SDXL
model = SDXL(SDXL.unet_config)
# Create dummy state dict with SDXL-style prefixes
state_dict = {
"conditioner.embedders.0.transformer.text_model.weight": torch.randn(10, 10),
"conditioner.embedders.1.model.layer": torch.randn(5, 5),
}
result = model.process_clip_state_dict(state_dict)
# Keys should be converted to clip_l and clip_g prefixes
prefixes = set()
for key in result.keys():
prefix = key.split('.')[0]
prefixes.add(prefix)
        # At least one converted prefix should appear; the dummy state dict
        # may not map every key, so the assertion is deliberately lenient
assert 'clip_l' in prefixes or 'clip_g' in prefixes, (
f"Expected clip_l or clip_g prefixes, got prefixes: {prefixes}"
)
class TestSDXLInModelsRegistry:
"""Test that SDXL variants are properly registered."""
def test_sdxl_in_models_list(self):
"""SDXL should be in the models registry."""
from src.SD15.SD15 import models
from src.SD15.SDXL import SDXL
assert SDXL in models, "SDXL should be in the models registry"
def test_sdxl_refiner_in_models_list(self):
"""SDXLRefiner should be in the models registry."""
from src.SD15.SD15 import models
from src.SD15.SDXL import SDXLRefiner
assert SDXLRefiner in models, "SDXLRefiner should be in the models registry"
def test_sdxl_variants_in_models_list(self):
"""SDXL variants (SSD1B, etc.) should be in the models registry."""
from src.SD15.SD15 import models
from src.SD15.SDXL import SSD1B, Segmind_Vega, KOALA_700M, KOALA_1B
for variant in [SSD1B, Segmind_Vega, KOALA_700M, KOALA_1B]:
assert variant in models, (
f"{variant.__name__} should be in the models registry"
)