from typing import List, NamedTuple, Any

import numpy as np
import cv2
import torch
from safetensors.torch import load_file

from library.original_unet import UNet2DConditionModel, SampleOutput
import library.model_util as model_util


class ControlNetInfo(NamedTuple):
    """A loaded ControlNet: the control U-Net, the zero-conv/hint module, an optional
    hint-image preprocessor, the application weight, and the ratio of sampling steps
    during which the net is applied."""

    unet: Any
    net: Any
    prep: Any
    weight: float
    ratio: float


class ControlNet(torch.nn.Module):
    """Holds the ControlNet-specific modules: the zero convolutions applied to the
    skip connections, the middle-block output convolution, and the hint encoder."""

    def __init__(self) -> None:
        super().__init__()

        # make control model
        self.control_model = torch.nn.Module()

        # zero convs: one 1x1 conv per skip connection
        # (the conv_in output plus the 11 down-block outputs of the SD U-Net)
        dims = [320, 320, 320, 320, 640, 640, 640, 1280, 1280, 1280, 1280, 1280]
        zero_convs = torch.nn.ModuleList()
        for i, dim in enumerate(dims):
            sub_list = torch.nn.ModuleList([torch.nn.Conv2d(dim, dim, 1)])
            zero_convs.append(sub_list)
        self.control_model.add_module("zero_convs", zero_convs)

        # 1x1 conv applied to the mid-block output
        middle_block_out = torch.nn.Conv2d(1280, 1280, 1)
        self.control_model.add_module("middle_block_out", torch.nn.ModuleList([middle_block_out]))

        # hint encoder: 3-channel hint image -> 320-channel feature map, downsampled
        # 8x by the three stride-2 convs, with SiLU between all but the last conv
        dims = [16, 16, 32, 32, 96, 96, 256, 320]
        strides = [1, 1, 2, 1, 2, 1, 2, 1]
        prev_dim = 3
        input_hint_block = torch.nn.Sequential()
        for i, (dim, stride) in enumerate(zip(dims, strides)):
            input_hint_block.append(torch.nn.Conv2d(prev_dim, dim, 3, stride, 1))
            if i < len(dims) - 1:
                input_hint_block.append(torch.nn.SiLU())
            prev_dim = dim
        self.control_model.add_module("input_hint_block", input_hint_block)


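# Quick shape check (illustrative sketch only, kept commented out like the merge
# version further below): the hint encoder downsamples 8x, so a 3x512x512 hint
# becomes a 320x64x64 feature map, matching the U-Net's conv_in output at 512x512.
"""
net = ControlNet()
hint = torch.randn(1, 3, 512, 512)
feat = net.control_model.input_hint_block(hint)
print(feat.shape)  # torch.Size([1, 320, 64, 64])
print(len(net.control_model.zero_convs))  # 12 zero convs: conv_in + 11 down-block outputs
"""

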
def load_control_net(v2, unet, model):
    device = unet.device

    # load the control SD model's state dict
    print(f"ControlNet: loading control SD model: {model}")

    if model_util.is_safetensors(model):
        ctrl_sd_sd = load_file(model)
    else:
        ctrl_sd_sd = torch.load(model, map_location="cpu")
        ctrl_sd_sd = ctrl_sd_sd.pop("state_dict", ctrl_sd_sd)

    # "transfer control" checkpoints store weight differences against the base U-Net
    is_difference = "difference" in ctrl_sd_sd
    print("ControlNet: loading difference:", is_difference)

    # convert the Diffusers U-Net weights to the SD (LDM) key layout so the control
    # weights can be merged in, and prefix the keys the way the checkpoint does
    ctrl_unet_sd_sd = model_util.convert_unet_state_dict_to_sd(v2, unet.state_dict())
    for key in list(ctrl_unet_sd_sd.keys()):
        ctrl_unet_sd_sd["model.diffusion_model." + key] = ctrl_unet_sd_sd.pop(key).clone()

    # split the control weights: keys matching the U-Net are merged (or added, for
    # difference checkpoints); the rest (zero convs, hint block) go to the ControlNet
    zero_conv_sd = {}
    for key in list(ctrl_sd_sd.keys()):
        if key.startswith("control_"):
            unet_key = "model.diffusion_" + key[len("control_") :]
            if unet_key not in ctrl_unet_sd_sd:  # zero conv or input hint block
                zero_conv_sd[key] = ctrl_sd_sd[key]
                continue
            if is_difference:  # add the difference to the base U-Net weight
                ctrl_unet_sd_sd[unet_key] += ctrl_sd_sd[key].to(device, dtype=unet.dtype)
            else:
                ctrl_unet_sd_sd[unet_key] = ctrl_sd_sd[key].to(device, dtype=unet.dtype)

    # convert the merged SD-format weights back to Diffusers format and load them
    unet_config = model_util.create_unet_diffusers_config(v2)
    ctrl_unet_du_sd = model_util.convert_ldm_unet_checkpoint(v2, ctrl_unet_sd_sd, unet_config)

    ctrl_unet = UNet2DConditionModel(**unet_config)
    info = ctrl_unet.load_state_dict(ctrl_unet_du_sd)
    print("ControlNet: loading Control U-Net:", info)

    ctrl_net = ControlNet()
    info = ctrl_net.load_state_dict(zero_conv_sd)
    print("ControlNet: loading ControlNet:", info)

    ctrl_unet.to(unet.device, dtype=unet.dtype)
    ctrl_net.to(unet.device, dtype=unet.dtype)
    return ctrl_unet, ctrl_net


def load_preprocess(prep_type: str):
    if prep_type is None or prep_type.lower() == "none":
        return None

    # "canny_<th1>_<th2>", e.g. "canny_63_191"; thresholds default to 63/191
    if prep_type.startswith("canny"):
        args = prep_type.split("_")
        th1 = int(args[1]) if len(args) >= 2 else 63
        th2 = int(args[2]) if len(args) >= 3 else 191

        def canny(img):
            img = cv2.cvtColor(np.asarray(img), cv2.COLOR_RGB2GRAY)
            img = cv2.Canny(img, th1, th2)
            # back to 3 channels: preprocess_ctrl_net_hint_image expects an HWC RGB image
            return cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)

        return canny

    print("Unsupported prep type:", prep_type)
    return None


def preprocess_ctrl_net_hint_image(image):
    # HWC uint8 RGB image -> NCHW float tensor in [0, 1]
    image = np.array(image).astype(np.float32) / 255.0
    image = image[None].transpose(0, 3, 1, 2)
    image = torch.from_numpy(image)
    return image


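# Usage sketch (illustrative): build a canny preprocessor and turn an RGB image into
# a hint tensor. The file name and PIL loading are assumptions; any HxWx3 uint8 RGB
# array works as input.
"""
from PIL import Image

prep = load_preprocess("canny_63_191")
img = np.asarray(Image.open("hint.png").convert("RGB"))  # hypothetical file
hint = preprocess_ctrl_net_hint_image(prep(img))
print(hint.shape)  # torch.Size([1, 3, H, W]), float32 in [0, 1]
"""

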
def get_guided_hints(control_nets: List[ControlNetInfo], num_latent_input, b_size, hints):
    guided_hints = []
    for i, cnet_info in enumerate(control_nets):
        # if a single hint image is given, share it across the batch; otherwise
        # assign hint images round-robin over batch entries and ControlNets
        b_hints = []
        if len(hints) == 1:
            hint = hints[0]
            if cnet_info.prep is not None:
                hint = cnet_info.prep(hint)
            hint = preprocess_ctrl_net_hint_image(hint)
            b_hints = [hint for _ in range(b_size)]
        else:
            for bi in range(b_size):
                hint = hints[(bi * len(control_nets) + i) % len(hints)]
                if cnet_info.prep is not None:
                    hint = cnet_info.prep(hint)
                hint = preprocess_ctrl_net_hint_image(hint)
                b_hints.append(hint)
        b_hints = torch.cat(b_hints, dim=0)
        b_hints = b_hints.to(cnet_info.unet.device, dtype=cnet_info.unet.dtype)

        # encode the hint images once; the result is reused at every sampling step
        guided_hint = cnet_info.net.control_model.input_hint_block(b_hints)
        guided_hints.append(guided_hint)
    return guided_hints


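# Hint assignment example (illustrative): with 2 ControlNets and batch size 3, the
# round-robin index (bi * len(control_nets) + i) % len(hints) interleaves the hints:
#   bi=0 -> net 0: hints[0], net 1: hints[1]
#   bi=1 -> net 0: hints[2], net 1: hints[3]
#   bi=2 -> net 0: hints[4 % len(hints)], net 1: hints[5 % len(hints)]

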
def call_unet_and_control_net(
    step,
    num_latent_input,
    original_unet,
    control_nets: List[ControlNetInfo],
    guided_hints,
    current_ratio,
    sample,
    timestep,
    encoder_hidden_states,
    encoder_hidden_states_for_control_net,
):
    # with multiple ControlNets, apply one per step in round-robin fashion
    # instead of merging their outputs
    cnet_cnt = len(control_nets)
    cnet_idx = step % cnet_cnt
    cnet_info = control_nets[cnet_idx]

    # skip the ControlNet once sampling has progressed past its ratio
    if cnet_info.ratio < current_ratio:
        return original_unet(sample, timestep, encoder_hidden_states)

    # ControlNet pass: collect the zero-conv residuals, scaled by the weight
    guided_hint = guided_hints[cnet_idx]
    guided_hint = guided_hint.repeat((num_latent_input, 1, 1, 1))
    outs = unet_forward(True, cnet_info.net, cnet_info.unet, guided_hint, None, sample, timestep, encoder_hidden_states_for_control_net)
    outs = [o * cnet_info.weight for o in outs]

    # U-Net pass with the ControlNet residuals added to the skip connections
    return unet_forward(False, cnet_info.net, original_unet, None, outs, sample, timestep, encoder_hidden_states)


| """
|
| # これはmergeのバージョン
|
| # ControlNet
|
| cnet_outs_list = []
|
| for i, cnet_info in enumerate(control_nets):
|
| # print(current_ratio, cnet_info.prep, cnet_info.weight, cnet_info.ratio)
|
| if cnet_info.ratio < current_ratio:
|
| continue
|
| guided_hint = guided_hints[i]
|
| outs = unet_forward(True, cnet_info.net, cnet_info.unet, guided_hint, None, sample, timestep, encoder_hidden_states)
|
| for i in range(len(outs)):
|
| outs[i] *= cnet_info.weight
|
|
|
| cnet_outs_list.append(outs)
|
|
|
| count = len(cnet_outs_list)
|
| if count == 0:
|
| return original_unet(sample, timestep, encoder_hidden_states)
|
|
|
| # sum of controlnets
|
| for i in range(1, count):
|
| cnet_outs_list[0] += cnet_outs_list[i]
|
|
|
| # U-Net
|
| return unet_forward(False, cnet_info.net, original_unet, None, cnet_outs_list[0], sample, timestep, encoder_hidden_states)
|
| """
|
|
|
|
|
def unet_forward(
    is_control_net,
    control_net: ControlNet,
    unet: UNet2DConditionModel,
    guided_hint,
    ctrl_outs,
    sample,
    timestep,
    encoder_hidden_states,
):
    # copied from UNet2DConditionModel.forward, with ControlNet taps added
    default_overall_up_factor = 2**unet.num_upsamplers

    # upsample size should be forwarded when the sample is not a multiple of `default_overall_up_factor`
    forward_upsample_size = False
    upsample_size = None

    if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
        print("Forward upsample size to force interpolation output size.")
        forward_upsample_size = True

    # 1. time embedding
    timesteps = timestep
    if not torch.is_tensor(timesteps):
        # mps does not support float64 / int64
        is_mps = sample.device.type == "mps"
        if isinstance(timestep, float):
            dtype = torch.float32 if is_mps else torch.float64
        else:
            dtype = torch.int32 if is_mps else torch.int64
        timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
    elif len(timesteps.shape) == 0:
        timesteps = timesteps[None].to(sample.device)

    # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
    timesteps = timesteps.expand(sample.shape[0])

    t_emb = unet.time_proj(timesteps)

    # `time_proj` always returns fp32, but the embedding layer may run in fp16,
    # so cast to the U-Net's dtype before embedding
    t_emb = t_emb.to(dtype=unet.dtype)
    emb = unet.time_embedding(t_emb)

    outs = []  # ControlNet residuals (one per zero conv, plus the middle block)
    zc_idx = 0  # index into the zero convs

    # 2. pre-process
    sample = unet.conv_in(sample)
    if is_control_net:
        sample += guided_hint
        outs.append(control_net.control_model.zero_convs[zc_idx][0](sample))
        zc_idx += 1

    # 3. down blocks
    down_block_res_samples = (sample,)
    for downsample_block in unet.down_blocks:
        if downsample_block.has_cross_attention:
            sample, res_samples = downsample_block(
                hidden_states=sample,
                temb=emb,
                encoder_hidden_states=encoder_hidden_states,
            )
        else:
            sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
        if is_control_net:
            for rs in res_samples:
                outs.append(control_net.control_model.zero_convs[zc_idx][0](rs))
                zc_idx += 1

        down_block_res_samples += res_samples

    # 4. mid block; the ControlNet pass stops here and returns the collected residuals
    sample = unet.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states)
    if is_control_net:
        outs.append(control_net.control_model.middle_block_out[0](sample))
        return outs

    # U-Net pass: add the ControlNet middle-block residual to the mid-block output
    if not is_control_net:
        sample += ctrl_outs.pop()

    # 5. up blocks
    for i, upsample_block in enumerate(unet.up_blocks):
        is_final_block = i == len(unet.up_blocks) - 1

        res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
        down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]

        # add the remaining ControlNet residuals to the skip connections, last-in first-out
        if not is_control_net and len(ctrl_outs) > 0:
            res_samples = list(res_samples)
            apply_ctrl_outs = ctrl_outs[-len(res_samples) :]
            ctrl_outs = ctrl_outs[: -len(res_samples)]
            for j in range(len(res_samples)):
                res_samples[j] = res_samples[j] + apply_ctrl_outs[j]
            res_samples = tuple(res_samples)

        # if we have not reached the final block and need to forward the upsample size, do it here
        if not is_final_block and forward_upsample_size:
            upsample_size = down_block_res_samples[-1].shape[2:]

        if upsample_block.has_cross_attention:
            sample = upsample_block(
                hidden_states=sample,
                temb=emb,
                res_hidden_states_tuple=res_samples,
                encoder_hidden_states=encoder_hidden_states,
                upsample_size=upsample_size,
            )
        else:
            sample = upsample_block(
                hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
            )

    # 6. post-process
    sample = unet.conv_norm_out(sample)
    sample = unet.conv_act(sample)
    sample = unet.conv_out(sample)

    return SampleOutput(sample=sample)


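# End-to-end usage sketch (illustrative; `unet`, `scheduler`, `latent_model_input`,
# `text_embeddings`, the hint image, and the checkpoint path are all assumptions
# about the surrounding generation script, not part of this module):
"""
ctrl_unet, ctrl_net = load_control_net(v2=False, unet=unet, model="control_sd15_canny.safetensors")
prep = load_preprocess("canny_63_191")
control_nets = [ControlNetInfo(ctrl_unet, ctrl_net, prep, weight=1.0, ratio=1.0)]

guided_hints = get_guided_hints(control_nets, num_latent_input, b_size, [hint_image])
for step, t in enumerate(scheduler.timesteps):
    current_ratio = step / len(scheduler.timesteps)
    noise_pred = call_unet_and_control_net(
        step, num_latent_input, unet, control_nets, guided_hints, current_ratio,
        latent_model_input, t, text_embeddings, text_embeddings_for_control_net,
    ).sample
    # ... classifier-free guidance and scheduler.step() as usual ...
"""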