import math
import os
from typing import Union

import torch
from torch import Tensor

import comfy.utils
from comfy.controlnet import ControlBase

from .logger import logger
from .utils import AdvancedControlBase, deepcopy_with_sharing, prepare_mask_batch
|
|
def extra_options_to_module_prefix(extra_options):
    # extra_options identifies the current location in the UNet:
    # "block" is a (block_type, block_id) tuple and "block_index" is the index
    # of the transformer block within that UNet block
    block = extra_options["block"]
    block_index = extra_options["block_index"]
    if block[0] == "input":
        module_pfx = f"lllite_unet_input_blocks_{block[1]}_1_transformer_blocks_{block_index}"
    elif block[0] == "middle":
        module_pfx = f"lllite_unet_middle_block_1_transformer_blocks_{block_index}"
    elif block[0] == "output":
        module_pfx = f"lllite_unet_output_blocks_{block[1]}_1_transformer_blocks_{block_index}"
    else:
        raise ValueError(f"ControlLLLite: invalid block name '{block[0]}'. Expected 'input', 'middle', or 'output'.")
    return module_pfx
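# Illustrative example of the mapping above:
#   extra_options = {"block": ("input", 4), "block_index": 1, ...}
#   -> "lllite_unet_input_blocks_4_1_transformer_blocks_1"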
|
|
class LLLitePatch:
    ATTN1 = "attn1"
    ATTN2 = "attn2"

    def __init__(self, modules: dict[str, 'LLLiteModule'], patch_type: str, control: Union[AdvancedControlBase, ControlBase] = None):
        self.modules = modules
        self.control = control
        self.patch_type = patch_type
|
    def __call__(self, q, k, v, extra_options):
        # do nothing if outside the scheduled timestep range
        if self.control.timestep_range is not None:
            if self.control.t > self.control.timestep_range[0] or self.control.t < self.control.timestep_range[1]:
                return q, k, v

        module_pfx = extra_options_to_module_prefix(extra_options)

        # in self-attention (attn1) q and k share the same last dim;
        # in cross-attention (attn2) k comes from the text context, so the dims differ
        is_attn1 = q.shape[-1] == k.shape[-1]
        if is_attn1:
            module_pfx = module_pfx + "_attn1"
        else:
            module_pfx = module_pfx + "_attn2"

        module_pfx_to_q = module_pfx + "_to_q"
        module_pfx_to_k = module_pfx + "_to_k"
        module_pfx_to_v = module_pfx + "_to_v"

        # add the conditioning residual from whichever LLLite modules exist at this location
        if module_pfx_to_q in self.modules:
            q = q + self.modules[module_pfx_to_q](q, self.control)
        if module_pfx_to_k in self.modules:
            k = k + self.modules[module_pfx_to_k](k, self.control)
        if module_pfx_to_v in self.modules:
            v = v + self.modules[module_pfx_to_v](v, self.control)

        return q, k, v
|
    def to(self, device):
        for d in self.modules.keys():
            self.modules[d] = self.modules[d].to(device)
        return self
|
    def set_control(self, control: Union[AdvancedControlBase, ControlBase]) -> 'LLLitePatch':
        self.control = control
        return self
|
    def clone_with_control(self, control: AdvancedControlBase):
        # shallow clone: shares the module dict but binds a different control object
        return LLLitePatch(self.modules, self.patch_type, control)
|
    def cleanup(self):
        # clear cached conditioning embeddings on all modules
        for module in self.modules.values():
            module.cleanup()
|
    def __deepcopy__(self, memo):
        self.cleanup()
        # copy everything except control, which is shared with the copy
        to_return: LLLitePatch = deepcopy_with_sharing(self, shared_attribute_names=['control'], memo=memo)
        # re-register the copy on the shared control, if one is set
        try:
            if self.patch_type == self.ATTN1:
                to_return.control.patch_attn1 = to_return
            elif self.patch_type == self.ATTN2:
                to_return.control.patch_attn2 = to_return
        except Exception:
            pass
        return to_return
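# Usage sketch (illustrative; the actual wiring happens where ControlLLLite is
# loaded, with `model` a ComfyUI ModelPatcher and `modules`/`control` from the loader):
#   model.set_model_attn1_patch(LLLitePatch(modules, LLLitePatch.ATTN1).set_control(control))
#   model.set_model_attn2_patch(LLLitePatch(modules, LLLitePatch.ATTN2).set_control(control))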
|
|
class LLLiteModule(torch.nn.Module):
    def __init__(
        self,
        name: str,
        is_conv2d: bool,
        in_dim: int,
        depth: int,
        cond_emb_dim: int,
        mlp_dim: int,
    ):
        super().__init__()
        self.name = name
        self.is_conv2d = is_conv2d
        self.is_first = False

        # conditioning image encoder: strided convs downsample the RGB hint to the
        # feature resolution used at this UNet depth
        modules = []
        modules.append(torch.nn.Conv2d(3, cond_emb_dim // 2, kernel_size=4, stride=4, padding=0))
        if depth == 1:
            modules.append(torch.nn.ReLU(inplace=True))
            modules.append(torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim, kernel_size=2, stride=2, padding=0))
        elif depth == 2:
            modules.append(torch.nn.ReLU(inplace=True))
            modules.append(torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim, kernel_size=4, stride=4, padding=0))
        elif depth == 3:
            modules.append(torch.nn.ReLU(inplace=True))
            modules.append(torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim // 2, kernel_size=4, stride=4, padding=0))
            modules.append(torch.nn.ReLU(inplace=True))
            modules.append(torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim, kernel_size=2, stride=2, padding=0))

        self.conditioning1 = torch.nn.Sequential(*modules)
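        # net downsampling from the strides above: depth 1 -> 1/8, depth 2 -> 1/16,
        # depth 3 -> 1/32 of the hint image (presumably the token resolution at each depth)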
|
        if self.is_conv2d:
            # 1x1-conv bottleneck: down-project features, mix in the cond embedding, project back up
            self.down = torch.nn.Sequential(
                torch.nn.Conv2d(in_dim, mlp_dim, kernel_size=1, stride=1, padding=0),
                torch.nn.ReLU(inplace=True),
            )
            self.mid = torch.nn.Sequential(
                torch.nn.Conv2d(mlp_dim + cond_emb_dim, mlp_dim, kernel_size=1, stride=1, padding=0),
                torch.nn.ReLU(inplace=True),
            )
            self.up = torch.nn.Sequential(
                torch.nn.Conv2d(mlp_dim, in_dim, kernel_size=1, stride=1, padding=0),
            )
        else:
            # the same bottleneck as an MLP over attention tokens
            self.down = torch.nn.Sequential(
                torch.nn.Linear(in_dim, mlp_dim),
                torch.nn.ReLU(inplace=True),
            )
            self.mid = torch.nn.Sequential(
                torch.nn.Linear(mlp_dim + cond_emb_dim, mlp_dim),
                torch.nn.ReLU(inplace=True),
            )
            self.up = torch.nn.Sequential(
                torch.nn.Linear(mlp_dim, in_dim),
            )

        self.depth = depth
        # cached conditioning embedding, plus the state it was computed for
        self.cond_emb = None
        self.cx_shape = None
        self.prev_batch = 0
        self.prev_sub_idxs = None
|
    def cleanup(self):
        # invalidate the cached conditioning embedding
        del self.cond_emb
        self.cond_emb = None
        self.cx_shape = None
        self.prev_batch = 0
        self.prev_sub_idxs = None
|
    def forward(self, x: Tensor, control: Union[AdvancedControlBase, ControlBase]):
        mask = None
        mask_tk = None

        # recompute the conditioning embedding only when the cache is stale:
        # first call, changed sub-batch indexes, or changed batch size
        if self.cond_emb is None or control.sub_idxs != self.prev_sub_idxs or x.shape[0] != self.prev_batch:
            cond_hint = control.cond_hint.to(x.device, dtype=x.dtype)
            # rescale the hint to the latent dims expected at this UNet level
            if control.latent_dims_div2 is not None and x.shape[-1] != 1280:
                cond_hint = comfy.utils.common_upscale(cond_hint, control.latent_dims_div2[0] * 8, control.latent_dims_div2[1] * 8, 'nearest-exact', "center").to(x.device, dtype=x.dtype)
            elif control.latent_dims_div4 is not None and x.shape[-1] == 1280:
                cond_hint = comfy.utils.common_upscale(cond_hint, control.latent_dims_div4[0] * 8, control.latent_dims_div4[1] * 8, 'nearest-exact', "center").to(x.device, dtype=x.dtype)
            cx = self.conditioning1(cond_hint)
            self.cx_shape = cx.shape
            if not self.is_conv2d:
                # flatten spatial dims into a token axis: (n, c, h, w) -> (n, h*w, c)
                n, c, h, w = cx.shape
                cx = cx.view(n, c, h * w).permute(0, 2, 1)
            self.cond_emb = cx
            self.prev_batch = x.shape[0]
            self.prev_sub_idxs = control.sub_idxs

        cx: torch.Tensor = self.cond_emb
|
        if not self.is_conv2d:
            # prepare masks at the cached embedding's spatial size, flattened to tokens
            n, c, h, w = self.cx_shape
            if control.mask_cond_hint is not None:
                mask = prepare_mask_batch(control.mask_cond_hint, (1, 1, h, w)).to(cx.dtype)
                mask = mask.view(mask.shape[0], 1, h * w).permute(0, 2, 1)
            if control.tk_mask_cond_hint is not None:
                mask_tk = prepare_mask_batch(control.tk_mask_cond_hint, (1, 1, h, w)).to(cx.dtype)
                mask_tk = mask_tk.view(mask_tk.shape[0], 1, h * w).permute(0, 2, 1)
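        # the sampler may batch cond and uncond together in x, so x's batch size can
        # be a multiple of the cached embedding's; tile cx and the masks to match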
        if x.shape[0] != cx.shape[0]:
            if self.is_conv2d:
                cx = cx.repeat(x.shape[0] // cx.shape[0], 1, 1, 1)
            else:
                cx = cx.repeat(x.shape[0] // cx.shape[0], 1, 1)
                if mask is not None:
                    mask = mask.repeat(x.shape[0] // mask.shape[0], 1, 1)
                if mask_tk is not None:
                    mask_tk = mask_tk.repeat(x.shape[0] // mask_tk.shape[0], 1, 1)

        if mask is None:
            mask = 1.0
        elif mask_tk is not None:
            mask = mask * mask_tk

        # bottleneck: concat the cond embedding with down-projected features, then project back up
        cx = torch.cat([cx, self.down(x)], dim=1 if self.is_conv2d else 2)
        cx = self.mid(cx)
        cx = self.up(cx)
        # apply per-latent keyframe multipliers, if scheduled
        if control.latent_keyframes is not None:
            cx = cx * control.calc_latent_keyframe_mults(x=cx, batched_number=control.batched_number)
        # scale the uncond portions of the batch by the weights' uncond multiplier
        if control.weights is not None and control.weights.has_uncond_multiplier:
            cond_or_uncond = control.batched_number.cond_or_uncond
            actual_length = cx.size(0) // control.batched_number
            for idx, cond_type in enumerate(cond_or_uncond):
                # cond_type == 1 marks an uncond sub-batch
                if cond_type == 1:
                    cx[actual_length*idx:actual_length*(idx+1)] *= control.weights.uncond_multiplier
        return cx * mask * control.strength * control._current_timestep_keyframe.strength
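# Shape walk-through (hypothetical sizes, assuming is_conv2d=False, in_dim=640,
# depth=1, cond_emb_dim=32, mlp_dim=64): conditioning1 maps a (n, 3, H, W) hint to
# (n, 32, H/8, W/8), flattened to (n, tokens, 32); forward() then takes x of shape
# (n, tokens, 640) -> down (n, tokens, 64) -> cat (n, tokens, 96) -> mid (n, tokens, 64)
# -> up (n, tokens, 640), scaled by mask and strengths before being returned.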
|