| | import logging |
| | import torch |
| | from enum import Enum, auto |
| | from typing import Any |
| |
|
| |
|
| | class PredictionType(Enum): |
| | EPS = auto() |
| | V = auto() |
| | X0 = auto() |
| | UNKNOWN = auto() |
| |
|
| | _RAW_TO_ENUM = { |
| | "eps": PredictionType.EPS, |
| | "epsilon": PredictionType.EPS, |
| | "flux": PredictionType.EPS, |
| | "chroma": PredictionType.EPS, |
| | "v": PredictionType.V, |
| | "v_prediction": PredictionType.V, |
| | "x0": PredictionType.X0, |
| | "sample": PredictionType.X0, |
| | } |
| |
|
| | class NRS: |
| | @classmethod |
| | def INPUT_TYPES(s): |
| | return {"required": { "model": ("MODEL",), |
| | "skew": ("FLOAT", {"default": 4.0, "min": -30.0, "max": 30.0, "step": 0.01}), |
| | "stretch": ("FLOAT", {"default": 2.0, "min": -30.0, "max": 30.0, "step": 0.01}), |
| | "squash": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01}), |
| | }} |
| | RETURN_TYPES = ("MODEL",) |
| | FUNCTION = "patch" |
| |
|
| | CATEGORY = "advanced/model" |
| |
|
| | def _get_pred_type(self, model) -> PredictionType: |
| | """ |
| | In order to support Comfy, Forge, and possibly other models |
| | and various loaders. |
| | Walk common wrappers until we find something that looks like a |
| | prediction-type flag, then map it to the enum. |
| | Defaults to EPS if all else fails. |
| | """ |
| | def _canon(p): |
| | if p is None: |
| | return "" |
| | if isinstance(p, bytes): |
| | p = p.decode(errors="ignore") |
| | if isinstance(p, Enum): |
| | p = p.name |
| | return str(p).strip().lower() |
| |
|
| | |
| | queue, seen = [model], set() |
| |
|
| | while queue: |
| | obj = queue.pop(0) |
| |
|
| | |
| | for attr in ("model_type", "prediction_type", "parameterization"): |
| | p = _canon(getattr(obj, attr, None)) |
| | if p: |
| | return _RAW_TO_ENUM.get(p, PredictionType.UNKNOWN) |
| |
|
| | |
| | for attr in ("model", "diffusion_model", "config", "scheduler", "inner_model", "model_sampling"): |
| | child = getattr(obj, attr, None) |
| | if child is not None and id(child) not in seen: |
| | seen.add(id(child)) |
| | queue.append(child) |
| |
|
| | |
| | return PredictionType.UNKNOWN |
| | |
| | def _convert_to_eps_space(self, x_orig, sig_root, sigma, cond, uncond): |
| | x_div = None |
| | eps_cond = cond |
| | eps_uncond = uncond |
| | if self.__pred_type == PredictionType.V: |
| | |
| | logging.debug(f"NRS._convert_to_eps_space: generating x_div, eps_cond, and eps_uncond for v-pred") |
| | x_div = x_orig / (sigma ** 2 + 1) |
| |
|
| | eps_cond = ((x_div - (x_orig - cond)) * sig_root) / (sigma) |
| | eps_uncond = ((x_div - (x_orig - uncond)) * sig_root) / (sigma) |
| | elif self.__pred_type == PredictionType.EPS: |
| | logging.debug(f"NRS._convert_to_eps_space: already in eps, no pre-scale needed") |
| | pass |
| | elif self.__pred_type == PredictionType.X0: |
| | raise NotImplementedError("NRS._convert_to_eps_space: x0-prediction not supported yet.") |
| | else: |
| | raise RuntimeError("NRS._convert_to_eps_space: Could not determine prediction type for this model.") |
| | |
| | return x_div, eps_cond, eps_uncond |
| |
|
| | def _finalize_from_eps_space(self, x_orig, x_div, x_final, sig_root, sigma): |
| | nrs_result = x_final |
| | if self.__pred_type == PredictionType.V: |
| | |
| | logging.debug(f"NRS._finalize_from_eps_space: generating cfg_result for v-pred") |
| | nrs_result = x_orig - (x_div - x_final * sigma / sig_root) |
| | elif self.__pred_type == PredictionType.EPS: |
| | |
| | logging.debug(f"NRS._finalize_from_eps_space: already in eps, no post-scale needed") |
| | pass |
| | elif self.__pred_type == PredictionType.X0: |
| | raise NotImplementedError("NRS._finalize_from_eps_space: x0-prediction not supported yet.") |
| | else: |
| | raise RuntimeError("NRS._finalize_from_eps_space: Could not determine prediction type for this model.") |
| | return nrs_result |
| | |
| | def _convert_to_v_space(self, x_orig, sig_root, sigma, cond, uncond): |
| | x_div = None |
| | v_cond = cond |
| | v_uncond = uncond |
| | if self.__pred_type == PredictionType.V: |
| | logging.debug(f"NRS._convert_to_v_space: already in v, no pre-scale needed") |
| | pass |
| | elif self.__pred_type == PredictionType.EPS: |
| | |
| | logging.debug(f"NRS._convert_to_v_space: generating x_div, v_cond, and v_uncond for eps") |
| | x_div = x_orig / (sigma ** 2 + 1) |
| | factor = sigma / sig_root |
| |
|
| | v_cond = x_orig - (x_div - cond * factor) |
| | v_uncond = x_orig - (x_div - uncond * factor) |
| | elif self.__pred_type == PredictionType.X0: |
| | raise NotImplementedError("NRS._convert_to_v_space: x0-prediction not supported yet.") |
| | else: |
| | raise RuntimeError("NRS._convert_to_v_space: Could not determine prediction type for this model.") |
| | |
| | return x_div, v_cond, v_uncond |
| |
|
| | def _finalize_from_v_space(self, x_orig, x_div, x_final, sig_root, sigma): |
| | nrs_result = x_final |
| | if self.__pred_type == PredictionType.V: |
| | |
| | logging.debug(f"NRS._finalize_from_v_space: already in v, no post-scale needed") |
| | pass |
| | elif self.__pred_type == PredictionType.EPS: |
| | |
| | logging.debug(f"NRS._finalize_from_v_space: generating cfg_result for eps") |
| | nrs_result = (x_div - (x_orig - x_final)) * (sig_root / sigma) |
| | elif self.__pred_type == PredictionType.X0: |
| | raise NotImplementedError("NRS._finalize_from_v_space: x0-prediction not supported yet.") |
| | else: |
| | raise RuntimeError("NRS._finalize_from_v_space: Could not determine prediction type for this model.") |
| | return nrs_result |
| | |
| | def patch(self, model, skew, stretch, squash): |
| | self.__pred_type = self._get_pred_type(model) if not hasattr(self, "__pred_type") else self.__pred_type |
| | self.__OPERATION_SPACE = PredictionType.V |
| |
|
| | def nrs(args): |
| | logging.debug(f"NRS.nrs: Skew: {skew}, Stretch: {stretch}, Squash: {squash}") |
| | |
| | cond = args["cond"] |
| | uncond = args["uncond"] |
| | x_orig = args["input"] |
| | |
| | sigma = args["sigma"] |
| | sigma = sigma.view(sigma.shape[:1] + (1,) * (cond.ndim - 1)) |
| | sig_root = (sigma ** 2 + 1).sqrt() |
| | |
| | x_div, nrs_cond, nrs_uncond = None, None, None |
| | match self.__OPERATION_SPACE: |
| | case PredictionType.V: |
| | x_div, nrs_cond, nrs_uncond = self._convert_to_v_space(x_orig, sig_root, sigma, cond, uncond) |
| | case PredictionType.EPS: |
| | x_div, nrs_cond, nrs_uncond = self._convert_to_eps_space(x_orig, sig_root, sigma, cond, uncond) |
| | case PredictionType.X0: |
| | raise RuntimeError("NRS.nrs: x0-prediction not supported yet.") |
| | case PredictionType.UNKNOWN: |
| | raise RuntimeError("NRS.nrs: Could not determine prediction type for this operation.") |
| | case _: |
| | raise RuntimeError("NRS.nrs: Invalid PredictionType used.") |
| |
|
| | def _dot(a, b): |
| | return (a*b).sum(dim=1, keepdim=True) |
| |
|
| | def _nrm2(v): |
| | return _dot(v, v) |
| |
|
| | eps = torch.finfo(nrs_cond.dtype).eps |
| | c_dot_c = _nrm2(nrs_cond) + eps |
| | u_dot_c = _dot(nrs_uncond, nrs_cond) |
| | u_on_c = (u_dot_c / c_dot_c) * nrs_cond |
| | |
| | |
| | proj_diff = nrs_cond - u_on_c |
| | stretched = nrs_cond + (stretch * proj_diff) |
| |
|
| | |
| | u_rej_c = nrs_uncond - u_on_c |
| | skewed = stretched - (skew * u_rej_c) |
| |
|
| | |
| | cond_len = nrs_cond.norm(dim=1, keepdim=True) |
| | nrs_len = skewed.norm(dim=1, keepdim=True) + eps |
| |
|
| | squash_scale = (1 - squash) + (squash * (cond_len / nrs_len)) |
| | x_final = skewed * squash_scale |
| |
|
| | match self.__OPERATION_SPACE: |
| | case PredictionType.V: |
| | return self._finalize_from_v_space(x_orig, x_div, x_final, sig_root, sigma) |
| | case PredictionType.EPS: |
| | return self._finalize_from_eps_space(x_orig, x_div, x_final, sig_root, sigma) |
| | case PredictionType.X0: |
| | raise RuntimeError("NRS.nrs: x0-prediction not supported yet.") |
| | case PredictionType.UNKNOWN: |
| | raise RuntimeError("NRS.nrs: Could not determine prediction type for this operation.") |
| | case _: |
| | raise RuntimeError("NRS.nrs: Invalid PredictionType used.") |
| | |
| | m = model.clone() |
| | m.set_model_sampler_cfg_function(nrs, True) |
| | return (m, ) |
| |
|
| | NODE_CLASS_MAPPINGS = { |
| | "NRS": NRS, |
| | } |
| |
|