LightDiffusion-Next / tests /test_refiner_noise.py
Aatricks's picture
Deploy ZeroGPU Gradio Space snapshot
b701455
"""Verification script for SDXL refiner noise fix.
This script verifies that the Pipeline correctly sets disable_noise=True
when calling the refiner model's generate method.
"""
import sys
import unittest
import pytest
from unittest.mock import MagicMock, patch
from pathlib import Path
# Add project root to path
project_root = Path(__file__).resolve().parent.parent
sys.path.append(str(project_root))
from src.Core import Context, Pipeline
from src.Core.AbstractModel import AbstractModel
@pytest.mark.unit
@pytest.mark.slow
class TestRefinerNoiseFix(unittest.TestCase):
    """Regression tests for the SDXL refiner noise fix in ``Pipeline.run``.

    The base model is expected to sample up to the refiner switch step with its
    normal noise behaviour, and the refiner is expected to continue from that
    step with ``disable_noise=True``.
    """

    @staticmethod
    def _make_mock_model(model_path, latents, conditioning, images):
        """Build a loaded ``AbstractModel`` mock with canned outputs.

        Args:
            model_path: Value exposed as the mock's ``model_path``.
            latents: Return value of ``generate``.
            conditioning: Return value of ``encode_prompt``.
            images: Return value of ``decode``.
        """
        model = MagicMock(spec=AbstractModel)
        model.is_loaded = True
        model.model_path = model_path
        model.capabilities.supports_lora = False
        model.generate.return_value = latents
        model.encode_prompt.return_value = conditioning
        model.decode.return_value = images
        return model

    @patch('src.Core.Pipeline.create_model')
    # NOTE(review): this patches SaveImage where it is defined; if Pipeline
    # imports it by name, 'src.Core.Pipeline.SaveImage' may be the effective
    # target — confirm against Pipeline's imports.
    @patch('src.FileManaging.ImageSaver.SaveImage')
    def test_refiner_disables_noise(self, mock_saver, mock_create_model):
        """Verify that refiner pass calls generate with disable_noise=True."""
        switch_step = 16

        # 1. Setup mocks
        mock_base = self._make_mock_model(
            "base.safetensors",
            {"samples": "base_latents"},
            ("pos", "neg"),
            ["base_image"],
        )
        mock_refiner = self._make_mock_model(
            "refiner.safetensors",
            {"samples": "refined_latents"},
            ("ref_pos", "ref_neg"),
            ["refined_image"],
        )
        # model_factory returns base first, then refiner
        mock_create_model.side_effect = [mock_base, mock_refiner]

        # 2. Create context with refiner
        ctx = Context.from_kwargs(
            prompt="test",
            w=1024,
            h=1024,
            steps=20,
            refiner_model_path="refiner.safetensors",
            refiner_switch_step=switch_step,
        )

        # 3. Run pipeline
        pipeline = Pipeline(model_factory=mock_create_model)
        with patch('src.Core.Pipeline.logger'):  # Silence logs
            pipeline.run(ctx)

        # 4. Base pass stops at the switch step and keeps default noise.
        # run() may omit disable_noise or pass False explicitly; both are
        # correct, so assert falsiness rather than absence.
        mock_base.generate.assert_called()  # clear failure instead of IndexError
        base_call = mock_base.generate.call_args_list[0]
        self.assertEqual(base_call.kwargs.get('last_step'), switch_step)
        self.assertFalse(
            base_call.kwargs.get('disable_noise', False),
            "Base pass must not disable noise",
        )

        # 5. Refiner pass resumes at the switch step and MUST disable noise,
        # since it continues from already-noised base latents.
        mock_refiner.generate.assert_called()
        refiner_call = mock_refiner.generate.call_args_list[0]
        self.assertEqual(refiner_call.kwargs.get('start_step'), switch_step)
        self.assertIs(
            refiner_call.kwargs.get('disable_noise'),
            True,
            "Refiner must be called with disable_noise=True",
        )
if __name__ == "__main__":
    # Direct execution uses the stdlib runner; under pytest this guard is
    # skipped and the TestCase is collected normally. Inside this guard
    # __name__ is "__main__", matching unittest.main()'s default module.
    unittest.main(module=__name__)