| from abc import ABC, abstractmethod |
| from typing import Union |
|
|
| import torch |
| import torch.nn.functional as F |
| from torch import Tensor, nn |
|
|
| import comfy.model_management as model_management |
| import comfy.ops |
| from comfy.cli_args import args |
| from comfy.ldm.modules.attention import attention_basic, attention_pytorch, attention_split, attention_sub_quad, default |
| from comfy.controlnet import broadcast_image_to |
| from comfy.utils import repeat_to_batch_size |
|
|
| from .motion_lora import MotionLoraInfo |
| from .logger import logger |
|
|
|
|
| |
| |
# Pick the attention implementation that CrossAttentionMM.forward calls.
# attention_basic is only the initial binding: one of the branches below
# always rebinds the name.
optimized_attention_mm = attention_basic
if model_management.xformers_enabled():
    # Deliberate no-op: no xformers implementation is selected here.
    # NOTE(review): presumably intentional - confirm before wiring one in.
    pass

if model_management.pytorch_attention_enabled():
    optimized_attention_mm = attention_pytorch
elif args.use_split_cross_attention:
    optimized_attention_mm = attention_split
else:
    optimized_attention_mm = attention_sub_quad
|
|
|
|
| |
class CrossAttentionMM(nn.Module):
    """Multi-head attention layer with optional extra scaling of the keys.

    Queries are projected from ``x``; keys and values are projected from
    ``context`` (falling back to ``x``), or values from ``value`` when given.
    The attention itself is delegated to the module-level
    ``optimized_attention_mm`` callable.
    """

    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None,
                 operations=comfy.ops.disable_weight_init if hasattr(comfy.ops, "disable_weight_init") else comfy.ops):
        """Build the q/k/v projections and the output projection.

        Args:
            query_dim: Feature size of the query input (and of the output).
            context_dim: Feature size of the context input; ``default`` falls
                back to ``query_dim`` when it is not provided.
            heads: Number of attention heads.
            dim_head: Feature size per head.
            dropout: Probability passed to the output ``nn.Dropout``.
            dtype: Forwarded to every ``Linear`` constructor.
            device: Forwarded to every ``Linear`` constructor.
            operations: Namespace that provides the ``Linear`` class to use.
        """
        super().__init__()
        context_dim = default(context_dim, query_dim)
        inner_dim = dim_head * heads

        self.heads = heads
        self.dim_head = dim_head
        # Optional multiplier applied to the keys in forward(); None disables it.
        self.scale = None
        # Stored for outside use; this class never reads it.
        self.default_scale = dim_head ** -0.5

        linear_kwargs = dict(dtype=dtype, device=device)
        self.to_q = operations.Linear(query_dim, inner_dim, bias=False, **linear_kwargs)
        self.to_k = operations.Linear(context_dim, inner_dim, bias=False, **linear_kwargs)
        self.to_v = operations.Linear(context_dim, inner_dim, bias=False, **linear_kwargs)

        out_proj = operations.Linear(inner_dim, query_dim, **linear_kwargs)
        self.to_out = nn.Sequential(out_proj, nn.Dropout(dropout))

    def forward(self, x, context=None, value=None, mask=None, scale_mask=None):
        """Run attention and return the output projection of the result.

        Args:
            x: Input projected to queries.
            context: Source for keys (and values); ``default`` substitutes
                ``x`` when it is not provided.
            value: Optional separate source for the values.
            mask: Passed through to the attention callable.
            scale_mask: Optional multiplier applied to the keys in place; it
                must be broadcastable against them.
        """
        q = self.to_q(x)
        context = default(context, x)
        k: Tensor = self.to_k(context)
        v = self.to_v(context if value is None else value)

        # Both scalings modify the key tensor in place, module scale first.
        for factor in (self.scale, scale_mask):
            if factor is not None:
                k *= factor

        attended = optimized_attention_mm(q, k, v, self.heads, mask)
        return self.to_out(attended)
