| import copy |
|
|
| from einops import rearrange |
| from torch import Tensor |
| import torch.nn.functional as F |
| import torch |
|
|
| import comfy.model_management |
| import comfy.utils |
| from comfy.model_patcher import ModelPatcher |
|
|
| from .motion_module_ad import AnimateDiffModel, has_mid_block, normalize_ad_state_dict |
| from .logger import logger |
| from .motion_utils import MotionCompatibilityError, NoiseType, normalize_min_max |
| from .motion_lora import MotionLoraInfo, MotionLoraList |
| from .model_utils import get_motion_lora_path, get_motion_model_path, get_sd_model_type |
|
|
|
|
| |
| |
class ModelPatcherAndInjector(ModelPatcher):
    """ModelPatcher that also injects/ejects a motion model around patching.

    The instance is seeded from an existing ``ModelPatcher`` and carries three
    extra attributes (``motion_injection_params``, ``motion_model`` and
    ``motion_model_sampling``) which start as ``None`` and are expected to be
    assigned by the caller.
    """

    def __init__(self, m: ModelPatcher):
        """Create the patcher from ``m``, copying its patch state.

        Args:
            m: Source patcher. Its model, devices, size and
                ``weight_inplace_update`` flag are forwarded to the parent
                constructor; patches, object patches, model options and model
                keys are copied or shared as noted below.
        """
        super().__init__(
            m.model,
            m.load_device,
            m.offload_device,
            m.size,
            m.current_device,
            weight_inplace_update=m.weight_inplace_update,
        )
        # Each patch list is shallow-copied so list edits here do not leak into ``m``.
        self.patches = {name: m.patches[name][:] for name in m.patches}
        self.object_patches = m.object_patches.copy()
        self.model_options = copy.deepcopy(m.model_options)
        # Shared by reference, not copied.
        self.model_keys = m.model_keys

        # Motion-specific state; nothing is injected while motion_model is None.
        self.motion_injection_params: InjectionParams = None
        self.motion_model: MotionModelPatcher = None
        self.motion_model_sampling = None

    def _move_motion_model_best_effort(self, device):
        """Call ``.to(device)`` on the wrapped motion module, ignoring any error."""
        try:
            self.motion_model.model.to(device)
        except Exception:
            pass

    def model_patches_to(self, device):
        """Move the parent's patches to ``device``, then the motion model if present."""
        super().model_patches_to(device)
        if self.motion_model is None:
            return
        self._move_motion_model_best_effort(device)

    def patch_model(self, device_to=None):
        """Run the parent's patching first, then inject the motion model.

        Returns:
            Whatever the parent ``patch_model`` returned.
        """
        result = super().patch_model(device_to)
        self.inject_model(device_to=device_to)
        return result

    def unpatch_model(self, device_to=None):
        """Eject the motion model first, then run the parent's unpatching.

        Returns:
            Whatever the parent ``unpatch_model`` returned.
        """
        self.eject_model(device_to=device_to)
        return super().unpatch_model(device_to)

    def inject_model(self, device_to=None):
        """Eject then inject the motion model into this patcher (no-op without one)."""
        if self.motion_model is None:
            return
        wrapped = self.motion_model.model
        # Eject before injecting so the inject call always starts from an ejected state.
        wrapped.eject(self)
        wrapped.inject(self)
        self._move_motion_model_best_effort(device_to)

    def eject_model(self, device_to=None):
        """Eject the motion model from this patcher (no-op without one)."""
        if self.motion_model is None:
            return
        self.motion_model.model.eject(self)
        self._move_motion_model_best_effort(device_to)

    def clone(self):
        """Return a new ``ModelPatcherAndInjector`` sharing this one's motion state.

        The three motion attributes are passed by reference, not copied.
        """
        duplicate = ModelPatcherAndInjector(self)
        duplicate.motion_model = self.motion_model
        duplicate.motion_injection_params = self.motion_injection_params
        duplicate.motion_model_sampling = self.motion_model_sampling
        return duplicate
