dikdimon's picture
Upload negative_rejection_steering using SD-Hub
eaaf420 verified
import logging
import torch
from enum import Enum, auto
from typing import Any
class PredictionType(Enum):
EPS = auto() # ε-prediction
V = auto() # v-prediction
X0 = auto() # x₀-prediction
UNKNOWN = auto() # couldn’t detect / new scheduler
_RAW_TO_ENUM = {
"eps": PredictionType.EPS,
"epsilon": PredictionType.EPS,
"flux": PredictionType.EPS,
"chroma": PredictionType.EPS,
"v": PredictionType.V,
"v_prediction": PredictionType.V,
"x0": PredictionType.X0,
"sample": PredictionType.X0,
}
class NRS:
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"skew": ("FLOAT", {"default": 4.0, "min": -30.0, "max": 30.0, "step": 0.01}),
"stretch": ("FLOAT", {"default": 2.0, "min": -30.0, "max": 30.0, "step": 0.01}),
"squash": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01}),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
CATEGORY = "advanced/model"
def _get_pred_type(self, model) -> PredictionType:
"""
In order to support Comfy, Forge, and possibly other models
and various loaders.
Walk common wrappers until we find something that looks like a
prediction-type flag, then map it to the enum.
Defaults to EPS if all else fails.
"""
def _canon(p):
if p is None:
return ""
if isinstance(p, bytes):
p = p.decode(errors="ignore")
if isinstance(p, Enum):
p = p.name
return str(p).strip().lower()
# Breadth-first search through a few well-known wrappers.
queue, seen = [model], set()
while queue:
obj = queue.pop(0)
# 1) direct hit on this object ---------------------------------
for attr in ("model_type", "prediction_type", "parameterization"):
p = _canon(getattr(obj, attr, None))
if p:
return _RAW_TO_ENUM.get(p, PredictionType.UNKNOWN)
# 2) enqueue child containers we care about -------------------
for attr in ("model", "diffusion_model", "config", "scheduler", "inner_model", "model_sampling"):
child = getattr(obj, attr, None)
if child is not None and id(child) not in seen:
seen.add(id(child))
queue.append(child)
# 3) default ------------------------------------------------------
return PredictionType.UNKNOWN
def _convert_to_eps_space(self, x_orig, sig_root, sigma, cond, uncond):
x_div = None
eps_cond = cond
eps_uncond = uncond
if self.__pred_type == PredictionType.V:
# v → ε conversion
logging.debug(f"NRS._convert_to_eps_space: generating x_div, eps_cond, and eps_uncond for v-pred")
x_div = x_orig / (sigma ** 2 + 1)
eps_cond = ((x_div - (x_orig - cond)) * sig_root) / (sigma)
eps_uncond = ((x_div - (x_orig - uncond)) * sig_root) / (sigma)
elif self.__pred_type == PredictionType.EPS:
logging.debug(f"NRS._convert_to_eps_space: already in eps, no pre-scale needed")
pass # already in ε space
elif self.__pred_type == PredictionType.X0:
raise NotImplementedError("NRS._convert_to_eps_space: x0-prediction not supported yet.")
else:
raise RuntimeError("NRS._convert_to_eps_space: Could not determine prediction type for this model.")
return x_div, eps_cond, eps_uncond
def _finalize_from_eps_space(self, x_orig, x_div, x_final, sig_root, sigma):
nrs_result = x_final
if self.__pred_type == PredictionType.V:
# ε → v conversion
logging.debug(f"NRS._finalize_from_eps_space: generating cfg_result for v-pred")
nrs_result = x_orig - (x_div - x_final * sigma / sig_root)
elif self.__pred_type == PredictionType.EPS:
# already in ε space
logging.debug(f"NRS._finalize_from_eps_space: already in eps, no post-scale needed")
pass
elif self.__pred_type == PredictionType.X0:
raise NotImplementedError("NRS._finalize_from_eps_space: x0-prediction not supported yet.")
else:
raise RuntimeError("NRS._finalize_from_eps_space: Could not determine prediction type for this model.")
return nrs_result
def _convert_to_v_space(self, x_orig, sig_root, sigma, cond, uncond):
x_div = None
v_cond = cond
v_uncond = uncond
if self.__pred_type == PredictionType.V:
logging.debug(f"NRS._convert_to_v_space: already in v, no pre-scale needed")
pass # already in v space
elif self.__pred_type == PredictionType.EPS:
# ε → v conversion
logging.debug(f"NRS._convert_to_v_space: generating x_div, v_cond, and v_uncond for eps")
x_div = x_orig / (sigma ** 2 + 1)
factor = sigma / sig_root
v_cond = x_orig - (x_div - cond * factor)
v_uncond = x_orig - (x_div - uncond * factor)
elif self.__pred_type == PredictionType.X0:
raise NotImplementedError("NRS._convert_to_v_space: x0-prediction not supported yet.")
else:
raise RuntimeError("NRS._convert_to_v_space: Could not determine prediction type for this model.")
return x_div, v_cond, v_uncond
def _finalize_from_v_space(self, x_orig, x_div, x_final, sig_root, sigma):
nrs_result = x_final
if self.__pred_type == PredictionType.V:
# already in v space
logging.debug(f"NRS._finalize_from_v_space: already in v, no post-scale needed")
pass
elif self.__pred_type == PredictionType.EPS:
# v → ε conversion
logging.debug(f"NRS._finalize_from_v_space: generating cfg_result for eps")
nrs_result = (x_div - (x_orig - x_final)) * (sig_root / sigma)
elif self.__pred_type == PredictionType.X0:
raise NotImplementedError("NRS._finalize_from_v_space: x0-prediction not supported yet.")
else:
raise RuntimeError("NRS._finalize_from_v_space: Could not determine prediction type for this model.")
return nrs_result
def patch(self, model, skew, stretch, squash):
self.__pred_type = self._get_pred_type(model) if not hasattr(self, "__pred_type") else self.__pred_type
self.__OPERATION_SPACE = PredictionType.V
def nrs(args):
logging.debug(f"NRS.nrs: Skew: {skew}, Stretch: {stretch}, Squash: {squash}")
# self.__pred_type = self.__pred_type if self.__pred_type is not None else self._get_pred_type(model)
cond = args["cond"]
uncond = args["uncond"]
x_orig = args["input"]
sigma = args["sigma"]
sigma = sigma.view(sigma.shape[:1] + (1,) * (cond.ndim - 1))
sig_root = (sigma ** 2 + 1).sqrt()
x_div, nrs_cond, nrs_uncond = None, None, None
match self.__OPERATION_SPACE:
case PredictionType.V:
x_div, nrs_cond, nrs_uncond = self._convert_to_v_space(x_orig, sig_root, sigma, cond, uncond)
case PredictionType.EPS:
x_div, nrs_cond, nrs_uncond = self._convert_to_eps_space(x_orig, sig_root, sigma, cond, uncond)
case PredictionType.X0:
raise RuntimeError("NRS.nrs: x0-prediction not supported yet.")
case PredictionType.UNKNOWN:
raise RuntimeError("NRS.nrs: Could not determine prediction type for this operation.")
case _:
raise RuntimeError("NRS.nrs: Invalid PredictionType used.")
def _dot(a, b):
return (a*b).sum(dim=1, keepdim=True) # [B,C,W,H] => [B,1,W,H]
def _nrm2(v):
return _dot(v, v)
eps = torch.finfo(nrs_cond.dtype).eps
c_dot_c = _nrm2(nrs_cond) + eps # [B,1,W,H]
u_dot_c = _dot(nrs_uncond, nrs_cond) # [B,1,W,H]
u_on_c = (u_dot_c / c_dot_c) * nrs_cond # [B,1,W,H] * [B,C,H,W]
# Amplify Cond based on length compared to projection of uncond
proj_diff = nrs_cond - u_on_c
stretched = nrs_cond + (stretch * proj_diff)
# Skew/Steer Conf based on rejection of uncond on cond
u_rej_c = nrs_uncond - u_on_c
skewed = stretched - (skew * u_rej_c)
# Squash final length back down to original length of cond
cond_len = nrs_cond.norm(dim=1, keepdim=True)
nrs_len = skewed.norm(dim=1, keepdim=True) + eps
squash_scale = (1 - squash) + (squash * (cond_len / nrs_len))
x_final = skewed * squash_scale
match self.__OPERATION_SPACE:
case PredictionType.V:
return self._finalize_from_v_space(x_orig, x_div, x_final, sig_root, sigma)
case PredictionType.EPS:
return self._finalize_from_eps_space(x_orig, x_div, x_final, sig_root, sigma)
case PredictionType.X0:
raise RuntimeError("NRS.nrs: x0-prediction not supported yet.")
case PredictionType.UNKNOWN:
raise RuntimeError("NRS.nrs: Could not determine prediction type for this operation.")
case _:
raise RuntimeError("NRS.nrs: Invalid PredictionType used.")
m = model.clone()
m.set_model_sampler_cfg_function(nrs, True)
return (m, )
NODE_CLASS_MAPPINGS = {
"NRS": NRS,
}