|
|
|
|
| |
class TemporalTransformerGeneric:
    """State and helpers for building a cached scale mask.

    The class defines no ``__init__``; the other methods read attributes that
    are only assigned in ``temporal_transformer_init``, so call that first.
    """

    def temporal_transformer_init(self, default_length: int):
        """Create every attribute, using ``default_length`` for both lengths."""
        self.video_length = default_length
        self.full_length = default_length
        self.scale_min = 1.0
        self.scale_max = 1.0
        # Source mask; None means get_scale_mask() returns None.
        self.raw_scale_mask: Union[Tensor, None] = None
        # Cache of the processed mask built by get_scale_mask().
        self.temp_scale_mask: Union[Tensor, None] = None
        # Optional indices used to select along dim 1 of the processed mask.
        self.sub_idxs: Union[list[int], None] = None
        # Batch size of the hidden states the cache was built for.
        self.prev_hidden_states_batch = 0

    def reset_temp_vars(self):
        """Drop the cached mask and forget the batch size it was built for."""
        # Rebinding to None already releases the old tensor; no `del` needed.
        self.temp_scale_mask = None
        self.prev_hidden_states_batch = 0

    def get_scale_mask(self, hidden_states: Tensor) -> Union[Tensor, None]:
        """Return the processed scale mask for ``hidden_states``, or None.

        Args:
            hidden_states: 4-D tensor; its shape is unpacked as
                (batch, channel, height, width).

        Returns:
            None when no raw mask is set. Otherwise a tensor laid out as
            (height*width [* batched_number], frames, 1), restricted to
            ``sub_idxs`` along dim 1 when those are set. The unrestricted
            tensor is cached and reused while the batch size stays the same.
        """
        if self.raw_scale_mask is None:
            return None
        shape = hidden_states.shape
        batch, _, height, width = shape

        # Identity check: `!= None` on a Tensor goes through tensor comparison
        # machinery and only works by falling back to object comparison.
        if self.temp_scale_mask is not None:
            if batch == self.prev_hidden_states_batch:
                if self.sub_idxs is not None:
                    return self.temp_scale_mask[:, self.sub_idxs, :]
                return self.temp_scale_mask
            # Batch size changed: discard the stale cache and rebuild below.
            self.temp_scale_mask = None

        self.prev_hidden_states_batch = batch
        # Resize the raw mask to this resolution -> (N, 1, height, width).
        mask = prepare_mask_batch(self.raw_scale_mask, shape=(self.full_length, 1, height, width))
        # NOTE(review): repeat_to_batch_size presumably makes dim 0 equal to
        # full_length, which would make the broadcast below a safety net.
        mask = repeat_to_batch_size(mask, self.full_length)
        if self.full_length != mask.shape[0]:
            mask = broadcast_image_to(mask, self.full_length, 1)

        frames, channel, height, width = mask.shape
        # (frames, c, h, w) -> (frames, h*w, c) -> (h*w, frames, c)
        mask = mask.permute(0, 2, 3, 1).reshape(frames, height * width, channel)
        mask = mask.permute(1, 0, 2)

        # Tile along dim 0 when hidden_states holds several video_length groups.
        batched_number = shape[0] // self.video_length
        if batched_number > 1:
            mask = torch.cat([mask] * batched_number, dim=0)

        # Cache on the same dtype/device as the hidden states.
        self.temp_scale_mask = mask.to(dtype=hidden_states.dtype, device=hidden_states.device)

        if self.sub_idxs is not None:
            return self.temp_scale_mask[:, self.sub_idxs, :]
        return self.temp_scale_mask
|
|
|
|
class BlockType:
    """String constants identifying a block as up, down or mid."""

    UP = "up"
    DOWN = "down"
    MID = "mid"