|
|
|
|
class MotionModelPatcher(ModelPatcher):
    """ModelPatcher whose ``model`` attribute is annotated as an ``AnimateDiffModel``."""

    def __init__(self, *args, **kwargs):
        """Forward all arguments to ``ModelPatcher.__init__``."""
        super().__init__(*args, **kwargs)
        # Self-assignment: only attaches the narrower type annotation.
        self.model: AnimateDiffModel = self.model

    def patch_model(self, *args, **kwargs):
        """Run the parent's patching, then refresh weights via ``prepare_weights``.

        Returns:
            Whatever the parent ``patch_model`` returned.
        """
        result = super().patch_model(*args, **kwargs)
        self.prepare_weights()
        return result

    def prepare_weights(self):
        """Resolve every state-dict entry through ComfyUI's low-VRAM helper.

        Each entry is passed to ``comfy.model_management.resolve_lowvram_weight``
        and the result is written back with ``comfy.utils.set_attr``. A failure
        while writing back is ignored and the loop moves on to the next entry.
        """
        module = self.model
        for name, tensor in module.state_dict().items():
            resolved = comfy.model_management.resolve_lowvram_weight(tensor, module, name)
            try:
                comfy.utils.set_attr(module, name, resolved)
            except Exception:
                pass

    def pre_run(self):
        """Hook run before use; currently just calls ``prepare_weights``."""
        self.prepare_weights()

    def cleanup(self):
        """Delegate to the wrapped model's ``cleanup`` when a model is set."""
        if self.model is None:
            return
        self.model.cleanup()
|
|
|
|
def get_vanilla_model_patcher(m: ModelPatcher) -> ModelPatcher:
    """Build a plain ``ModelPatcher`` that mirrors the patch state of ``m``.

    Args:
        m: Source patcher (any ``ModelPatcher`` subclass instance works).

    Returns:
        A new ``ModelPatcher`` over the same model and devices. Patch lists
        are shallow-copied, ``object_patches`` is copied, ``model_options`` is
        deep-copied and ``model_keys`` is shared by reference.
    """
    vanilla = ModelPatcher(
        m.model,
        m.load_device,
        m.offload_device,
        m.size,
        m.current_device,
        weight_inplace_update=m.weight_inplace_update,
    )
    vanilla.patches = {name: m.patches[name][:] for name in m.patches}
    vanilla.object_patches = m.object_patches.copy()
    vanilla.model_options = copy.deepcopy(m.model_options)
    vanilla.model_keys = m.model_keys
    return vanilla
|
|
| |
| |
| |
| |
| |
| |
| |
| |
def load_motion_lora_as_patches(motion_model: MotionModelPatcher, lora: MotionLoraInfo) -> None:
    """Load a motion LoRA file and register it as weight patches on ``motion_model``.

    Only keys containing ``"temporal"`` are used. Every remaining key that does
    not contain ``"up."`` is treated as a down weight, paired with the key
    obtained by replacing ``".down."`` with ``".up."``, and the product
    ``up @ down`` is added as a patch under the corresponding model key.

    Args:
        motion_model: Patcher that receives the patches via ``add_patches``.
        lora: LoRA descriptor; ``lora.name`` selects the file and
            ``lora.strength`` is passed as ``strength_patch``.

    Raises:
        ValueError: If the file contains no key with ``"temporal"`` in it.
        KeyError: If a down weight has no matching up weight in the file.
    """
    lora_path = get_motion_lora_path(lora.name)
    logger.info(f"Loading motion LoRA {lora.name}")
    loaded = comfy.utils.load_torch_file(lora_path)

    # Discard everything that is not a temporal (motion) weight.
    state_dict = {key: value for key, value in loaded.items() if "temporal" in key}
    del loaded
    if not state_dict:
        raise ValueError(f"'{lora.name}' contains no temporal keys; it is not a valid motion LoRA!")

    model_has_midblock = motion_model.model.mid_block is not None
    lora_version = "v2" if has_mid_block(state_dict) else "v1"
    logger.info(f"Applying a {lora_version} LoRA ({lora.name}) to a {motion_model.model.mm_info.mm_version} motion model.")

    patches = {}
    for key, weight_down in state_dict.items():
        # A model without a mid block has nowhere to apply mid-block weights.
        if not model_has_midblock and "mid_block" in key:
            continue
        # Up weights are consumed together with their down counterpart below.
        if "up." in key:
            continue

        up_key = key.replace(".down.", ".up.")
        if up_key not in state_dict:
            raise KeyError(
                f"Motion LoRA '{lora.name}' has down weight '{key}' but no matching up weight '{up_key}'."
            )
        weight_up = state_dict[up_key]

        # Translate the LoRA key into the motion model's parameter name.
        model_key = (
            key.replace("processor.", "")
            .replace("_lora", "")
            .replace("down.", "")
            .replace("up.", "")
        )
        model_key = model_key.replace("to_out.", "to_out.0.")

        # Patch payload is the combined matrix up @ down, wrapped in a 1-tuple.
        patches[model_key] = (torch.mm(weight_up, weight_down),)
    del state_dict

    motion_model.add_patches(patches=patches, strength_patch=lora.strength)
