File size: 3,132 Bytes
b701455
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
"""Verification script for SDXL refiner noise fix.

This script verifies that the Pipeline correctly sets disable_noise=True
when calling the refiner model's generate method.
"""

import sys
import unittest
import pytest
from unittest.mock import MagicMock, patch
from pathlib import Path

# Make the project root (the parent of the directory holding this file)
# importable, so the `src` package below resolves when run from anywhere.
project_root = Path(__file__).resolve().parents[1]
sys.path.append(str(project_root))

from src.Core import Context, Pipeline
from src.Core.AbstractModel import AbstractModel

@pytest.mark.unit
@pytest.mark.slow
class TestRefinerNoiseFix(unittest.TestCase):
    """Check how ``Pipeline.run`` splits sampling between base and refiner models."""

    @staticmethod
    def _make_model_mock(model_path, samples, prompt_pair, images):
        """Build a loaded ``AbstractModel`` mock that returns canned outputs.

        Args:
            model_path: Value exposed as ``model_path``.
            samples: Value placed under ``"samples"`` in ``generate()``'s result.
            prompt_pair: Return value of ``encode_prompt()``.
            images: Return value of ``decode()``.
        """
        # spec= keeps the mock restricted to attributes AbstractModel really has.
        model = MagicMock(spec=AbstractModel)
        model.is_loaded = True
        model.model_path = model_path
        model.capabilities.supports_lora = False
        model.generate.return_value = {"samples": samples}
        model.encode_prompt.return_value = prompt_pair
        model.decode.return_value = images
        return model

    @patch('src.Core.Pipeline.create_model')
    @patch('src.FileManaging.ImageSaver.SaveImage')
    def test_refiner_disables_noise(self, mock_saver, mock_create_model):
        """Verify that refiner pass calls generate with disable_noise=True.

        The base model must sample up to ``refiner_switch_step`` without
        disabling noise. The refiner must resume at that same step with
        ``disable_noise=True`` (presumably so the base model's partially
        denoised latents are not noised a second time).

        ``mock_saver`` is unused; SaveImage is patched only so the pipeline
        does not touch the real saver (presumably to avoid writing files).
        """
        switch_step = 16

        # 1. Model mocks; the factory hands out the base first, then the refiner.
        mock_base = self._make_model_mock(
            "base.safetensors", "base_latents", ("pos", "neg"), ["base_image"]
        )
        mock_refiner = self._make_model_mock(
            "refiner.safetensors", "refined_latents",
            ("ref_pos", "ref_neg"), ["refined_image"],
        )
        mock_create_model.side_effect = [mock_base, mock_refiner]

        # 2. Context that requests a refiner hand-off at switch_step.
        ctx = Context.from_kwargs(
            prompt="test",
            w=1024,
            h=1024,
            steps=20,
            refiner_model_path="refiner.safetensors",
            refiner_switch_step=switch_step,
        )

        # 3. Run the pipeline with logging silenced.
        pipeline = Pipeline(model_factory=mock_create_model)
        with patch('src.Core.Pipeline.logger'):
            pipeline.run(ctx)

        # 4. Base pass: stops at the switch step and keeps noise enabled.
        self.assertTrue(
            mock_base.generate.call_args_list,
            "Base model generate() was never called",
        )
        base_kwargs = mock_base.generate.call_args_list[0].kwargs
        self.assertEqual(base_kwargs.get('last_step'), switch_step)
        # Omitting disable_noise and passing False explicitly are equivalent.
        self.assertFalse(
            base_kwargs.get('disable_noise', False),
            "Base pass must not disable noise",
        )

        # 5. Refiner pass: resumes at the switch step with noise disabled.
        self.assertTrue(
            mock_refiner.generate.call_args_list,
            "Refiner generate() was never called",
        )
        refiner_kwargs = mock_refiner.generate.call_args_list[0].kwargs
        self.assertEqual(refiner_kwargs.get('start_step'), switch_step)
        self.assertIs(
            refiner_kwargs.get('disable_noise'),
            True,
            "Refiner must be called with disable_noise=True",
        )

if __name__ == "__main__":
    # Allow running this file directly (outside pytest) with the unittest runner.
    unittest.main()