File size: 5,795 Bytes
baac5bb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 |
class Store:
    """Minimal attribute bag used as module-level shared state.

    Instances start empty; attributes are attached dynamically. Two stores
    compare equal when they hold the same attributes with equal values.
    """

    def __repr__(self):
        """Return ``Store(name=value, ...)`` with attributes sorted by name."""
        fields = vars(self)
        rendered = ", ".join(
            f"{name}={fields[name]!r}" for name in sorted(fields)
        )
        return f"{type(self).__name__}({rendered})"

    def __eq__(self, other):
        """Compare attribute dictionaries; defer for non-``Store`` operands."""
        # Returning NotImplemented (instead of touching other.__dict__) keeps
        # comparisons such as ``store == 1`` or ``store == None`` from raising
        # AttributeError; Python then falls back to identity comparison.
        if not isinstance(other, Store):
            return NotImplemented
        return vars(self) == vars(other)

    # Defining __eq__ already makes the class unhashable; spell it out because
    # instances are mutable and must not be used as dict keys / set members.
    __hash__ = None
# Module-level shared state. register_hooks() saves the original (unpatched)
# callables on it, and the sampling hooks stash the current run's sigmas,
# model_options and extra_args on it.
store = Store()
# ==================== Hook into sampling functions for ControlNet ====================
import comfy.samplers
def KSAMPLER_sample(*args, **kwargs):
    """Hook installed over ``comfy.samplers.KSAMPLER.sample``.

    Reads ``extra_args`` (keyword, or 4th positional argument) and its
    ``'model_options'`` entry. When those options contain the
    ``'tiled_diffusion'`` key, the sigmas, the options and the extra args are
    published on the module-level ``store``; otherwise any previously
    published values are removed. The original sampler is then called with
    the unmodified arguments and its result is returned.
    """
    wrapped = store.KSAMPLER_sample

    call_extras = None
    options = None
    try:
        call_extras = kwargs['extra_args'] if 'extra_args' in kwargs else args[3]
        options = call_extras['model_options']
    except Exception:
        # Best effort: if the arguments cannot be located, behave as if tiled
        # diffusion is not active.
        pass

    tiled = (
        options is not None
        and 'tiled_diffusion' in options
        and call_extras is not None
    )
    if not tiled:
        # Drop stale context left over from an earlier tiled run.
        for name in ('sigmas', 'model_options', 'extra_args'):
            _delattr(store, name)
        return wrapped(*args, **kwargs)

    call_sigmas = kwargs['sigmas'] if 'sigmas' in kwargs else args[2]
    # A 'sigmas' entry in model_options (placed there by the KSampler_sample
    # hook) takes precedence over the sigmas passed to this call; it is
    # removed from the options as it is read.
    stashed_sigmas = options.pop('sigmas', None)
    store.sigmas = call_sigmas if stashed_sigmas is None else stashed_sigmas
    store.model_options = options
    store.extra_args = call_extras
    return wrapped(*args, **kwargs)
def KSampler_sample(*args, **kwargs):
    """Hook installed over ``comfy.samplers.KSampler.sample``.

    When the sampler's model patcher has ``'tiled_diffusion'`` in its
    ``model_options``, a copy of those options with an added ``'sigmas'``
    entry is assigned back to the patcher before delegating. The sigmas are
    taken from the ``sigmas`` keyword, else the 11th positional argument,
    else the sampler's own ``sigmas`` attribute. The original method is
    always called with the unmodified arguments.
    """
    wrapped = store.KSampler_sample
    sampler = args[0]
    patcher = getattr(sampler, 'model', None)
    options = getattr(patcher, 'model_options', None)

    # Nothing to prepare unless tiled diffusion is configured on the model.
    if options is None or 'tiled_diffusion' not in options:
        return wrapped(*args, **kwargs)

    schedule = None
    try:
        schedule = kwargs['sigmas'] if 'sigmas' in kwargs else args[10]
    except Exception:
        # Argument not supplied in either form; fall back below.
        pass
    if schedule is None:
        schedule = getattr(sampler, 'sigmas', None)

    if schedule is not None:
        # Work on a copy so the options object that was found is not mutated.
        patched_options = options.copy()
        patched_options['sigmas'] = schedule
        sampler.model.model_options = patched_options
    return wrapped(*args, **kwargs)