|
|
|
|
def load_motion_module(model_name: str, model: ModelPatcher, motion_lora: MotionLoraList = None, motion_model_settings: 'MotionModelSettings' = None) -> MotionModelPatcher:
    """Load a motion module from disk and wrap it in a ``MotionModelPatcher``.

    Steps, in order: load and normalize the state dict, check that the motion
    module's SD type matches ``model``, apply ``motion_model_settings`` to the
    weights, build the ``AnimateDiffModel``, and finally register any motion
    LoRAs as patches.

    Args:
        model_name: Name used to locate the motion module file.
        model: Patcher of the base model; supplies dtype and devices.
        motion_lora: Optional LoRA list; each entry in ``.loras`` is applied.
        motion_model_settings: Optional weight adjustments (see
            ``apply_mm_settings``).

    Returns:
        The patcher wrapping the freshly built motion model.

    Raises:
        MotionCompatibilityError: If the motion module targets a different SD
            model type than ``model``.
    """
    model_path = get_motion_model_path(model_name)
    logger.info(f"Loading motion module {model_name}")
    mm_state_dict = comfy.utils.load_torch_file(model_path, safe_load=True)
    mm_state_dict, mm_info = normalize_ad_state_dict(mm_state_dict=mm_state_dict, mm_name=model_name)

    # Refuse to pair a motion module with a base model of another SD type.
    base_sd_type = get_sd_model_type(model)
    if base_sd_type != mm_info.sd_type:
        raise MotionCompatibilityError(
            f"Motion module '{mm_info.mm_name}' is intended for {mm_info.sd_type} models, "
            f"but the provided model is type {base_sd_type}."
        )

    # Weight adjustments happen on the raw state dict, before it is loaded.
    mm_state_dict = apply_mm_settings(model_dict=mm_state_dict, mm_settings=motion_model_settings)

    wrapper = AnimateDiffModel(mm_state_dict=mm_state_dict, mm_info=mm_info)
    wrapper.to(model.model_dtype())
    wrapper.to(model.offload_device)
    wrapper.load_state_dict(mm_state_dict)

    patcher = MotionModelPatcher(
        model=wrapper,
        load_device=model.load_device,
        offload_device=model.offload_device,
    )
    if motion_lora is not None:
        for lora in motion_lora.loras:
            load_motion_lora_as_patches(patcher, lora)
    return patcher
|
|
|
|
def interpolate_pe_to_length(model_dict: dict[str, Tensor], key: str, new_length: int):
    """Bilinearly resize the tensor stored at ``model_dict[key]``, in place.

    The entry must be rank 3 (the einops pattern requires it). Its middle axis
    is resized to ``new_length`` while the last axis keeps its size. The dict
    entry is replaced with the resized tensor; nothing is returned.

    Args:
        model_dict: Mapping that holds the tensor and receives the result.
        key: Entry of ``model_dict`` to resize.
        new_length: Target size of the middle axis.
    """
    source = model_dict[key]
    last_dim = source.shape[-1]
    # F.interpolate in bilinear mode needs a 4-D input, so add a leading axis of 1.
    expanded = rearrange(source, "(t b) f d -> t b f d", t=1)
    resized = F.interpolate(expanded, size=(new_length, last_dim), mode="bilinear")
    model_dict[key] = rearrange(resized, "t b f d -> (t b) f d", t=1)
