File size: 9,412 Bytes
b701455
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
"""Classifier-Free Guidance implementation."""
import math
import logging
import torch
from src.cond import cond, cond_util


def cfg_function(model, cond_pred, uncond_pred, cond_scale, x, timestep, model_options=None, cond=None, uncond=None):
    """Combine conditional and unconditional predictions with classifier-free guidance.

    Args:
        model: Model object, forwarded to ``sampler_cfg_function`` and
            ``sampler_post_cfg_function`` hooks.
        cond_pred: Denoised prediction for the positive conditioning.
        uncond_pred: Denoised prediction for the negative conditioning. May be
            ``None`` when the unconditional pass was skipped; dynamic CFG
            rescaling is then not applied.
        cond_scale: Guidance scale.
        x: Model input for this step.
        timestep: Current timestep/sigma, forwarded to hooks as both
            ``"timestep"`` and ``"sigma"``.
        model_options: Optional dict of hooks and settings; ``None`` is
            treated as an empty dict.
        cond, uncond: Raw conditioning, forwarded to post-CFG hooks only.

    Returns:
        The guided prediction, after every ``sampler_post_cfg_function`` hook
        has been applied in order.
    """
    if model_options is None:
        model_options = {}

    # Dynamic CFG rescaling compares cond against uncond, so it is only
    # meaningful (and only safe) when an unconditional prediction exists.
    guider = model_options.get("cfg_guider")
    if (
        guider is not None
        and uncond_pred is not None
        and getattr(guider, "dynamic_cfg_rescaling", False)
    ):
        cond_scale = guider._apply_dynamic_cfg_rescaling(cond_pred, uncond_pred, cond_scale)

    if "sampler_cfg_function" in model_options:
        # Custom CFG hooks work in "noise" space (input minus denoised) and
        # return noise, which is converted back to a denoised prediction here.
        cfg_result = x - model_options["sampler_cfg_function"]({
            "cond": x - cond_pred, "uncond": x - uncond_pred, "cond_scale": cond_scale,
            "timestep": timestep, "input": x, "sigma": timestep,
            "cond_denoised": cond_pred, "uncond_denoised": uncond_pred,
            "model": model, "model_options": model_options,
        })
    elif math.isclose(cond_scale, 1.0):
        # At scale 1.0 the guided result equals the conditional prediction.
        cfg_result = cond_pred
    else:
        # uncond + cond_scale * (cond - uncond)
        cfg_result = torch.lerp(uncond_pred, cond_pred, cond_scale)

    for fn in model_options.get("sampler_post_cfg_function", []):
        cfg_result = fn({
            "denoised": cfg_result, "cond": cond, "uncond": uncond, "model": model,
            "uncond_denoised": uncond_pred, "cond_denoised": cond_pred,
            "sigma": timestep, "model_options": model_options, "input": x,
        })
    return cfg_result


def sampling_function(model, x, timestep, uncond, condo, cond_scale, model_options=None, seed=None):
    """Evaluate the model on positive/negative conditioning and apply CFG.

    Args:
        model: Model passed to ``cond.calc_cond_batch`` and to hooks.
        x: Model input for this step.
        timestep: Current timestep/sigma.
        uncond: Negative conditioning.
        condo: Positive conditioning (named to avoid shadowing the ``cond`` module).
        cond_scale: Guidance scale.
        model_options: Optional dict of hooks and settings; ``None`` is
            treated as an empty dict.
        seed: Accepted for interface compatibility; not used here.

    Returns:
        The result of ``cfg_function`` for this step.
    """
    if model_options is None:
        model_options = {}

    # cfg_function returns the conditional prediction unchanged at scale 1.0
    # on its default path, so the unconditional pass is skipped unless the
    # caller explicitly disables that shortcut.
    skip_uncond = math.isclose(cond_scale, 1.0) and not model_options.get("disable_cfg1_optimization", False)
    uncond_ = None if skip_uncond else uncond
    conds = [condo, uncond_]
    cond_outputs = cond.calc_cond_batch(model, conds, x, timestep, model_options)

    # Pre-CFG hooks may replace the raw [cond, uncond] outputs.
    for fn in model_options.get("sampler_pre_cfg_function", []):
        cond_outputs = fn({
            "conds": conds, "conds_out": cond_outputs, "cond_scale": cond_scale,
            "timestep": timestep, "input": x, "sigma": timestep,
            "model": model, "model_options": model_options,
        })

    return cfg_function(model, cond_outputs[0], cond_outputs[1], cond_scale, x, timestep,
                        model_options, condo, uncond_)


