| | |
| |
|
| |
|
| | from lib_dynamic_thresholding.dynthres_core import DynThresh |
| |
|
| |
|
class DynamicThresholdingNode:
    """Node that installs a dynamic-thresholding CFG function on a model.

    The class attributes (``INPUT_TYPES``, ``RETURN_TYPES``, ``FUNCTION``,
    ``CATEGORY``) appear to follow the ComfyUI custom-node layout: ``FUNCTION``
    names the method to invoke (``patch``) and ``INPUT_TYPES`` declares the
    keyword arguments that method receives.
    """

    @classmethod
    def INPUT_TYPES(cls):
        """Return the declaration of every input accepted by ``patch``.

        Each key matches a ``patch`` parameter name.  Values are tuples whose
        first element is either a type tag (``"MODEL"``, ``"FLOAT"``) or a list
        of allowed choices, optionally followed by a dict of widget options.
        The key order is kept exactly as declared.
        """
        return {
            "required": {
                "model": ("MODEL",),
                "mimic_scale": (
                    "FLOAT",
                    {"default": 7.0, "min": 0.0, "max": 100.0, "step": 0.5},
                ),
                "threshold_percentile": (
                    "FLOAT",
                    {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01},
                ),
                "mimic_mode": (DynThresh.Modes,),
                "mimic_scale_min": (
                    "FLOAT",
                    {"default": 0.0, "min": 0.0, "max": 100.0, "step": 0.5},
                ),
                "cfg_mode": (DynThresh.Modes,),
                "cfg_scale_min": (
                    "FLOAT",
                    {"default": 0.0, "min": 0.0, "max": 100.0, "step": 0.5},
                ),
                "sched_val": (
                    "FLOAT",
                    {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01},
                ),
                "separate_feature_channels": (["enable", "disable"],),
                "scaling_startpoint": (DynThresh.Startpoints,),
                "variability_measure": (DynThresh.Variabilities,),
                "interpolate_phi": (
                    "FLOAT",
                    {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01},
                ),
            }
        }

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch"
    CATEGORY = "advanced/mcmonkey"

    def patch(self, model, mimic_scale, threshold_percentile, mimic_mode,
              mimic_scale_min, cfg_mode, cfg_scale_min, sched_val,
              separate_feature_channels, scaling_startpoint,
              variability_measure, interpolate_phi):
        """Return a clone of ``model`` whose CFG function applies DynThresh.

        Args:
            model: Object providing ``clone()`` and
                ``model.model_sampling.timestep(sigma)``; the clone must
                provide ``set_model_sampler_cfg_function``.
            separate_feature_channels: ``"enable"`` or ``"disable"``; passed
                to ``DynThresh`` as a boolean.
            mimic_scale, threshold_percentile, mimic_mode, mimic_scale_min,
            cfg_mode, cfg_scale_min, sched_val, scaling_startpoint,
            variability_measure, interpolate_phi: Forwarded unchanged to the
                ``DynThresh`` constructor.

        Returns:
            A one-element tuple containing the patched clone.
        """
        # The same value is given to DynThresh and used below to invert the
        # sampler timestep, so keep it in one place.
        # NOTE(review): presumably the last index of a 1000-step schedule;
        # confirm against the DynThresh constructor.
        step_ceiling = 999

        dynamic_thresh = DynThresh(
            mimic_scale,
            threshold_percentile,
            mimic_mode,
            mimic_scale_min,
            cfg_mode,
            cfg_scale_min,
            sched_val,
            0,
            step_ceiling,
            separate_feature_channels == "enable",
            scaling_startpoint,
            variability_measure,
            interpolate_phi,
        )

        def sampler_dyn_thresh(args):
            """CFG callback: combine cond/uncond through ``dynamic_thresh``."""
            # Named to avoid shadowing the ``input`` builtin.
            sampler_input = args["input"]

            # DynThresh is given the differences from the input; the combined
            # result is subtracted from the input again on return.
            cond = sampler_input - args["cond"]
            uncond = sampler_input - args["uncond"]
            cond_scale = args["cond_scale"]

            # Only the first element of the timestep result is used to set
            # the current step on the shared DynThresh instance.
            time_step = model.model.model_sampling.timestep(args["sigma"])
            time_step = time_step[0].item()
            dynamic_thresh.step = step_ceiling - time_step

            return sampler_input - dynamic_thresh.dynthresh(cond, uncond, cond_scale, None)

        # Patch a clone so the model passed in is left untouched.
        m = model.clone()
        m.set_model_sampler_cfg_function(sampler_dyn_thresh)
        return (m, )
| |
|