|
|
|
|
def interpolate_pe_to_length_diffs(model_dict: dict[str, Tensor], key: str, new_length: int):
    """Bilinearly resize the tensor stored at ``model_dict[key]``, in place.

    As written, this performs the same steps as ``interpolate_pe_to_length``:
    the rank-3 entry gets a leading axis of 1, its middle axis is resized to
    ``new_length`` (last axis unchanged), and the result replaces the entry.

    Args:
        model_dict: Mapping that holds the tensor and receives the result.
        key: Entry of ``model_dict`` to resize.
        new_length: Target size of the middle axis.
    """
    source = model_dict[key]
    last_dim = source.shape[-1]
    # F.interpolate in bilinear mode needs a 4-D input, so add a leading axis of 1.
    expanded = rearrange(source, "(t b) f d -> t b f d", t=1)
    resized = F.interpolate(expanded, size=(new_length, last_dim), mode="bilinear")
    model_dict[key] = rearrange(resized, "t b f d -> (t b) f d", t=1)
|
|
|
|
def interpolate_pe_to_length_pingpong(model_dict: dict[str, Tensor], key: str, new_length: int):
    """Extend ``model_dict[key]`` along axis 1 by ping-ponging, then cut to ``new_length``.

    When the entry is shorter than ``new_length`` it is extended by alternately
    appending its reversed interior (axis 1 without the first and last index)
    and the original tensor, so ``[0, 1, 2, 3]`` grows as
    ``0 1 2 3 | 2 1 | 0 1 2 3 | 2 1 | ...``. The result is always sliced to the
    first ``new_length`` positions of axis 1 and written back to the dict.

    Args:
        model_dict: Mapping that holds the tensor and receives the result.
        key: Entry of ``model_dict`` to extend; must be rank 3.
        new_length: Target size of axis 1.

    Raises:
        ValueError: If the entry has size 0 along axis 1 but ``new_length`` is
            positive (nothing to repeat; this case previously never terminated).
    """
    pe = model_dict[key]
    current_length = pe.shape[1]
    if current_length < new_length:
        if current_length == 0:
            raise ValueError(
                f"Cannot ping-pong '{key}' to length {new_length}: it has no entries along axis 1."
            )
        # Interior reversed; empty when the PE has fewer than 3 positions.
        reversed_interior = torch.flip(pe[:, 1:-1, :], [1])
        segments = [pe]
        use_flipped = True
        while current_length < new_length:
            segment = reversed_interior if use_flipped else pe
            segments.append(segment)
            current_length += segment.shape[1]
            use_flipped = not use_flipped
        # Single concatenation instead of re-copying the growing tensor each round.
        pe = torch.cat(segments, dim=1)
    model_dict[key] = pe[:, :new_length]
|
|
|
|
def freeze_mask_of_pe(model_dict: dict[str, Tensor], key: str):
    """Overwrite the tail of axis 2 at every axis-1 position with position 0's values.

    The split point is ``shape[2] // 64``: indices below it along axis 2 are
    left untouched, indices from it onward are replaced (via broadcasting) by
    the values found at index 0 of axis 1. The tensor is modified in place.

    Args:
        model_dict: Mapping that holds the tensor.
        key: Entry of ``model_dict`` to modify; must be rank 3.
    """
    pe = model_dict[key]
    split_at = pe.shape[2] // 64
    # Source is a length-1 slice along axis 1, so it broadcasts across that axis.
    pe[:, :, split_at:] = pe[:, :1, split_at:]
|
|
|
|
def freeze_mask_of_attn(model_dict: dict[str, Tensor], key: str):
    """Scale the top-left quadrant of ``model_dict[key]`` by 1.5, in place.

    The quadrant spans the first ``shape[0] // 2`` indices of both axis 0 and
    axis 1.

    Args:
        model_dict: Mapping that holds the tensor.
        key: Entry of ``model_dict`` to modify; needs at least two axes.
    """
    weight = model_dict[key]
    half = weight.shape[0] // 2
    # Slicing yields a view, so mul_ writes straight into the stored tensor.
    weight[:half, :half].mul_(1.5)
