Spaces:
Paused
Paused
# Registry of the original comfy.samplers functions, keyed by function name.
# It is filled by the patch-installation loop. The wrappers built by patch1
# and patch2 look up the function to delegate to here at call time.
store = dict()
| # ==================== Hook into sampling functions for ControlNet ==================== | |
| import comfy.samplers | |
def patch1(fn_name):
    """Return a wrapper for ``comfy.samplers.calc_cond_batch``.

    The wrapper copies the call's ``model_options`` onto the ``x_in`` object,
    but only when ``x_in`` has no ``model_options`` attribute yet. This lets
    the ``get_area_and_mult`` hook in this module see the options. Every
    argument is then forwarded unchanged to the saved implementation
    ``store[fn_name]``.

    Args:
        fn_name: key under which the original function is kept in ``store``.
    """
    def calc_cond_batch(*args, **kwargs):
        # Positional fallbacks assume x_in is the 3rd and model_options the
        # 5th positional argument of the wrapped function.
        if 'x_in' in kwargs:
            latent = kwargs['x_in']
        else:
            latent = args[2]
        if 'model_options' in kwargs:
            options = kwargs['model_options']
        else:
            options = args[4]
        # An attribute that is already present is deliberately left untouched.
        if not hasattr(latent, 'model_options'):
            setattr(latent, 'model_options', options)
        original = store[fn_name]
        return original(*args, **kwargs)

    return calc_cond_batch
def patch2(fn_name):
    """Return a wrapper for ``comfy.samplers.get_area_and_mult``.

    Before forwarding to ``store[fn_name]``, the wrapper toggles the
    ``get_control`` attribute of ``conds['control']``, if that entry exists:

    * If ``x_in.model_options`` exists and contains ``'tiled_diffusion'``,
      the current ``get_control`` is saved once as ``get_control_orig``. It is
      then replaced by a function that ignores its arguments and returns the
      control object itself. (Presumably the tiled sampler drives the control
      chain on its own; TODO confirm.)
    * Otherwise, a previously saved ``get_control_orig`` is put back.

    Args:
        fn_name: key under which the original function is kept in ``store``.
    """
    def get_area_and_mult(*args, **kwargs):
        # Positional fallbacks assume conds is the 1st and x_in the 2nd argument.
        if 'x_in' in kwargs:
            latent = kwargs['x_in']
        else:
            latent = args[1]
        if 'conds' in kwargs:
            conds = kwargs['conds']
        else:
            conds = args[0]

        options = getattr(latent, 'model_options', None)
        tiled = options is not None and 'tiled_diffusion' in options

        if 'control' in conds:
            control = conds['control']
            if tiled:
                if not hasattr(control, 'get_control_orig'):
                    control.get_control_orig = control.get_control
                control.get_control = lambda *a, **kw: control
            elif (hasattr(control, 'get_control_orig')
                  and control.get_control != control.get_control_orig):
                control.get_control = control.get_control_orig

        return store[fn_name](*args, **kwargs)

    return get_area_and_mult
# Each entry: (module to patch, attribute name, factory building the wrapper).
patches = [
    (comfy.samplers, 'calc_cond_batch', patch1),
    (comfy.samplers, 'get_area_and_mult', patch2),
]
for parent, fn_name, create_patch in patches:
    current = getattr(parent, fn_name)
    # The wrappers look up the real implementation through this module's
    # `store` at call time. If this module is executed again in the same
    # namespace (e.g. importlib.reload), `store` is rebound and `current` is
    # the wrapper installed last time. Saving that wrapper as the "original"
    # would make it call itself through `store` until RecursionError, so we
    # unwrap to the original it was tagged with instead.
    store[fn_name] = getattr(current, '_tiled_diffusion_orig', current)
    wrapper = create_patch(fn_name)
    wrapper._tiled_diffusion_orig = store[fn_name]
    setattr(parent, fn_name, wrapper)