|
|
|
|
class GenericMotionWrapper(nn.Module, ABC):
    """Abstract base for motion-module wrappers.

    Subclasses must implement the ``set_*`` hooks and ``reset_temp_vars``;
    the ``reset*`` helpers defined here are built on top of those hooks.
    """

    def __init__(self, mm_hash: str, mm_name: str, loras: list[MotionLoraInfo]):
        """Store identifying metadata and create empty placeholders.

        Args:
            mm_hash: Hash string stored as ``self.mm_hash``.
            mm_name: Name string stored as ``self.mm_name``.
            loras: Stored as ``self.loras``; ``has_loras`` treats ``None`` as
                "no loras".
        """
        super().__init__()
        # Block containers start empty; nothing in this class fills them.
        self.down_blocks: nn.ModuleList = None
        self.up_blocks: nn.ModuleList = None
        self.mid_block = None
        self.mm_hash = mm_hash
        self.mm_name = mm_name
        # Placeholder strings; nothing in this class replaces them.
        self.version = "FILLTHISIN"
        self.injector_version = "VERYIMPORTANT_FILLTHISIN"
        self.AD_video_length: int = 0
        self.loras = loras

    def has_loras(self) -> bool:
        """Return True unless ``self.loras`` is None (an empty list counts)."""
        return self.loras is not None

    @abstractmethod
    def set_video_length(self, video_length: int, full_length: int):
        """Subclass hook receiving ``video_length`` and ``full_length``."""

    @abstractmethod
    def set_scale_multiplier(self, multiplier: Union[float, None]):
        """Subclass hook; ``reset_scale_multiplier`` calls it with None."""

    @abstractmethod
    def set_masks(self, masks: Tensor, min_val: float, max_val: float):
        """Subclass hook receiving ``masks``, ``min_val`` and ``max_val``."""

    @abstractmethod
    def set_sub_idxs(self, sub_idxs: list[int]):
        """Subclass hook; ``reset_sub_idxs`` calls it with None."""

    @abstractmethod
    def reset_temp_vars(self):
        """Subclass hook invoked as the last step of ``reset``."""

    def reset_scale_multiplier(self):
        """Clear the scale multiplier by passing None to the setter."""
        self.set_scale_multiplier(None)

    def reset_sub_idxs(self):
        """Clear the sub-indices by passing None to the setter."""
        self.set_sub_idxs(None)

    def reset(self):
        """Reset sub-indices, then the scale multiplier, then temp variables."""
        self.reset_sub_idxs()
        self.reset_scale_multiplier()
        self.reset_temp_vars()
|
|
|
|
class GroupNormAD(torch.nn.GroupNorm):
    """``torch.nn.GroupNorm`` subclass whose forward calls ``F.group_norm`` directly."""

    def __init__(self, num_groups: int, num_channels: int, eps: float = 1e-5, affine: bool = True,
                 device=None, dtype=None) -> None:
        """Forward every argument unchanged to ``torch.nn.GroupNorm``."""
        super().__init__(
            num_groups=num_groups,
            num_channels=num_channels,
            eps=eps,
            affine=affine,
            device=device,
            dtype=dtype,
        )

    def forward(self, input: Tensor) -> Tensor:
        """Apply group normalization using this module's weight, bias and eps."""
        normalized = F.group_norm(
            input,
            num_groups=self.num_groups,
            weight=self.weight,
            bias=self.bias,
            eps=self.eps,
        )
        return normalized
|
|
|
|
| |
| |
def normalize_min_max(x: Tensor, new_min = 0.0, new_max = 1.0):
    """Linearly rescale ``x`` so its minimum maps to ``new_min`` and its maximum to ``new_max``.

    Args:
        x: Non-empty tensor to rescale (``x.min()`` raises on an empty tensor).
        new_min: Value the smallest element is mapped to.
        new_max: Value the largest element is mapped to.

    Returns:
        A new tensor of the same shape as ``x``. When every element of ``x``
        is equal there is no range to stretch, and every element maps to
        ``new_min`` instead of the NaN that 0/0 would produce.
    """
    x_min, x_max = x.min(), x.max()
    span = x_max - x_min
    # Guard the constant-input case: substitute a span of 1 so the division
    # yields 0 rather than NaN.
    span = torch.where(span == 0, torch.ones_like(span), span)
    return (((x - x_min) / span) * (new_max - new_min)) + new_min