|
|
|
|
def apply_mm_settings(model_dict: dict[str, Tensor], mm_settings: 'MotionModelSettings') -> dict[str, Tensor]:
    """Apply the adjustments described by ``mm_settings`` to a motion-model state dict.

    Keys are dispatched on their name:

    * containing ``"attention_blocks"`` and ``"pos_encoder"``: positional
      encoding operations (stretch, strength, initial offset, cap, interpolate,
      final offset), in that fixed order;
    * containing ``"attention_blocks"`` only: overall attention strength, then
      at most one of the q/k/v/out-weight/out-bias multipliers;
    * anything else: the ``other_strength`` multiplier.

    Args:
        model_dict: State dict to adjust; entries are modified or replaced.
        mm_settings: Settings object, or ``None`` for "change nothing".

    Returns:
        The same ``model_dict`` object that was passed in.
    """
    if mm_settings is None or not mm_settings.has_anything_to_apply():
        return model_dict

    for name in model_dict:
        if "attention_blocks" not in name:
            # NOTE(review): other_strength is taken to target non-attention keys — confirm.
            if mm_settings.has_other_strength():
                model_dict[name] *= mm_settings.other_strength
            continue

        if "pos_encoder" in name:
            # The order of these steps matters: each one sees the previous result.
            if mm_settings.has_motion_pe_stretch():
                stretched_length = model_dict[name].shape[1] + mm_settings.motion_pe_stretch
                interpolate_pe_to_length(model_dict, name, new_length=stretched_length)
            if mm_settings.has_pe_strength():
                model_dict[name] *= mm_settings.pe_strength
            if mm_settings.has_initial_pe_idx_offset():
                model_dict[name] = model_dict[name][:, mm_settings.initial_pe_idx_offset:]
            if mm_settings.has_cap_initial_pe_length():
                model_dict[name] = model_dict[name][:, :mm_settings.cap_initial_pe_length]
            if mm_settings.has_interpolate_pe_to_length():
                interpolate_pe_to_length(model_dict, name, new_length=mm_settings.interpolate_pe_to_length)
            if mm_settings.has_final_pe_idx_offset():
                model_dict[name] = model_dict[name][:, mm_settings.final_pe_idx_offset:]
            continue

        # Remaining attention-block weights.
        if mm_settings.has_attn_strength():
            model_dict[name] *= mm_settings.attn_strength
        if not mm_settings.has_any_attn_sub_strength():
            continue

        # Only the first matching branch is applied.
        if "to_q" in name and mm_settings.has_attn_q_strength():
            model_dict[name] *= mm_settings.attn_q_strength
        elif "to_k" in name and mm_settings.has_attn_k_strength():
            model_dict[name] *= mm_settings.attn_k_strength
        elif "to_v" in name and mm_settings.has_attn_v_strength():
            model_dict[name] *= mm_settings.attn_v_strength
        elif "to_out" in name:
            trimmed = name.strip()
            if trimmed.endswith("weight") and mm_settings.has_attn_out_weight_strength():
                model_dict[name] *= mm_settings.attn_out_weight_strength
            elif trimmed.endswith("bias") and mm_settings.has_attn_out_bias_strength():
                model_dict[name] *= mm_settings.attn_out_bias_strength
    return model_dict