def get_area_and_mult(*args, **kwargs):
    """Hook installed over ``comfy.samplers.get_area_and_mult``.

    If ``conds`` (keyword, or first positional argument) carries a
    ``'control'`` object, its ``get_control`` is toggled depending on whether
    ``store.model_options`` contains ``'tiled_diffusion'``:

    * active   -> the original method is saved once as ``get_control_orig``
      and ``get_control`` is replaced by a stub that returns the control
      object itself;
    * inactive -> a previously saved ``get_control_orig`` is restored.

    The original function is then called and its result returned.
    """
    conds = kwargs['conds'] if 'conds' in kwargs else args[0]
    options = getattr(store, 'model_options', None)
    tiled = options is not None and 'tiled_diffusion' in options

    if 'control' in conds:
        control = conds['control']
        if tiled:
            if not hasattr(control, 'get_control_orig'):
                control.get_control_orig = control.get_control
            # NOTE(review): presumably the tiled-diffusion code evaluates the
            # control itself later, so the stub only hands the object
            # through - confirm against the consumer.
            control.get_control = lambda *a, **kw: control
        elif (
            hasattr(control, 'get_control_orig')
            and control.get_control != control.get_control_orig
        ):
            control.get_control = control.get_control_orig

    return store.get_area_and_mult(*args, **kwargs)
def _delattr(obj, attr):
try:
if hasattr(obj, attr): delattr(obj, attr)
except Exception: ...
def register_hooks():
    """Install the sampling hooks on ``comfy.samplers``.

    For each ``(parent, attribute, hook)`` entry the currently installed
    callable is saved on ``store`` under the hook's own name (this is what
    each hook later calls as the original) and ``parent.attribute`` is
    replaced by the hook.

    Calling this function more than once is safe: an attribute that already
    points at our hook is left alone. Without that guard a second call would
    save the hook as its own "original" and the next sampling call would
    recurse forever.
    """
    patches = [
        (comfy.samplers.KSampler, 'sample', KSampler_sample),
        (comfy.samplers.KSAMPLER, 'sample', KSAMPLER_sample),
        (comfy.samplers, 'get_area_and_mult', get_area_and_mult),
    ]
    for parent, fn_name, fn_patch in patches:
        current = getattr(parent, fn_name)
        if current is fn_patch:
            # Already hooked by an earlier call; keep the saved original.
            continue
        # NOTE(review): the check is made on ``parent`` but the value is
        # written to ``store``, and both 'sample' entries share the
        # ``_sample`` name, so the second one overwrites the first. Kept
        # as-is for compatibility; confirm whether ``store._<name>`` is read
        # anywhere.
        if not hasattr(parent, f"_{fn_name}"):
            setattr(store, f"_{fn_name}", current)
        setattr(store, fn_patch.__name__, current)
        setattr(parent, fn_name, fn_patch)
# Install the hooks as soon as this module is executed.
register_hooks()
# ==================== Patch pre_run_control ====================
# Is this necessary anymore?
def pre_run_control(model, conds):
    """Prepare every control object found in ``conds`` for a sampling run.

    For each entry of ``conds`` that has a ``'control'`` key, the control's
    ``cleanup()`` is attempted (failures are ignored) and then its
    ``pre_run(model, percent_to_timestep_function)`` is called.

    Args:
        model: object exposing ``model_sampling.percent_to_sigma``.
        conds: iterable of mappings, each optionally holding a ``'control'``.
    """
    s = model.model_sampling

    # One callback is enough for the whole loop; it looks up
    # ``percent_to_sigma`` at call time, as the per-iteration lambda did.
    def percent_to_timestep_function(a):
        return s.percent_to_sigma(a)

    for x in conds:
        if 'control' not in x:
            continue
        control = x['control']
        try:
            control.cleanup()
        except Exception:
            # Best effort: a control without cleanup(), or one whose cleanup
            # fails, must not prevent pre_run from being called.
            pass
        control.pre_run(model, percent_to_timestep_function)
