| import torch |
| from typing_extensions import override |
|
|
| from comfy_api.latest import ComfyExtension, io |
|
|
|
|
def project(v0, v1):
    """Split ``v0`` into parts parallel and orthogonal to ``v1``.

    The last three dimensions of each tensor are treated together as one
    vector: ``v1`` is L2-normalized over them and the dot product with
    ``v0`` is summed over them. Any leading dimensions are kept, so each
    leading index is projected independently.

    Args:
        v0: Tensor to decompose.
        v1: Tensor giving the reference direction; must broadcast with ``v0``.

    Returns:
        A ``(parallel, orthogonal)`` tuple where ``parallel`` lies along the
        normalized ``v1`` and ``orthogonal`` is ``v0 - parallel``.
    """
    reduce_dims = [-1, -2, -3]
    unit_dir = torch.nn.functional.normalize(v1, dim=reduce_dims)
    # Scalar projection coefficient, kept broadcastable against unit_dir.
    coeff = (v0 * unit_dir).sum(dim=reduce_dims, keepdim=True)
    parallel = coeff * unit_dir
    orthogonal = v0 - parallel
    return parallel, orthogonal
|
|
class APG(io.ComfyNode):
    """Node that patches a model with Adaptive Projected Guidance.

    ``execute`` clones the incoming model and registers a pre-CFG function
    on the clone. That function rewrites the conditional prediction from
    the guidance vector ``cond - uncond``, after optional momentum
    accumulation, optional norm clamping, and a parallel/orthogonal split
    relative to ``cond``.
    """

    @classmethod
    def define_schema(cls) -> io.Schema:
        """Declare the node id, display name, category, inputs and output."""
        model_input = io.Model.Input("model")
        eta_input = io.Float.Input(
            "eta",
            default=1.0,
            min=-10.0,
            max=10.0,
            step=0.01,
            tooltip="Controls the scale of the parallel guidance vector. Default CFG behavior at a setting of 1.",
        )
        norm_threshold_input = io.Float.Input(
            "norm_threshold",
            default=5.0,
            min=0.0,
            max=50.0,
            step=0.1,
            tooltip="Normalize guidance vector to this value, normalization disable at a setting of 0.",
        )
        momentum_input = io.Float.Input(
            "momentum",
            default=0.0,
            min=-5.0,
            max=1.0,
            step=0.01,
            tooltip="Controls a running average of guidance during diffusion, disabled at a setting of 0.",
        )
        return io.Schema(
            node_id="APG",
            display_name="Adaptive Projected Guidance",
            category="sampling/custom_sampling",
            inputs=[
                model_input,
                eta_input,
                norm_threshold_input,
                momentum_input,
            ],
            outputs=[io.Model.Output()],
        )

    @classmethod
    def execute(cls, model, eta, norm_threshold, momentum) -> io.NodeOutput:
        """Return a clone of ``model`` with the APG pre-CFG function attached.

        Args:
            model: Model to patch; it is cloned, the original is not modified
                by this method.
            eta: Multiplier applied to the guidance component parallel to
                the conditional prediction.
            norm_threshold: Upper bound for the guidance L2 norm (taken over
                the last three dims); values <= 0 skip the clamp.
            momentum: Weight of the previous accumulated guidance; 0 skips
                accumulation.

        Returns:
            ``io.NodeOutput`` wrapping the patched clone.
        """
        # State shared across calls of the hook below. The accumulator is the
        # int 0 until the first guidance tensor has been stored in it.
        accumulated = 0
        last_sigma = None

        def pre_cfg_function(args):
            nonlocal accumulated, last_sigma

            conds_out = args["conds_out"]
            # With a single prediction there is nothing to guide against.
            if len(conds_out) == 1:
                return conds_out

            cond, uncond = conds_out[0], conds_out[1]
            sigma = args["sigma"][0]
            cond_scale = args["cond_scale"]

            # A sigma above the previous one clears the accumulator
            # (presumably this marks the start of a new sampling run -- confirm).
            if last_sigma is not None and sigma > last_sigma:
                accumulated = 0
            last_sigma = sigma

            delta = cond - uncond

            if momentum != 0:
                if torch.is_tensor(accumulated):
                    accumulated = momentum * accumulated + delta
                else:
                    accumulated = delta
                delta = accumulated

            if norm_threshold > 0:
                delta_norm = delta.norm(p=2, dim=[-1, -2, -3], keepdim=True)
                # The factor is capped at 1, so guidance is only ever shrunk.
                ratio = norm_threshold / delta_norm
                shrink = torch.minimum(torch.ones_like(delta_norm), ratio)
                delta = delta * shrink

            parallel, orthogonal = project(delta, cond)
            steered = orthogonal + eta * parallel

            # The raw (cond - uncond) is recomputed here on purpose: ``delta``
            # may have been replaced by the momentum/clamp steps above.
            # NOTE(review): the division by cond_scale presumably offsets the
            # CFG scaling applied after this hook -- that code is not visible here.
            patched_cond = (uncond + steered) + (cond - uncond) / cond_scale

            return [patched_cond, uncond] + conds_out[2:]

        patched = model.clone()
        patched.set_model_sampler_pre_cfg_function(pre_cfg_function)
        return io.NodeOutput(patched)
|
|
|
|
class ApgExtension(ComfyExtension):
    """Extension object that exposes the node classes defined in this file."""

    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        """Return the list of node classes provided by this extension."""
        node_classes: list[type[io.ComfyNode]] = [APG]
        return node_classes
|
|
async def comfy_entrypoint() -> ApgExtension:
    """Create and return a new ``ApgExtension`` instance."""
    extension = ApgExtension()
    return extension
|
|