|
|
|
|
class InjectionParams:
    """Bundle of parameters carried alongside an injected motion model.

    Besides the constructor arguments, the object holds context settings
    (filled by ``set_context``), an optional LoRA list, a
    ``MotionModelSettings`` instance, a noise type and ``sub_idxs``.
    """

    def __init__(self, video_length: int, unlimited_area_hack: bool, apply_mm_groupnorm_hack: bool, beta_schedule: str, model_name: str,
                 apply_v2_models_properly: bool=False) -> None:
        """Store the given values and initialise everything else to its default.

        Args:
            video_length: Stored as ``video_length``.
            unlimited_area_hack: Stored as ``unlimited_area_hack``.
            apply_mm_groupnorm_hack: Stored as ``apply_mm_groupnorm_hack``.
            beta_schedule: Stored as ``beta_schedule``.
            model_name: Stored as ``model_name``.
            apply_v2_models_properly: Stored as ``apply_v2_models_properly``.
        """
        self.video_length = video_length
        self.full_length = None
        self.unlimited_area_hack = unlimited_area_hack
        self.apply_mm_groupnorm_hack = apply_mm_groupnorm_hack
        self.beta_schedule = beta_schedule
        self.model_name = model_name
        self.apply_v2_models_properly = apply_v2_models_properly
        # Context settings; unset until set_context() is called.
        self.context_length: int = None
        self.context_stride: int = None
        self.context_overlap: int = None
        self.context_schedule: str = None
        self.closed_loop: bool = False
        self.sync_context_to_pe = False
        self.loras: MotionLoraList = None
        self.motion_model_settings = MotionModelSettings()
        self.noise_type: str = NoiseType.DEFAULT
        # Not carried over by clone(); a clone always starts with None here.
        self.sub_idxs = None

    def set_context(self, context_length: int, context_stride: int, context_overlap: int, context_schedule: str, closed_loop: bool, sync_context_to_pe: bool=False):
        """Assign all context-related attributes in one call."""
        self.context_length = context_length
        self.context_stride = context_stride
        self.context_overlap = context_overlap
        self.context_schedule = context_schedule
        self.closed_loop = closed_loop
        self.sync_context_to_pe = sync_context_to_pe

    def set_loras(self, loras: MotionLoraList):
        """Store a clone of ``loras``; ``None`` clears the stored list.

        Accepting ``None`` mirrors the attribute's initial value and the guard
        already present in ``clone``.
        """
        self.loras = loras.clone() if loras is not None else None

    def set_motion_model_settings(self, motion_model_settings: 'MotionModelSettings'):
        """Store ``motion_model_settings``, substituting defaults for ``None``.

        The given object is stored by reference, not copied.
        """
        if motion_model_settings is None:
            self.motion_model_settings = MotionModelSettings()
        else:
            self.motion_model_settings = motion_model_settings

    def reset_context(self):
        """Reset the context attributes to their initial values.

        ``sync_context_to_pe`` is left as it is.
        """
        self.context_length = None
        self.context_stride = None
        self.context_overlap = None
        self.context_schedule = None
        self.closed_loop = False

    def clone(self) -> 'InjectionParams':
        """Return a copy of these parameters.

        Constructor values, ``full_length``, ``noise_type`` and the context
        settings are copied; the LoRA list is cloned when present; the motion
        model settings object is shared by reference. ``sub_idxs`` is not
        copied.
        """
        new_params = InjectionParams(
            self.video_length, self.unlimited_area_hack, self.apply_mm_groupnorm_hack,
            self.beta_schedule, self.model_name, apply_v2_models_properly=self.apply_v2_models_properly,
        )
        new_params.full_length = self.full_length
        new_params.noise_type = self.noise_type
        new_params.set_context(
            context_length=self.context_length, context_stride=self.context_stride,
            context_overlap=self.context_overlap, context_schedule=self.context_schedule,
            closed_loop=self.closed_loop, sync_context_to_pe=self.sync_context_to_pe,
        )
        if self.loras is not None:
            new_params.loras = self.loras.clone()
        new_params.set_motion_model_settings(self.motion_model_settings)
        return new_params
