| import torch |
| from typing_extensions import override |
|
|
| from comfy.k_diffusion.sampling import sigma_to_half_log_snr |
| from comfy_api.latest import ComfyExtension, io |
|
|
|
|
class EpsilonScaling(io.ComfyNode):
    """
    Epsilon Scaling, from 'Elucidating the Exposure Bias in Diffusion Models'
    (https://arxiv.org/abs/2308.15321v6).

    Exposure bias is reduced by shrinking the model's predicted noise during
    sampling. The same constant factor is applied at every step, which is the
    "uniform schedule" variant recommended by the paper.
    """

    @classmethod
    def define_schema(cls):
        """Describe the node: a model input, a scaling factor, and a patched model output."""
        model_in = io.Model.Input("model")
        factor_in = io.Float.Input(
            "scaling_factor",
            default=1.005,
            min=0.5,
            max=1.5,
            step=0.001,
            display_mode=io.NumberDisplay.number,
        )
        return io.Schema(
            node_id="Epsilon Scaling",
            category="model_patches/unet",
            inputs=[model_in, factor_in],
            outputs=[io.Model.Output()],
        )

    @classmethod
    def execute(cls, model, scaling_factor) -> io.NodeOutput:
        """Return a clone of ``model`` whose post-CFG output has its noise divided by ``scaling_factor``."""
        # A zero factor would divide by zero; fall back to a tiny positive value.
        # (The schema's min of 0.5 should normally make this unreachable.)
        divisor = 1e-9 if scaling_factor == 0 else scaling_factor

        def _scale_predicted_noise(args):
            """
            Post-CFG hook: recover the noise implied by the guided prediction,
            divide it by the factor, and rebuild the denoised estimate from it.
            """
            guided = args["denoised"]
            latent_in = args["input"]
            implied_noise = latent_in - guided
            return latent_in - implied_noise / divisor

        # Patch a clone so the caller's model object is not modified.
        patched = model.clone()
        patched.set_model_sampler_post_cfg_function(_scale_predicted_noise)
        return io.NodeOutput(patched)
|
|
|
|
def compute_tsr_rescaling_factor(
    snr: torch.Tensor, tsr_k: float, tsr_variance: float
) -> torch.Tensor:
    """Return the Temporal Score Rescaling ratio ``r`` for each SNR value.

    Implements equation (6) of https://arxiv.org/pdf/2510.01184v1:

        r = (snr * variance + 1) / (snr * variance / k + 1)

    Args:
        snr: Signal-to-noise ratio tensor; entries may be +inf.
        tsr_k: Rescaling strength ``k``. ``k == 1`` gives ``r == 1``.
        tsr_variance: TSR variance. The node passes ``tsr_sigma ** 2``.

    Returns:
        Tensor shaped like ``snr``. Where ``snr`` is +inf the closed form
        evaluates to inf/inf, so those entries are set to the limit ``tsr_k``.
    """
    weighted = snr * tsr_variance
    ratio = (weighted + 1) / (weighted / tsr_k + 1)
    # As snr -> +inf the ratio tends to k; replace the NaN from inf/inf with it.
    return torch.where(torch.isposinf(snr), tsr_k, ratio)
|
|
|
|
class TemporalScoreRescaling(io.ComfyNode):
    """
    TSR - Temporal Score Rescaling (arXiv 2510.01184), applied as a post-CFG function.

    The guided prediction is rescaled by the ratio from equation (6) of the paper
    (see ``compute_tsr_rescaling_factor``), which steers sampling diversity.
    """

    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="TemporalScoreRescaling",
            display_name="TSR - Temporal Score Rescaling",
            category="model_patches/unet",
            inputs=[
                io.Model.Input("model"),
                io.Float.Input(
                    "tsr_k",
                    tooltip=(
                        "Controls the rescaling strength.\n"
                        "Lower k produces more detailed results; higher k produces smoother results in image generation. Setting k = 1 disables rescaling."
                    ),
                    default=0.95,
                    min=0.01,
                    max=100.0,
                    step=0.001,
                    display_mode=io.NumberDisplay.number,
                ),
                io.Float.Input(
                    "tsr_sigma",
                    tooltip=(
                        "Controls how early rescaling takes effect.\n"
                        "Larger values take effect earlier."
                    ),
                    default=1.0,
                    min=0.01,
                    max=100.0,
                    step=0.001,
                    display_mode=io.NumberDisplay.number,
                ),
            ],
            outputs=[
                io.Model.Output(
                    display_name="patched_model",
                ),
            ],
            description=(
                "[Post-CFG Function]\n"
                "TSR - Temporal Score Rescaling (2510.01184)\n\n"
                "Rescaling the model's score or noise to steer the sampling diversity.\n"
            ),
        )

    @classmethod
    def execute(cls, model, tsr_k, tsr_sigma) -> io.NodeOutput:
        """Return a clone of ``model`` with the TSR post-CFG hook installed."""
        tsr_variance = tsr_sigma**2

        def temporal_score_rescaling(args):
            """
            Post-CFG hook: rescale the guided ``denoised`` prediction per sample.

            ``args["sigma"]`` may be a single value or one value per batch sample.
            Samples with zero noise or zero SNR are returned unchanged.
            """
            denoised = args["denoised"]
            x = args["input"]
            sigma = args["sigma"]
            curr_model = args["model"]

            # k == 1 makes r == 1 everywhere, i.e. no rescaling.
            if tsr_k == 1:
                return denoised

            # Reshape sigma to (B, 1, ..., 1) so per-sample quantities broadcast
            # across x's non-batch dims instead of lining up with x's last dim.
            # A single-element sigma becomes (1, 1, ..., 1) and broadcasts to all samples.
            sigma = sigma.reshape(-1, *([1] * (x.ndim - 1)))

            # Noise-free samples need no rescaling. Skip the model lookup if that is all of them.
            noise_free = sigma == 0
            if noise_free.all():
                return denoised

            model_sampling = curr_model.current_patcher.get_model_object("model_sampling")
            half_log_snr = sigma_to_half_log_snr(sigma, model_sampling)
            snr = (2 * half_log_snr).exp()

            # Zero SNR means r == 1 and alpha == 0, so those samples also pass through.
            passthrough = noise_free | (snr == 0)
            if passthrough.all():
                return denoised

            rescaling_r = compute_tsr_rescaling_factor(snr, tsr_k, tsr_variance)

            # Derived from scaled_denoised = (x - r * sigma * noise) / alpha.
            alpha = sigma * half_log_snr.exp()
            # Pass-through samples may have alpha of 0 or NaN. Use 1 there to keep
            # x / alpha finite; those entries are replaced by `denoised` below.
            safe_alpha = torch.where(passthrough, torch.ones_like(alpha), alpha)
            rescaled = torch.lerp(x / safe_alpha, denoised, rescaling_r)
            return torch.where(passthrough, denoised, rescaled)

        m = model.clone()
        m.set_model_sampler_post_cfg_function(temporal_score_rescaling)
        return io.NodeOutput(m)
|
|
|
|
class EpsilonScalingExtension(ComfyExtension):
    """Extension exposing the Epsilon Scaling and TSR nodes from this module."""

    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        """Return the node classes provided by this extension."""
        provided = [EpsilonScaling, TemporalScoreRescaling]
        return provided
|
|
|
|
async def comfy_entrypoint() -> EpsilonScalingExtension:
    """Build and return this module's extension instance."""
    extension = EpsilonScalingExtension()
    return extension
|
|