File size: 10,916 Bytes
302920f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 |
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Optional, Union
import torch
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.unets.unet_2d_blocks import (
CrossAttnDownBlock2D,
DownBlock2D,
)
from diffusers.utils import BaseOutput, logging
from torch import nn
from torch.nn import functional as F
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@dataclass
class ControlNetOutput(BaseOutput):
down_block_res_samples: tuple[torch.Tensor]
mid_block_res_sample: torch.Tensor
class ControlNetConditioningEmbedding(nn.Module):
"""
Quoting from https://huggingface.co/papers/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN
[11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized
training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the
convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides
(activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full
model) to encode image-space conditions ... into feature maps ..."
"""
def __init__(
self,
conditioning_embedding_channels: int,
conditioning_channels: int = 3,
block_out_channels: tuple[int] = (16, 32, 96, 256),
):
super().__init__()
self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)
self.blocks = nn.ModuleList([])
for i in range(len(block_out_channels) - 1):
channel_in = block_out_channels[i]
channel_out = block_out_channels[i + 1]
self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1))
self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2))
self.conv_out = zero_module(
nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1)
)
def forward(self, conditioning):
embedding = self.conv_in(conditioning)
embedding = F.silu(embedding)
for block in self.blocks:
embedding = block(embedding)
embedding = F.silu(embedding)
embedding = self.conv_out(embedding)
return embedding
class ControlNetModel(ModelMixin, ConfigMixin):
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
in_channels: int = 4,
out_channels: int = 320,
controlnet_conditioning_channel_order: str = "rgb",
conditioning_embedding_out_channels: Optional[tuple[int]] = (16, 32, 96, 256),
):
super().__init__()
# for control image
self.controlnet_cond_embedding = ControlNetConditioningEmbedding(
conditioning_embedding_channels=out_channels,
block_out_channels=conditioning_embedding_out_channels,
)
@property
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> dict[str, AttentionProcessor]:
r"""
Returns:
`dict` of attention processors: A dictionary containing all attention processors used in the model with
indexed by its weight name.
"""
# set recursively
processors = {}
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: dict[str, AttentionProcessor]):
if hasattr(module, "set_processor"):
processors[f"{name}.processor"] = module.processor
for sub_name, child in module.named_children():
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
return processors
for name, module in self.named_children():
fn_recursive_add_processors(name, module, processors)
return processors
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor: Union[AttentionProcessor, dict[str, AttentionProcessor]]):
r"""
Parameters:
`processor (`dict` of `AttentionProcessor` or `AttentionProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
of **all** `Attention` layers.
In case `processor` is a dict, the key needs to define the path to the corresponding cross attention processor. This is strongly recommended when setting trainable attention processors.:
"""
count = len(self.attn_processors.keys())
if isinstance(processor, dict) and len(processor) != count:
raise ValueError(
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
)
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
if hasattr(module, "set_processor"):
if not isinstance(processor, dict):
module.set_processor(processor)
else:
module.set_processor(processor.pop(f"{name}.processor"))
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
def set_default_attn_processor(self):
"""
Disables custom attention processors and sets the default attention implementation.
"""
self.set_attn_processor(AttnProcessor())
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice
def set_attention_slice(self, slice_size):
r"""
Enable sliced attention computation.
When this option is enabled, the attention module will split the input tensor in slices, to compute attention
in several steps. This is useful to save some memory in exchange for a small speed decrease.
Args:
slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
`"max"`, maximum amount of memory will be saved by running only one slice at a time. If a number is
provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
must be a multiple of `slice_size`.
"""
sliceable_head_dims = []
def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
if hasattr(module, "set_attention_slice"):
sliceable_head_dims.append(module.sliceable_head_dim)
for child in module.children():
fn_recursive_retrieve_sliceable_dims(child)
# retrieve number of attention layers
for module in self.children():
fn_recursive_retrieve_sliceable_dims(module)
num_sliceable_layers = len(sliceable_head_dims)
if slice_size == "auto":
# half the attention head size is usually a good trade-off between
# speed and memory
slice_size = [dim // 2 for dim in sliceable_head_dims]
elif slice_size == "max":
# make smallest slice possible
slice_size = num_sliceable_layers * [1]
slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
if len(slice_size) != len(sliceable_head_dims):
raise ValueError(
f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
)
for i in range(len(slice_size)):
size = slice_size[i]
dim = sliceable_head_dims[i]
if size is not None and size > dim:
raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
# Recursively walk through all the children.
# Any children which exposes the set_attention_slice method
# gets the message
def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: list[int]):
if hasattr(module, "set_attention_slice"):
module.set_attention_slice(slice_size.pop())
for child in module.children():
fn_recursive_set_attention_slice(child, slice_size)
reversed_slice_size = list(reversed(slice_size))
for module in self.children():
fn_recursive_set_attention_slice(module, reversed_slice_size)
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)):
module.gradient_checkpointing = value
def forward(
self,
controlnet_cond: torch.FloatTensor,
) -> Union[ControlNetOutput, tuple]:
# check channel order
channel_order = self.config.controlnet_conditioning_channel_order
if channel_order == "rgb":
# in rgb order by default
...
elif channel_order == "bgr":
controlnet_cond = torch.flip(controlnet_cond, dims=[1])
else:
raise ValueError(f"unknown `controlnet_conditioning_channel_order`: {channel_order}")
# 2. pre-process
controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)
return controlnet_cond
def zero_module(module):
for p in module.parameters():
nn.init.zeros_(p)
return module
|