Metaphysix2's picture
Upload folder using huggingface_hub
3e5f61c
import copy
from einops import rearrange
from torch import Tensor
import torch.nn.functional as F
import torch
import comfy.model_management
import comfy.utils
from comfy.model_patcher import ModelPatcher
from .motion_module_ad import AnimateDiffModel, has_mid_block, normalize_ad_state_dict
from .logger import logger
from .motion_utils import MotionCompatibilityError, NoiseType, normalize_min_max
from .motion_lora import MotionLoraInfo, MotionLoraList
from .model_utils import get_motion_lora_path, get_motion_model_path, get_sd_model_type
# some motion_model casts here might fail if model becomes metatensor or is not castable;
# should not really matter if it fails, so ignore raised Exceptions
class ModelPatcherAndInjector(ModelPatcher):
def __init__(self, m: ModelPatcher):
# replicate ModelPatcher.clone() to initialize ModelPatcherAndInjector
super().__init__(m.model, m.load_device, m.offload_device, m.size, m.current_device, weight_inplace_update=m.weight_inplace_update)
self.patches = {}
for k in m.patches:
self.patches[k] = m.patches[k][:]
self.object_patches = m.object_patches.copy()
self.model_options = copy.deepcopy(m.model_options)
self.model_keys = m.model_keys
# injection stuff
self.motion_injection_params: InjectionParams = None
self.motion_model: MotionModelPatcher = None
self.motion_model_sampling = None
def model_patches_to(self, device):
super().model_patches_to(device)
if self.motion_model is not None:
try:
self.motion_model.model.to(device)
except Exception:
pass
def patch_model(self, device_to=None):
# first, perform model patching
patched_model = super().patch_model(device_to)
# finally, perform motion model injection
self.inject_model(device_to=device_to)
return patched_model
def unpatch_model(self, device_to=None):
# first, eject motion model from unet
self.eject_model(device_to=device_to)
# finally, do normal model unpatching
return super().unpatch_model(device_to)
def inject_model(self, device_to=None):
if self.motion_model is not None:
self.motion_model.model.eject(self)
self.motion_model.model.inject(self)
try:
self.motion_model.model.to(device_to)
except Exception:
pass
def eject_model(self, device_to=None):
if self.motion_model is not None:
self.motion_model.model.eject(self)
try:
self.motion_model.model.to(device_to)
except Exception:
pass
def clone(self):
cloned = ModelPatcherAndInjector(self)
cloned.motion_model = self.motion_model
cloned.motion_injection_params = self.motion_injection_params
cloned.motion_model_sampling = self.motion_model_sampling
return cloned
class MotionModelPatcher(ModelPatcher):
# Mostly here so that type hints work in IDEs
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.model: AnimateDiffModel = self.model
def patch_model(self, *args, **kwargs):
# patch as normal, but prepare_weights so that lowvram meta device works properly
patched_model = super().patch_model(*args, **kwargs)
self.prepare_weights()
return patched_model
def prepare_weights(self):
# in case lowvram is active and meta device is used, need to convert weights
# otherwise, will get exceptions thrown related to meta device
state_dict = self.model.state_dict()
for key in state_dict:
weight = comfy.model_management.resolve_lowvram_weight(state_dict[key], self.model, key)
try:
comfy.utils.set_attr(self.model, key, weight)
except Exception:
pass
def pre_run(self):
# just in case, prepare_weights before every run
self.prepare_weights()
def cleanup(self):
if self.model is not None:
self.model.cleanup()
def get_vanilla_model_patcher(m: ModelPatcher) -> ModelPatcher:
model = ModelPatcher(m.model, m.load_device, m.offload_device, m.size, m.current_device, weight_inplace_update=m.weight_inplace_update)
model.patches = {}
for k in m.patches:
model.patches[k] = m.patches[k][:]
model.object_patches = m.object_patches.copy()
model.model_options = copy.deepcopy(m.model_options)
model.model_keys = m.model_keys
return model
# adapted from https://github.com/guoyww/AnimateDiff/blob/main/animatediff/utils/convert_lora_safetensor_to_diffusers.py
# Example LoRA keys:
# down_blocks.0.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.0.processor.to_q_lora.down.weight
# down_blocks.0.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.0.processor.to_q_lora.up.weight
#
# Example model keys:
# down_blocks.0.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.0.to_q.weight
#
def load_motion_lora_as_patches(motion_model: MotionModelPatcher, lora: MotionLoraInfo) -> None:
def get_version(has_midblock: bool):
return "v2" if has_midblock else "v1"
lora_path = get_motion_lora_path(lora.name)
logger.info(f"Loading motion LoRA {lora.name}")
state_dict = comfy.utils.load_torch_file(lora_path)
# remove all non-temporal keys (in case model has extra stuff in it)
for key in list(state_dict.keys()):
if "temporal" not in key:
del state_dict[key]
if len(state_dict) == 0:
raise ValueError(f"'{lora.name}' contains no temporal keys; it is not a valid motion LoRA!")
model_has_midblock = motion_model.model.mid_block != None
lora_has_midblock = has_mid_block(state_dict)
logger.info(f"Applying a {get_version(lora_has_midblock)} LoRA ({lora.name}) to a { motion_model.model.mm_info.mm_version} motion model.")
patches = {}
# convert lora state dict to one that matches motion_module keys and tensors
for key in state_dict:
# if motion_module doesn't have a midblock, skip mid_block entries
if not model_has_midblock:
if "mid_block" in key: continue
# only process lora down key (we will process up at the same time as down)
if "up." in key: continue
# get up key version of down key
up_key = key.replace(".down.", ".up.")
# adapt key to match motion_module key format - remove 'processor.', '_lora', 'down.', and 'up.'
model_key = key.replace("processor.", "").replace("_lora", "").replace("down.", "").replace("up.", "")
# motion_module keys have a '0.' after all 'to_out.' weight keys
model_key = model_key.replace("to_out.", "to_out.0.")
weight_down = state_dict[key]
weight_up = state_dict[up_key]
# actual weights obtained by matrix multiplication of up and down weights
# save as a tuple, so that (Motion)ModelPatcher's calculate_weight function detects len==1, applying it correctly
patches[model_key] = (torch.mm(weight_up, weight_down),)
del state_dict
# add patches to motion ModelPatcher
motion_model.add_patches(patches=patches, strength_patch=lora.strength)
def load_motion_module(model_name: str, model: ModelPatcher, motion_lora: MotionLoraList = None, motion_model_settings: 'MotionModelSettings' = None) -> MotionModelPatcher:
model_path = get_motion_model_path(model_name)
logger.info(f"Loading motion module {model_name}")
mm_state_dict = comfy.utils.load_torch_file(model_path, safe_load=True)
# TODO: check for empty state dict?
# get normalized state_dict and motion model info
mm_state_dict, mm_info = normalize_ad_state_dict(mm_state_dict=mm_state_dict, mm_name=model_name)
# check that motion model is compatible with sd model
model_sd_type = get_sd_model_type(model)
if model_sd_type != mm_info.sd_type:
raise MotionCompatibilityError(f"Motion module '{mm_info.mm_name}' is intended for {mm_info.sd_type} models, " \
+ f"but the provided model is type {model_sd_type}.")
# apply motion model settings
mm_state_dict = apply_mm_settings(model_dict=mm_state_dict, mm_settings=motion_model_settings)
# initialize AnimateDiffModelWrapper
ad_wrapper = AnimateDiffModel(mm_state_dict=mm_state_dict, mm_info=mm_info)
ad_wrapper.to(model.model_dtype())
ad_wrapper.to(model.offload_device)
load_result = ad_wrapper.load_state_dict(mm_state_dict)
# TODO: report load_result of motion_module loading?
# wrap motion_module into a ModelPatcher, to allow motion lora patches
motion_model = MotionModelPatcher(model=ad_wrapper, load_device=model.load_device, offload_device=model.offload_device)
# load motion_lora, if present
if motion_lora is not None:
for lora in motion_lora.loras:
load_motion_lora_as_patches(motion_model, lora)
return motion_model
def interpolate_pe_to_length(model_dict: dict[str, Tensor], key: str, new_length: int):
pe_shape = model_dict[key].shape
temp_pe = rearrange(model_dict[key], "(t b) f d -> t b f d", t=1)
temp_pe = F.interpolate(temp_pe, size=(new_length, pe_shape[-1]), mode="bilinear")
temp_pe = rearrange(temp_pe, "t b f d -> (t b) f d", t=1)
model_dict[key] = temp_pe
del temp_pe
def interpolate_pe_to_length_diffs(model_dict: dict[str, Tensor], key: str, new_length: int):
# TODO: fill out and try out
pe_shape = model_dict[key].shape
temp_pe = rearrange(model_dict[key], "(t b) f d -> t b f d", t=1)
temp_pe = F.interpolate(temp_pe, size=(new_length, pe_shape[-1]), mode="bilinear")
temp_pe = rearrange(temp_pe, "t b f d -> (t b) f d", t=1)
model_dict[key] = temp_pe
del temp_pe
def interpolate_pe_to_length_pingpong(model_dict: dict[str, Tensor], key: str, new_length: int):
if model_dict[key].shape[1] < new_length:
temp_pe = model_dict[key]
flipped_temp_pe = torch.flip(temp_pe[:, 1:-1, :], [1])
use_flipped = True
preview_pe = None
while model_dict[key].shape[1] < new_length:
preview_pe = model_dict[key]
model_dict[key] = torch.cat([model_dict[key], flipped_temp_pe if use_flipped else temp_pe], dim=1)
use_flipped = not use_flipped
del temp_pe
del flipped_temp_pe
del preview_pe
model_dict[key] = model_dict[key][:, :new_length]
def freeze_mask_of_pe(model_dict: dict[str, Tensor], key: str):
pe_portion = model_dict[key].shape[2] // 64
first_pe = model_dict[key][:,:1,:]
model_dict[key][:,:,pe_portion:] = first_pe[:,:,pe_portion:]
del first_pe
def freeze_mask_of_attn(model_dict: dict[str, Tensor], key: str):
attn_portion = model_dict[key].shape[0] // 2
model_dict[key][:attn_portion,:attn_portion] *= 1.5
def apply_mm_settings(model_dict: dict[str, Tensor], mm_settings: 'MotionModelSettings') -> dict[str, Tensor]:
if mm_settings is None:
return model_dict
if not mm_settings.has_anything_to_apply():
return model_dict
for key in model_dict:
if "attention_blocks" in key:
if "pos_encoder" in key:
# apply simple motion pe stretch, if needed
if mm_settings.has_motion_pe_stretch():
new_pe_length = model_dict[key].shape[1] + mm_settings.motion_pe_stretch
interpolate_pe_to_length(model_dict, key, new_length=new_pe_length)
# apply pe_strength, if needed
if mm_settings.has_pe_strength():
model_dict[key] *= mm_settings.pe_strength
# apply pe_idx_offset, if needed
if mm_settings.has_initial_pe_idx_offset():
model_dict[key] = model_dict[key][:, mm_settings.initial_pe_idx_offset:]
# apply has_cap_initial_pe_length, if needed
if mm_settings.has_cap_initial_pe_length():
model_dict[key] = model_dict[key][:, :mm_settings.cap_initial_pe_length]
# apply interpolate_pe_to_length, if needed
if mm_settings.has_interpolate_pe_to_length():
interpolate_pe_to_length(model_dict, key, new_length=mm_settings.interpolate_pe_to_length)
# apply final_pe_idx_offset, if needed
if mm_settings.has_final_pe_idx_offset():
model_dict[key] = model_dict[key][:, mm_settings.final_pe_idx_offset:]
else:
# apply attn_strenth, if needed
if mm_settings.has_attn_strength():
model_dict[key] *= mm_settings.attn_strength
# apply specific attn_strengths, if needed
if mm_settings.has_any_attn_sub_strength():
if "to_q" in key and mm_settings.has_attn_q_strength():
model_dict[key] *= mm_settings.attn_q_strength
elif "to_k" in key and mm_settings.has_attn_k_strength():
model_dict[key] *= mm_settings.attn_k_strength
elif "to_v" in key and mm_settings.has_attn_v_strength():
model_dict[key] *= mm_settings.attn_v_strength
elif "to_out" in key:
if key.strip().endswith("weight") and mm_settings.has_attn_out_weight_strength():
model_dict[key] *= mm_settings.attn_out_weight_strength
elif key.strip().endswith("bias") and mm_settings.has_attn_out_bias_strength():
model_dict[key] *= mm_settings.attn_out_bias_strength
# apply other strength, if needed
elif mm_settings.has_other_strength():
model_dict[key] *= mm_settings.other_strength
return model_dict
class InjectionParams:
def __init__(self, video_length: int, unlimited_area_hack: bool, apply_mm_groupnorm_hack: bool, beta_schedule: str, model_name: str,
apply_v2_models_properly: bool=False) -> None:
self.video_length = video_length
self.full_length = None
self.unlimited_area_hack = unlimited_area_hack
self.apply_mm_groupnorm_hack = apply_mm_groupnorm_hack
self.beta_schedule = beta_schedule
self.model_name = model_name
self.apply_v2_models_properly = apply_v2_models_properly
self.context_length: int = None
self.context_stride: int = None
self.context_overlap: int = None
self.context_schedule: str = None
self.closed_loop: bool = False
self.sync_context_to_pe = False
self.loras: MotionLoraList = None
self.motion_model_settings = MotionModelSettings()
self.noise_type: str = NoiseType.DEFAULT
self.sub_idxs = None # value should NOT be included in clone, so it will auto reset
def set_context(self, context_length: int, context_stride: int, context_overlap: int, context_schedule: str, closed_loop: bool, sync_context_to_pe: bool=False):
self.context_length = context_length
self.context_stride = context_stride
self.context_overlap = context_overlap
self.context_schedule = context_schedule
self.closed_loop = closed_loop
self.sync_context_to_pe = sync_context_to_pe
def set_loras(self, loras: MotionLoraList):
self.loras = loras.clone()
def set_motion_model_settings(self, motion_model_settings: 'MotionModelSettings'):
if motion_model_settings is None:
self.motion_model_settings = MotionModelSettings()
else:
self.motion_model_settings = motion_model_settings
def reset_context(self):
self.context_length = None
self.context_stride = None
self.context_overlap = None
self.context_schedule = None
self.closed_loop = False
def clone(self) -> 'InjectionParams':
new_params = InjectionParams(
self.video_length, self.unlimited_area_hack, self.apply_mm_groupnorm_hack,
self.beta_schedule, self.model_name, apply_v2_models_properly=self.apply_v2_models_properly,
)
new_params.full_length = self.full_length
new_params.noise_type = self.noise_type
new_params.set_context(
context_length=self.context_length, context_stride=self.context_stride,
context_overlap=self.context_overlap, context_schedule=self.context_schedule,
closed_loop=self.closed_loop, sync_context_to_pe=self.sync_context_to_pe,
)
if self.loras is not None:
new_params.loras = self.loras.clone()
new_params.set_motion_model_settings(self.motion_model_settings)
return new_params
class MotionModelSettings:
def __init__(self,
pe_strength: float=1.0,
attn_strength: float=1.0,
attn_q_strength: float=1.0,
attn_k_strength: float=1.0,
attn_v_strength: float=1.0,
attn_out_weight_strength: float=1.0,
attn_out_bias_strength: float=1.0,
other_strength: float=1.0,
cap_initial_pe_length: int=0, interpolate_pe_to_length: int=0,
initial_pe_idx_offset: int=0, final_pe_idx_offset: int=0,
motion_pe_stretch: int=0,
attn_scale: float=1.0,
mask_attn_scale: Tensor=None,
mask_attn_scale_min: float=1.0,
mask_attn_scale_max: float=1.0,
):
# general strengths
self.pe_strength = pe_strength
self.attn_strength = attn_strength
self.other_strength = other_strength
# specific attn strengths
self.attn_q_strength = attn_q_strength
self.attn_k_strength = attn_k_strength
self.attn_v_strength = attn_v_strength
self.attn_out_weight_strength = attn_out_weight_strength
self.attn_out_bias_strength = attn_out_bias_strength
# PE-interpolation settings
self.cap_initial_pe_length = cap_initial_pe_length
self.interpolate_pe_to_length = interpolate_pe_to_length
self.initial_pe_idx_offset = initial_pe_idx_offset
self.final_pe_idx_offset = final_pe_idx_offset
self.motion_pe_stretch = motion_pe_stretch
# attention scale settings
self.attn_scale = attn_scale
# attention scale mask settings
self.mask_attn_scale = mask_attn_scale.clone() if mask_attn_scale is not None else mask_attn_scale
self.mask_attn_scale_min = mask_attn_scale_min
self.mask_attn_scale_max = mask_attn_scale_max
self._prepare_mask_attn_scale()
def _prepare_mask_attn_scale(self):
if self.mask_attn_scale is not None:
self.mask_attn_scale = normalize_min_max(self.mask_attn_scale, self.mask_attn_scale_min, self.mask_attn_scale_max)
def has_mask_attn_scale(self) -> bool:
return self.mask_attn_scale is not None
def has_pe_strength(self) -> bool:
return self.pe_strength != 1.0
def has_attn_strength(self) -> bool:
return self.attn_strength != 1.0
def has_other_strength(self) -> bool:
return self.other_strength != 1.0
def has_cap_initial_pe_length(self) -> bool:
return self.cap_initial_pe_length > 0
def has_interpolate_pe_to_length(self) -> bool:
return self.interpolate_pe_to_length > 0
def has_initial_pe_idx_offset(self) -> bool:
return self.initial_pe_idx_offset > 0
def has_final_pe_idx_offset(self) -> bool:
return self.final_pe_idx_offset > 0
def has_motion_pe_stretch(self) -> bool:
return self.motion_pe_stretch > 0
def has_anything_to_apply(self) -> bool:
return self.has_pe_strength() \
or self.has_attn_strength() \
or self.has_other_strength() \
or self.has_cap_initial_pe_length() \
or self.has_interpolate_pe_to_length() \
or self.has_initial_pe_idx_offset() \
or self.has_final_pe_idx_offset() \
or self.has_motion_pe_stretch() \
or self.has_any_attn_sub_strength()
def has_any_attn_sub_strength(self) -> bool:
return self.has_attn_q_strength() \
or self.has_attn_k_strength() \
or self.has_attn_v_strength() \
or self.has_attn_out_weight_strength() \
or self.has_attn_out_bias_strength()
def has_attn_q_strength(self) -> bool:
return self.attn_q_strength != 1.0
def has_attn_k_strength(self) -> bool:
return self.attn_k_strength != 1.0
def has_attn_v_strength(self) -> bool:
return self.attn_v_strength != 1.0
def has_attn_out_weight_strength(self) -> bool:
return self.attn_out_weight_strength != 1.0
def has_attn_out_bias_strength(self) -> bool:
return self.attn_out_bias_strength != 1.0