# Replace comfy.samplers.pre_run_control with the version above, which
# attempts cleanup() on each control before calling pre_run().
comfy.samplers.pre_run_control = pre_run_control
# ==================== Patch SAG ====================
from math import sqrt
import torch.nn.functional as F
import comfy_extras.nodes_sag
from comfy_extras.nodes_sag import gaussian_blur_2d
def calc_closest_factors(a):
    """Return the factor pair ``(b, c)`` of ``a`` closest to its square root.

    The search starts at ``int(sqrt(a))`` and walks downward, so ``b <= c``
    and ``b * c == a``. Returns ``None`` when there is no candidate divisor
    (``a == 0``).
    """
    divisor = int(sqrt(a))
    while divisor > 0:
        quotient, remainder = divmod(a, divisor)
        if remainder == 0:
            return (divisor, quotient)
        divisor -= 1
    return None
def create_blur_map(x0, attn, sigma=3.0, threshold=1.0):
    """Blur ``x0`` only where the pooled attention exceeds ``threshold``.

    Args:
        x0: 4-D tensor unpacked as ``(b, _, lh, lw)``.
        attn: 3-D tensor unpacked as ``(_, hw1, hw2)``; it is regrouped to
            ``(b, -1, hw1, hw2)`` before pooling.
        sigma: sigma forwarded to ``gaussian_blur_2d`` (kernel size is 9).
        threshold: cut-off applied to the pooled attention.

    Returns:
        ``blurred * mask + x0 * (1 - mask)`` where ``mask`` is the
        thresholded attention resized to ``(lh, lw)``.
    """
    _, hw1, hw2 = attn.shape
    b, _, lh, lw = x0.shape
    grouped = attn.reshape(b, -1, hw1, hw2)

    # Average over the regrouped axis, then sum over the next one, and
    # binarise against the threshold.
    pooled = grouped.mean(1, keepdim=False).sum(1, keepdim=False)
    selected = pooled > threshold

    # Pick a 2-D layout for the flat mask from the factor pair of hw1 nearest
    # to a square, giving the larger factor to the longer latent side.
    # NOTE(review): the reshape below needs hw2 == rows * cols (== hw1) -
    # confirm attn is always square in its last two dims.
    factors = calc_closest_factors(hw1)
    if lh > lw:
        rows = max(factors)
    else:
        rows = min(factors)
    cols = factors[1] if rows == factors[0] else factors[0]

    selected = selected.reshape(b, rows, cols)
    selected = selected.unsqueeze(1)
    selected = selected.type(grouped.dtype)

    # Resize the mask to the latent resolution.
    selected = F.interpolate(selected, (lh, lw))

    blurred = gaussian_blur_2d(x0, kernel_size=9, sigma=sigma)
    return blurred * selected + x0 * (1 - selected)
# Replace comfy_extras.nodes_sag.create_blur_map with the version above, which
# derives the mask layout from the factor pair of the attention size.
comfy_extras.nodes_sag.create_blur_map = create_blur_map
# ==================== Patch Gligen ====================
def _set_position(self, boxes, masks, positive_embeddings):
objs = self.position_net(boxes, masks, positive_embeddings)
def func(x, extra_options):
key = extra_options["transformer_index"]
module = self.module_list[key]
nonlocal objs
_objs = objs.repeat(-(x.shape[0] // -objs.shape[0]),1,1) if x.shape[0] > objs.shape[0] else objs
return module(x, _objs.to(device=x.device, dtype=x.dtype))
return func
import comfy.gligen
# Replace Gligen._set_position with the version above, which repeats the
# position objects when the incoming batch is larger than theirs.
comfy.gligen.Gligen._set_position = _set_position
|