|
|
|
|
| |
def prepare_mask_batch(mask: Tensor, shape: tuple[int, ...], multiplier: int=1, match_dim1=False):
    """Resize a mask to the spatial size given by ``shape``.

    Args:
        mask: Tensor whose last two dimensions are height and width; all
            leading dimensions are flattened into the batch dimension.
        shape: Size tuple (a ``torch.Size`` works too) indexed as
            (batch, channels, height, width). Only indices 1-3 are read.
        multiplier: Factor applied to the target height and width.
        match_dim1: When True, repeat the single-channel result ``shape[1]``
            times along dim 1.

    Returns:
        A new tensor of shape (N, 1, height*multiplier, width*multiplier), or
        (N, shape[1], ...) when ``match_dim1`` is set. ``mask`` is left as is.
    """
    target_size = (shape[2] * multiplier, shape[3] * multiplier)
    # Work on a copy reshaped to (N, 1, H, W), the 4-D layout that bilinear
    # interpolation expects.
    mask = mask.clone().reshape((-1, 1, mask.shape[-2], mask.shape[-1]))
    mask = F.interpolate(mask, size=target_size, mode="bilinear")
    if match_dim1:
        mask = torch.cat([mask] * shape[1], dim=1)
    return mask
|
|
|
|
class NoiseType:
    """Names of the supported noise strategies and the builders behind them.

    Every builder other than the default one seeds through
    ``torch.manual_seed``, which also reseeds torch's global default generator,
    and creates its noise on the CPU.
    """

    DEFAULT = "default"
    REPEATED = "repeated"
    CONSTANT = "constant"
    AUTO1111 = "auto1111"

    LIST = [DEFAULT, REPEATED, CONSTANT, AUTO1111]

    @classmethod
    def prepare_noise(cls, noise_type: str, latents: Tensor, noise: Tensor, context_length: int, seed: int):
        """Return the noise for ``noise_type``.

        ``noise`` is returned untouched for the default type, and also (after
        a warning is logged) for any unrecognized type.
        """
        if noise_type == cls.DEFAULT:
            return noise
        # Checked in order with ==, exactly like an if/elif chain.
        builders = (
            (cls.REPEATED, cls.prepare_noise_repeated),
            (cls.CONSTANT, cls.prepare_noise_constant),
            (cls.AUTO1111, cls.prepare_noise_auto1111),
        )
        for type_name, build in builders:
            if noise_type == type_name:
                return build(latents, noise, context_length, seed)
        logger.warning(f"Noise type {noise_type} not recognized, proceeding with default noise.")
        return noise

    @classmethod
    def prepare_noise_repeated(cls, latents: Tensor, noise: Tensor, context_length: int, seed: int):
        """Tile the first ``context_length`` noise entries across the batch.

        Returns ``noise`` unchanged when ``context_length`` is falsy.
        """
        if not context_length:
            return noise
        total = latents.shape[0]
        rng = torch.manual_seed(seed)
        fresh = torch.randn(latents.size(), dtype=latents.dtype, layout=latents.layout, generator=rng, device="cpu")
        window = fresh[:context_length]
        # One more copy than strictly needed, then trim to the batch size.
        copies = (total // context_length) + 1
        tiled = torch.cat([window] * copies, dim=0)
        return tiled[:total]

    @classmethod
    def prepare_noise_constant(cls, latents: Tensor, noise: Tensor, context_length: int, seed: int):
        """Generate one noise entry and repeat it for every batch entry."""
        total = latents.shape[0]
        entry_shape = (1,) + tuple(latents.shape[1:4])
        rng = torch.manual_seed(seed)
        entry = torch.randn(entry_shape, dtype=latents.dtype, layout=latents.layout, generator=rng, device="cpu")
        return torch.cat([entry] * total, dim=0)

    @classmethod
    def prepare_noise_auto1111(cls, latents: Tensor, noise: Tensor, context_length: int, seed: int):
        """Generate each batch entry from its own seed, ``seed + index``."""
        total = latents.shape[0]
        entry_shape = (1,) + tuple(latents.shape[1:4])
        # The generator is reseeded right before each entry is drawn.
        entries = [
            torch.randn(entry_shape, dtype=latents.dtype, layout=latents.layout,
                        generator=torch.manual_seed(seed + offset), device="cpu")
            for offset in range(total)
        ]
        return torch.cat(entries, dim=0)
|
|
|
|
class MotionCompatibilityError(ValueError):
    """``ValueError`` subclass that adds no behavior of its own.

    NOTE(review): presumably raised for incompatible motion modules; the
    raise sites are not in this file.
    """
|
|