class CFGGuider:
    """Classifier-free guidance wrapper around a model patcher.

    Holds positive/negative conditioning and a guidance scale, and exposes
    ``predict_noise`` (also reachable via ``__call__``) for a sampler to call
    once per step. Optional behaviours:

    * CFG-free sampling (``set_cfg_free_params``): after a configurable
      fraction of the schedule, the guidance scale falls linearly to 0.
    * Dynamic CFG rescaling (``dynamic_cfg_rescaling``): the scale is adjusted
      from the spread of ``cond - uncond`` (see ``_apply_dynamic_cfg_rescaling``).
    * Adaptive noise (``adaptive_noise_enabled``): a complexity score is
      recorded for every prediction and, on later runs, the sigma schedule is
      multiplied by a factor >= 1 derived from the running average.
    """

    def __init__(self, model_patcher, flux=False, dynamic_cfg_rescaling=False, dynamic_cfg_method="variance",
                 dynamic_cfg_percentile=95.0, dynamic_cfg_target_scale=7.0,
                 adaptive_noise_enabled=False, adaptive_noise_method="complexity"):
        self.model_patcher = model_patcher
        # Prefer the patcher's own model_options, then the wrapped model's, else empty.
        inner_model = getattr(model_patcher, "model", model_patcher)
        self.model_options = getattr(
            model_patcher,
            "model_options",
            getattr(inner_model, "model_options", {}),
        )
        self.original_conds = {}  # conditioning as converted by inner_set_conds, keyed by name
        self.cfg = 1.0  # scale used for the current step (may be ramped by CFG-free)
        self.cfg_free_enabled = False
        self.cfg_free_start_percent = 70.0
        self.original_cfg = 1.0  # scale set via set_cfg; CFG-free ramps down from this
        self.sigmas = None  # schedule of the current run, used for CFG-free step lookup
        self.flux = flux  # Flag for FLUX model behavior
        self.dynamic_cfg_rescaling = dynamic_cfg_rescaling
        self.dynamic_cfg_method = dynamic_cfg_method
        self.dynamic_cfg_percentile = dynamic_cfg_percentile
        self.dynamic_cfg_target_scale = dynamic_cfg_target_scale
        self.adaptive_noise_enabled = adaptive_noise_enabled
        self.adaptive_noise_method = adaptive_noise_method
        self.complexity_history = []  # one float per predict_noise call while adaptive noise is on
        self.base_sigmas = None  # un-scaled schedule of the most recent adaptive run

    def set_conds(self, positive, negative):
        """Store positive and negative conditioning."""
        self.inner_set_conds({"positive": positive, "negative": negative})

    def set_cfg(self, cfg):
        """Set the guidance scale (both the live and the CFG-free reference value)."""
        self.cfg = cfg
        self.original_cfg = cfg

    def set_cfg_free_params(self, enabled=False, start_percent=70.0):
        """Configure CFG-free sampling.

        Args:
            enabled: Whether the guidance scale is ramped down during sampling.
            start_percent: Schedule progress (0-100) at which the ramp begins;
                clamped into [0, 100].
        """
        self.cfg_free_enabled = enabled
        self.cfg_free_start_percent = max(0.0, min(100.0, start_percent))
        if enabled:
            # Report the clamped value that will actually be used.
            print(f"CFG-Free sampling ACTIVE: CFG will reduce to 0 starting at {self.cfg_free_start_percent:.0f}% of steps")

    def _update_cfg_for_sigma(self, sigma):
        """Set ``self.cfg`` for the schedule step nearest to ``sigma``.

        Progress is the index of the entry in ``self.sigmas`` closest to
        ``sigma``, as a percentage of ``len(self.sigmas) - 1``. Below
        ``cfg_free_start_percent`` the original scale is used; from there the
        scale falls linearly to 0 at 100%. No-op when CFG-free sampling is
        disabled or no schedule with at least two entries is set.
        """
        if not self.cfg_free_enabled or self.sigmas is None or len(self.sigmas) <= 1:
            return

        # The sampler presumably passes one sigma per batch item (shape [batch]);
        # float() rejects multi-element tensors, so take the first entry
        # (assumed identical across the batch within a step — confirm).
        if isinstance(sigma, torch.Tensor):
            sigma = sigma.reshape(-1)[0]
        target = float(sigma)

        # Convert the schedule once rather than per comparison inside min().
        if isinstance(self.sigmas, torch.Tensor):
            schedule = self.sigmas.tolist()
        else:
            schedule = [float(s) for s in self.sigmas]

        total_steps = len(schedule) - 1
        current_step = min(range(len(schedule)), key=lambda i: abs(schedule[i] - target))
        progress = current_step / total_steps * 100.0

        if progress >= self.cfg_free_start_percent:
            remaining = 100.0 - self.cfg_free_start_percent
            if remaining > 0:
                self.cfg = max(0.0, self.original_cfg * (1.0 - (progress - self.cfg_free_start_percent) / remaining))
        else:
            self.cfg = self.original_cfg

    def _apply_dynamic_cfg_rescaling(self, cond_pred, uncond_pred, cond_scale):
        """Return a guidance scale adjusted from the spread of ``cond_pred - uncond_pred``.

        ``"variance"`` shrinks the scale as the variance of the difference grows.
        ``"range"`` divides the scale by the inter-percentile range of the
        difference (floored at 0.01) and caps it at ``dynamic_cfg_target_scale``.
        Any other method leaves the scale unchanged. The result is always
        clamped to [1.0, 20.0]. When rescaling is disabled, ``cond_scale`` is
        returned untouched.
        """
        if not self.dynamic_cfg_rescaling:
            return cond_scale

        diff = cond_pred - uncond_pred
        if self.dynamic_cfg_method == "variance":
            # Normalised variance capped at 10, so the divisor stays within [1, 2].
            variance = min(torch.var(diff).item() / 0.1, 10.0)
            adjusted = cond_scale / (1.0 + variance * 0.1)
        elif self.dynamic_cfg_method == "range":
            flat = diff.flatten()
            # torch.quantile only accepts float32/float64 input; upcast fp16/bf16 predictions.
            if flat.dtype not in (torch.float32, torch.float64):
                flat = flat.float()
            levels = torch.tensor(
                [(100 - self.dynamic_cfg_percentile) / 100, self.dynamic_cfg_percentile / 100],
                dtype=flat.dtype,
                device=flat.device,
            )
            # One call computes both percentiles from a single sort.
            low, high = torch.quantile(flat, levels).tolist()
            adjusted = min(cond_scale / max(high - low, 0.01), self.dynamic_cfg_target_scale)
        else:
            adjusted = cond_scale
        return max(1.0, min(adjusted, 20.0))

    def inner_set_conds(self, conds):
        """Convert each named conditioning via ``cond.convert_cond`` and store it."""
        for k in conds:
            self.original_conds[k] = cond.convert_cond(conds[k])

    def __call__(self, *args, **kwargs):
        """Alias for ``predict_noise`` so the guider can be used as the sampler's model."""
        return self.predict_noise(*args, **kwargs)

    def predict_noise(self, x, timestep, model_options=None, seed=None):
        """Return the guided prediction for ``x`` at ``timestep``.

        Args:
            x: Model input for this step.
            timestep: Current sigma/timestep from the sampler.
            model_options: Optional dict of hooks/settings; ``None`` means empty.
            seed: Forwarded to ``sampling_function``.
        """
        if self.cfg_free_enabled:
            self._update_cfg_for_sigma(timestep)

        # Fresh dict: exposes this guider to cfg_function (dynamic rescaling)
        # without mutating the caller's options.
        opts = {**(model_options or {}), "cfg_guider": self}
        result = sampling_function(self.inner_model, x, timestep,
                                   self.conds.get("negative"), self.conds.get("positive"),
                                   self.cfg, model_options=opts, seed=seed)

        if self.adaptive_noise_enabled:
            self.complexity_history.append(self._calc_complexity(result))
        return result

    def _calc_complexity(self, prediction):
        """Return a scalar complexity score for ``prediction``.

        ``"complexity"``: mean of |horizontal diff| + |vertical diff| over the
        last two dims. Any other method: variance over dims 2 and 3, averaged
        over the remaining dims. Indexing assumes a 4-D tensor, presumably
        (batch, channels, H, W) — confirm against the sampler.
        """
        if self.adaptive_noise_method == "complexity":
            dx = prediction[:, :, :, 1:] - prediction[:, :, :, :-1]
            dy = prediction[:, :, 1:, :] - prediction[:, :, :-1, :]
            # Crop both to their common region so they can be summed element-wise.
            h, w = min(dx.shape[2], dy.shape[2]), min(dx.shape[3], dy.shape[3])
            return (dx[:, :, :h, :w].abs() + dy[:, :, :h, :w].abs()).mean().item()
        return prediction.var(dim=[2, 3]).mean().item()

    def inner_sample(self, noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, pipeline=False):
        """Run ``sampler`` over ``sigmas`` with this guider as the model.

        Returns the sampler output converted to float32 and passed through
        ``inner_model.process_latent_out``.
        """
        # All-zero latents skip process_latent_in (presumably an empty txt2img latent — confirm).
        if latent_image is not None and torch.count_nonzero(latent_image) > 0:
            latent_image = self.inner_model.process_latent_in(latent_image)

        self.conds = cond.process_conds(self.inner_model, noise, self.conds, device, latent_image, denoise_mask, seed)

        if self.adaptive_noise_enabled and len(self.complexity_history) > 0:
            # Scale the schedule supplied for *this* run. Caching the first
            # run's schedule forever would reuse a tensor of the wrong length
            # or device whenever the step count/denoise/device changes.
            self.base_sigmas = sigmas.clone()
            avg = sum(self.complexity_history) / len(self.complexity_history)
            sigmas = self.base_sigmas * (1.0 + (avg / max(0.01, avg + 0.1)) * 0.5)

        # Record the schedule the sampler will actually step through, so the
        # CFG-free nearest-sigma lookup matches the timesteps it receives.
        self.sigmas = sigmas

        samples = sampler.sample(self, sigmas, {"model_options": self.model_options, "seed": seed},
                                 callback, noise, latent_image, denoise_mask, disable_pbar, pipeline=pipeline)
        return self.inner_model.process_latent_out(samples.to(torch.float32))

    def sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None,
               disable_pbar=False, seed=None, pipeline=False):
        """Prepare the model and conditioning, run sampling, and release per-run state.

        Args:
            noise: Initial noise; its shape is passed to ``prepare_sampling``.
            latent_image: Starting latent, or ``None``.
            sampler: Object whose ``sample`` method drives the denoising loop.
            sigmas: Noise schedule.
            denoise_mask, callback, disable_pbar, seed, pipeline: Forwarded
                to ``inner_sample``.

        Returns:
            The output of ``inner_sample``.
        """
        # Work on copies so the stored conditioning can be reused across runs.
        self.conds = {k: [a.copy() for a in v] for k, v in self.original_conds.items()}
        self.inner_model, self.conds, self.loaded_models = cond_util.prepare_sampling(
            self.model_patcher, noise.shape, self.conds)
        try:
            inner_patcher = getattr(self.model_patcher, "model", self.model_patcher)
            device = getattr(self.model_patcher, "load_device", getattr(inner_patcher, "load_device", None))

            # Handle mock objects in tests
            if not isinstance(device, (torch.device, str)):
                from src.Device import Device
                device = Device.get_torch_device()

            # inner_sample explicitly accepts a None latent, so don't call .to() on it.
            latent = latent_image.to(device) if latent_image is not None else None
            output = self.inner_sample(noise.to(device), latent, device, sampler,
                                       sigmas.to(device), denoise_mask, callback, disable_pbar, seed, pipeline)

            from src.Device.ModelCache import get_model_cache
            get_model_cache().prevent_model_cleanup(self.conds, self.loaded_models)
        finally:
            # Drop per-run references even if sampling raises, so a failed
            # run does not keep the model and conditioning alive on the guider.
            del self.inner_model, self.conds, self.loaded_models
        return output