from typing import Callable import math import torch from torch import Tensor from torch.nn.functional import group_norm from einops import rearrange import comfy.ldm.modules.attention as attention from comfy.ldm.modules.diffusionmodules import openaimodel import comfy.model_management as model_management import comfy.samplers as comfy_samplers import comfy.sample as comfy_sample import comfy.utils from comfy.controlnet import ControlBase from .context import get_context_scheduler from .motion_utils import GroupNormAD, NoiseType from .model_utils import ModelTypeSD, wrap_function_to_inject_xformers_bug_info from .model_injection import InjectionParams, ModelPatcherAndInjector, MotionModelPatcher from .motion_module_ad import AnimateDiffFormat, AnimateDiffInfo, AnimateDiffVersion, VanillaTemporalModule from .logger import logger ################################################################################## ###################################################################### # Global variable to use to more conveniently hack variable access into samplers class AnimateDiffHelper_GlobalState: def __init__(self): self.motion_model: MotionModelPatcher = None self.params: InjectionParams = None self.reset() def reset(self): self.start_step: int = 0 self.last_step: int = 0 self.current_step: int = 0 self.total_steps: int = 0 self.video_length: int = 0 self.context_frames: int = None self.context_stride: int = None self.context_overlap: int = None self.context_schedule: str = None self.closed_loop: bool = False self.sync_context_to_pe: bool = False self.sub_idxs: list = None if self.motion_model is not None: del self.motion_model self.motion_model = None if self.params is not None: del self.params self.params = None def update_with_inject_params(self, params: InjectionParams): self.video_length = params.video_length self.context_frames = params.context_length self.context_stride = params.context_stride self.context_overlap = params.context_overlap self.context_schedule = params.context_schedule self.closed_loop = params.closed_loop self.sync_context_to_pe = params.sync_context_to_pe self.params = params def is_using_sliding_context(self): return self.context_frames is not None def create_exposed_params(self): # This dict will be exposed to be used by other extensions # DO NOT change any of the key names # or I will find you 👁.👁 return { "full_length": self.video_length, "context_length": self.context_frames, "sub_idxs": self.sub_idxs, } ADGS = AnimateDiffHelper_GlobalState() ###################################################################### ################################################################################## ################################################################################## #### Code Injection ################################################## # refer to forward_timestep_embed in comfy/ldm/modules/diffusionmodules/openaimodel.py def forward_timestep_embed_factory() -> Callable: if hasattr(attention, "SpatialVideoTransformer"): def forward_timestep_embed(ts, x, emb, context=None, transformer_options={}, output_shape=None, time_context=None, num_video_frames=None, image_only_indicator=None): for layer in ts: if isinstance(layer, openaimodel.VideoResBlock): x = layer(x, emb, num_video_frames, image_only_indicator) elif isinstance(layer, openaimodel.TimestepBlock): x = layer(x, emb) elif isinstance(layer, VanillaTemporalModule): x = layer(x, context) elif isinstance(layer, attention.SpatialVideoTransformer): x = layer(x, context, time_context, num_video_frames, image_only_indicator, transformer_options) if "transformer_index" in transformer_options: transformer_options["transformer_index"] += 1 if "current_index" in transformer_options: # keep this for backward compat, for now transformer_options["current_index"] += 1 elif isinstance(layer, attention.SpatialTransformer): x = layer(x, context, transformer_options) if "transformer_index" in transformer_options: transformer_options["transformer_index"] += 1 if "current_index" in transformer_options: # keep this for backward compat, for now transformer_options["current_index"] += 1 elif isinstance(layer, openaimodel.Upsample): x = layer(x, output_shape=output_shape) else: x = layer(x) return x # keep old version for backwards compatibility (TODO: remove at end of 2023) else: def forward_timestep_embed(ts, x, emb, context=None, transformer_options={}, output_shape=None): for layer in ts: if isinstance(layer, openaimodel.TimestepBlock): x = layer(x, emb) elif isinstance(layer, VanillaTemporalModule): x = layer(x, context) elif isinstance(layer, attention.SpatialTransformer): x = layer(x, context, transformer_options) if "current_index" in transformer_options: transformer_options["current_index"] += 1 elif isinstance(layer, openaimodel.Upsample): x = layer(x, output_shape=output_shape) else: x = layer(x) return x return forward_timestep_embed def unlimited_memory_required(*args, **kwargs): return 0 def groupnorm_mm_factory(params: InjectionParams): def groupnorm_mm_forward(self, input: Tensor) -> Tensor: # axes_factor normalizes batch based on total conds and unconds passed in batch; # the conds and unconds per batch can change based on VRAM optimizations that may kick in if not ADGS.is_using_sliding_context(): axes_factor = input.size(0)//params.video_length else: axes_factor = input.size(0)//params.context_length input = rearrange(input, "(b f) c h w -> b c f h w", b=axes_factor) input = group_norm(input, self.num_groups, self.weight, self.bias, self.eps) input = rearrange(input, "b c f h w -> (b f) c h w", b=axes_factor) return input return groupnorm_mm_forward def get_additional_models_factory(orig_get_additional_models: Callable, motion_model: MotionModelPatcher): def get_additional_models_with_motion(*args, **kwargs): models, inference_memory = orig_get_additional_models(*args, **kwargs) models.append(motion_model) # TODO: account for inference memory as well? return models, inference_memory return get_additional_models_with_motion ###################################################################### ################################################################################## def prepare_mask_ad(noise_mask, shape, device): """ensures noise mask is of proper dimensions""" noise_mask = torch.nn.functional.interpolate(noise_mask.reshape((-1, 1, noise_mask.shape[-2], noise_mask.shape[-1])), size=(shape[2], shape[3]), mode="bilinear") #noise_mask = noise_mask.round() noise_mask = torch.cat([noise_mask] * shape[1], dim=1) noise_mask = comfy.utils.repeat_to_batch_size(noise_mask, shape[0]) noise_mask = noise_mask.to(device) return noise_mask def apply_params_to_motion_model(motion_model: MotionModelPatcher, params: InjectionParams): if params.context_length and params.video_length > params.context_length: logger.info(f"Sliding context window activated - latents passed in ({params.video_length}) greater than context_length {params.context_length}.") else: logger.info(f"Regular AnimateDiff activated - latents passed in ({params.video_length}) less or equal to context_length {params.context_length}.") params.reset_context() # if no context_length, treat video length as intended AD frame window if not params.context_length: if params.video_length > motion_model.model.encoding_max_len: raise ValueError(f"Without a context window, AnimateDiff model {motion_model.model.mm_info.mm_name} has upper limit of {motion_model.model.encoding_max_len} frames, but received {params.video_length} latents.") motion_model.model.set_video_length(params.video_length, params.full_length) # otherwise, treat context_length as intended AD frame window else: if params.context_length > motion_model.model.encoding_max_len: raise ValueError(f"AnimateDiff model {motion_model.model.mm_info.mm_name} has upper limit of {motion_model.model.encoding_max_len} frames for a context window, but received context length of {params.context_length}.") motion_model.model.set_video_length(params.context_length, params.full_length) # inject model logger.info(f"Using motion module {motion_model.model.mm_info.mm_name} version {motion_model.model.mm_info.mm_version}.") class FunctionInjectionHolder: def __init__(self): pass def inject_functions(self, model: ModelPatcherAndInjector, params: InjectionParams): # Save Original Functions self.orig_forward_timestep_embed = openaimodel.forward_timestep_embed # needed to account for VanillaTemporalModule self.orig_memory_required = model.model.memory_required # allows for "unlimited area hack" to prevent halving of conds/unconds self.orig_groupnorm_forward = torch.nn.GroupNorm.forward # used to normalize latents to remove "flickering" of colors/brightness between frames self.orig_groupnormad_forward = GroupNormAD.forward self.orig_sampling_function = comfy_samplers.sampling_function # used to support sliding context windows in samplers self.orig_prepare_mask = comfy_sample.prepare_mask self.orig_get_additional_models = comfy_sample.get_additional_models # Inject Functions openaimodel.forward_timestep_embed = forward_timestep_embed_factory() if params.unlimited_area_hack: model.model.memory_required = unlimited_memory_required # only apply groupnorm hack if not [v3 or (AnimateDiff SD1.5 and v2 and should apply v2 properly)] info: AnimateDiffInfo = model.motion_model.model.mm_info if not (info.mm_version == AnimateDiffVersion.V3 or (info.mm_format == AnimateDiffFormat.ANIMATEDIFF and info.sd_type == ModelTypeSD.SD1_5 and info.mm_version == AnimateDiffVersion.V2 and params.apply_v2_models_properly)): torch.nn.GroupNorm.forward = groupnorm_mm_factory(params) if params.apply_mm_groupnorm_hack: GroupNormAD.forward = groupnorm_mm_factory(params) comfy_samplers.sampling_function = sliding_sampling_function comfy_sample.prepare_mask = prepare_mask_ad comfy_sample.get_additional_models = get_additional_models_factory(self.orig_get_additional_models, model.motion_model) del info def restore_functions(self, model: ModelPatcherAndInjector): # Restoration try: model.model.memory_required = self.orig_memory_required openaimodel.forward_timestep_embed = self.orig_forward_timestep_embed torch.nn.GroupNorm.forward = self.orig_groupnorm_forward GroupNormAD.forward = self.orig_groupnormad_forward comfy_samplers.sampling_function = self.orig_sampling_function comfy_sample.prepare_mask = self.orig_prepare_mask comfy_sample.get_additional_models = self.orig_get_additional_models except AttributeError: logger.error("Encountered AttributeError while attempting to restore functions - likely, an error occured while trying " + \ "to save original functions before injection, and a more specific error was thrown by ComfyUI.") def motion_sample_factory(orig_comfy_sample: Callable) -> Callable: def motion_sample(model: ModelPatcherAndInjector, noise: Tensor, *args, **kwargs): # check if model is intended for injecting if type(model) != ModelPatcherAndInjector: return orig_comfy_sample(model, noise, *args, **kwargs) # otherwise, injection time latents = None function_injections = FunctionInjectionHolder() try: # clone params from model params = model.motion_injection_params.clone() # get amount of latents passed in, and store in params latents = args[-1] params.video_length = latents.size(0) params.full_length = latents.size(0) # reset global state ADGS.reset() # store and inject functions function_injections.inject_functions(model, params) # apply custom noise, if needed disable_noise = kwargs.get("disable_noise") or False seed = kwargs["seed"] if not disable_noise: # if context asks for specific noise, do it noise = NoiseType.prepare_noise(params.noise_type, latents=latents, noise=noise, context_length=params.context_length, seed=seed) # apply params to motion model apply_params_to_motion_model(model.motion_model, params) # handle GLOBALSTATE vars and step tally ADGS.update_with_inject_params(params) ADGS.start_step = kwargs.get("start_step") or 0 ADGS.current_step = ADGS.start_step ADGS.last_step = kwargs.get("last_step") or 0 original_callback = kwargs.get("callback", None) def ad_callback(step, x0, x, total_steps): if original_callback is not None: original_callback(step, x0, x, total_steps) # update GLOBALSTATE for next iteration ADGS.current_step = ADGS.start_step + step + 1 kwargs["callback"] = ad_callback ADGS.motion_model = model.motion_model model.motion_model.pre_run() return wrap_function_to_inject_xformers_bug_info(orig_comfy_sample)(model, noise, *args, **kwargs) finally: del latents del noise # reset global state ADGS.reset() # restore injected functions function_injections.restore_functions(model) del function_injections return motion_sample def sliding_sampling_function(model, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None): def get_area_and_mult(conds, x_in, timestep_in): area = (x_in.shape[2], x_in.shape[3], 0, 0) strength = 1.0 if 'timestep_start' in conds: timestep_start = conds['timestep_start'] if timestep_in[0] > timestep_start: return None if 'timestep_end' in conds: timestep_end = conds['timestep_end'] if timestep_in[0] < timestep_end: return None if 'area' in conds: area = conds['area'] if 'strength' in conds: strength = conds['strength'] input_x = x_in[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]] if 'mask' in conds: # Scale the mask to the size of the input # The mask should have been resized as we began the sampling process mask_strength = 1.0 if "mask_strength" in conds: mask_strength = conds["mask_strength"] mask = conds['mask'] assert(mask.shape[1] == x_in.shape[2]) assert(mask.shape[2] == x_in.shape[3]) mask = mask[:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]] * mask_strength mask = mask.unsqueeze(1).repeat(input_x.shape[0] // mask.shape[0], input_x.shape[1], 1, 1) else: mask = torch.ones_like(input_x) mult = mask * strength if 'mask' not in conds: rr = 8 if area[2] != 0: for t in range(rr): mult[:,:,t:1+t,:] *= ((1.0/rr) * (t + 1)) if (area[0] + area[2]) < x_in.shape[2]: for t in range(rr): mult[:,:,area[0] - 1 - t:area[0] - t,:] *= ((1.0/rr) * (t + 1)) if area[3] != 0: for t in range(rr): mult[:,:,:,t:1+t] *= ((1.0/rr) * (t + 1)) if (area[1] + area[3]) < x_in.shape[3]: for t in range(rr): mult[:,:,:,area[1] - 1 - t:area[1] - t] *= ((1.0/rr) * (t + 1)) conditionning = {} model_conds = conds["model_conds"] for c in model_conds: conditionning[c] = model_conds[c].process_cond(batch_size=x_in.shape[0], device=x_in.device, area=area) control = None if 'control' in conds: control = conds['control'] patches = None if 'gligen' in conds: gligen = conds['gligen'] patches = {} gligen_type = gligen[0] gligen_model = gligen[1] if gligen_type == "position": gligen_patch = gligen_model.model.set_position(input_x.shape, gligen[2], input_x.device) else: gligen_patch = gligen_model.model.set_empty(input_x.shape, input_x.device) patches['middle_patch'] = [gligen_patch] return (input_x, mult, conditionning, area, control, patches) def cond_equal_size(c1, c2): if c1 is c2: return True if c1.keys() != c2.keys(): return False for k in c1: if not c1[k].can_concat(c2[k]): return False return True def can_concat_cond(c1, c2): if c1[0].shape != c2[0].shape: return False #control if (c1[4] is None) != (c2[4] is None): return False if c1[4] is not None: if c1[4] is not c2[4]: return False #patches if (c1[5] is None) != (c2[5] is None): return False if (c1[5] is not None): if c1[5] is not c2[5]: return False return cond_equal_size(c1[2], c2[2]) def cond_cat(c_list): c_crossattn = [] c_concat = [] c_adm = [] crossattn_max_len = 0 temp = {} for x in c_list: for k in x: cur = temp.get(k, []) cur.append(x[k]) temp[k] = cur out = {} for k in temp: conds = temp[k] out[k] = conds[0].concat(conds[1:]) return out def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options): out_cond = torch.zeros_like(x_in) out_count = torch.ones_like(x_in) * 1e-37 out_uncond = torch.zeros_like(x_in) out_uncond_count = torch.ones_like(x_in) * 1e-37 COND = 0 UNCOND = 1 to_run = [] for x in cond: p = get_area_and_mult(x, x_in, timestep) if p is None: continue to_run += [(p, COND)] if uncond is not None: for x in uncond: p = get_area_and_mult(x, x_in, timestep) if p is None: continue to_run += [(p, UNCOND)] while len(to_run) > 0: first = to_run[0] first_shape = first[0][0].shape to_batch_temp = [] for x in range(len(to_run)): if can_concat_cond(to_run[x][0], first[0]): to_batch_temp += [x] to_batch_temp.reverse() to_batch = to_batch_temp[:1] free_memory = model_management.get_free_memory(x_in.device) for i in range(1, len(to_batch_temp) + 1): batch_amount = to_batch_temp[:len(to_batch_temp)//i] input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:] if model.memory_required(input_shape) < free_memory: to_batch = batch_amount break input_x = [] mult = [] c = [] cond_or_uncond = [] area = [] control = None patches = None for x in to_batch: o = to_run.pop(x) p = o[0] input_x += [p[0]] mult += [p[1]] c += [p[2]] area += [p[3]] cond_or_uncond += [o[1]] control = p[4] patches = p[5] batch_chunks = len(cond_or_uncond) input_x = torch.cat(input_x) c = cond_cat(c) timestep_ = torch.cat([timestep] * batch_chunks) if control is not None: c['control'] = control.get_control(input_x, timestep_, c, len(cond_or_uncond)) transformer_options = {} if 'transformer_options' in model_options: transformer_options = model_options['transformer_options'].copy() if patches is not None: if "patches" in transformer_options: cur_patches = transformer_options["patches"].copy() for p in patches: if p in cur_patches: cur_patches[p] = cur_patches[p] + patches[p] else: cur_patches[p] = patches[p] else: transformer_options["patches"] = patches transformer_options["cond_or_uncond"] = cond_or_uncond transformer_options["sigmas"] = timestep transformer_options["ad_params"] = ADGS.create_exposed_params() c['transformer_options'] = transformer_options if 'model_function_wrapper' in model_options: output = model_options['model_function_wrapper'](model.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks) else: output = model.apply_model(input_x, timestep_, **c).chunk(batch_chunks) del input_x for o in range(batch_chunks): if cond_or_uncond[o] == COND: out_cond[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += output[o] * mult[o] out_count[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += mult[o] else: out_uncond[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += output[o] * mult[o] out_uncond_count[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += mult[o] del mult out_cond /= out_count del out_count out_uncond /= out_uncond_count del out_uncond_count return out_cond, out_uncond # sliding_calc_cond_uncond_batch inspired by ashen's initial hack for 16-frame sliding context: # https://github.com/comfyanonymous/ComfyUI/compare/master...ashen-sensored:ComfyUI:master def sliding_calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options): # get context scheduler context_scheduler = get_context_scheduler(ADGS.context_schedule) # figure out how input is split axes_factor = x_in.size(0)//ADGS.video_length # prepare final cond, uncond, and out_count cond_final = torch.zeros_like(x_in) uncond_final = torch.zeros_like(x_in) out_count_final = torch.zeros((x_in.shape[0], 1, 1, 1), device=x_in.device) def prepare_control_objects(control: ControlBase, full_idxs: list[int]): if control.previous_controlnet is not None: prepare_control_objects(control.previous_controlnet, full_idxs) control.sub_idxs = full_idxs control.full_latent_length = ADGS.video_length control.context_length = ADGS.context_frames def get_resized_cond(cond_in, full_idxs) -> list: # reuse or resize cond items to match context requirements resized_cond = [] # cond object is a list containing a dict - outer list is irrelevant, so just loop through it for actual_cond in cond_in: resized_actual_cond = actual_cond.copy() # now we are in the inner dict - "pooled_output" is a tensor, "control" is a ControlBase object, "model_conds" is dictionary for key in actual_cond: try: cond_item = actual_cond[key] if isinstance(cond_item, Tensor): # check that tensor is the expected length - x.size(0) if cond_item.size(0) == x_in.size(0): # if so, it's subsetting time - tell controls the expected indeces so they can handle them actual_cond_item = cond_item[full_idxs] resized_actual_cond[key] = actual_cond_item else: resized_actual_cond[key] = cond_item # look for control elif key == "control": control_item = cond_item if hasattr(control_item, "sub_idxs"): prepare_control_objects(control_item, full_idxs) else: raise ValueError(f"Control type {type(control_item).__name__} may not support required features for sliding context window; \ use Control objects from Kosinkadink/Advanced-ControlNet nodes, or make sure Advanced-ControlNet is updated.") resized_actual_cond[key] = control_item del control_item elif isinstance(cond_item, dict): new_cond_item = cond_item.copy() # when in dictionary, look for tensors and CONDCrossAttn [comfy/conds.py] (has cond attr that is a tensor) for cond_key, cond_value in new_cond_item.items(): if isinstance(cond_value, Tensor): if cond_value.size(0) == x_in.size(0): new_cond_item[cond_key] = cond_value[full_idxs] # if has cond that is a Tensor, check if needs to be subset elif hasattr(cond_value, "cond") and isinstance(cond_value.cond, Tensor): if cond_value.cond.size(0) == x_in.size(0): new_cond_item[cond_key] = cond_value._copy_with(cond_value.cond[full_idxs]) resized_actual_cond[key] = new_cond_item else: resized_actual_cond[key] = cond_item finally: del cond_item # just in case to prevent VRAM issues resized_cond.append(resized_actual_cond) return resized_cond # perform calc_cond_uncond_batch per context window for ctx_idxs in context_scheduler(ADGS.current_step, ADGS.total_steps, ADGS.video_length, ADGS.context_frames, ADGS.context_stride, ADGS.context_overlap, ADGS.closed_loop): ADGS.sub_idxs = ctx_idxs ADGS.params.sub_idxs = ADGS.sub_idxs ADGS.motion_model.model.set_sub_idxs(ADGS.sub_idxs) # account for all portions of input frames full_idxs = [] for n in range(axes_factor): for ind in ctx_idxs: full_idxs.append((ADGS.video_length*n)+ind) # get subsections of x, timestep, cond, uncond, cond_concat sub_x = x_in[full_idxs] sub_timestep = timestep[full_idxs] sub_cond = get_resized_cond(cond, full_idxs) if cond is not None else None sub_uncond = get_resized_cond(uncond, full_idxs) if uncond is not None else None sub_cond_out, sub_uncond_out = calc_cond_uncond_batch(model, sub_cond, sub_uncond, sub_x, sub_timestep, model_options) cond_final[full_idxs] += sub_cond_out uncond_final[full_idxs] += sub_uncond_out out_count_final[full_idxs] += 1 # increment which indeces were used # normalize cond and uncond via division by context usage counts cond_final /= out_count_final uncond_final /= out_count_final del out_count_final return cond_final, uncond_final if math.isclose(cond_scale, 1.0): uncond = None if not ADGS.is_using_sliding_context(): cond, uncond = calc_cond_uncond_batch(model, cond, uncond, x, timestep, model_options) else: cond, uncond = sliding_calc_cond_uncond_batch(model, cond, uncond, x, timestep, model_options) if "sampler_cfg_function" in model_options: args = {"cond": x - cond, "uncond": x - uncond, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep} return x - model_options["sampler_cfg_function"](args) else: return uncond + (cond - uncond) * cond_scale