|
|
|
|
class MotionModelSettings:
    """Adjustments that can be applied to a motion model.

    Each ``has_*`` method reports whether the matching value differs from its
    neutral default (``1.0`` for strengths, ``0`` for the positional-encoding
    lengths and offsets, ``None`` for the attention-scale mask).
    """

    def __init__(self,
                 pe_strength: float=1.0,
                 attn_strength: float=1.0,
                 attn_q_strength: float=1.0,
                 attn_k_strength: float=1.0,
                 attn_v_strength: float=1.0,
                 attn_out_weight_strength: float=1.0,
                 attn_out_bias_strength: float=1.0,
                 other_strength: float=1.0,
                 cap_initial_pe_length: int=0, interpolate_pe_to_length: int=0,
                 initial_pe_idx_offset: int=0, final_pe_idx_offset: int=0,
                 motion_pe_stretch: int=0,
                 attn_scale: float=1.0,
                 mask_attn_scale: Tensor=None,
                 mask_attn_scale_min: float=1.0,
                 mask_attn_scale_max: float=1.0,
                 ):
        """Store every argument under an attribute of the same name.

        ``mask_attn_scale`` is cloned when given and then passed through
        ``normalize_min_max`` with the min/max arguments.
        """
        # General strength multipliers.
        self.pe_strength = pe_strength
        self.attn_strength = attn_strength
        self.other_strength = other_strength
        # Per-projection attention multipliers.
        self.attn_q_strength = attn_q_strength
        self.attn_k_strength = attn_k_strength
        self.attn_v_strength = attn_v_strength
        self.attn_out_weight_strength = attn_out_weight_strength
        self.attn_out_bias_strength = attn_out_bias_strength
        # Positional-encoding lengths and offsets.
        self.cap_initial_pe_length = cap_initial_pe_length
        self.interpolate_pe_to_length = interpolate_pe_to_length
        self.initial_pe_idx_offset = initial_pe_idx_offset
        self.final_pe_idx_offset = final_pe_idx_offset
        self.motion_pe_stretch = motion_pe_stretch
        # Attention scale values.
        self.attn_scale = attn_scale
        # Clone so that normalising below never touches the caller's tensor.
        self.mask_attn_scale = None if mask_attn_scale is None else mask_attn_scale.clone()
        self.mask_attn_scale_min = mask_attn_scale_min
        self.mask_attn_scale_max = mask_attn_scale_max
        self._prepare_mask_attn_scale()

    def _prepare_mask_attn_scale(self):
        """Replace the mask with ``normalize_min_max(mask, min, max)`` when a mask is set."""
        if self.mask_attn_scale is None:
            return
        self.mask_attn_scale = normalize_min_max(self.mask_attn_scale, self.mask_attn_scale_min, self.mask_attn_scale_max)

    def has_mask_attn_scale(self) -> bool:
        """Return True when a mask tensor is stored."""
        return self.mask_attn_scale is not None

    def has_pe_strength(self) -> bool:
        """Return True when ``pe_strength`` differs from 1.0."""
        return self.pe_strength != 1.0

    def has_attn_strength(self) -> bool:
        """Return True when ``attn_strength`` differs from 1.0."""
        return self.attn_strength != 1.0

    def has_other_strength(self) -> bool:
        """Return True when ``other_strength`` differs from 1.0."""
        return self.other_strength != 1.0

    def has_cap_initial_pe_length(self) -> bool:
        """Return True when ``cap_initial_pe_length`` is positive."""
        return self.cap_initial_pe_length > 0

    def has_interpolate_pe_to_length(self) -> bool:
        """Return True when ``interpolate_pe_to_length`` is positive."""
        return self.interpolate_pe_to_length > 0

    def has_initial_pe_idx_offset(self) -> bool:
        """Return True when ``initial_pe_idx_offset`` is positive."""
        return self.initial_pe_idx_offset > 0

    def has_final_pe_idx_offset(self) -> bool:
        """Return True when ``final_pe_idx_offset`` is positive."""
        return self.final_pe_idx_offset > 0

    def has_motion_pe_stretch(self) -> bool:
        """Return True when ``motion_pe_stretch`` is positive."""
        return self.motion_pe_stretch > 0

    def has_anything_to_apply(self) -> bool:
        """Return True when any strength or positional-encoding setting is non-default.

        ``attn_scale`` and the mask settings are not part of this check.
        """
        return (
            self.has_pe_strength()
            or self.has_attn_strength()
            or self.has_other_strength()
            or self.has_cap_initial_pe_length()
            or self.has_interpolate_pe_to_length()
            or self.has_initial_pe_idx_offset()
            or self.has_final_pe_idx_offset()
            or self.has_motion_pe_stretch()
            or self.has_any_attn_sub_strength()
        )

    def has_any_attn_sub_strength(self) -> bool:
        """Return True when any q/k/v/out-weight/out-bias multiplier is non-default."""
        return (
            self.has_attn_q_strength()
            or self.has_attn_k_strength()
            or self.has_attn_v_strength()
            or self.has_attn_out_weight_strength()
            or self.has_attn_out_bias_strength()
        )

    def has_attn_q_strength(self) -> bool:
        """Return True when ``attn_q_strength`` differs from 1.0."""
        return self.attn_q_strength != 1.0

    def has_attn_k_strength(self) -> bool:
        """Return True when ``attn_k_strength`` differs from 1.0."""
        return self.attn_k_strength != 1.0

    def has_attn_v_strength(self) -> bool:
        """Return True when ``attn_v_strength`` differs from 1.0."""
        return self.attn_v_strength != 1.0

    def has_attn_out_weight_strength(self) -> bool:
        """Return True when ``attn_out_weight_strength`` differs from 1.0."""
        return self.attn_out_weight_strength != 1.0

    def has_attn_out_bias_strength(self) -> bool:
        """Return True when ``attn_out_bias_strength`` differs from 1.0."""
        return self.attn_out_bias_strength != 1.0
|
|