| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import copy |
| | import unittest |
| |
|
| | import numpy as np |
| | import torch |
| | from torch import nn |
| |
|
| | from diffusers import ControlNetXSAdapter, UNet2DConditionModel, UNetControlNetXSModel |
| | from diffusers.utils import logging |
| | from diffusers.utils.testing_utils import enable_full_determinism, floats_tensor, is_flaky, torch_device |
| |
|
| | from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin |
| |
|
| |
|
# Executed once at import time, before any test in this module runs.
enable_full_determinism()

# Module-level logger; not referenced by the test code visible in this module.
logger = logging.get_logger(__name__)
| |
|
| |
|
class UNetControlNetXSModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
    """Tests for `UNetControlNetXSModel`.

    Besides the shared checks inherited from `ModelTesterMixin` / `UNetTesterMixin`, this verifies that:
    - `UNetControlNetXSModel.from_unet` copies every UNet and ControlNet-XS-adapter weight to the expected key,
    - `freeze_unet_params` freezes the base-UNet parameters and leaves the control parameters trainable,
    - gradient checkpointing is enabled on the expected submodules,
    - `apply_control=False` reproduces the plain UNet output,
    - a time-embedding-mixing adapter produces outputs of the same shape.
    """

    model_class = UNetControlNetXSModel
    main_input_name = "sample"

    @property
    def dummy_input(self):
        """Return a small random forward-pass input dict (batch size 4) placed on `torch_device`."""
        batch_size = 4
        num_channels = 4
        sizes = (16, 16)
        conditioning_image_size = (3, 32, 32)

        noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
        time_step = torch.tensor([10]).to(torch_device)
        # Last dimension (8) matches `cross_attention_dim` in `prepare_init_args_and_inputs_for_common`.
        encoder_hidden_states = floats_tensor((batch_size, 4, 8)).to(torch_device)
        controlnet_cond = floats_tensor((batch_size, *conditioning_image_size)).to(torch_device)
        conditioning_scale = 1

        return {
            "sample": noise,
            "timestep": time_step,
            "encoder_hidden_states": encoder_hidden_states,
            "controlnet_cond": controlnet_cond,
            "conditioning_scale": conditioning_scale,
        }

    @property
    def input_shape(self):
        # Per-sample shape of `sample` in `dummy_input`: (channels, height, width).
        return (4, 16, 16)

    @property
    def output_shape(self):
        # The model is expected to return a tensor with the same per-sample shape as its input.
        return (4, 16, 16)

    def prepare_init_args_and_inputs_for_common(self):
        """Return `(init_dict, inputs_dict)` for a tiny `UNetControlNetXSModel` and matching dummy inputs."""
        init_dict = {
            "sample_size": 16,
            "down_block_types": ("DownBlock2D", "CrossAttnDownBlock2D"),
            "up_block_types": ("CrossAttnUpBlock2D", "UpBlock2D"),
            "block_out_channels": (4, 8),
            "cross_attention_dim": 8,
            "transformer_layers_per_block": 1,
            "num_attention_heads": 2,
            "norm_num_groups": 4,
            "upcast_attention": False,
            "ctrl_block_out_channels": [2, 4],
            "ctrl_num_attention_heads": 4,
            "ctrl_max_norm_num_groups": 2,
            "ctrl_conditioning_embedding_out_channels": (2, 2),
        }
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict

    def get_dummy_unet(self):
        """Build a tiny `UNet2DConditionModel`.

        Some tests build the `UNetControlNetXSModel` from a UNet and a ControlNet-XS adapter, and therefore need the
        underlying UNet.
        """
        return UNet2DConditionModel(
            block_out_channels=(4, 8),
            layers_per_block=2,
            sample_size=16,
            in_channels=4,
            out_channels=4,
            down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
            up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
            cross_attention_dim=8,
            norm_num_groups=4,
            use_linear_projection=True,
        )

    def get_dummy_controlnet_from_unet(self, unet, **kwargs):
        """Build a `ControlNetXSAdapter` matching `unet`.

        Extra keyword arguments are forwarded to `ControlNetXSAdapter.from_unet`.
        """
        return ControlNetXSAdapter.from_unet(unet, size_ratio=1, conditioning_embedding_out_channels=(2, 2), **kwargs)

    def test_from_unet(self):
        """`UNetControlNetXSModel.from_unet` must place every UNet and adapter weight under the expected key."""
        unet = self.get_dummy_unet()
        controlnet = self.get_dummy_controlnet_from_unet(unet)

        model = UNetControlNetXSModel.from_unet(unet, controlnet)
        model_state_dict = model.state_dict()

        def assert_equal_weights(module, weight_dict_prefix):
            # Every parameter of `module` must appear in the combined model's state dict under
            # `<weight_dict_prefix>.<param_name>` with an identical value.
            for param_name, param_value in module.named_parameters():
                key = f"{weight_dict_prefix}.{param_name}"
                assert torch.equal(model_state_dict[key], param_value), f"Weight mismatch for '{key}'"

        # --- Weights coming from the base UNet ---
        modules_from_unet = [
            "time_embedding",
            "conv_in",
            "conv_norm_out",
            "conv_out",
        ]
        for p in modules_from_unet:
            assert_equal_weights(getattr(unet, p), "base_" + p)

        # These submodules may be absent or None depending on the UNet configuration.
        optional_modules_from_unet = [
            "class_embedding",
            "add_time_proj",
            "add_embedding",
        ]
        for p in optional_modules_from_unet:
            if getattr(unet, p, None) is not None:
                assert_equal_weights(getattr(unet, p), "base_" + p)

        assert len(unet.down_blocks) == len(model.down_blocks)
        for i, d in enumerate(unet.down_blocks):
            assert_equal_weights(d.resnets, f"down_blocks.{i}.base_resnets")
            if hasattr(d, "attentions"):
                assert_equal_weights(d.attentions, f"down_blocks.{i}.base_attentions")
            if getattr(d, "downsamplers", None) is not None:
                assert_equal_weights(d.downsamplers[0], f"down_blocks.{i}.base_downsamplers")

        assert_equal_weights(unet.mid_block, "mid_block.base_midblock")

        assert len(unet.up_blocks) == len(model.up_blocks)
        for i, u in enumerate(unet.up_blocks):
            assert_equal_weights(u.resnets, f"up_blocks.{i}.resnets")
            if hasattr(u, "attentions"):
                assert_equal_weights(u.attentions, f"up_blocks.{i}.attentions")
            if getattr(u, "upsamplers", None) is not None:
                assert_equal_weights(u.upsamplers[0], f"up_blocks.{i}.upsamplers")

        # --- Weights coming from the ControlNet-XS adapter (adapter name -> name in UNetControlNetXSModel) ---
        modules_from_controlnet = {
            "controlnet_cond_embedding": "controlnet_cond_embedding",
            "conv_in": "ctrl_conv_in",
            "control_to_base_for_conv_in": "control_to_base_for_conv_in",
        }
        optional_modules_from_controlnet = {"time_embedding": "ctrl_time_embedding"}
        for name_in_controlnet, name_in_unetcnxs in modules_from_controlnet.items():
            assert_equal_weights(getattr(controlnet, name_in_controlnet), name_in_unetcnxs)

        for name_in_controlnet, name_in_unetcnxs in optional_modules_from_controlnet.items():
            if getattr(controlnet, name_in_controlnet, None) is not None:
                assert_equal_weights(getattr(controlnet, name_in_controlnet), name_in_unetcnxs)

        assert len(controlnet.down_blocks) == len(model.down_blocks)
        for i, d in enumerate(controlnet.down_blocks):
            assert_equal_weights(d.resnets, f"down_blocks.{i}.ctrl_resnets")
            assert_equal_weights(d.base_to_ctrl, f"down_blocks.{i}.base_to_ctrl")
            assert_equal_weights(d.ctrl_to_base, f"down_blocks.{i}.ctrl_to_base")
            if d.attentions is not None:
                assert_equal_weights(d.attentions, f"down_blocks.{i}.ctrl_attentions")
            if d.downsamplers is not None:
                assert_equal_weights(d.downsamplers, f"down_blocks.{i}.ctrl_downsamplers")

        assert_equal_weights(controlnet.mid_block.base_to_ctrl, "mid_block.base_to_ctrl")
        assert_equal_weights(controlnet.mid_block.midblock, "mid_block.ctrl_midblock")
        assert_equal_weights(controlnet.mid_block.ctrl_to_base, "mid_block.ctrl_to_base")

        assert len(controlnet.up_connections) == len(model.up_blocks)
        for i, u in enumerate(controlnet.up_connections):
            assert_equal_weights(u.ctrl_to_base, f"up_blocks.{i}.ctrl_to_base")

    def test_freeze_unet(self):
        """`freeze_unet_params` must freeze all base-UNet parameters and leave all control parameters trainable."""

        def assert_frozen(module):
            for p in module.parameters():
                assert not p.requires_grad

        def assert_unfrozen(module):
            for p in module.parameters():
                assert p.requires_grad

        init_dict, _ = self.prepare_init_args_and_inputs_for_common()
        model = UNetControlNetXSModel(**init_dict)
        model.freeze_unet_params()

        # --- Base-UNet parts: must be frozen ---
        modules_from_unet = [
            model.base_time_embedding,
            model.base_conv_in,
            model.base_conv_norm_out,
            model.base_conv_out,
        ]
        for m in modules_from_unet:
            assert_frozen(m)

        optional_modules_from_unet = [
            model.base_add_time_proj,
            model.base_add_embedding,
        ]
        for m in optional_modules_from_unet:
            if m is not None:
                assert_frozen(m)

        for d in model.down_blocks:
            assert_frozen(d.base_resnets)
            # Blocks without attention do not hold an nn.ModuleList here, so they are skipped.
            if isinstance(d.base_attentions, nn.ModuleList):
                assert_frozen(d.base_attentions)
            if d.base_downsamplers is not None:
                assert_frozen(d.base_downsamplers)

        assert_frozen(model.mid_block.base_midblock)

        for u in model.up_blocks:
            assert_frozen(u.resnets)
            if isinstance(u.attentions, nn.ModuleList):
                assert_frozen(u.attentions)
            if u.upsamplers is not None:
                assert_frozen(u.upsamplers)

        # --- Control (adapter) parts: must stay trainable ---
        modules_from_controlnet = [
            model.controlnet_cond_embedding,
            model.ctrl_conv_in,
            model.control_to_base_for_conv_in,
        ]
        optional_modules_from_controlnet = [model.ctrl_time_embedding]

        for m in modules_from_controlnet:
            assert_unfrozen(m)
        for m in optional_modules_from_controlnet:
            if m is not None:
                assert_unfrozen(m)

        for d in model.down_blocks:
            assert_unfrozen(d.ctrl_resnets)
            assert_unfrozen(d.base_to_ctrl)
            assert_unfrozen(d.ctrl_to_base)
            if isinstance(d.ctrl_attentions, nn.ModuleList):
                assert_unfrozen(d.ctrl_attentions)
            if d.ctrl_downsamplers is not None:
                assert_unfrozen(d.ctrl_downsamplers)

        assert_unfrozen(model.mid_block.base_to_ctrl)
        assert_unfrozen(model.mid_block.ctrl_midblock)
        assert_unfrozen(model.mid_block.ctrl_to_base)

        for u in model.up_blocks:
            assert_unfrozen(u.ctrl_to_base)

    def test_gradient_checkpointing_is_applied(self):
        """`enable_gradient_checkpointing` must switch checkpointing on for exactly the expected module classes."""
        # Local import: `import unittest` alone does not guarantee that the `unittest.mock` submodule is loaded.
        from unittest import mock

        # Maps module class name -> last `value` passed to `_set_gradient_checkpointing` for it.
        modules_with_gc_enabled = {}

        def _set_gradient_checkpointing_new(_model, module, value=False):
            # Same effect as the real hook, but also records which module classes received the call.
            if hasattr(module, "gradient_checkpointing"):
                module.gradient_checkpointing = value
                modules_with_gc_enabled[module.__class__.__name__] = value

        init_dict, _ = self.prepare_init_args_and_inputs_for_common()

        # `copy.copy(SomeClass)` returns the class object itself, so patching a "copy" would permanently alter
        # `UNetControlNetXSModel` for every later test. `patch.object` restores the original attribute on exit.
        with mock.patch.object(UNetControlNetXSModel, "_set_gradient_checkpointing", _set_gradient_checkpointing_new):
            model = UNetControlNetXSModel(**init_dict)
            model.enable_gradient_checkpointing()

        expected_set = {
            "Transformer2DModel",
            "UNetMidBlock2DCrossAttn",
            "ControlNetXSCrossAttnDownBlock2D",
            "ControlNetXSCrossAttnMidBlock2D",
            "ControlNetXSCrossAttnUpBlock2D",
        }

        assert set(modules_with_gc_enabled.keys()) == expected_set
        assert all(modules_with_gc_enabled.values()), "All modules should be enabled"

    @is_flaky
    def test_forward_no_control(self):
        """With `apply_control=False` the combined model must reproduce the plain UNet output."""
        unet = self.get_dummy_unet()
        controlnet = self.get_dummy_controlnet_from_unet(unet)

        model = UNetControlNetXSModel.from_unet(unet, controlnet)

        unet = unet.to(torch_device)
        model = model.to(torch_device)

        input_ = self.dummy_input

        # The plain UNet does not accept the control-specific arguments.
        control_specific_input = ["controlnet_cond", "conditioning_scale"]
        input_for_unet = {k: v for k, v in input_.items() if k not in control_specific_input}

        with torch.no_grad():
            unet_output = unet(**input_for_unet).sample.cpu()
            unet_controlnet_output = model(**input_, apply_control=False).sample.cpu()

        max_diff = np.abs(unet_output.flatten() - unet_controlnet_output.flatten()).max()
        assert max_diff < 3e-4, f"Max abs difference {max_diff} exceeds tolerance 3e-4"

    def test_time_embedding_mixing(self):
        """An adapter with time-embedding mixing must produce outputs of the same shape as one without."""
        unet = self.get_dummy_unet()
        controlnet = self.get_dummy_controlnet_from_unet(unet)
        controlnet_mix_time = self.get_dummy_controlnet_from_unet(
            unet, time_embedding_mix=0.5, learn_time_embedding=True
        )

        model = UNetControlNetXSModel.from_unet(unet, controlnet)
        model_mix_time = UNetControlNetXSModel.from_unet(unet, controlnet_mix_time)

        model = model.to(torch_device)
        model_mix_time = model_mix_time.to(torch_device)

        input_ = self.dummy_input

        with torch.no_grad():
            output = model(**input_).sample
            output_mix_time = model_mix_time(**input_).sample

        assert output.shape == output_mix_time.shape, f"{output.shape} != {output_mix_time.shape}"

    def test_forward_with_norm_groups(self):
        # Deliberately overrides the inherited mixin test with a no-op, so it always passes.
        # NOTE(review): the reason is not visible in this file -- presumably this model does not support the
        # `norm_num_groups` configurations the mixin test uses; confirm, and consider `self.skipTest(...)` so the
        # test is reported as skipped rather than passed.
        pass
| |
|