build-tools / diffusers /models /controlnets /controlnet_xs.py
salmankhanpm's picture
Add files using upload-large-folder tool
69e1a8d verified
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from math import gcd
from typing import Any
import torch
from torch import Tensor, nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import BaseOutput, logging
from ...utils.torch_utils import apply_freeu
from ..attention import AttentionMixin
from ..attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
CROSS_ATTENTION_PROCESSORS,
Attention,
AttnAddedKVProcessor,
AttnProcessor,
FusedAttnProcessor2_0,
)
from ..embeddings import TimestepEmbedding, Timesteps
from ..modeling_utils import ModelMixin
from ..unets.unet_2d_blocks import (
CrossAttnDownBlock2D,
CrossAttnUpBlock2D,
Downsample2D,
ResnetBlock2D,
Transformer2DModel,
UNetMidBlock2DCrossAttn,
Upsample2D,
)
from ..unets.unet_2d_condition import UNet2DConditionModel
from .controlnet import ControlNetConditioningEmbedding
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@dataclass
class ControlNetXSOutput(BaseOutput):
"""
The output of [`UNetControlNetXSModel`].
Args:
sample (`Tensor` of shape `(batch_size, num_channels, height, width)`):
The output of the `UNetControlNetXSModel`. Unlike `ControlNetOutput` this is NOT to be added to the base
model output, but is already the final output.
"""
sample: Tensor = None
class DownBlockControlNetXSAdapter(nn.Module):
"""Components that together with corresponding components from the base model will form a
`ControlNetXSCrossAttnDownBlock2D`"""
def __init__(
self,
resnets: nn.ModuleList,
base_to_ctrl: nn.ModuleList,
ctrl_to_base: nn.ModuleList,
attentions: nn.ModuleList | None = None,
downsampler: nn.Conv2d | None = None,
):
super().__init__()
self.resnets = resnets
self.base_to_ctrl = base_to_ctrl
self.ctrl_to_base = ctrl_to_base
self.attentions = attentions
self.downsamplers = downsampler
class MidBlockControlNetXSAdapter(nn.Module):
"""Components that together with corresponding components from the base model will form a
`ControlNetXSCrossAttnMidBlock2D`"""
def __init__(self, midblock: UNetMidBlock2DCrossAttn, base_to_ctrl: nn.ModuleList, ctrl_to_base: nn.ModuleList):
super().__init__()
self.midblock = midblock
self.base_to_ctrl = base_to_ctrl
self.ctrl_to_base = ctrl_to_base
class UpBlockControlNetXSAdapter(nn.Module):
"""Components that together with corresponding components from the base model will form a `ControlNetXSCrossAttnUpBlock2D`"""
def __init__(self, ctrl_to_base: nn.ModuleList):
super().__init__()
self.ctrl_to_base = ctrl_to_base
def get_down_block_adapter(
base_in_channels: int,
base_out_channels: int,
ctrl_in_channels: int,
ctrl_out_channels: int,
temb_channels: int,
max_norm_num_groups: int | None = 32,
has_crossattn=True,
transformer_layers_per_block: int | tuple[int] | None = 1,
num_attention_heads: int | None = 1,
cross_attention_dim: int | None = 1024,
add_downsample: bool = True,
upcast_attention: bool | None = False,
use_linear_projection: bool | None = True,
):
num_layers = 2 # only support sd + sdxl
resnets = []
attentions = []
ctrl_to_base = []
base_to_ctrl = []
if isinstance(transformer_layers_per_block, int):
transformer_layers_per_block = [transformer_layers_per_block] * num_layers
for i in range(num_layers):
base_in_channels = base_in_channels if i == 0 else base_out_channels
ctrl_in_channels = ctrl_in_channels if i == 0 else ctrl_out_channels
# Before the resnet/attention application, information is concatted from base to control.
# Concat doesn't require change in number of channels
base_to_ctrl.append(make_zero_conv(base_in_channels, base_in_channels))
resnets.append(
ResnetBlock2D(
in_channels=ctrl_in_channels + base_in_channels, # information from base is concatted to ctrl
out_channels=ctrl_out_channels,
temb_channels=temb_channels,
groups=find_largest_factor(ctrl_in_channels + base_in_channels, max_factor=max_norm_num_groups),
groups_out=find_largest_factor(ctrl_out_channels, max_factor=max_norm_num_groups),
eps=1e-5,
)
)
if has_crossattn:
attentions.append(
Transformer2DModel(
num_attention_heads,
ctrl_out_channels // num_attention_heads,
in_channels=ctrl_out_channels,
num_layers=transformer_layers_per_block[i],
cross_attention_dim=cross_attention_dim,
use_linear_projection=use_linear_projection,
upcast_attention=upcast_attention,
norm_num_groups=find_largest_factor(ctrl_out_channels, max_factor=max_norm_num_groups),
)
)
# After the resnet/attention application, information is added from control to base
# Addition requires change in number of channels
ctrl_to_base.append(make_zero_conv(ctrl_out_channels, base_out_channels))
if add_downsample:
# Before the downsampler application, information is concatted from base to control
# Concat doesn't require change in number of channels
base_to_ctrl.append(make_zero_conv(base_out_channels, base_out_channels))
downsamplers = Downsample2D(
ctrl_out_channels + base_out_channels, use_conv=True, out_channels=ctrl_out_channels, name="op"
)
# After the downsampler application, information is added from control to base
# Addition requires change in number of channels
ctrl_to_base.append(make_zero_conv(ctrl_out_channels, base_out_channels))
else:
downsamplers = None
down_block_components = DownBlockControlNetXSAdapter(
resnets=nn.ModuleList(resnets),
base_to_ctrl=nn.ModuleList(base_to_ctrl),
ctrl_to_base=nn.ModuleList(ctrl_to_base),
)
if has_crossattn:
down_block_components.attentions = nn.ModuleList(attentions)
if downsamplers is not None:
down_block_components.downsamplers = downsamplers
return down_block_components
def get_mid_block_adapter(
base_channels: int,
ctrl_channels: int,
temb_channels: int | None = None,
max_norm_num_groups: int | None = 32,
transformer_layers_per_block: int = 1,
num_attention_heads: int | None = 1,
cross_attention_dim: int | None = 1024,
upcast_attention: bool = False,
use_linear_projection: bool = True,
):
# Before the midblock application, information is concatted from base to control.
# Concat doesn't require change in number of channels
base_to_ctrl = make_zero_conv(base_channels, base_channels)
midblock = UNetMidBlock2DCrossAttn(
transformer_layers_per_block=transformer_layers_per_block,
in_channels=ctrl_channels + base_channels,
out_channels=ctrl_channels,
temb_channels=temb_channels,
# number or norm groups must divide both in_channels and out_channels
resnet_groups=find_largest_factor(gcd(ctrl_channels, ctrl_channels + base_channels), max_norm_num_groups),
cross_attention_dim=cross_attention_dim,
num_attention_heads=num_attention_heads,
use_linear_projection=use_linear_projection,
upcast_attention=upcast_attention,
)
# After the midblock application, information is added from control to base
# Addition requires change in number of channels
ctrl_to_base = make_zero_conv(ctrl_channels, base_channels)
return MidBlockControlNetXSAdapter(base_to_ctrl=base_to_ctrl, midblock=midblock, ctrl_to_base=ctrl_to_base)
def get_up_block_adapter(
out_channels: int,
prev_output_channel: int,
ctrl_skip_channels: list[int],
):
ctrl_to_base = []
num_layers = 3 # only support sd + sdxl
for i in range(num_layers):
resnet_in_channels = prev_output_channel if i == 0 else out_channels
ctrl_to_base.append(make_zero_conv(ctrl_skip_channels[i], resnet_in_channels))
return UpBlockControlNetXSAdapter(ctrl_to_base=nn.ModuleList(ctrl_to_base))
class ControlNetXSAdapter(ModelMixin, AttentionMixin, ConfigMixin):
r"""
A `ControlNetXSAdapter` model. To use it, pass it into a `UNetControlNetXSModel` (together with a
`UNet2DConditionModel` base model).
This model inherits from [`ModelMixin`] and [`ConfigMixin`]. Check the superclass documentation for it's generic
methods implemented for all models (such as downloading or saving).
Like `UNetControlNetXSModel`, `ControlNetXSAdapter` is compatible with StableDiffusion and StableDiffusion-XL. It's
default parameters are compatible with StableDiffusion.
Parameters:
conditioning_channels (`int`, defaults to 3):
Number of channels of conditioning input (e.g. an image)
conditioning_channel_order (`str`, defaults to `"rgb"`):
The channel order of conditional image. Will convert to `rgb` if it's `bgr`.
conditioning_embedding_out_channels (`tuple[int]`, defaults to `(16, 32, 96, 256)`):
The tuple of output channels for each block in the `controlnet_cond_embedding` layer.
time_embedding_mix (`float`, defaults to 1.0):
If 0, then only the control adapters's time embedding is used. If 1, then only the base unet's time
embedding is used. Otherwise, both are combined.
learn_time_embedding (`bool`, defaults to `False`):
Whether a time embedding should be learned. If yes, `UNetControlNetXSModel` will combine the time
embeddings of the base model and the control adapter. If no, `UNetControlNetXSModel` will use the base
model's time embedding.
num_attention_heads (`list[int]`, defaults to `[4]`):
The number of attention heads.
block_out_channels (`list[int]`, defaults to `[4, 8, 16, 16]`):
The tuple of output channels for each block.
base_block_out_channels (`list[int]`, defaults to `[320, 640, 1280, 1280]`):
The tuple of output channels for each block in the base unet.
cross_attention_dim (`int`, defaults to 1024):
The dimension of the cross attention features.
down_block_types (`list[str]`, defaults to `["CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"]`):
The tuple of downsample blocks to use.
sample_size (`int`, defaults to 96):
Height and width of input/output sample.
transformer_layers_per_block (`int | tuple[int]`, defaults to 1):
The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
[`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
upcast_attention (`bool`, defaults to `True`):
Whether the attention computation should always be upcasted.
max_norm_num_groups (`int`, defaults to 32):
Maximum number of groups in group normal. The actual number will be the largest divisor of the respective
channels, that is <= max_norm_num_groups.
"""
@register_to_config
def __init__(
self,
conditioning_channels: int = 3,
conditioning_channel_order: str = "rgb",
conditioning_embedding_out_channels: tuple[int] = (16, 32, 96, 256),
time_embedding_mix: float = 1.0,
learn_time_embedding: bool = False,
num_attention_heads: int | tuple[int] = 4,
block_out_channels: tuple[int] = (4, 8, 16, 16),
base_block_out_channels: tuple[int] = (320, 640, 1280, 1280),
cross_attention_dim: int = 1024,
down_block_types: tuple[str] = (
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"DownBlock2D",
),
sample_size: int | None = 96,
transformer_layers_per_block: int | tuple[int] = 1,
upcast_attention: bool = True,
max_norm_num_groups: int = 32,
use_linear_projection: bool = True,
):
super().__init__()
time_embedding_input_dim = base_block_out_channels[0]
time_embedding_dim = base_block_out_channels[0] * 4
# Check inputs
if conditioning_channel_order not in ["rgb", "bgr"]:
raise ValueError(f"unknown `conditioning_channel_order`: {conditioning_channel_order}")
if len(block_out_channels) != len(down_block_types):
raise ValueError(
f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
)
if not isinstance(transformer_layers_per_block, (list, tuple)):
transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
if not isinstance(cross_attention_dim, (list, tuple)):
cross_attention_dim = [cross_attention_dim] * len(down_block_types)
# see https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 for why `ControlNetXSAdapter` takes `num_attention_heads` instead of `attention_head_dim`
if not isinstance(num_attention_heads, (list, tuple)):
num_attention_heads = [num_attention_heads] * len(down_block_types)
if len(num_attention_heads) != len(down_block_types):
raise ValueError(
f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
)
# 5 - Create conditioning hint embedding
self.controlnet_cond_embedding = ControlNetConditioningEmbedding(
conditioning_embedding_channels=block_out_channels[0],
block_out_channels=conditioning_embedding_out_channels,
conditioning_channels=conditioning_channels,
)
# time
if learn_time_embedding:
self.time_embedding = TimestepEmbedding(time_embedding_input_dim, time_embedding_dim)
else:
self.time_embedding = None
self.down_blocks = nn.ModuleList([])
self.up_connections = nn.ModuleList([])
# input
self.conv_in = nn.Conv2d(4, block_out_channels[0], kernel_size=3, padding=1)
self.control_to_base_for_conv_in = make_zero_conv(block_out_channels[0], base_block_out_channels[0])
# down
base_out_channels = base_block_out_channels[0]
ctrl_out_channels = block_out_channels[0]
for i, down_block_type in enumerate(down_block_types):
base_in_channels = base_out_channels
base_out_channels = base_block_out_channels[i]
ctrl_in_channels = ctrl_out_channels
ctrl_out_channels = block_out_channels[i]
has_crossattn = "CrossAttn" in down_block_type
is_final_block = i == len(down_block_types) - 1
self.down_blocks.append(
get_down_block_adapter(
base_in_channels=base_in_channels,
base_out_channels=base_out_channels,
ctrl_in_channels=ctrl_in_channels,
ctrl_out_channels=ctrl_out_channels,
temb_channels=time_embedding_dim,
max_norm_num_groups=max_norm_num_groups,
has_crossattn=has_crossattn,
transformer_layers_per_block=transformer_layers_per_block[i],
num_attention_heads=num_attention_heads[i],
cross_attention_dim=cross_attention_dim[i],
add_downsample=not is_final_block,
upcast_attention=upcast_attention,
use_linear_projection=use_linear_projection,
)
)
# mid
self.mid_block = get_mid_block_adapter(
base_channels=base_block_out_channels[-1],
ctrl_channels=block_out_channels[-1],
temb_channels=time_embedding_dim,
transformer_layers_per_block=transformer_layers_per_block[-1],
num_attention_heads=num_attention_heads[-1],
cross_attention_dim=cross_attention_dim[-1],
upcast_attention=upcast_attention,
use_linear_projection=use_linear_projection,
)
# up
# The skip connection channels are the output of the conv_in and of all the down subblocks
ctrl_skip_channels = [block_out_channels[0]]
for i, out_channels in enumerate(block_out_channels):
number_of_subblocks = (
3 if i < len(block_out_channels) - 1 else 2
) # every block has 3 subblocks, except last one, which has 2 as it has no downsampler
ctrl_skip_channels.extend([out_channels] * number_of_subblocks)
reversed_base_block_out_channels = list(reversed(base_block_out_channels))
base_out_channels = reversed_base_block_out_channels[0]
for i in range(len(down_block_types)):
prev_base_output_channel = base_out_channels
base_out_channels = reversed_base_block_out_channels[i]
ctrl_skip_channels_ = [ctrl_skip_channels.pop() for _ in range(3)]
self.up_connections.append(
get_up_block_adapter(
out_channels=base_out_channels,
prev_output_channel=prev_base_output_channel,
ctrl_skip_channels=ctrl_skip_channels_,
)
)
@classmethod
def from_unet(
cls,
unet: UNet2DConditionModel,
size_ratio: float | None = None,
block_out_channels: list[int] | None = None,
num_attention_heads: list[int] | None = None,
learn_time_embedding: bool = False,
time_embedding_mix: int = 1.0,
conditioning_channels: int = 3,
conditioning_channel_order: str = "rgb",
conditioning_embedding_out_channels: tuple[int] = (16, 32, 96, 256),
):
r"""
Instantiate a [`ControlNetXSAdapter`] from a [`UNet2DConditionModel`].
Parameters:
unet (`UNet2DConditionModel`):
The UNet model we want to control. The dimensions of the ControlNetXSAdapter will be adapted to it.
size_ratio (float, *optional*, defaults to `None`):
When given, block_out_channels is set to a fraction of the base model's block_out_channels. Either this
or `block_out_channels` must be given.
block_out_channels (`list[int]`, *optional*, defaults to `None`):
Down blocks output channels in control model. Either this or `size_ratio` must be given.
num_attention_heads (`list[int]`, *optional*, defaults to `None`):
The dimension of the attention heads. The naming seems a bit confusing and it is, see
https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 for why.
learn_time_embedding (`bool`, defaults to `False`):
Whether the `ControlNetXSAdapter` should learn a time embedding.
time_embedding_mix (`float`, defaults to 1.0):
If 0, then only the control adapter's time embedding is used. If 1, then only the base unet's time
embedding is used. Otherwise, both are combined.
conditioning_channels (`int`, defaults to 3):
Number of channels of conditioning input (e.g. an image)
conditioning_channel_order (`str`, defaults to `"rgb"`):
The channel order of conditional image. Will convert to `rgb` if it's `bgr`.
conditioning_embedding_out_channels (`tuple[int]`, defaults to `(16, 32, 96, 256)`):
The tuple of output channel for each block in the `controlnet_cond_embedding` layer.
"""
# Check input
fixed_size = block_out_channels is not None
relative_size = size_ratio is not None
if not (fixed_size ^ relative_size):
raise ValueError(
"Pass exactly one of `block_out_channels` (for absolute sizing) or `size_ratio` (for relative sizing)."
)
# Create model
block_out_channels = block_out_channels or [int(b * size_ratio) for b in unet.config.block_out_channels]
if num_attention_heads is None:
# The naming seems a bit confusing and it is, see https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 for why.
num_attention_heads = unet.config.attention_head_dim
model = cls(
conditioning_channels=conditioning_channels,
conditioning_channel_order=conditioning_channel_order,
conditioning_embedding_out_channels=conditioning_embedding_out_channels,
time_embedding_mix=time_embedding_mix,
learn_time_embedding=learn_time_embedding,
num_attention_heads=num_attention_heads,
block_out_channels=block_out_channels,
base_block_out_channels=unet.config.block_out_channels,
cross_attention_dim=unet.config.cross_attention_dim,
down_block_types=unet.config.down_block_types,
sample_size=unet.config.sample_size,
transformer_layers_per_block=unet.config.transformer_layers_per_block,
upcast_attention=unet.config.upcast_attention,
max_norm_num_groups=unet.config.norm_num_groups,
use_linear_projection=unet.config.use_linear_projection,
)
# ensure that the ControlNetXSAdapter is the same dtype as the UNet2DConditionModel
model.to(unet.dtype)
return model
def forward(self, *args, **kwargs):
raise ValueError(
"A ControlNetXSAdapter cannot be run by itself. Use it together with a UNet2DConditionModel to instantiate a UNetControlNetXSModel."
)
class UNetControlNetXSModel(ModelMixin, AttentionMixin, ConfigMixin):
r"""
A UNet fused with a ControlNet-XS adapter model
This model inherits from [`ModelMixin`] and [`ConfigMixin`]. Check the superclass documentation for it's generic
methods implemented for all models (such as downloading or saving).
`UNetControlNetXSModel` is compatible with StableDiffusion and StableDiffusion-XL. It's default parameters are
compatible with StableDiffusion.
It's parameters are either passed to the underlying `UNet2DConditionModel` or used exactly like in
`ControlNetXSAdapter` . See their documentation for details.
"""
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
# unet configs
sample_size: int | None = 96,
down_block_types: tuple[str] = (
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"DownBlock2D",
),
up_block_types: tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
block_out_channels: tuple[int] = (320, 640, 1280, 1280),
norm_num_groups: int | None = 32,
cross_attention_dim: int | tuple[int] = 1024,
transformer_layers_per_block: int | tuple[int] = 1,
num_attention_heads: int | tuple[int] = 8,
addition_embed_type: str | None = None,
addition_time_embed_dim: int | None = None,
upcast_attention: bool = True,
use_linear_projection: bool = True,
time_cond_proj_dim: int | None = None,
projection_class_embeddings_input_dim: int | None = None,
# additional controlnet configs
time_embedding_mix: float = 1.0,
ctrl_conditioning_channels: int = 3,
ctrl_conditioning_embedding_out_channels: tuple[int] = (16, 32, 96, 256),
ctrl_conditioning_channel_order: str = "rgb",
ctrl_learn_time_embedding: bool = False,
ctrl_block_out_channels: tuple[int] = (4, 8, 16, 16),
ctrl_num_attention_heads: int | tuple[int] = 4,
ctrl_max_norm_num_groups: int = 32,
):
super().__init__()
if time_embedding_mix < 0 or time_embedding_mix > 1:
raise ValueError("`time_embedding_mix` needs to be between 0 and 1.")
if time_embedding_mix < 1 and not ctrl_learn_time_embedding:
raise ValueError("To use `time_embedding_mix` < 1, `ctrl_learn_time_embedding` must be `True`")
if addition_embed_type is not None and addition_embed_type != "text_time":
raise ValueError(
"As `UNetControlNetXSModel` currently only supports StableDiffusion and StableDiffusion-XL, `addition_embed_type` must be `None` or `'text_time'`."
)
if not isinstance(transformer_layers_per_block, (list, tuple)):
transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
if not isinstance(cross_attention_dim, (list, tuple)):
cross_attention_dim = [cross_attention_dim] * len(down_block_types)
if not isinstance(num_attention_heads, (list, tuple)):
num_attention_heads = [num_attention_heads] * len(down_block_types)
if not isinstance(ctrl_num_attention_heads, (list, tuple)):
ctrl_num_attention_heads = [ctrl_num_attention_heads] * len(down_block_types)
base_num_attention_heads = num_attention_heads
self.in_channels = 4
# # Input
self.base_conv_in = nn.Conv2d(4, block_out_channels[0], kernel_size=3, padding=1)
self.controlnet_cond_embedding = ControlNetConditioningEmbedding(
conditioning_embedding_channels=ctrl_block_out_channels[0],
block_out_channels=ctrl_conditioning_embedding_out_channels,
conditioning_channels=ctrl_conditioning_channels,
)
self.ctrl_conv_in = nn.Conv2d(4, ctrl_block_out_channels[0], kernel_size=3, padding=1)
self.control_to_base_for_conv_in = make_zero_conv(ctrl_block_out_channels[0], block_out_channels[0])
# # Time
time_embed_input_dim = block_out_channels[0]
time_embed_dim = block_out_channels[0] * 4
self.base_time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos=True, downscale_freq_shift=0)
self.base_time_embedding = TimestepEmbedding(
time_embed_input_dim,
time_embed_dim,
cond_proj_dim=time_cond_proj_dim,
)
if ctrl_learn_time_embedding:
self.ctrl_time_embedding = TimestepEmbedding(
in_channels=time_embed_input_dim, time_embed_dim=time_embed_dim
)
else:
self.ctrl_time_embedding = None
if addition_embed_type is None:
self.base_add_time_proj = None
self.base_add_embedding = None
else:
self.base_add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
self.base_add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
# # Create down blocks
down_blocks = []
base_out_channels = block_out_channels[0]
ctrl_out_channels = ctrl_block_out_channels[0]
for i, down_block_type in enumerate(down_block_types):
base_in_channels = base_out_channels
base_out_channels = block_out_channels[i]
ctrl_in_channels = ctrl_out_channels
ctrl_out_channels = ctrl_block_out_channels[i]
has_crossattn = "CrossAttn" in down_block_type
is_final_block = i == len(down_block_types) - 1
down_blocks.append(
ControlNetXSCrossAttnDownBlock2D(
base_in_channels=base_in_channels,
base_out_channels=base_out_channels,
ctrl_in_channels=ctrl_in_channels,
ctrl_out_channels=ctrl_out_channels,
temb_channels=time_embed_dim,
norm_num_groups=norm_num_groups,
ctrl_max_norm_num_groups=ctrl_max_norm_num_groups,
has_crossattn=has_crossattn,
transformer_layers_per_block=transformer_layers_per_block[i],
base_num_attention_heads=base_num_attention_heads[i],
ctrl_num_attention_heads=ctrl_num_attention_heads[i],
cross_attention_dim=cross_attention_dim[i],
add_downsample=not is_final_block,
upcast_attention=upcast_attention,
use_linear_projection=use_linear_projection,
)
)
# # Create mid block
self.mid_block = ControlNetXSCrossAttnMidBlock2D(
base_channels=block_out_channels[-1],
ctrl_channels=ctrl_block_out_channels[-1],
temb_channels=time_embed_dim,
norm_num_groups=norm_num_groups,
ctrl_max_norm_num_groups=ctrl_max_norm_num_groups,
transformer_layers_per_block=transformer_layers_per_block[-1],
base_num_attention_heads=base_num_attention_heads[-1],
ctrl_num_attention_heads=ctrl_num_attention_heads[-1],
cross_attention_dim=cross_attention_dim[-1],
upcast_attention=upcast_attention,
use_linear_projection=use_linear_projection,
)
# # Create up blocks
up_blocks = []
rev_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
rev_num_attention_heads = list(reversed(base_num_attention_heads))
rev_cross_attention_dim = list(reversed(cross_attention_dim))
# The skip connection channels are the output of the conv_in and of all the down subblocks
ctrl_skip_channels = [ctrl_block_out_channels[0]]
for i, out_channels in enumerate(ctrl_block_out_channels):
number_of_subblocks = (
3 if i < len(ctrl_block_out_channels) - 1 else 2
) # every block has 3 subblocks, except last one, which has 2 as it has no downsampler
ctrl_skip_channels.extend([out_channels] * number_of_subblocks)
reversed_block_out_channels = list(reversed(block_out_channels))
out_channels = reversed_block_out_channels[0]
for i, up_block_type in enumerate(up_block_types):
prev_output_channel = out_channels
out_channels = reversed_block_out_channels[i]
in_channels = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
ctrl_skip_channels_ = [ctrl_skip_channels.pop() for _ in range(3)]
has_crossattn = "CrossAttn" in up_block_type
is_final_block = i == len(block_out_channels) - 1
up_blocks.append(
ControlNetXSCrossAttnUpBlock2D(
in_channels=in_channels,
out_channels=out_channels,
prev_output_channel=prev_output_channel,
ctrl_skip_channels=ctrl_skip_channels_,
temb_channels=time_embed_dim,
resolution_idx=i,
has_crossattn=has_crossattn,
transformer_layers_per_block=rev_transformer_layers_per_block[i],
num_attention_heads=rev_num_attention_heads[i],
cross_attention_dim=rev_cross_attention_dim[i],
add_upsample=not is_final_block,
upcast_attention=upcast_attention,
norm_num_groups=norm_num_groups,
use_linear_projection=use_linear_projection,
)
)
self.down_blocks = nn.ModuleList(down_blocks)
self.up_blocks = nn.ModuleList(up_blocks)
self.base_conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups)
self.base_conv_act = nn.SiLU()
self.base_conv_out = nn.Conv2d(block_out_channels[0], 4, kernel_size=3, padding=1)
@classmethod
def from_unet(
cls,
unet: UNet2DConditionModel,
controlnet: ControlNetXSAdapter | None = None,
size_ratio: float | None = None,
ctrl_block_out_channels: list[float] | None = None,
time_embedding_mix: float | None = None,
ctrl_optional_kwargs: dict | None = None,
):
r"""
Instantiate a [`UNetControlNetXSModel`] from a [`UNet2DConditionModel`] and an optional [`ControlNetXSAdapter`]
.
Parameters:
unet (`UNet2DConditionModel`):
The UNet model we want to control.
controlnet (`ControlNetXSAdapter`):
The ControlNet-XS adapter with which the UNet will be fused. If none is given, a new ControlNet-XS
adapter will be created.
size_ratio (float, *optional*, defaults to `None`):
Used to construct the controlnet if none is given. See [`ControlNetXSAdapter.from_unet`] for details.
ctrl_block_out_channels (`list[int]`, *optional*, defaults to `None`):
Used to construct the controlnet if none is given. See [`ControlNetXSAdapter.from_unet`] for details,
where this parameter is called `block_out_channels`.
time_embedding_mix (`float`, *optional*, defaults to None):
Used to construct the controlnet if none is given. See [`ControlNetXSAdapter.from_unet`] for details.
ctrl_optional_kwargs (`Dict`, *optional*, defaults to `None`):
Passed to the `init` of the new controlnet if no controlnet was given.
"""
if controlnet is None:
controlnet = ControlNetXSAdapter.from_unet(
unet, size_ratio, ctrl_block_out_channels, **ctrl_optional_kwargs
)
else:
if any(
o is not None for o in (size_ratio, ctrl_block_out_channels, time_embedding_mix, ctrl_optional_kwargs)
):
raise ValueError(
"When a controlnet is passed, none of these parameters should be passed: size_ratio, ctrl_block_out_channels, time_embedding_mix, ctrl_optional_kwargs."
)
# # get params
params_for_unet = [
"sample_size",
"down_block_types",
"up_block_types",
"block_out_channels",
"norm_num_groups",
"cross_attention_dim",
"transformer_layers_per_block",
"addition_embed_type",
"addition_time_embed_dim",
"upcast_attention",
"use_linear_projection",
"time_cond_proj_dim",
"projection_class_embeddings_input_dim",
]
params_for_unet = {k: v for k, v in unet.config.items() if k in params_for_unet}
# The naming seems a bit confusing and it is, see https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 for why.
params_for_unet["num_attention_heads"] = unet.config.attention_head_dim
params_for_controlnet = [
"conditioning_channels",
"conditioning_embedding_out_channels",
"conditioning_channel_order",
"learn_time_embedding",
"block_out_channels",
"num_attention_heads",
"max_norm_num_groups",
]
params_for_controlnet = {"ctrl_" + k: v for k, v in controlnet.config.items() if k in params_for_controlnet}
params_for_controlnet["time_embedding_mix"] = controlnet.config.time_embedding_mix
# # create model
model = cls.from_config({**params_for_unet, **params_for_controlnet})
# # load weights
# from unet
modules_from_unet = [
"time_embedding",
"conv_in",
"conv_norm_out",
"conv_out",
]
for m in modules_from_unet:
getattr(model, "base_" + m).load_state_dict(getattr(unet, m).state_dict())
optional_modules_from_unet = [
"add_time_proj",
"add_embedding",
]
for m in optional_modules_from_unet:
if hasattr(unet, m) and getattr(unet, m) is not None:
getattr(model, "base_" + m).load_state_dict(getattr(unet, m).state_dict())
# from controlnet
model.controlnet_cond_embedding.load_state_dict(controlnet.controlnet_cond_embedding.state_dict())
model.ctrl_conv_in.load_state_dict(controlnet.conv_in.state_dict())
if controlnet.time_embedding is not None:
model.ctrl_time_embedding.load_state_dict(controlnet.time_embedding.state_dict())
model.control_to_base_for_conv_in.load_state_dict(controlnet.control_to_base_for_conv_in.state_dict())
# from both
model.down_blocks = nn.ModuleList(
ControlNetXSCrossAttnDownBlock2D.from_modules(b, c)
for b, c in zip(unet.down_blocks, controlnet.down_blocks)
)
model.mid_block = ControlNetXSCrossAttnMidBlock2D.from_modules(unet.mid_block, controlnet.mid_block)
model.up_blocks = nn.ModuleList(
ControlNetXSCrossAttnUpBlock2D.from_modules(b, c)
for b, c in zip(unet.up_blocks, controlnet.up_connections)
)
# ensure that the UNetControlNetXSModel is the same dtype as the UNet2DConditionModel
model.to(unet.dtype)
return model
def freeze_unet_params(self) -> None:
"""Freeze the weights of the parts belonging to the base UNet2DConditionModel, and leave everything else unfrozen for fine
tuning."""
# Freeze everything
for param in self.parameters():
param.requires_grad = True
# Unfreeze ControlNetXSAdapter
base_parts = [
"base_time_proj",
"base_time_embedding",
"base_add_time_proj",
"base_add_embedding",
"base_conv_in",
"base_conv_norm_out",
"base_conv_act",
"base_conv_out",
]
base_parts = [getattr(self, part) for part in base_parts if getattr(self, part) is not None]
for part in base_parts:
for param in part.parameters():
param.requires_grad = False
for d in self.down_blocks:
d.freeze_base_params()
self.mid_block.freeze_base_params()
for u in self.up_blocks:
u.freeze_base_params()
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
def set_default_attn_processor(self):
"""
Disables custom attention processors and sets the default attention implementation.
"""
if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
processor = AttnAddedKVProcessor()
elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
processor = AttnProcessor()
else:
raise ValueError(
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
)
self.set_attn_processor(processor)
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.enable_freeu
def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
r"""Enables the FreeU mechanism from https://huggingface.co/papers/2309.11497.
The suffixes after the scaling factors represent the stage blocks where they are being applied.
Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values that
are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
Args:
s1 (`float`):
Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
mitigate the "oversmoothing effect" in the enhanced denoising process.
s2 (`float`):
Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
mitigate the "oversmoothing effect" in the enhanced denoising process.
b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
"""
for i, upsample_block in enumerate(self.up_blocks):
setattr(upsample_block, "s1", s1)
setattr(upsample_block, "s2", s2)
setattr(upsample_block, "b1", b1)
setattr(upsample_block, "b2", b2)
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.disable_freeu
def disable_freeu(self):
"""Disables the FreeU mechanism."""
freeu_keys = {"s1", "s2", "b1", "b2"}
for i, upsample_block in enumerate(self.up_blocks):
for k in freeu_keys:
if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None:
setattr(upsample_block, k, None)
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
def fuse_qkv_projections(self):
"""
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
are fused. For cross-attention modules, key and value projection matrices are fused.
> [!WARNING] > This API is 🧪 experimental.
"""
self.original_attn_processors = None
for _, attn_processor in self.attn_processors.items():
if "Added" in str(attn_processor.__class__.__name__):
raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
self.original_attn_processors = self.attn_processors
for module in self.modules():
if isinstance(module, Attention):
module.fuse_projections(fuse=True)
self.set_attn_processor(FusedAttnProcessor2_0())
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
def unfuse_qkv_projections(self):
"""Disables the fused QKV projection if enabled.
> [!WARNING] > This API is 🧪 experimental.
"""
if self.original_attn_processors is not None:
self.set_attn_processor(self.original_attn_processors)
def forward(
self,
sample: Tensor,
timestep: torch.Tensor | float | int,
encoder_hidden_states: torch.Tensor,
controlnet_cond: torch.Tensor | None = None,
conditioning_scale: float | None = 1.0,
class_labels: torch.Tensor | None = None,
timestep_cond: torch.Tensor | None = None,
attention_mask: torch.Tensor | None = None,
cross_attention_kwargs: dict[str, Any] | None = None,
added_cond_kwargs: dict[str, torch.Tensor] | None = None,
return_dict: bool = True,
apply_control: bool = True,
) -> ControlNetXSOutput | tuple:
"""
The [`ControlNetXSModel`] forward method.
Args:
sample (`Tensor`):
The noisy input tensor.
timestep (`torch.Tensor | float | int`):
The number of timesteps to denoise an input.
encoder_hidden_states (`torch.Tensor`):
The encoder hidden states.
controlnet_cond (`Tensor`):
The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
conditioning_scale (`float`, defaults to `1.0`):
How much the control model affects the base model outputs.
class_labels (`torch.Tensor`, *optional*, defaults to `None`):
Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
Additional conditional embeddings for timestep. If provided, the embeddings will be summed with the
timestep_embedding passed through the `self.time_embedding` layer to obtain the final timestep
embeddings.
attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
negative values to the attention scores corresponding to "discard" tokens.
cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`):
A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
added_cond_kwargs (`dict`):
Additional conditions for the Stable Diffusion XL UNet.
return_dict (`bool`, defaults to `True`):
Whether or not to return a [`~models.controlnets.controlnet.ControlNetOutput`] instead of a plain
tuple.
apply_control (`bool`, defaults to `True`):
If `False`, the input is run only through the base model.
Returns:
[`~models.controlnetxs.ControlNetXSOutput`] **or** `tuple`:
If `return_dict` is `True`, a [`~models.controlnetxs.ControlNetXSOutput`] is returned, otherwise a
tuple is returned where the first element is the sample tensor.
"""
# check channel order
if self.config.ctrl_conditioning_channel_order == "bgr":
controlnet_cond = torch.flip(controlnet_cond, dims=[1])
# prepare attention_mask
if attention_mask is not None:
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
attention_mask = attention_mask.unsqueeze(1)
# 1. time
timesteps = timestep
if not torch.is_tensor(timesteps):
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
# This would be a good case for the `match` statement (Python 3.10+)
is_mps = sample.device.type == "mps"
is_npu = sample.device.type == "npu"
if isinstance(timestep, float):
dtype = torch.float32 if (is_mps or is_npu) else torch.float64
else:
dtype = torch.int32 if (is_mps or is_npu) else torch.int64
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
elif len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timesteps = timesteps.expand(sample.shape[0])
t_emb = self.base_time_proj(timesteps)
# timesteps does not contain any weights and will always return f32 tensors
# but time_embedding might actually be running in fp16. so we need to cast here.
# there might be better ways to encapsulate this.
t_emb = t_emb.to(dtype=sample.dtype)
if self.config.ctrl_learn_time_embedding and apply_control:
ctrl_temb = self.ctrl_time_embedding(t_emb, timestep_cond)
base_temb = self.base_time_embedding(t_emb, timestep_cond)
interpolation_param = self.config.time_embedding_mix**0.3
temb = ctrl_temb * interpolation_param + base_temb * (1 - interpolation_param)
else:
temb = self.base_time_embedding(t_emb)
# added time & text embeddings
aug_emb = None
if self.config.addition_embed_type is None:
pass
elif self.config.addition_embed_type == "text_time":
# SDXL - style
if "text_embeds" not in added_cond_kwargs:
raise ValueError(
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
)
text_embeds = added_cond_kwargs.get("text_embeds")
if "time_ids" not in added_cond_kwargs:
raise ValueError(
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
)
time_ids = added_cond_kwargs.get("time_ids")
time_embeds = self.base_add_time_proj(time_ids.flatten())
time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
add_embeds = add_embeds.to(temb.dtype)
aug_emb = self.base_add_embedding(add_embeds)
else:
raise ValueError(
f"ControlNet-XS currently only supports StableDiffusion and StableDiffusion-XL, so addition_embed_type = {self.config.addition_embed_type} is currently not supported."
)
temb = temb + aug_emb if aug_emb is not None else temb
# text embeddings
cemb = encoder_hidden_states
# Preparation
h_ctrl = h_base = sample
hs_base, hs_ctrl = [], []
# Cross Control
guided_hint = self.controlnet_cond_embedding(controlnet_cond)
# 1 - conv in & down
h_base = self.base_conv_in(h_base)
h_ctrl = self.ctrl_conv_in(h_ctrl)
if guided_hint is not None:
h_ctrl += guided_hint
if apply_control:
h_base = h_base + self.control_to_base_for_conv_in(h_ctrl) * conditioning_scale # add ctrl -> base
hs_base.append(h_base)
hs_ctrl.append(h_ctrl)
for down in self.down_blocks:
h_base, h_ctrl, residual_hb, residual_hc = down(
hidden_states_base=h_base,
hidden_states_ctrl=h_ctrl,
temb=temb,
encoder_hidden_states=cemb,
conditioning_scale=conditioning_scale,
cross_attention_kwargs=cross_attention_kwargs,
attention_mask=attention_mask,
apply_control=apply_control,
)
hs_base.extend(residual_hb)
hs_ctrl.extend(residual_hc)
# 2 - mid
h_base, h_ctrl = self.mid_block(
hidden_states_base=h_base,
hidden_states_ctrl=h_ctrl,
temb=temb,
encoder_hidden_states=cemb,
conditioning_scale=conditioning_scale,
cross_attention_kwargs=cross_attention_kwargs,
attention_mask=attention_mask,
apply_control=apply_control,
)
# 3 - up
for up in self.up_blocks:
n_resnets = len(up.resnets)
skips_hb = hs_base[-n_resnets:]
skips_hc = hs_ctrl[-n_resnets:]
hs_base = hs_base[:-n_resnets]
hs_ctrl = hs_ctrl[:-n_resnets]
h_base = up(
hidden_states=h_base,
res_hidden_states_tuple_base=skips_hb,
res_hidden_states_tuple_ctrl=skips_hc,
temb=temb,
encoder_hidden_states=cemb,
conditioning_scale=conditioning_scale,
cross_attention_kwargs=cross_attention_kwargs,
attention_mask=attention_mask,
apply_control=apply_control,
)
# 4 - conv out
h_base = self.base_conv_norm_out(h_base)
h_base = self.base_conv_act(h_base)
h_base = self.base_conv_out(h_base)
if not return_dict:
return (h_base,)
return ControlNetXSOutput(sample=h_base)
class ControlNetXSCrossAttnDownBlock2D(nn.Module):
def __init__(
self,
base_in_channels: int,
base_out_channels: int,
ctrl_in_channels: int,
ctrl_out_channels: int,
temb_channels: int,
norm_num_groups: int = 32,
ctrl_max_norm_num_groups: int = 32,
has_crossattn=True,
transformer_layers_per_block: int | tuple[int] | None = 1,
base_num_attention_heads: int | None = 1,
ctrl_num_attention_heads: int | None = 1,
cross_attention_dim: int | None = 1024,
add_downsample: bool = True,
upcast_attention: bool | None = False,
use_linear_projection: bool | None = True,
):
super().__init__()
base_resnets = []
base_attentions = []
ctrl_resnets = []
ctrl_attentions = []
ctrl_to_base = []
base_to_ctrl = []
num_layers = 2 # only support sd + sdxl
if isinstance(transformer_layers_per_block, int):
transformer_layers_per_block = [transformer_layers_per_block] * num_layers
for i in range(num_layers):
base_in_channels = base_in_channels if i == 0 else base_out_channels
ctrl_in_channels = ctrl_in_channels if i == 0 else ctrl_out_channels
# Before the resnet/attention application, information is concatted from base to control.
# Concat doesn't require change in number of channels
base_to_ctrl.append(make_zero_conv(base_in_channels, base_in_channels))
base_resnets.append(
ResnetBlock2D(
in_channels=base_in_channels,
out_channels=base_out_channels,
temb_channels=temb_channels,
groups=norm_num_groups,
)
)
ctrl_resnets.append(
ResnetBlock2D(
in_channels=ctrl_in_channels + base_in_channels, # information from base is concatted to ctrl
out_channels=ctrl_out_channels,
temb_channels=temb_channels,
groups=find_largest_factor(
ctrl_in_channels + base_in_channels, max_factor=ctrl_max_norm_num_groups
),
groups_out=find_largest_factor(ctrl_out_channels, max_factor=ctrl_max_norm_num_groups),
eps=1e-5,
)
)
if has_crossattn:
base_attentions.append(
Transformer2DModel(
base_num_attention_heads,
base_out_channels // base_num_attention_heads,
in_channels=base_out_channels,
num_layers=transformer_layers_per_block[i],
cross_attention_dim=cross_attention_dim,
use_linear_projection=use_linear_projection,
upcast_attention=upcast_attention,
norm_num_groups=norm_num_groups,
)
)
ctrl_attentions.append(
Transformer2DModel(
ctrl_num_attention_heads,
ctrl_out_channels // ctrl_num_attention_heads,
in_channels=ctrl_out_channels,
num_layers=transformer_layers_per_block[i],
cross_attention_dim=cross_attention_dim,
use_linear_projection=use_linear_projection,
upcast_attention=upcast_attention,
norm_num_groups=find_largest_factor(ctrl_out_channels, max_factor=ctrl_max_norm_num_groups),
)
)
# After the resnet/attention application, information is added from control to base
# Addition requires change in number of channels
ctrl_to_base.append(make_zero_conv(ctrl_out_channels, base_out_channels))
if add_downsample:
# Before the downsampler application, information is concatted from base to control
# Concat doesn't require change in number of channels
base_to_ctrl.append(make_zero_conv(base_out_channels, base_out_channels))
self.base_downsamplers = Downsample2D(
base_out_channels, use_conv=True, out_channels=base_out_channels, name="op"
)
self.ctrl_downsamplers = Downsample2D(
ctrl_out_channels + base_out_channels, use_conv=True, out_channels=ctrl_out_channels, name="op"
)
# After the downsampler application, information is added from control to base
# Addition requires change in number of channels
ctrl_to_base.append(make_zero_conv(ctrl_out_channels, base_out_channels))
else:
self.base_downsamplers = None
self.ctrl_downsamplers = None
self.base_resnets = nn.ModuleList(base_resnets)
self.ctrl_resnets = nn.ModuleList(ctrl_resnets)
self.base_attentions = nn.ModuleList(base_attentions) if has_crossattn else [None] * num_layers
self.ctrl_attentions = nn.ModuleList(ctrl_attentions) if has_crossattn else [None] * num_layers
self.base_to_ctrl = nn.ModuleList(base_to_ctrl)
self.ctrl_to_base = nn.ModuleList(ctrl_to_base)
self.gradient_checkpointing = False
@classmethod
def from_modules(cls, base_downblock: CrossAttnDownBlock2D, ctrl_downblock: DownBlockControlNetXSAdapter):
# get params
def get_first_cross_attention(block):
return block.attentions[0].transformer_blocks[0].attn2
base_in_channels = base_downblock.resnets[0].in_channels
base_out_channels = base_downblock.resnets[0].out_channels
ctrl_in_channels = (
ctrl_downblock.resnets[0].in_channels - base_in_channels
) # base channels are concatted to ctrl channels in init
ctrl_out_channels = ctrl_downblock.resnets[0].out_channels
temb_channels = base_downblock.resnets[0].time_emb_proj.in_features
num_groups = base_downblock.resnets[0].norm1.num_groups
ctrl_num_groups = ctrl_downblock.resnets[0].norm1.num_groups
if hasattr(base_downblock, "attentions"):
has_crossattn = True
transformer_layers_per_block = len(base_downblock.attentions[0].transformer_blocks)
base_num_attention_heads = get_first_cross_attention(base_downblock).heads
ctrl_num_attention_heads = get_first_cross_attention(ctrl_downblock).heads
cross_attention_dim = get_first_cross_attention(base_downblock).cross_attention_dim
upcast_attention = get_first_cross_attention(base_downblock).upcast_attention
use_linear_projection = base_downblock.attentions[0].use_linear_projection
else:
has_crossattn = False
transformer_layers_per_block = None
base_num_attention_heads = None
ctrl_num_attention_heads = None
cross_attention_dim = None
upcast_attention = None
use_linear_projection = None
add_downsample = base_downblock.downsamplers is not None
# create model
model = cls(
base_in_channels=base_in_channels,
base_out_channels=base_out_channels,
ctrl_in_channels=ctrl_in_channels,
ctrl_out_channels=ctrl_out_channels,
temb_channels=temb_channels,
norm_num_groups=num_groups,
ctrl_max_norm_num_groups=ctrl_num_groups,
has_crossattn=has_crossattn,
transformer_layers_per_block=transformer_layers_per_block,
base_num_attention_heads=base_num_attention_heads,
ctrl_num_attention_heads=ctrl_num_attention_heads,
cross_attention_dim=cross_attention_dim,
add_downsample=add_downsample,
upcast_attention=upcast_attention,
use_linear_projection=use_linear_projection,
)
# # load weights
model.base_resnets.load_state_dict(base_downblock.resnets.state_dict())
model.ctrl_resnets.load_state_dict(ctrl_downblock.resnets.state_dict())
if has_crossattn:
model.base_attentions.load_state_dict(base_downblock.attentions.state_dict())
model.ctrl_attentions.load_state_dict(ctrl_downblock.attentions.state_dict())
if add_downsample:
model.base_downsamplers.load_state_dict(base_downblock.downsamplers[0].state_dict())
model.ctrl_downsamplers.load_state_dict(ctrl_downblock.downsamplers.state_dict())
model.base_to_ctrl.load_state_dict(ctrl_downblock.base_to_ctrl.state_dict())
model.ctrl_to_base.load_state_dict(ctrl_downblock.ctrl_to_base.state_dict())
return model
def freeze_base_params(self) -> None:
"""Freeze the weights of the parts belonging to the base UNet2DConditionModel, and leave everything else unfrozen for fine
tuning."""
# Unfreeze everything
for param in self.parameters():
param.requires_grad = True
# Freeze base part
base_parts = [self.base_resnets]
if isinstance(self.base_attentions, nn.ModuleList): # attentions can be a list of Nones
base_parts.append(self.base_attentions)
if self.base_downsamplers is not None:
base_parts.append(self.base_downsamplers)
for part in base_parts:
for param in part.parameters():
param.requires_grad = False
def forward(
self,
hidden_states_base: Tensor,
temb: Tensor,
encoder_hidden_states: Tensor | None = None,
hidden_states_ctrl: Tensor | None = None,
conditioning_scale: float | None = 1.0,
attention_mask: Tensor | None = None,
cross_attention_kwargs: dict[str, Any] | None = None,
encoder_attention_mask: Tensor | None = None,
apply_control: bool = True,
) -> tuple[Tensor, Tensor, tuple[Tensor, ...], tuple[Tensor, ...]]:
if cross_attention_kwargs is not None:
if cross_attention_kwargs.get("scale", None) is not None:
logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
h_base = hidden_states_base
h_ctrl = hidden_states_ctrl
base_output_states = ()
ctrl_output_states = ()
base_blocks = list(zip(self.base_resnets, self.base_attentions))
ctrl_blocks = list(zip(self.ctrl_resnets, self.ctrl_attentions))
for (b_res, b_attn), (c_res, c_attn), b2c, c2b in zip(
base_blocks, ctrl_blocks, self.base_to_ctrl, self.ctrl_to_base
):
# concat base -> ctrl
if apply_control:
h_ctrl = torch.cat([h_ctrl, b2c(h_base)], dim=1)
# apply base subblock
if torch.is_grad_enabled() and self.gradient_checkpointing:
h_base = self._gradient_checkpointing_func(b_res, h_base, temb)
else:
h_base = b_res(h_base, temb)
if b_attn is not None:
h_base = b_attn(
h_base,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]
# apply ctrl subblock
if apply_control:
if torch.is_grad_enabled() and self.gradient_checkpointing:
h_ctrl = self._gradient_checkpointing_func(c_res, h_ctrl, temb)
else:
h_ctrl = c_res(h_ctrl, temb)
if c_attn is not None:
h_ctrl = c_attn(
h_ctrl,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]
# add ctrl -> base
if apply_control:
h_base = h_base + c2b(h_ctrl) * conditioning_scale
base_output_states = base_output_states + (h_base,)
ctrl_output_states = ctrl_output_states + (h_ctrl,)
if self.base_downsamplers is not None: # if we have a base_downsampler, then also a ctrl_downsampler
b2c = self.base_to_ctrl[-1]
c2b = self.ctrl_to_base[-1]
# concat base -> ctrl
if apply_control:
h_ctrl = torch.cat([h_ctrl, b2c(h_base)], dim=1)
# apply base subblock
h_base = self.base_downsamplers(h_base)
# apply ctrl subblock
if apply_control:
h_ctrl = self.ctrl_downsamplers(h_ctrl)
# add ctrl -> base
if apply_control:
h_base = h_base + c2b(h_ctrl) * conditioning_scale
base_output_states = base_output_states + (h_base,)
ctrl_output_states = ctrl_output_states + (h_ctrl,)
return h_base, h_ctrl, base_output_states, ctrl_output_states
class ControlNetXSCrossAttnMidBlock2D(nn.Module):
def __init__(
self,
base_channels: int,
ctrl_channels: int,
temb_channels: int | None = None,
norm_num_groups: int = 32,
ctrl_max_norm_num_groups: int = 32,
transformer_layers_per_block: int = 1,
base_num_attention_heads: int | None = 1,
ctrl_num_attention_heads: int | None = 1,
cross_attention_dim: int | None = 1024,
upcast_attention: bool = False,
use_linear_projection: bool | None = True,
):
super().__init__()
# Before the midblock application, information is concatted from base to control.
# Concat doesn't require change in number of channels
self.base_to_ctrl = make_zero_conv(base_channels, base_channels)
self.base_midblock = UNetMidBlock2DCrossAttn(
transformer_layers_per_block=transformer_layers_per_block,
in_channels=base_channels,
temb_channels=temb_channels,
resnet_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim,
num_attention_heads=base_num_attention_heads,
use_linear_projection=use_linear_projection,
upcast_attention=upcast_attention,
)
self.ctrl_midblock = UNetMidBlock2DCrossAttn(
transformer_layers_per_block=transformer_layers_per_block,
in_channels=ctrl_channels + base_channels,
out_channels=ctrl_channels,
temb_channels=temb_channels,
# number or norm groups must divide both in_channels and out_channels
resnet_groups=find_largest_factor(
gcd(ctrl_channels, ctrl_channels + base_channels), ctrl_max_norm_num_groups
),
cross_attention_dim=cross_attention_dim,
num_attention_heads=ctrl_num_attention_heads,
use_linear_projection=use_linear_projection,
upcast_attention=upcast_attention,
)
# After the midblock application, information is added from control to base
# Addition requires change in number of channels
self.ctrl_to_base = make_zero_conv(ctrl_channels, base_channels)
self.gradient_checkpointing = False
@classmethod
def from_modules(
cls,
base_midblock: UNetMidBlock2DCrossAttn,
ctrl_midblock: MidBlockControlNetXSAdapter,
):
base_to_ctrl = ctrl_midblock.base_to_ctrl
ctrl_to_base = ctrl_midblock.ctrl_to_base
ctrl_midblock = ctrl_midblock.midblock
# get params
def get_first_cross_attention(midblock):
return midblock.attentions[0].transformer_blocks[0].attn2
base_channels = ctrl_to_base.out_channels
ctrl_channels = ctrl_to_base.in_channels
transformer_layers_per_block = len(base_midblock.attentions[0].transformer_blocks)
temb_channels = base_midblock.resnets[0].time_emb_proj.in_features
num_groups = base_midblock.resnets[0].norm1.num_groups
ctrl_num_groups = ctrl_midblock.resnets[0].norm1.num_groups
base_num_attention_heads = get_first_cross_attention(base_midblock).heads
ctrl_num_attention_heads = get_first_cross_attention(ctrl_midblock).heads
cross_attention_dim = get_first_cross_attention(base_midblock).cross_attention_dim
upcast_attention = get_first_cross_attention(base_midblock).upcast_attention
use_linear_projection = base_midblock.attentions[0].use_linear_projection
# create model
model = cls(
base_channels=base_channels,
ctrl_channels=ctrl_channels,
temb_channels=temb_channels,
norm_num_groups=num_groups,
ctrl_max_norm_num_groups=ctrl_num_groups,
transformer_layers_per_block=transformer_layers_per_block,
base_num_attention_heads=base_num_attention_heads,
ctrl_num_attention_heads=ctrl_num_attention_heads,
cross_attention_dim=cross_attention_dim,
upcast_attention=upcast_attention,
use_linear_projection=use_linear_projection,
)
# load weights
model.base_to_ctrl.load_state_dict(base_to_ctrl.state_dict())
model.base_midblock.load_state_dict(base_midblock.state_dict())
model.ctrl_midblock.load_state_dict(ctrl_midblock.state_dict())
model.ctrl_to_base.load_state_dict(ctrl_to_base.state_dict())
return model
def freeze_base_params(self) -> None:
"""Freeze the weights of the parts belonging to the base UNet2DConditionModel, and leave everything else unfrozen for fine
tuning."""
# Unfreeze everything
for param in self.parameters():
param.requires_grad = True
# Freeze base part
for param in self.base_midblock.parameters():
param.requires_grad = False
def forward(
self,
hidden_states_base: Tensor,
temb: Tensor,
encoder_hidden_states: Tensor,
hidden_states_ctrl: Tensor | None = None,
conditioning_scale: float | None = 1.0,
cross_attention_kwargs: dict[str, Any] | None = None,
attention_mask: Tensor | None = None,
encoder_attention_mask: Tensor | None = None,
apply_control: bool = True,
) -> tuple[Tensor, Tensor]:
if cross_attention_kwargs is not None:
if cross_attention_kwargs.get("scale", None) is not None:
logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
h_base = hidden_states_base
h_ctrl = hidden_states_ctrl
joint_args = {
"temb": temb,
"encoder_hidden_states": encoder_hidden_states,
"attention_mask": attention_mask,
"cross_attention_kwargs": cross_attention_kwargs,
"encoder_attention_mask": encoder_attention_mask,
}
if apply_control:
h_ctrl = torch.cat([h_ctrl, self.base_to_ctrl(h_base)], dim=1) # concat base -> ctrl
h_base = self.base_midblock(h_base, **joint_args) # apply base mid block
if apply_control:
h_ctrl = self.ctrl_midblock(h_ctrl, **joint_args) # apply ctrl mid block
h_base = h_base + self.ctrl_to_base(h_ctrl) * conditioning_scale # add ctrl -> base
return h_base, h_ctrl
class ControlNetXSCrossAttnUpBlock2D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
prev_output_channel: int,
ctrl_skip_channels: list[int],
temb_channels: int,
norm_num_groups: int = 32,
resolution_idx: int | None = None,
has_crossattn=True,
transformer_layers_per_block: int = 1,
num_attention_heads: int = 1,
cross_attention_dim: int = 1024,
add_upsample: bool = True,
upcast_attention: bool = False,
use_linear_projection: bool | None = True,
):
super().__init__()
resnets = []
attentions = []
ctrl_to_base = []
num_layers = 3 # only support sd + sdxl
self.has_cross_attention = has_crossattn
self.num_attention_heads = num_attention_heads
if isinstance(transformer_layers_per_block, int):
transformer_layers_per_block = [transformer_layers_per_block] * num_layers
for i in range(num_layers):
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
resnet_in_channels = prev_output_channel if i == 0 else out_channels
ctrl_to_base.append(make_zero_conv(ctrl_skip_channels[i], resnet_in_channels))
resnets.append(
ResnetBlock2D(
in_channels=resnet_in_channels + res_skip_channels,
out_channels=out_channels,
temb_channels=temb_channels,
groups=norm_num_groups,
)
)
if has_crossattn:
attentions.append(
Transformer2DModel(
num_attention_heads,
out_channels // num_attention_heads,
in_channels=out_channels,
num_layers=transformer_layers_per_block[i],
cross_attention_dim=cross_attention_dim,
use_linear_projection=use_linear_projection,
upcast_attention=upcast_attention,
norm_num_groups=norm_num_groups,
)
)
self.resnets = nn.ModuleList(resnets)
self.attentions = nn.ModuleList(attentions) if has_crossattn else [None] * num_layers
self.ctrl_to_base = nn.ModuleList(ctrl_to_base)
if add_upsample:
self.upsamplers = Upsample2D(out_channels, use_conv=True, out_channels=out_channels)
else:
self.upsamplers = None
self.gradient_checkpointing = False
self.resolution_idx = resolution_idx
@classmethod
def from_modules(cls, base_upblock: CrossAttnUpBlock2D, ctrl_upblock: UpBlockControlNetXSAdapter):
ctrl_to_base_skip_connections = ctrl_upblock.ctrl_to_base
# get params
def get_first_cross_attention(block):
return block.attentions[0].transformer_blocks[0].attn2
out_channels = base_upblock.resnets[0].out_channels
in_channels = base_upblock.resnets[-1].in_channels - out_channels
prev_output_channels = base_upblock.resnets[0].in_channels - out_channels
ctrl_skip_channelss = [c.in_channels for c in ctrl_to_base_skip_connections]
temb_channels = base_upblock.resnets[0].time_emb_proj.in_features
num_groups = base_upblock.resnets[0].norm1.num_groups
resolution_idx = base_upblock.resolution_idx
if hasattr(base_upblock, "attentions"):
has_crossattn = True
transformer_layers_per_block = len(base_upblock.attentions[0].transformer_blocks)
num_attention_heads = get_first_cross_attention(base_upblock).heads
cross_attention_dim = get_first_cross_attention(base_upblock).cross_attention_dim
upcast_attention = get_first_cross_attention(base_upblock).upcast_attention
use_linear_projection = base_upblock.attentions[0].use_linear_projection
else:
has_crossattn = False
transformer_layers_per_block = None
num_attention_heads = None
cross_attention_dim = None
upcast_attention = None
use_linear_projection = None
add_upsample = base_upblock.upsamplers is not None
# create model
model = cls(
in_channels=in_channels,
out_channels=out_channels,
prev_output_channel=prev_output_channels,
ctrl_skip_channels=ctrl_skip_channelss,
temb_channels=temb_channels,
norm_num_groups=num_groups,
resolution_idx=resolution_idx,
has_crossattn=has_crossattn,
transformer_layers_per_block=transformer_layers_per_block,
num_attention_heads=num_attention_heads,
cross_attention_dim=cross_attention_dim,
add_upsample=add_upsample,
upcast_attention=upcast_attention,
use_linear_projection=use_linear_projection,
)
# load weights
model.resnets.load_state_dict(base_upblock.resnets.state_dict())
if has_crossattn:
model.attentions.load_state_dict(base_upblock.attentions.state_dict())
if add_upsample:
model.upsamplers.load_state_dict(base_upblock.upsamplers[0].state_dict())
model.ctrl_to_base.load_state_dict(ctrl_to_base_skip_connections.state_dict())
return model
def freeze_base_params(self) -> None:
"""Freeze the weights of the parts belonging to the base UNet2DConditionModel, and leave everything else unfrozen for fine
tuning."""
# Unfreeze everything
for param in self.parameters():
param.requires_grad = True
# Freeze base part
base_parts = [self.resnets]
if isinstance(self.attentions, nn.ModuleList): # attentions can be a list of Nones
base_parts.append(self.attentions)
if self.upsamplers is not None:
base_parts.append(self.upsamplers)
for part in base_parts:
for param in part.parameters():
param.requires_grad = False
def forward(
self,
hidden_states: Tensor,
res_hidden_states_tuple_base: tuple[Tensor, ...],
res_hidden_states_tuple_ctrl: tuple[Tensor, ...],
temb: Tensor,
encoder_hidden_states: Tensor | None = None,
conditioning_scale: float | None = 1.0,
cross_attention_kwargs: dict[str, Any] | None = None,
attention_mask: Tensor | None = None,
upsample_size: int | None = None,
encoder_attention_mask: Tensor | None = None,
apply_control: bool = True,
) -> Tensor:
if cross_attention_kwargs is not None:
if cross_attention_kwargs.get("scale", None) is not None:
logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
is_freeu_enabled = (
getattr(self, "s1", None)
and getattr(self, "s2", None)
and getattr(self, "b1", None)
and getattr(self, "b2", None)
)
def maybe_apply_freeu_to_subblock(hidden_states, res_h_base):
# FreeU: Only operate on the first two stages
if is_freeu_enabled:
return apply_freeu(
self.resolution_idx,
hidden_states,
res_h_base,
s1=self.s1,
s2=self.s2,
b1=self.b1,
b2=self.b2,
)
else:
return hidden_states, res_h_base
for resnet, attn, c2b, res_h_base, res_h_ctrl in zip(
self.resnets,
self.attentions,
self.ctrl_to_base,
reversed(res_hidden_states_tuple_base),
reversed(res_hidden_states_tuple_ctrl),
):
if apply_control:
hidden_states += c2b(res_h_ctrl) * conditioning_scale
hidden_states, res_h_base = maybe_apply_freeu_to_subblock(hidden_states, res_h_base)
hidden_states = torch.cat([hidden_states, res_h_base], dim=1)
if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
else:
hidden_states = resnet(hidden_states, temb)
if attn is not None:
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]
if self.upsamplers is not None:
hidden_states = self.upsamplers(hidden_states, upsample_size)
return hidden_states
def make_zero_conv(in_channels, out_channels=None):
return zero_module(nn.Conv2d(in_channels, out_channels, 1, padding=0))
def zero_module(module):
for p in module.parameters():
nn.init.zeros_(p)
return module
def find_largest_factor(number, max_factor):
factor = max_factor
if factor >= number:
return number
while factor != 0:
residual = number % factor
if residual == 0:
return factor
factor -= 1