import torch
from typing_extensions import override

from comfy.k_diffusion.sampling import sigma_to_half_log_snr
from comfy_api.latest import ComfyExtension, io


class EpsilonScaling(io.ComfyNode):
    """
    Implements the Epsilon Scaling method from 'Elucidating the Exposure Bias in Diffusion Models'
    (https://arxiv.org/abs/2308.15321v6).

    This method mitigates exposure bias by scaling the predicted noise during sampling,
    which can significantly improve sample quality. This implementation uses the "uniform schedule"
    recommended by the paper for its practicality and effectiveness.
    """
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="Epsilon Scaling",
            category="model_patches/unet",
            inputs=[
                io.Model.Input("model"),
                io.Float.Input(
                    "scaling_factor",
                    default=1.005,
                    min=0.5,
                    max=1.5,
                    step=0.001,
                    display_mode=io.NumberDisplay.number,
                ),
            ],
            outputs=[
                io.Model.Output(),
            ],
        )

    @classmethod
    def execute(cls, model, scaling_factor) -> io.NodeOutput:
        # Prevent division by zero, though the UI's min value should prevent this.
        if scaling_factor == 0:
            scaling_factor = 1e-9

        def epsilon_scaling_function(args):
            """
            This function is applied after the CFG guidance has been calculated.
            It recalculates the denoised latent by scaling the predicted noise.
            """
            denoised = args["denoised"]
            x = args["input"]

            noise_pred = x - denoised

            scaled_noise_pred = noise_pred / scaling_factor

            new_denoised = x - scaled_noise_pred

            return new_denoised

        # Clone the model patcher to avoid modifying the original model in place
        model_clone = model.clone()

        model_clone.set_model_sampler_post_cfg_function(epsilon_scaling_function)

        return io.NodeOutput(model_clone)


def compute_tsr_rescaling_factor(
    snr: torch.Tensor, tsr_k: float, tsr_variance: float
) -> torch.Tensor:
    """Compute the rescaling score ratio in Temporal Score Rescaling.

    See equation (6) in https://arxiv.org/pdf/2510.01184v1.
    """
    posinf_mask = torch.isposinf(snr)
    rescaling_factor = (snr * tsr_variance + 1) / (snr * tsr_variance / tsr_k + 1)
    return torch.where(posinf_mask, tsr_k, rescaling_factor)  # when snr → inf, r = tsr_k


class TemporalScoreRescaling(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="TemporalScoreRescaling",
            display_name="TSR - Temporal Score Rescaling",
            category="model_patches/unet",
            inputs=[
                io.Model.Input("model"),
                io.Float.Input(
                    "tsr_k",
                    tooltip=(
                        "Controls the rescaling strength.\n"
                        "Lower k produces more detailed results; higher k produces smoother results in image generation. Setting k = 1 disables rescaling."
                    ),
                    default=0.95,
                    min=0.01,
                    max=100.0,
                    step=0.001,
                    display_mode=io.NumberDisplay.number,
                ),
                io.Float.Input(
                    "tsr_sigma",
                    tooltip=(
                        "Controls how early rescaling takes effect.\n"
                        "Larger values take effect earlier."
                    ),
                    default=1.0,
                    min=0.01,
                    max=100.0,
                    step=0.001,
                    display_mode=io.NumberDisplay.number,
                ),
            ],
            outputs=[
                io.Model.Output(
                    display_name="patched_model",
                ),
            ],
            description=(
                "[Post-CFG Function]\n"
                "TSR - Temporal Score Rescaling (2510.01184)\n\n"
                "Rescaling the model's score or noise to steer the sampling diversity.\n"
            ),
        )

    @classmethod
    def execute(cls, model, tsr_k, tsr_sigma) -> io.NodeOutput:
        tsr_variance = tsr_sigma**2

        def temporal_score_rescaling(args):
            denoised = args["denoised"]
            x = args["input"]
            sigma = args["sigma"]
            curr_model = args["model"]

            # No rescaling (r = 1) or no noise
            if tsr_k == 1 or sigma == 0:
                return denoised

            model_sampling = curr_model.current_patcher.get_model_object("model_sampling")
            half_log_snr = sigma_to_half_log_snr(sigma, model_sampling)
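            # half_log_snr = log(alpha / sigma), so snr = alpha^2 / sigma^2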
            snr = (2 * half_log_snr).exp()

            # No rescaling needed (r = 1)
            if snr == 0:
                return denoised

            rescaling_r = compute_tsr_rescaling_factor(snr, tsr_k, tsr_variance)

            # Derived from scaled_denoised = (x - r * sigma * noise) / alpha
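            # With noise = (x - alpha * denoised) / sigma, this reduces to
            # (1 - r) * x / alpha + r * denoised, i.e. lerp(x / alpha, denoised, r).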
            alpha = sigma * half_log_snr.exp()
            return torch.lerp(x / alpha, denoised, rescaling_r)

        m = model.clone()
        m.set_model_sampler_post_cfg_function(temporal_score_rescaling)
        return io.NodeOutput(m)


class EpsilonScalingExtension(ComfyExtension):
    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        return [
            EpsilonScaling,
            TemporalScoreRescaling,
        ]


async def comfy_entrypoint() -> EpsilonScalingExtension:
    return EpsilonScalingExtension()
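

# Minimal sanity check for compute_tsr_rescaling_factor, not a ComfyUI node. It still
# needs the ComfyUI environment because of the module-level imports above, and only
# verifies that the rescaling ratio moves from 1 at SNR = 0 toward tsr_k as SNR grows.
if __name__ == "__main__":
    snr = torch.tensor([0.0, 1.0, 100.0, float("inf")])
    r = compute_tsr_rescaling_factor(snr, tsr_k=0.95, tsr_variance=1.0)
    print(r)  # expected: values decreasing from 1.0 toward 0.95, exactly 0.95 at inf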