| # ==================== Patch pre_run_control ==================== | |
| # Is this necessary anymore? | |
def pre_run_control(model, conds):
    """Prepare every control object attached to ``conds`` for a sampling run.

    For each cond that has a ``'control'`` entry:

    1. ``cleanup()`` is called on a best-effort basis. Any exception is
       logged at debug level and then ignored.
    2. ``pre_run(model, percent_to_timestep)`` is called, where
       ``percent_to_timestep`` forwards its argument to
       ``model.model_sampling.percent_to_sigma``.

    Args:
        model: object exposing ``model_sampling``; that object must provide
            ``percent_to_sigma`` if any control calls the conversion function.
        conds: iterable of dict-like conds; those without ``'control'`` are skipped.

    Returns:
        None.
    """
    import logging

    sampling = model.model_sampling

    def percent_to_timestep(percent):
        # Looked up at call time, so percent_to_sigma is only required when a
        # control actually uses this conversion.
        return sampling.percent_to_sigma(percent)

    for cond in conds:
        if 'control' not in cond:
            continue
        control = cond['control']
        try:
            control.cleanup()
        except Exception:
            # Best effort: a failed cleanup must not block the run, but it
            # should not vanish silently either.
            logging.getLogger(__name__).debug(
                "control cleanup failed; continuing", exc_info=True)
        control.pre_run(model, percent_to_timestep)
# Replace comfy.samplers.pre_run_control with the pre_run_control defined above.
setattr(comfy.samplers, 'pre_run_control', pre_run_control)
| # ==================== Patch SAG ==================== | |
| import math | |
| import torch.nn.functional as F | |
| import comfy_extras.nodes_sag | |
| from comfy_extras.nodes_sag import gaussian_blur_2d | |
def create_blur_map(x0, attn, sigma=3.0, threshold=1.0):
    """Blend a Gaussian-blurred ``x0`` back into ``x0`` under an attention mask.

    Args:
        x0: 4-D tensor unpacked as ``(batch, _, height, width)``.
        attn: 3-D tensor ``(_, hw1, hw2)``, reshaped to
            ``(batch, -1, hw1, hw2)`` using ``x0``'s batch size.
        sigma: standard deviation passed to ``gaussian_blur_2d``; the kernel
            size is fixed at 9.
        threshold: cut-off applied to the pooled attention scores.

    Returns:
        ``blurred * mask + x0 * (1 - mask)``, where the binary mask is
        upsampled to ``(height, width)``.

    The spatial shape of the attention map is not derived from ``x0``.
    Instead it is guessed as the factor pair of ``hw1`` closest to a square,
    with the larger factor on the height axis when height > width and on
    the width axis otherwise.
    """
    _, n_query, n_key = attn.shape
    batch, _, height, width = x0.shape
    folded = attn.reshape(batch, -1, n_query, n_key)

    # Average over the folded axis, sum over the hw1 axis, then threshold.
    # This yields one boolean per hw2 position. The reshape to the hw1-sized
    # grid below assumes hw2 == hw1 (self-attention); TODO confirm upstream.
    hot = folded.mean(1, keepdim=False).sum(1, keepdim=False) > threshold

    def _closest_factor_pair(n):
        # Walk down from floor(sqrt(n)) to the first divisor (None if n < 1).
        d = int(math.sqrt(n))
        while d > 0:
            if n % d == 0:
                return d, n // d
            d -= 1
        return None

    lo, hi = sorted(_closest_factor_pair(n_query))
    if height > width:
        grid = (hi, lo)
    else:
        grid = (lo, hi)

    mask = hot.reshape(batch, *grid).unsqueeze(1).type(folded.dtype)
    mask = F.interpolate(mask, (height, width))

    smoothed = gaussian_blur_2d(x0, kernel_size=9, sigma=sigma)
    return smoothed * mask + x0 * (1 - mask)
# Install the create_blur_map defined above into comfy_extras.nodes_sag,
# overriding that module's attribute of the same name.
setattr(comfy_extras.nodes_sag, 'create_blur_map', create_blur_map)
| # ==================== Patch Gligen ==================== | |
| def _set_position(self, boxes, masks, positive_embeddings): | |
| objs = self.position_net(boxes, masks, positive_embeddings) | |
| def func(x, extra_options): | |
| key = extra_options["transformer_index"] | |
| module = self.module_list[key] | |
| nonlocal objs | |
| _objs = objs.repeat(-(x.shape[0] // -objs.shape[0]),1,1) if x.shape[0] > objs.shape[0] else objs | |
| return module(x, _objs.to(device=x.device, dtype=x.dtype)) | |
| return func | |
| import comfy.gligen | |
# Make every Gligen instance use the _set_position defined above.
setattr(comfy.gligen.Gligen, '_set_position', _set_position)