Spaces:
Runtime error
Runtime error
| # Copyright 2023 The HuggingFace Team. All rights reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| import math | |
| from dataclasses import dataclass | |
| from typing import Any, Dict, List, Optional, Tuple, Union | |
| import datetime | |
| import torch | |
| import torch.utils.checkpoint | |
| from torch import nn | |
| from torch.nn import functional as F | |
| from torch.nn.modules.normalization import GroupNorm | |
| import os | |
| from diffusers.configuration_utils import ConfigMixin, register_to_config | |
| from diffusers.models.attention_processor import AttentionProcessor | |
| from diffusers.utils import USE_PEFT_BACKEND | |
| from diffusers.models.autoencoders import AutoencoderKL | |
| from diffusers.models.lora import LoRACompatibleConv | |
| from diffusers.models.modeling_utils import ModelMixin | |
| from diffusers.models.unets.unet_2d_blocks import ( | |
| CrossAttnDownBlock2D, | |
| CrossAttnUpBlock2D, | |
| DownBlock2D, | |
| Downsample2D, | |
| ResnetBlock2D, | |
| Transformer2DModel, | |
| UpBlock2D, | |
| Upsample2D, | |
| ) | |
| from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel | |
| from diffusers.utils import BaseOutput, logging | |
| import numpy as np | |
| from PIL import Image | |
| from safetensors import safe_open | |
| from .attention_autoencoder import AttentionAutoencoder, PositionalEncoding | |
| logger = logging.get_logger(__name__) # pylint: disable=invalid-name | |
@dataclass
class ControlNetXSOutput(BaseOutput):
    """
    The output of [`ControlNetXSModel`].

    Args:
        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            The output of the `ControlNetXSModel`. Unlike `ControlNetOutput` this is NOT to be added to the base model
            output, but is already the final output.
    """

    # `@dataclass` turns this annotation into a real field with a generated `__init__`.
    # NOTE(review): diffusers documents `BaseOutput` subclasses as dataclasses; confirm against the
    # installed diffusers version.
    sample: torch.FloatTensor = None
# copied from diffusers.models.controlnet.ControlNetConditioningEmbedding
class ControlNetConditioningEmbedding(nn.Module):
    """
    Small convolutional encoder that turns an image-space condition into a feature map.

    Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN
    [11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized
    training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the
    convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides
    (activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full
    model) to encode image-space conditions ... into feature maps ..."

    Note that this implementation uses 3 × 3 kernels and SiLU activations.

    Parameters:
        conditioning_embedding_channels (`int`):
            Number of channels produced by the final (zero-initialised) convolution.
        conditioning_channels (`int`, defaults to 3):
            Number of channels of the conditioning input.
        block_out_channels (`Tuple[int, ...]`, defaults to `(16, 32, 96, 256)`):
            Channel widths of the intermediate stages. Every transition between two consecutive widths
            contains one stride-2 convolution.
    """

    def __init__(
        self,
        conditioning_embedding_channels: int,
        conditioning_channels: int = 3,
        block_out_channels: Tuple[int, ...] = (16, 32, 96, 256),
    ):
        super().__init__()
        first_width = block_out_channels[0]
        last_width = block_out_channels[-1]

        self.conv_in = nn.Conv2d(conditioning_channels, first_width, kernel_size=3, padding=1)

        # One width-preserving conv followed by one stride-2, width-changing conv per transition.
        self.blocks = nn.ModuleList([])
        for width_in, width_out in zip(block_out_channels[:-1], block_out_channels[1:]):
            self.blocks.extend(
                [
                    nn.Conv2d(width_in, width_in, kernel_size=3, padding=1),
                    nn.Conv2d(width_in, width_out, kernel_size=3, padding=1, stride=2),
                ]
            )

        # The output projection starts at zero (see `zero_module`).
        self.conv_out = zero_module(
            nn.Conv2d(last_width, conditioning_embedding_channels, kernel_size=3, padding=1)
        )

    def forward(self, conditioning):
        """Encode `conditioning`; SiLU follows every convolution except the final `conv_out`."""
        hidden = F.silu(self.conv_in(conditioning))
        for conv in self.blocks:
            hidden = F.silu(conv(hidden))
        return self.conv_out(hidden)
class ControlNetConditioningEmbeddingBig(nn.Module):
    """
    Text-conditioned convolutional encoder/decoder with a 6-channel bottleneck.

    Pipeline: `conv_in` -> cross-attention against `text_embeds` -> six stride-2 convolutions (ending at 640
    channels) -> addition of a globally pooled feature vector -> bottleneck (640 -> 6 -> 320 channels) -> three
    stride-2 transposed convolutions at 320 channels.

    Parameters:
        conditioning_embedding_channels (`int`):
            Accepted but not used by this implementation.
        conditioning_channels (`int`, defaults to 4):
            Number of channels of the input tensor.
        block_out_channels (`Tuple[int, ...]`, defaults to `(16, 32, 96, 256)`):
            Only the first entry is used (width after `conv_in`); the remaining widths are hard-coded below.
        text_embed_dim (`int`, defaults to 768):
            Feature size of the text embeddings fed to the cross-attention layer.
    """

    def __init__(
        self,
        conditioning_embedding_channels: int,
        conditioning_channels: int = 4,
        block_out_channels: Tuple[int, ...] = (16, 32, 96, 256),
        text_embed_dim: int = 768,
    ):
        super().__init__()
        stem_width = block_out_channels[0]

        self.conv_in = nn.Conv2d(conditioning_channels, stem_width, kernel_size=3, padding=1)
        self.cross_attention = CrossAttention(stem_width, text_embed_dim)

        # Encoder: every layer is a stride-2 conv, widths grow from `stem_width` up to 640.
        encoder_widths = (stem_width, 64, 128, 256, 320, 512, 640)
        self.encoder = nn.ModuleList(
            [
                nn.Conv2d(width_in, width_out, kernel_size=3, stride=2, padding=1)
                for width_in, width_out in zip(encoder_widths[:-1], encoder_widths[1:])
            ]
        )

        # Projection applied to the globally average-pooled encoder output.
        self.global_fc = nn.Linear(640, 640)

        # Bottleneck: squeeze to 6 channels, then expand to the decoder width.
        self.bottleneck_down = nn.Conv2d(640, 6, kernel_size=3, stride=1, padding=1)
        self.bottleneck_up = nn.Conv2d(6, 320, kernel_size=3, stride=1, padding=1)

        # Decoder: three transposed convs; kernel 4 / stride 2 / padding 1 doubles height and width each time.
        self.decoder = nn.ModuleList(
            [nn.ConvTranspose2d(320, 320, kernel_size=4, stride=2, padding=1) for _ in range(3)]
        )

    def forward(self, x, text_embeds):
        """Run the encoder/bottleneck/decoder on `x`, conditioned on `text_embeds`."""
        feats = self.cross_attention(self.conv_in(x), text_embeds)

        for conv in self.encoder:
            feats = F.relu(conv(feats))

        # Global average pooling over the spatial positions, projected and broadcast back onto the map.
        batch, channels, _, _ = feats.shape
        pooled = feats.view(batch, channels, -1).mean(dim=2)
        global_feats = self.global_fc(pooled).view(batch, channels, 1, 1)
        feats = feats + global_feats.expand_as(feats)

        # No activation between the two bottleneck convolutions.
        feats = self.bottleneck_up(self.bottleneck_down(feats))

        for deconv in self.decoder:
            feats = F.relu(deconv(feats))
        return feats
class CrossAttention(nn.Module):
    """
    Single-head cross-attention from a 2D feature map onto a sequence of context vectors.

    Queries come from a 1x1 convolution over the feature map, keys and values from linear projections of the
    context. The attended result is added back onto the input (residual connection).

    Parameters:
        dim: Number of channels of the feature map (also the attention width).
        context_dim: Feature size of each context vector.
    """

    def __init__(self, dim, context_dim):
        super().__init__()
        self.to_q = nn.Conv2d(dim, dim, 1)
        self.to_k = nn.Linear(context_dim, dim)
        self.to_v = nn.Linear(context_dim, dim)
        self.scale = dim ** -0.5

    def forward(self, x, context):
        """Attend every spatial position of `x` (B, C, H, W) over `context` (B, T, context_dim)."""
        batch, channels, height, width = x.shape

        queries = self.to_q(x).flatten(2).transpose(1, 2)  # (B, H*W, C)
        keys = self.to_k(context)  # (B, T, C)
        values = self.to_v(context)  # (B, T, C)

        scores = (queries @ keys.transpose(-2, -1)) * self.scale  # (B, H*W, T)
        weights = torch.softmax(scores, dim=-1)

        attended = weights @ values  # (B, H*W, C)
        attended = attended.transpose(1, 2).reshape(batch, channels, height, width)
        return attended + x
def zero_module(module):
    """Overwrite every parameter of `module` with zeros in place and return the same module object."""
    with torch.no_grad():
        for param in module.parameters():
            param.zero_()
    return module
| class StyleCodesModel(ModelMixin, ConfigMixin): | |
| r""" | |
| Based off ControlNet-XS | |
| """ | |
@classmethod
def init_original(cls, base_model: UNet2DConditionModel, is_sdxl=True):
    """
    Create a model with the same hyper-parameters as the original ControlNet-XS paper
    (https://github.com/vislearn/ControlNet-XS).

    This is an alternate constructor, hence the `@classmethod`: it is called on the class, not on an instance.

    Parameters:
        base_model (`UNet2DConditionModel`):
            Base UNet model. Needs to be either StableDiffusion or StableDiffusion-XL.
        is_sdxl (`bool`, defaults to `True`):
            Whether passed `base_model` is a StableDiffusion-XL model.

    Returns:
        Whatever `from_unet` returns for the selected preset.
    """

    def get_dim_attn_heads(base_model: UNet2DConditionModel, size_ratio: float, num_attn_heads: int):
        """
        Currently, diffusers can only set the dimension of attention heads (see https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 for why).
        The original ControlNet-XS model, however, defines the number of attention heads.
        That's why we compute the dimensions needed to get the correct number of attention heads.
        """
        block_out_channels = [int(size_ratio * c) for c in base_model.config.block_out_channels]
        return [math.ceil(c / num_attn_heads) for c in block_out_channels]

    # The two presets differ only in these three values; everything else passed to `from_unet` is shared.
    if is_sdxl:
        time_embedding_mix, size_ratio, num_attn_heads = 0.95, 0.1, 64
    else:
        time_embedding_mix, size_ratio, num_attn_heads = 1.0, 0.0125, 8

    return cls.from_unet(
        base_model,
        time_embedding_mix=time_embedding_mix,
        learn_embedding=True,
        size_ratio=size_ratio,
        conditioning_embedding_out_channels=(16, 32, 96, 256),
        num_attention_heads=get_dim_attn_heads(base_model, size_ratio, num_attn_heads),
    )
@classmethod
def _gather_subblock_sizes(cls, unet: UNet2DConditionModel, base_or_control: str):
    """To create correctly sized connections between base and control model, we need to know
    the input and output channels of each subblock.

    Declared as a classmethod because it is invoked on the class itself
    (`StyleCodesModel._gather_subblock_sizes(unet, base_or_control=...)`).

    Parameters:
        unet (`UNet2DConditionModel`):
            Unet of which the subblock channels sizes are to be gathered.
        base_or_control (`str`):
            Needs to be either "base" or "control". The value is only validated; encoder, mid block and
            decoder sizes are gathered in both cases.

    Returns:
        `dict` with the keys "down", "mid" and "up", each mapping to a list of `(in_channels, out_channels)`
        tuples. "down" starts with the entry for `unet.conv_in`.

    Raises:
        ValueError: If `base_or_control` is invalid, or if a down/up block of an unsupported type is found.
    """
    if base_or_control not in ["base", "control"]:
        raise ValueError("`base_or_control` needs to be either `base` or `control`")

    channel_sizes = {"down": [], "mid": [], "up": []}

    # input convolution
    channel_sizes["down"].append((unet.conv_in.in_channels, unet.conv_in.out_channels))

    # encoder blocks: one entry per resnet, plus one for the downsampler when the block has one
    for module in unet.down_blocks:
        if not isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)):
            raise ValueError(f"Encountered unknown module of type {type(module)} while creating ControlNet-XS.")
        channel_sizes["down"].extend((r.in_channels, r.out_channels) for r in module.resnets)
        if module.downsamplers:
            channel_sizes["down"].append(
                (module.downsamplers[0].channels, module.downsamplers[0].out_channels)
            )

    # middle block: only the first resnet is recorded
    channel_sizes["mid"].append((unet.mid_block.resnets[0].in_channels, unet.mid_block.resnets[0].out_channels))

    # decoder blocks: one entry per resnet (upsamplers are not recorded)
    for module in unet.up_blocks:
        if not isinstance(module, (CrossAttnUpBlock2D, UpBlock2D)):
            raise ValueError(
                f"Encountered unknown module of type {type(module)} while creating ControlNet-XS."
            )
        channel_sizes["up"].extend((r.in_channels, r.out_channels) for r in module.resnets)

    return channel_sizes
| def _make_colab_linear_layer(self, in_channels, out_channels): | |
| # Create a Linear layer where in_features = in_channels + out_channels | |
| #in_features = in_channels + out_channels | |
| linear_layer = nn.Linear(in_channels, out_channels) | |
| # Initialize weights as identity | |
| with torch.no_grad(): | |
| linear_layer.weight.copy_(torch.eye(in_channels)) | |
| return linear_layer | |
@register_to_config
def __init__(
    self,
    conditioning_channels: int = 3,
    conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256),
    controlnet_conditioning_channel_order: str = "rgb",
    time_embedding_input_dim: int = 320,
    time_embedding_dim: int = 1280,
    time_embedding_mix: float = 1.0,
    learn_embedding: bool = False,
    base_model_channel_sizes: Dict[str, List[Tuple[int]]] = {
        "down": [
            (4, 320),
            (320, 320),
            (320, 320),
            (320, 320),
            (320, 640),
            (640, 640),
            (640, 640),
            (640, 1280),
            (1280, 1280),
        ],
        "mid": [(1280, 1280)],
        "up": [
            (2560, 1280),
            (2560, 1280),
            (1920, 1280),
            (1920, 640),
            (1280, 640),
            (960, 640),
            (960, 320),
            (640, 320),
            (640, 320),
        ],
    },
    sample_size: Optional[int] = None,
    down_block_types: Tuple[str] = (
        "CrossAttnDownBlock2D",
        "CrossAttnDownBlock2D",
        "CrossAttnDownBlock2D",
        "DownBlock2D",
    ),
    up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
    block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
    norm_num_groups: Optional[int] = 32,
    cross_attention_dim: Union[int, Tuple[int]] = 1280,
    transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
    num_attention_heads: Optional[Union[int, Tuple[int]]] = 8,
    upcast_attention: bool = False,
):
    """
    Build the control UNet, the identity-initialised linear layers that connect it to the base model, and the
    style-reference autoencoder.

    `@register_to_config` is required because other methods read constructor arguments back through
    `self.config` (e.g. `forward` reads `self.config.controlnet_conditioning_channel_order`).

    Parameters:
        conditioning_channels, conditioning_embedding_out_channels, controlnet_conditioning_channel_order,
        time_embedding_mix, learn_embedding:
            Not used while building the modules below; they are part of the constructor signature only.
        time_embedding_input_dim (`int`), time_embedding_dim (`int`):
            Passed to `adjust_time_dims`; `time_embedding_dim` is also given to the control UNet.
        base_model_channel_sizes (`dict`):
            `(in_channels, out_channels)` per subblock of the base model, keyed by "down", "mid" and "up"
            (the format returned by `_gather_subblock_sizes`). The default dict is shared between calls and
            stored by reference in `self.ch_inout_base`; it is not mutated here and must not be mutated later.
        sample_size, down_block_types, up_block_types, block_out_channels, norm_num_groups,
        cross_attention_dim, transformer_layers_per_block, upcast_attention:
            Forwarded unchanged to `UNet2DConditionModel`.
        num_attention_heads:
            Forwarded to `UNet2DConditionModel` as `attention_head_dim`.
    """
    super().__init__()

    # 1 - Create control unet
    self.control_model = UNet2DConditionModel(
        sample_size=sample_size,
        down_block_types=down_block_types,
        up_block_types=up_block_types,
        block_out_channels=block_out_channels,
        norm_num_groups=norm_num_groups,
        cross_attention_dim=cross_attention_dim,
        transformer_layers_per_block=transformer_layers_per_block,
        attention_head_dim=num_attention_heads,
        use_linear_projection=True,
        upcast_attention=upcast_attention,
        time_embedding_dim=time_embedding_dim,
    )

    # 2 - Do model surgery on control model
    # 2.1 - Allow to use the same time information as the base model
    adjust_time_dims(self.control_model, time_embedding_input_dim, time_embedding_dim)

    # 2.2 - Make group norms work with modified channel sizes
    adjust_group_norms(self.control_model)

    # 3 - Gather channel sizes
    self.ch_inout_ctrl = StyleCodesModel._gather_subblock_sizes(self.control_model, base_or_control="control")
    self.ch_inout_base = base_model_channel_sizes
    logger.debug("subblock sizes of control model: %s", self.ch_inout_ctrl)

    # 4 - Build connections between base and control model
    # `down_zero_convs_in` and `control_model.middle_block_out` are registered but stay empty; the layer
    # for the mid block lives directly on `self.middle_block_out`.
    self.control_model.down_zero_convs_in = nn.ModuleList([])
    self.control_model.middle_block_out = nn.ModuleList([])
    self.control_model.up_zero_convs_out = nn.ModuleList([])

    # Each connection layer is twice as wide as the block output it belongs to.
    mid_width = 2 * self.ch_inout_ctrl["mid"][-1][1]
    self.middle_block_out = self._make_colab_linear_layer(in_channels=mid_width, out_channels=mid_width)

    # One layer per base decoder subblock, limited to the number of control decoder subblocks.
    # TODO: the layer widths come from the base model's channel sizes only; revisit before release.
    for ch_io_base, _ in zip(self.ch_inout_base["up"], self.ch_inout_ctrl["up"]):
        up_width = 2 * ch_io_base[1]
        self.control_model.up_zero_convs_out.append(
            self._make_colab_linear_layer(in_channels=up_width, out_channels=up_width)
        )

    # 5 - Create the style reference autoencoder
    # Placed on the GPU when one is present (as before); on CPU-only hosts it stays on the CPU instead of
    # raising, and follows the rest of the model on a later `.to(...)` call.
    self.sref_autoencoder = AttentionAutoencoder()
    if torch.cuda.is_available():
        self.sref_autoencoder = self.sref_autoencoder.to(device="cuda")

    # Only the decoder side of the control UNet is kept; drop the parts that are never used.
    del self.control_model.down_blocks
    del self.control_model.conv_norm_out
    del self.control_model.conv_out
    del self.control_model.time_embedding
    del self.control_model.conv_in
def load_model(self, path: str):
    """Load the model from the given path.

    For `.safetensors` files only the tensors whose key starts with `controlnet.` are read; that prefix is
    removed before loading. Any other file is read with `torch.load` and must contain a mapping with a
    `"controlnet"` entry holding the state dict.

    The weights are loaded non-strictly and the result of `load_state_dict` (missing / unexpected keys) is
    printed.

    Parameters:
        path (`str`):
            Path to the model checkpoint.

    Raises:
        KeyError: If a non-safetensors checkpoint has no `"controlnet"` entry.
    """
    if os.path.splitext(path)[-1] == ".safetensors":
        prefix = "controlnet."
        controlnet_weights = {}
        with safe_open(path, framework="pt", device="cpu") as f:
            for key in f.keys():
                if key.startswith(prefix):
                    # Slice instead of `str.replace` so that only the leading prefix is removed, even
                    # if the same text occurs again further inside the key.
                    controlnet_weights[key[len(prefix):]] = f.get_tensor(key)
        state_dict = {"controlnet": controlnet_weights}
    else:
        # SECURITY: `torch.load` unpickles the file and can execute arbitrary code.
        # Only use this branch with checkpoints from a trusted source.
        state_dict = torch.load(path, map_location="cpu")

    print("load controlnet", self.load_state_dict(state_dict["controlnet"],strict=False))
@classmethod
def from_unet(
    cls,
    unet: UNet2DConditionModel,
    conditioning_channels: int = 3,
    conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256),
    controlnet_conditioning_channel_order: str = "rgb",
    learn_embedding: bool = False,
    time_embedding_mix: float = 1.0,
    block_out_channels: Optional[Tuple[int]] = None,
    size_ratio: Optional[float] = None,
    num_attention_heads: Optional[Union[int, Tuple[int]]] = 8,
    norm_num_groups: Optional[int] = None,
):
    r"""
    Instantiate the model from a [`UNet2DConditionModel`].

    This is an alternate constructor (`cls(**kwargs)` is returned), hence the `@classmethod`.

    Parameters:
        unet (`UNet2DConditionModel`):
            The UNet model we want to control. The dimensions of the control model will be adapted to it.
        conditioning_channels (`int`, defaults to 3):
            Number of channels of conditioning input (e.g. an image)
        conditioning_embedding_out_channels (`tuple[int]`, defaults to `(16, 32, 96, 256)`):
            The tuple of output channel for each block in the `controlnet_cond_embedding` layer.
        controlnet_conditioning_channel_order (`str`, defaults to `"rgb"`):
            The channel order of conditional image. Will convert to `rgb` if it's `bgr`.
        learn_embedding (`bool`, defaults to `False`):
            Whether to use time embedding of the control model. If yes, the time embedding is a linear interpolation
            of the time embeddings of the control and base model with interpolation parameter
            `time_embedding_mix**3`.
        time_embedding_mix (`float`, defaults to 1.0):
            Linear interpolation parameter used if `learn_embedding` is `True`.
        block_out_channels (`Tuple[int]`, *optional*):
            Down blocks output channels in control model. Either this or `size_ratio` must be given.
        size_ratio (float, *optional*):
            When given, block_out_channels is set to a relative fraction of the base model's block_out_channels.
            Either this or `block_out_channels` must be given.
        num_attention_heads (`Union[int, Tuple[int]]`, *optional*):
            The dimension of the attention heads. The naming seems a bit confusing and it is, see https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 for why.
            A falsy value falls back to `unet.config.attention_head_dim`.
        norm_num_groups (int, *optional*, defaults to `None`):
            The number of groups to use for the normalization of the control unet. If `None`, the base
            model's `norm_num_groups` is used when it divides every entry of `block_out_channels`;
            otherwise `min(block_out_channels)` is tried.

    Raises:
        ValueError: If not exactly one of `block_out_channels` / `size_ratio` is given, if the attention
            head dimension does not divide the channel sizes, or if no suitable `norm_num_groups` is found.
    """
    # Check input: exactly one of the two sizing modes must be selected
    fixed_size = block_out_channels is not None
    relative_size = size_ratio is not None
    if not (fixed_size ^ relative_size):
        raise ValueError(
            "Pass exactly one of `block_out_channels` (for absolute sizing) or `size_ratio` (for relative sizing)."
        )

    # Create model
    if block_out_channels is None:
        block_out_channels = [int(size_ratio * c) for c in unet.config.block_out_channels]

    # Check that attention heads and group norms match channel sizes
    # - attention heads
    def attn_heads_match_channel_sizes(attn_heads, channel_sizes):
        if isinstance(attn_heads, (tuple, list)):
            return all(c % a == 0 for a, c in zip(attn_heads, channel_sizes))
        return all(c % attn_heads == 0 for c in channel_sizes)

    num_attention_heads = num_attention_heads or unet.config.attention_head_dim
    if not attn_heads_match_channel_sizes(num_attention_heads, block_out_channels):
        raise ValueError(
            f"The dimension of attention heads ({num_attention_heads}) must divide `block_out_channels` ({block_out_channels}). If you didn't set `num_attention_heads` the default settings don't match your model. Set `num_attention_heads` manually."
        )

    # - group norms
    def group_norms_match_channel_sizes(num_groups, channel_sizes):
        return all(c % num_groups == 0 for c in channel_sizes)

    if norm_num_groups is None:
        if group_norms_match_channel_sizes(unet.config.norm_num_groups, block_out_channels):
            norm_num_groups = unet.config.norm_num_groups
        else:
            # Fallback: the smallest channel count, provided it divides all the others
            norm_num_groups = min(block_out_channels)
            if not group_norms_match_channel_sizes(norm_num_groups, block_out_channels):
                raise ValueError(
                    f"`block_out_channels` ({block_out_channels}) don't match the base models `norm_num_groups` ({unet.config.norm_num_groups}). Setting `norm_num_groups` to `min(block_out_channels)` ({norm_num_groups}) didn't fix this. Pass `norm_num_groups` explicitly so it divides all block_out_channels."
                )
            print(
                f"`norm_num_groups` was set to `min(block_out_channels)` (={norm_num_groups}) so it divides all block_out_channels` ({block_out_channels}). Set it explicitly to remove this information."
            )

    # Clone params from base unet if
    #    (i)   it's required to build SD or SDXL, and
    #    (ii)  it's not used for the time embedding (as time embedding of control model is never used), and
    #    (iii) it's not set further below anyway
    to_keep = [
        "cross_attention_dim",
        "down_block_types",
        "sample_size",
        "transformer_layers_per_block",
        "up_block_types",
        "upcast_attention",
    ]
    kwargs = {k: v for k, v in dict(unet.config).items() if k in to_keep}
    kwargs.update(
        block_out_channels=block_out_channels,
        num_attention_heads=num_attention_heads,
        norm_num_groups=norm_num_groups,
    )

    # Add controlnetxs-specific params; the time embedding sizes are read off the base model's layers
    kwargs.update(
        conditioning_channels=conditioning_channels,
        controlnet_conditioning_channel_order=controlnet_conditioning_channel_order,
        time_embedding_input_dim=unet.time_embedding.linear_1.in_features,
        time_embedding_dim=unet.time_embedding.linear_2.out_features,
        time_embedding_mix=time_embedding_mix,
        learn_embedding=learn_embedding,
        base_model_channel_sizes=cls._gather_subblock_sizes(unet, base_or_control="base"),
        conditioning_embedding_out_channels=conditioning_embedding_out_channels,
    )

    return cls(**kwargs)
@property
def attn_processors(self) -> Dict[str, AttentionProcessor]:
    r"""
    Returns:
        `dict` of attention processors: A dictionary containing all attention processors used in the model,
        indexed by weight name.

    Exposed as a property so that it mirrors `self.control_model.attn_processors`, which is read below as a
    plain attribute and returned as-is.
    """
    return self.control_model.attn_processors
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
    r"""
    Sets the attention processor to use to compute attention.

    The call is delegated unchanged to the control UNet (`self.control_model`).

    Parameters:
        processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
            The instantiated processor class or a dictionary of processor classes that will be set as the processor
            for **all** `Attention` layers.

            If `processor` is a dict, the key needs to define the path to the corresponding cross attention
            processor. This is strongly recommended when setting trainable attention processors.
    """
    control_unet = self.control_model
    control_unet.set_attn_processor(processor)
def set_default_attn_processor(self):
    """
    Disables custom attention processors and sets the default attention implementation.

    The call is delegated to the control UNet (`self.control_model`).
    """
    control_unet = self.control_model
    control_unet.set_default_attn_processor()
def set_attention_slice(self, slice_size):
    r"""
    Enable sliced attention computation.

    When this option is enabled, the attention module splits the input tensor in slices to compute attention in
    several steps. This is useful for saving some memory in exchange for a small decrease in speed.

    The call is delegated unchanged to the control UNet (`self.control_model`).

    Args:
        slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
            When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
            `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
            provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
            must be a multiple of `slice_size`.
    """
    control_unet = self.control_model
    control_unet.set_attention_slice(slice_size)
def _set_gradient_checkpointing(self, module, value=False):
    """
    Switch gradient checkpointing on or off for `module`.

    Only `UNet2DConditionModel` instances are touched; any other module is ignored. A truthy `value`
    enables checkpointing, a falsy one disables it.
    """
    if not isinstance(module, UNet2DConditionModel):
        return
    toggle = module.enable_gradient_checkpointing if value else module.disable_gradient_checkpointing
    toggle()
def forward(
    self,
    base_model: UNet2DConditionModel,
    sample: torch.FloatTensor,
    timestep: Union[torch.Tensor, float, int],
    encoder_hidden_states: torch.Tensor,
    encoder_hidden_states_controlnet: torch.Tensor,
    controlnet_cond: torch.Tensor,
    conditioning_scale: float = 1.0,
    class_labels: Optional[torch.Tensor] = None,
    timestep_cond: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
    return_dict: bool = True,
    stylecode=None,
) -> Union[ControlNetXSOutput, Tuple]:
    """
    Run `base_model` on `sample` while mixing in features produced by the control model.

    The base encoder (down blocks) runs on its own. The control features are obtained from
    `self.sref_autoencoder` (either by encoding `controlnet_cond` or by decoding `stylecode`), then the
    control mid/up blocks run in lock-step with the base mid/up blocks. After the mid block and after every
    up sub-block the control and base features are concatenated, passed through a mixing module and the
    base half of the result is blended back into the base features with weight `conditioning_scale`.

    Args:
        base_model (`UNet2DConditionModel`):
            The base unet model we want to control.
        sample (`torch.FloatTensor`):
            The noisy input tensor.
        timestep (`Union[torch.Tensor, float, int]`):
            The current denoising timestep. Python scalars and 0-dim tensors are broadcast to the batch size
            of `sample`.
        encoder_hidden_states (`torch.Tensor`):
            The encoder hidden states. Used as cross-attention context for both the base and the control
            sub-blocks.
        encoder_hidden_states_controlnet (`torch.Tensor`):
            Encoder hidden states for the control branch. Only read when the base model's
            `addition_embed_type` is `"text"`, where it produces the extra embedding added to the control
            time embedding.
        controlnet_cond (`torch.Tensor`):
            The conditioning input. It is flipped along dim 1 when the configured channel order is `"bgr"`
            and, when `stylecode` is `None`, encoded by `self.sref_autoencoder.forward_encoding`.
            NOTE(review): expected shape/layout is defined by the autoencoder - confirm against it.
        conditioning_scale (`float`, defaults to `1.0`):
            Blend weight of the mixed features: `h_base * (1 - scale) + h_mixed * scale`.
        class_labels (`torch.Tensor`, *optional*, defaults to `None`):
            Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
        timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
            Additional conditional embeddings for timestep. Only forwarded to the time embedding layers when
            `self.config.learn_embedding` is set.
        attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
            An attention mask where `1` keeps a token and `0` discards it. It is converted into an additive
            bias (`0` / `-10000.0`) and gets a singleton dimension inserted at position 1.
        cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`):
            A kwargs dictionary that if specified is passed along to the transformer blocks.
        added_cond_kwargs (`dict`, *optional*, defaults to `None`):
            Additional conditions for the Stable Diffusion XL UNet. Must contain `text_embeds` and `time_ids`
            when the base model's `addition_embed_type` is `"text_time"`.
        return_dict (`bool`, defaults to `True`):
            Whether or not to wrap the result in a `ControlNetXSOutput`.
        stylecode (*optional*, defaults to `None`):
            If given, the control features are generated by `self.sref_autoencoder.forward_from_stylecode`
            instead of encoding `controlnet_cond`.

    Returns:
        `ControlNetXSOutput` or `torch.Tensor`:
            If `return_dict` is `True`, a `ControlNetXSOutput` whose `sample` is the output tensor.
            Otherwise the output tensor itself is returned (it is *not* wrapped in a tuple).

    Raises:
        ValueError: unknown channel order, missing `class_labels` when the base model has a class embedding,
            or missing `text_embeds` / `time_ids` for `addition_embed_type == "text_time"`.
        NotImplementedError: for the `"text_image"`, `"image"` and `"image_hint"` addition embed types.
    """

    def _mix_ctrl_and_base(mixer, h_ctrl, h_base):
        """Jointly process control and base features with `mixer` and blend the result into `h_base`."""
        batch_size, channels, height, width = h_ctrl.shape
        # (B, 2C, H, W) -> (B, H*W, 2C): the mixer operates on a sequence of spatial positions
        joint = torch.cat([h_ctrl, h_base], dim=1)
        joint = joint.view(batch_size, channels * 2, height * width).permute(0, 2, 1)
        joint = mixer(joint)
        # back to (B, 2C, H, W), then split into the control half and the base half
        joint = joint.permute(0, 2, 1).view(batch_size, channels * 2, height, width)
        h_ctrl, h_base_mixed = torch.chunk(joint, 2, dim=1)
        h_base = h_base * (1 - conditioning_scale) + h_base_mixed * conditioning_scale
        return h_ctrl, h_base

    # check channel order ("rgb" is the default and needs no change)
    channel_order = self.config.controlnet_conditioning_channel_order
    if channel_order == "bgr":
        controlnet_cond = torch.flip(controlnet_cond, dims=[1])
    elif channel_order != "rgb":
        raise ValueError(f"unknown `controlnet_conditioning_channel_order`: {channel_order}")

    # prepare attention_mask
    if attention_mask is not None:
        attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
        attention_mask = attention_mask.unsqueeze(1)

    # 1. time
    timesteps = timestep
    if not torch.is_tensor(timesteps):
        # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
        is_mps = sample.device.type == "mps"
        if isinstance(timestep, float):
            dtype = torch.float32 if is_mps else torch.float64
        else:
            dtype = torch.int32 if is_mps else torch.int64
        timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
    elif len(timesteps.shape) == 0:
        timesteps = timesteps[None].to(sample.device)

    # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
    timesteps = timesteps.expand(sample.shape[0])

    t_emb = base_model.time_proj(timesteps)

    # time_proj does not contain any weights and will always return f32 tensors
    # but time_embedding might actually be running in fp16. so we need to cast here.
    t_emb = t_emb.to(dtype=sample.dtype)

    if self.config.learn_embedding:
        ctrl_temb = self.control_model.time_embedding(t_emb, timestep_cond)
        base_temb = base_model.time_embedding(t_emb, timestep_cond)
        interpolation_param = self.config.time_embedding_mix**0.3

        temb = ctrl_temb * interpolation_param + base_temb * (1 - interpolation_param)
    else:
        temb = base_model.time_embedding(t_emb)

    # added time & text embeddings
    aug_emb = None
    aug_emb_ctrl = None

    if base_model.class_embedding is not None:
        if class_labels is None:
            raise ValueError("class_labels should be provided when num_class_embeds > 0")

        if base_model.config.class_embed_type == "timestep":
            class_labels = base_model.time_proj(class_labels)

        class_emb = base_model.class_embedding(class_labels).to(dtype=self.dtype)
        temb = temb + class_emb

    addition_embed_type = base_model.config.addition_embed_type
    if addition_embed_type == "text":
        aug_emb = base_model.add_embedding(encoder_hidden_states)
        aug_emb_ctrl = base_model.add_embedding(encoder_hidden_states_controlnet)
    elif addition_embed_type in ("text_image", "image", "image_hint"):
        raise NotImplementedError()
    elif addition_embed_type == "text_time":
        # SDXL - style
        # `added_cond_kwargs` defaults to None; treat that like a missing key instead of failing with a TypeError
        if added_cond_kwargs is None or "text_embeds" not in added_cond_kwargs:
            raise ValueError(
                f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
            )
        text_embeds = added_cond_kwargs.get("text_embeds")
        if "time_ids" not in added_cond_kwargs:
            raise ValueError(
                f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
            )
        time_ids = added_cond_kwargs.get("time_ids")
        time_embeds = base_model.add_time_proj(time_ids.flatten())
        time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
        add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
        add_embeds = add_embeds.to(temb.dtype)
        aug_emb = base_model.add_embedding(add_embeds)

    temb = temb + aug_emb if aug_emb is not None else temb
    # NOTE(review): `temb` already contains `aug_emb` at this point, so the control embedding is
    # temb + aug_emb + aug_emb_ctrl - confirm this is intended.
    temb_ctrl = temb + aug_emb_ctrl if aug_emb_ctrl is not None else temb

    # text embeddings
    # TODO: skip the cross attention layers of the control branch entirely instead of feeding them the base context
    cemb = encoder_hidden_states
    cemb_ctrl = encoder_hidden_states

    # Preparation
    h_base = sample
    hs_base = []
    it_up_convs_out = iter(self.control_model.up_zero_convs_out)

    base_down_subblocks = self.to_sub_blocks(base_model.down_blocks)
    base_mid_subblocks = self.to_sub_blocks([base_model.mid_block])
    ctrl_mid_subblocks = self.to_sub_blocks([self.control_model.mid_block])
    base_up_subblocks = self.to_sub_blocks(base_model.up_blocks)
    ctrl_up_subblocks = self.to_sub_blocks(self.control_model.up_blocks)

    # Cross Control
    # 0 - conv in (base only; the control branch starts from the conditioning input)
    h_base = base_model.conv_in(h_base)
    h_ctrl = controlnet_cond
    hs_base.append(h_base)

    # 1 - down (base only, the control model has no active encoder here)
    for m_base in base_down_subblocks:
        h_base = m_base(h_base, temb, cemb, attention_mask, cross_attention_kwargs)
        hs_base.append(h_base)

    # control features at the spatial size of the base bottleneck
    if stylecode is None:
        h_ctrl, _ = self.sref_autoencoder.forward_encoding(h_ctrl, h_base.shape[2], h_base.shape[3])
    else:
        h_ctrl = self.sref_autoencoder.forward_from_stylecode(
            stylecode, h_base.shape[2], h_base.shape[3], h_base.dtype, h_base.device
        )

    # 2 - mid
    for m_base, m_ctrl in zip(base_mid_subblocks, ctrl_mid_subblocks):
        h_base = m_base(h_base, temb, cemb, attention_mask, cross_attention_kwargs)  # apply base subblock
        h_ctrl = m_ctrl(h_ctrl, temb_ctrl, cemb_ctrl, attention_mask, cross_attention_kwargs)  # apply ctrl subblock

    h_ctrl, h_base = _mix_ctrl_and_base(self.middle_block_out, h_ctrl, h_base)

    # 3 - up
    for m_base, m_ctrl in zip(base_up_subblocks, ctrl_up_subblocks):
        skip = hs_base.pop()
        h_base_with_skip = torch.cat([h_base, skip], dim=1)  # base decoder gets the base encoder skip connection
        # the control branch has no encoder skips, so pad with zeros of the same shape
        h_ctrl = torch.cat([h_ctrl, torch.zeros_like(skip)], dim=1)

        h_ctrl = m_ctrl(h_ctrl, temb_ctrl, cemb_ctrl, attention_mask, cross_attention_kwargs)  # apply ctrl subblock
        h_base = m_base(h_base_with_skip, temb, cemb, attention_mask, cross_attention_kwargs)  # apply base subblock

        h_ctrl, h_base = _mix_ctrl_and_base(next(it_up_convs_out), h_ctrl, h_base)

    h_base = base_model.conv_norm_out(h_base)
    h_base = base_model.conv_act(h_base)
    h_base = base_model.conv_out(h_base)

    if not return_dict:
        # kept as a bare tensor for backward compatibility with existing callers
        return h_base

    return ControlNetXSOutput(sample=h_base)
| # NOTE: the commented-out `pre_process` below is stale and needs to be updated before it can be re-enabled. | |
| # def pre_process( | |
| # self, | |
| # base_model: UNet2DConditionModel, | |
| # sample: torch.FloatTensor, | |
| # timestep: Union[torch.Tensor, float, int], | |
| # encoder_hidden_states: torch.Tensor, | |
| # controlnet_cond: torch.Tensor, | |
| # conditioning_scale: float = 1.0, | |
| # class_labels: Optional[torch.Tensor] = None, | |
| # timestep_cond: Optional[torch.Tensor] = None, | |
| # attention_mask: Optional[torch.Tensor] = None, | |
| # cross_attention_kwargs: Optional[Dict[str, Any]] = None, | |
| # added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, | |
| # return_dict: bool = True | |
| # ): | |
| # """ | |
| # The [`ControlNetModel`] forward method. | |
| # Args: | |
| # base_model (`UNet2DConditionModel`): | |
| # The base unet model we want to control. | |
| # sample (`torch.FloatTensor`): | |
| # The noisy input tensor. | |
| # timestep (`Union[torch.Tensor, float, int]`): | |
| # The number of timesteps to denoise an input. | |
| # encoder_hidden_states (`torch.Tensor`): | |
| # The encoder hidden states. | |
| # controlnet_cond (`torch.FloatTensor`): | |
| # The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`. | |
| # conditioning_scale (`float`, defaults to `1.0`): | |
| # How much the control model affects the base model outputs. | |
| # class_labels (`torch.Tensor`, *optional*, defaults to `None`): | |
| # Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings. | |
| # timestep_cond (`torch.Tensor`, *optional*, defaults to `None`): | |
| # Additional conditional embeddings for timestep. If provided, the embeddings will be summed with the | |
| # timestep_embedding passed through the `self.time_embedding` layer to obtain the final timestep | |
| # embeddings. | |
| # attention_mask (`torch.Tensor`, *optional*, defaults to `None`): | |
| # An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask | |
| # is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large | |
| # negative values to the attention scores corresponding to "discard" tokens. | |
| # added_cond_kwargs (`dict`): | |
| # Additional conditions for the Stable Diffusion XL UNet. | |
| # cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`): | |
| # A kwargs dictionary that if specified is passed along to the `AttnProcessor`. | |
| # return_dict (`bool`, defaults to `True`): | |
| # Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple. | |
| # Returns: | |
| # [`~models.controlnetxs.ControlNetXSOutput`] **or** `tuple`: | |
| # If `return_dict` is `True`, a [`~models.controlnetxs.ControlNetXSOutput`] is returned, otherwise a | |
| # tuple is returned where the first element is the sample tensor. | |
| # """ | |
| # # check channel order | |
| # channel_order = self.config.controlnet_conditioning_channel_order | |
| # if channel_order == "rgb": | |
| # # in rgb order by default | |
| # ... | |
| # elif channel_order == "bgr": | |
| # controlnet_cond = torch.flip(controlnet_cond, dims=[1]) | |
| # else: | |
| # raise ValueError(f"unknown `controlnet_conditioning_channel_order`: {channel_order}") | |
| # # scale control strength | |
| # n_connections = len(self.control_model.down_zero_convs_out) + 1 + len(self.control_model.up_zero_convs_out) | |
| # scale_list = torch.full((n_connections,), conditioning_scale) | |
| # # prepare attention_mask | |
| # if attention_mask is not None: | |
| # attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 | |
| # attention_mask = attention_mask.unsqueeze(1) | |
| # # 1. time | |
| # timesteps = timestep | |
| # if not torch.is_tensor(timesteps): | |
| # # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can | |
| # # This would be a good case for the `match` statement (Python 3.10+) | |
| # is_mps = sample.device.type == "mps" | |
| # if isinstance(timestep, float): | |
| # dtype = torch.float32 if is_mps else torch.float64 | |
| # else: | |
| # dtype = torch.int32 if is_mps else torch.int64 | |
| # timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) | |
| # elif len(timesteps.shape) == 0: | |
| # timesteps = timesteps[None].to(sample.device) | |
| # # broadcast to batch dimension in a way that's compatible with ONNX/Core ML | |
| # timesteps = timesteps.expand(sample.shape[0]) | |
| # t_emb = base_model.time_proj(timesteps) | |
| # # timesteps does not contain any weights and will always return f32 tensors | |
| # # but time_embedding might actually be running in fp16. so we need to cast here. | |
| # # there might be better ways to encapsulate this. | |
| # t_emb = t_emb.to(dtype=sample.dtype) | |
| # if self.config.learn_embedding: | |
| # ctrl_temb = self.control_model.time_embedding(t_emb, timestep_cond) | |
| # base_temb = base_model.time_embedding(t_emb, timestep_cond) | |
| # interpolation_param = self.config.time_embedding_mix**0.3 | |
| # temb = ctrl_temb * interpolation_param + base_temb * (1 - interpolation_param) | |
| # else: | |
| # temb = base_model.time_embedding(t_emb) | |
| # # added time & text embeddings | |
| # aug_emb = None | |
| # if base_model.class_embedding is not None: | |
| # if class_labels is None: | |
| # raise ValueError("class_labels should be provided when num_class_embeds > 0") | |
| # if base_model.config.class_embed_type == "timestep": | |
| # class_labels = base_model.time_proj(class_labels) | |
| # class_emb = base_model.class_embedding(class_labels).to(dtype=self.dtype) | |
| # temb = temb + class_emb | |
| # if base_model.config.addition_embed_type is not None: | |
| # if base_model.config.addition_embed_type == "text": | |
| # aug_emb = base_model.add_embedding(encoder_hidden_states) | |
| # elif base_model.config.addition_embed_type == "text_image": | |
| # raise NotImplementedError() | |
| # elif base_model.config.addition_embed_type == "text_time": | |
| # # SDXL - style | |
| # if "text_embeds" not in added_cond_kwargs: | |
| # raise ValueError( | |
| # f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`" | |
| # ) | |
| # text_embeds = added_cond_kwargs.get("text_embeds") | |
| # if "time_ids" not in added_cond_kwargs: | |
| # raise ValueError( | |
| # f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`" | |
| # ) | |
| # time_ids = added_cond_kwargs.get("time_ids") | |
| # time_embeds = base_model.add_time_proj(time_ids.flatten()) | |
| # time_embeds = time_embeds.reshape((text_embeds.shape[0], -1)) | |
| # add_embeds = torch.concat([text_embeds, time_embeds], dim=-1) | |
| # add_embeds = add_embeds.to(temb.dtype) | |
| # aug_emb = base_model.add_embedding(add_embeds) | |
| # elif base_model.config.addition_embed_type == "image": | |
| # raise NotImplementedError() | |
| # elif base_model.config.addition_embed_type == "image_hint": | |
| # raise NotImplementedError() | |
| # temb = temb + aug_emb if aug_emb is not None else temb | |
| # # text embeddings | |
| # cemb = encoder_hidden_states | |
| # # Preparation | |
| # guided_hint = self.controlnet_cond_embedding(controlnet_cond) | |
| # #guided_hint=None | |
| # # h_ctrl = h_base = sample | |
| # # hs_base, hs_ctrl = [], [] | |
| # # it_down_convs_in, it_down_convs_out, it_up_convs_in, it_up_convs_out = map( | |
| # # iter, (self.control_model.down_zero_convs_in, self.control_model.down_zero_convs_out, self.control_model.up_zero_convs_in, self.control_model.up_zero_convs_out) | |
| # # ) | |
| # scales = iter(scale_list) | |
| # return temb,cemb,scales,guided_hint | |
def _make_zero_conv(self, in_channels, out_channels=None):
    """Create a 1x1 convolution and pass it through `zero_module`.

    Args:
        in_channels (`int`): Number of input channels.
        out_channels (`int`, *optional*):
            Number of output channels. Defaults to `in_channels` when omitted.

    Returns:
        The result of `zero_module` applied to the new `nn.Conv2d`.
    """
    # `nn.Conv2d` cannot be built with `out_channels=None`, so fall back to a same-width conv
    if out_channels is None:
        out_channels = in_channels
    return zero_module(nn.Conv2d(in_channels, out_channels, 1, padding=0))
def _make_identity_conv(self, in_channels, out_channels=None):
    """Create a bias-free 1x1 convolution.

    Args:
        in_channels (`int`): Number of input channels.
        out_channels (`int`, *optional*):
            Number of output channels. Defaults to `in_channels` when omitted.

    Returns:
        `nn.Conv2d`: A newly initialised 1x1 convolution without bias.
    """
    # `nn.Conv2d` cannot be built with `out_channels=None`, so fall back to a same-width conv
    if out_channels is None:
        out_channels = in_channels
    return nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0, bias=False)
def _check_if_vae_compatible(self, vae: AutoencoderKL):
    """Compare the downscale factor of the conditioning embedding with that of a VAE.

    Each factor is `2 ** (number_of_levels - 1)`, where the number of levels is the length of
    `self.config.conditioning_embedding_out_channels` and `vae.config.block_out_channels` respectively.

    Returns:
        `tuple`: `(compatible, condition_downscale_factor, vae_downscale_factor)`, where `compatible` is
        `True` when both factors are equal.
    """
    num_condition_levels = len(self.config.conditioning_embedding_out_channels)
    num_vae_levels = len(vae.config.block_out_channels)

    condition_downscale_factor = 2 ** (num_condition_levels - 1)
    vae_downscale_factor = 2 ** (num_vae_levels - 1)

    return (
        condition_downscale_factor == vae_downscale_factor,
        condition_downscale_factor,
        vae_downscale_factor,
    )
def to_sub_blocks(self, blocks):
    """Split UNet blocks into a flat list of `SubBlock`s.

    Grouping rules, applied per block in order:
      * every resnet is paired with the attention at the same index; resnets without a matching
        attention (or all resnets, if the block has no attentions) become single-module groups,
      * upsamplers are appended to the most recent group,
      * every downsampler forms a group of its own.

    Args:
        blocks: A single block or an iterable of blocks.

    Returns:
        `list` of `SubBlock`, one per group.
    """
    if not is_iterable(blocks):
        blocks = [blocks]

    groups = []
    for block in blocks:
        if hasattr(block, "resnets"):
            attentions = getattr(block, "attentions", None)
            if attentions is None:
                groups.extend([resnet] for resnet in block.resnets)
            else:
                groups.extend([resnet, attn] for resnet, attn in zip(block.resnets, attentions))
                # there can be more resnets than attentions; the surplus ones stand alone
                for idx in range(len(attentions), len(block.resnets)):
                    groups.append([block.resnets[idx]])

        # upsamplers are part of the same subblock
        upsamplers = getattr(block, "upsamplers", None)
        if upsamplers is not None:
            for upsampler in upsamplers:
                groups[-1].append(upsampler)

        # downsamplers are own subblock
        downsamplers = getattr(block, "downsamplers", None)
        if downsamplers is not None:
            groups.extend([downsampler] for downsampler in downsamplers)

    return [SubBlock(group) for group in groups]
class SubBlock(nn.ModuleList):
    """A SubBlock is the largest piece of either the base or the control model that is executed
    independently of the other model.

    It is a `nn.ModuleList` whose members are called in order, each with the arguments its type expects.
    """

    def __init__(self, ms, *args, **kwargs):
        # accept a single module as well as an iterable of modules
        modules = ms if is_iterable(ms) else [ms]
        super().__init__(modules, *args, **kwargs)

    def forward(
        self,
        x: torch.Tensor,
        temb: torch.Tensor,
        cemb: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    ):
        """Run `x` through every child, handing each one the inputs it needs.

        Resnets receive the time embedding, transformers receive the context embedding plus attention
        arguments, and down-/upsamplers receive only the hidden states.

        Raises:
            ValueError: if a child is none of the supported module types.
        """
        for layer in self:
            if isinstance(layer, ResnetBlock2D):
                x = layer(x, temb)
            elif isinstance(layer, Transformer2DModel):
                attn_out = layer(
                    x,
                    cemb,
                    attention_mask=attention_mask,
                    cross_attention_kwargs=cross_attention_kwargs,
                )
                x = attn_out.sample
            elif isinstance(layer, (Downsample2D, Upsample2D)):
                x = layer(x)
            else:
                raise ValueError(
                    f"Type of m is {type(layer)} but should be `ResnetBlock2D`, `Transformer2DModel`, `Downsample2D` or `Upsample2D`"
                )
        return x
def adjust_time_dims(unet: UNet2DConditionModel, in_dim: int, out_dim: int):
    """Replace the first linear layer of the UNet's time embedding.

    `unet.time_embedding.linear_1` is overwritten in place with a newly initialised
    `nn.Linear(in_dim, out_dim)`; any previous weights of that layer are discarded.
    """
    replacement = nn.Linear(in_features=in_dim, out_features=out_dim)
    unet.time_embedding.linear_1 = replacement
def increase_block_input_in_encoder_resnet(unet: UNet2DConditionModel, block_no, resnet_idx, by):
    """Increase channels sizes to allow for additional concatted information from base model.

    Operates in place on `unet.down_blocks[block_no].resnets[resnet_idx]`: `norm1` and `conv1` are replaced
    by newly initialised modules that accept `by` more input channels, a matching 1x1 `conv_shortcut` is
    installed, and the resnet's `in_channels` attribute is increased by `by`. Weights of the replaced
    modules are not carried over.

    Args:
        unet (`UNet2DConditionModel`): Model to modify.
        block_no (`int`): Index into `unet.down_blocks`.
        resnet_idx (`int`): Index into that block's `resnets`.
        by (`int`): Number of input channels to add.
    """
    resnet = unet.down_blocks[block_no].resnets[resnet_idx]
    old_norm1, old_conv1 = resnet.norm1, resnet.conv1

    # norm
    norm_args = ("num_groups", "num_channels", "eps", "affine")
    for a in norm_args:
        assert hasattr(old_norm1, a)
    norm_kwargs = {a: getattr(old_norm1, a) for a in norm_args}
    norm_kwargs["num_channels"] += by  # surgery done here

    # conv1
    conv1_args = (
        "in_channels",
        "out_channels",
        "kernel_size",
        "stride",
        "padding",
        "dilation",
        "groups",
        "bias",
        "padding_mode",
    )
    for a in conv1_args:
        assert hasattr(old_conv1, a)
    conv1_kwargs = {a: getattr(old_conv1, a) for a in conv1_args}
    # as param, bias is a boolean, but as attr, it's a tensor (or None when the conv has no bias)
    conv1_kwargs["bias"] = old_conv1.bias is not None
    conv1_kwargs["in_channels"] += by  # surgery done here

    # conv_shortcut
    # as we changed the input size of the block, the input and output sizes are likely different,
    # therefore we need a conv_shortcut (simply adding won't work)
    conv_shortcut_args_kwargs = {
        "in_channels": conv1_kwargs["in_channels"],
        "out_channels": conv1_kwargs["out_channels"],
        # default arguments from resnet.__init__
        "kernel_size": 1,
        "stride": 1,
        "padding": 0,
        "bias": True,
    }

    # swap old with new modules
    conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
    resnet.norm1 = GroupNorm(**norm_kwargs)
    resnet.conv1 = conv_cls(**conv1_kwargs)
    resnet.conv_shortcut = conv_cls(**conv_shortcut_args_kwargs)
    print(f"increasing down {resnet.in_channels} by {by}")
    resnet.in_channels += by  # surgery done here
def increase_block_input_in_decoder_resnet(unet: UNet2DConditionModel, block_no, resnet_idx, by):
    """Increase channels sizes to allow for additional concatted information from base model.

    Operates in place on `unet.up_blocks[block_no].resnets[resnet_idx]`: `norm1` and `conv1` are replaced
    by newly initialised modules that accept `by` more input channels, a matching 1x1 `conv_shortcut` is
    installed, and the resnet's `in_channels` attribute is increased by `by`. Weights of the replaced
    modules are not carried over.

    Args:
        unet (`UNet2DConditionModel`): Model to modify.
        block_no (`int`): Index into `unet.up_blocks`.
        resnet_idx (`int`): Index into that block's `resnets`.
        by (`int`): Number of input channels to add.
    """
    resnet = unet.up_blocks[block_no].resnets[resnet_idx]
    old_norm1, old_conv1 = resnet.norm1, resnet.conv1

    # norm
    norm_args = ("num_groups", "num_channels", "eps", "affine")
    for a in norm_args:
        assert hasattr(old_norm1, a)
    norm_kwargs = {a: getattr(old_norm1, a) for a in norm_args}
    norm_kwargs["num_channels"] += by  # surgery done here

    # conv1
    conv1_args = (
        "in_channels",
        "out_channels",
        "kernel_size",
        "stride",
        "padding",
        "dilation",
        "groups",
        "bias",
        "padding_mode",
    )
    for a in conv1_args:
        assert hasattr(old_conv1, a)
    conv1_kwargs = {a: getattr(old_conv1, a) for a in conv1_args}
    # as param, bias is a boolean, but as attr, it's a tensor (or None when the conv has no bias)
    conv1_kwargs["bias"] = old_conv1.bias is not None
    conv1_kwargs["in_channels"] += by  # surgery done here

    # conv_shortcut
    # as we changed the input size of the block, the input and output sizes are likely different,
    # therefore we need a conv_shortcut (simply adding won't work)
    conv_shortcut_args_kwargs = {
        "in_channels": conv1_kwargs["in_channels"],
        "out_channels": conv1_kwargs["out_channels"],
        # default arguments from resnet.__init__
        "kernel_size": 1,
        "stride": 1,
        "padding": 0,
        "bias": True,
    }

    # swap old with new modules
    conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
    resnet.norm1 = GroupNorm(**norm_kwargs)
    resnet.conv1 = conv_cls(**conv1_kwargs)
    resnet.conv_shortcut = conv_cls(**conv_shortcut_args_kwargs)
    print(f"increasing up {resnet.in_channels} by {by}")
    resnet.in_channels += by  # surgery done here
def increase_block_input_in_encoder_downsampler(unet: UNet2DConditionModel, block_no, by):
    """Increase channels sizes to allow for additional concatted information from base model.

    Operates in place on `unet.down_blocks[block_no].downsamplers[0]`: its `conv` is replaced by a newly
    initialised convolution that accepts `by` more input channels and the downsampler's `channels`
    attribute is increased by `by`. Weights of the replaced conv are not carried over.

    Args:
        unet (`UNet2DConditionModel`): Model to modify.
        block_no (`int`): Index into `unet.down_blocks`.
        by (`int`): Number of input channels to add.
    """
    downsampler = unet.down_blocks[block_no].downsamplers[0]
    old_down = downsampler.conv

    args = (
        "in_channels",
        "out_channels",
        "kernel_size",
        "stride",
        "padding",
        "dilation",
        "groups",
        "bias",
        "padding_mode",
    )
    for a in args:
        assert hasattr(old_down, a)
    kwargs = {a: getattr(old_down, a) for a in args}
    # as param, bias is a boolean, but as attr, it's a tensor (or None when the conv has no bias)
    kwargs["bias"] = old_down.bias is not None
    kwargs["in_channels"] += by  # surgery done here

    # swap old with new modules
    conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
    downsampler.conv = conv_cls(**kwargs)
    downsampler.channels += by  # surgery done here
def increase_block_input_in_decoder_downsampler(unet: UNet2DConditionModel, block_no, by):
    """Increase channels sizes to allow for additional concatted information from base model.

    Despite its name, this function operates on the *upsampler* of a decoder block: the `conv` of
    `unet.up_blocks[block_no].upsamplers[0]` is replaced in place by a newly initialised convolution that
    accepts `by` more input channels, and the upsampler's `channels` attribute is increased by `by`.
    Weights of the replaced conv are not carried over.

    Args:
        unet (`UNet2DConditionModel`): Model to modify.
        block_no (`int`): Index into `unet.up_blocks`.
        by (`int`): Number of input channels to add.
    """
    upsampler = unet.up_blocks[block_no].upsamplers[0]
    old_conv = upsampler.conv

    # `lora_layer` is deliberately not copied (consistent with the other surgery helpers in this file):
    # a plain `nn.Conv2d` has no such attribute and an existing LoRA layer would not fit the new width.
    args = (
        "in_channels",
        "out_channels",
        "kernel_size",
        "stride",
        "padding",
        "dilation",
        "groups",
        "bias",
        "padding_mode",
    )
    for a in args:
        assert hasattr(old_conv, a)
    kwargs = {a: getattr(old_conv, a) for a in args}
    # as param, bias is a boolean, but as attr, it's a tensor (or None when the conv has no bias)
    kwargs["bias"] = old_conv.bias is not None
    kwargs["in_channels"] += by  # surgery done here

    # swap old with new modules
    conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
    upsampler.conv = conv_cls(**kwargs)
    upsampler.channels += by  # surgery done here
def increase_block_input_in_mid_resnet(unet: UNet2DConditionModel, by):
    """Increase channels sizes to allow for additional concatted information from base model.

    Operates in place on `unet.mid_block.resnets[0]`: `norm1` and `conv1` are replaced by newly initialised
    modules that accept `by` more input channels, a matching 1x1 `conv_shortcut` is installed, and the
    resnet's `in_channels` attribute is increased by `by`. Weights of the replaced modules are not
    carried over.

    Args:
        unet (`UNet2DConditionModel`): Model to modify.
        by (`int`): Number of input channels to add.
    """
    resnet = unet.mid_block.resnets[0]
    old_norm1, old_conv1 = resnet.norm1, resnet.conv1

    # norm
    norm_args = ("num_groups", "num_channels", "eps", "affine")
    for a in norm_args:
        assert hasattr(old_norm1, a)
    norm_kwargs = {a: getattr(old_norm1, a) for a in norm_args}
    norm_kwargs["num_channels"] += by  # surgery done here

    # conv1
    conv1_args = (
        "in_channels",
        "out_channels",
        "kernel_size",
        "stride",
        "padding",
        "dilation",
        "groups",
        "bias",
        "padding_mode",
    )
    conv1_kwargs = {a: getattr(old_conv1, a) for a in conv1_args}
    # as param, bias is a boolean, but as attr, it's a tensor (or None when the conv has no bias)
    conv1_kwargs["bias"] = old_conv1.bias is not None
    conv1_kwargs["in_channels"] += by  # surgery done here

    # conv_shortcut
    # as we changed the input size of the block, the input and output sizes are likely different,
    # therefore we need a conv_shortcut (simply adding won't work)
    conv_shortcut_args_kwargs = {
        "in_channels": conv1_kwargs["in_channels"],
        "out_channels": conv1_kwargs["out_channels"],
        # default arguments from resnet.__init__
        "kernel_size": 1,
        "stride": 1,
        "padding": 0,
        "bias": True,
    }

    # swap old with new modules
    conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
    resnet.norm1 = GroupNorm(**norm_kwargs)
    resnet.conv1 = conv_cls(**conv1_kwargs)
    resnet.conv_shortcut = conv_cls(**conv_shortcut_args_kwargs)
    resnet.in_channels += by  # surgery done here
def adjust_group_norms(unet: UNet2DConditionModel, max_num_group: int = 32):
    """Raise the group count of group norms that use fewer than `max_num_group` groups.

    Visits `norm1`/`norm2` of every resnet and `norm` of every attention in the UNet's down blocks and
    mid block. A norm whose `num_groups` is below `max_num_group` gets the largest divisor of its
    `num_channels` that does not exceed `max_num_group` (or `num_channels` itself if that is smaller).
    Only the `num_groups` attribute is changed, in place.
    """

    def largest_divisor(number, start):
        # Largest value <= start that divides number; number itself when start is not below it.
        if start >= number:
            return number
        candidate = start
        while candidate != 0:
            if number % candidate == 0:
                return candidate
            candidate -= 1
        return None

    def regroup(norm):
        if norm.num_groups < max_num_group:
            norm.num_groups = largest_divisor(norm.num_channels, start=max_num_group)

    for block in (*unet.down_blocks, unet.mid_block):
        # resnets
        for resnet in block.resnets:
            regroup(resnet.norm1)
            regroup(resnet.norm2)

        # transformers
        if hasattr(block, "attentions"):
            for attention in block.attentions:
                regroup(attention.norm)
def is_iterable(o):
    """Return ``True`` when ``iter(o)`` succeeds and ``o`` is not a ``str``.

    Strings are deliberately reported as non-iterable even though Python can
    iterate over them; any object for which ``iter`` raises ``TypeError`` is
    reported as non-iterable as well.
    """
    if not isinstance(o, str):
        try:
            iter(o)
        except TypeError:
            pass
        else:
            return True
    return False
def _debug_image_to_hwc(array):
    """Rearrange ``array`` so the channel axis comes last where one can be inferred."""
    if array.ndim == 4:
        # Render only the first sample. Plain indexing works for any batch
        # size, whereas np.squeeze(axis=0) raises for batches larger than one.
        return np.transpose(array[0], (1, 2, 0))
    if array.ndim == 3 and array.shape[0] in (1, 3, 4):
        # A leading axis of length 1, 3 or 4 is taken to be the channel axis.
        return np.transpose(array, (1, 2, 0))
    if array.ndim == 2:
        return np.expand_dims(array, axis=-1)
    return array


def _debug_image_to_uint8(array, noise_threshold):
    """Scale ``array`` to ``uint8``; ``uint8`` input is returned unchanged."""
    if array.dtype == np.uint8:
        return array
    if array.max() <= 1 + noise_threshold:
        # Assume the image is in [0, 1] range with some noise
        return (np.clip(array, 0, 1) * 255).astype(np.uint8)
    # Assume the image is in a wider range, possibly due to noise
    lower_percentile = np.percentile(array, 1)
    upper_percentile = np.percentile(array, 99)
    span = upper_percentile - lower_percentile
    if span == 0:
        # Constant image: avoid dividing by zero (which would yield NaNs).
        return np.zeros(array.shape, dtype=np.uint8)
    clipped = np.clip(array, lower_percentile, upper_percentile)
    return ((clipped - lower_percentile) / span * 255).astype(np.uint8)


def save_debug_image(image, folder='debug_images', noise_threshold=0.1):
    """Write ``image`` to a timestamped PNG in ``folder``, printing diagnostics.

    The image is converted to a NumPy array, moved to channel-last layout
    (only the first sample of a 4-D batch is kept), and scaled to ``uint8``.
    If the PNG cannot be written for any reason, the processed array is saved
    next to the intended PNG path as a ``.npy`` file instead.

    Args:
        image: A ``torch.Tensor`` or ``np.ndarray``. Any other type is
            reported and ignored.
        folder: Output directory; created if it does not exist.
        noise_threshold: Non-``uint8`` data whose maximum is at most
            ``1 + noise_threshold`` is clipped to ``[0, 1]`` and scaled by
            255; anything else is rescaled between its 1st and 99th
            percentiles.

    Returns:
        None. Empty inputs and unsupported types are reported and nothing is
        saved.
    """
    os.makedirs(folder, exist_ok=True)
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S_%f")
    filepath = os.path.join(folder, f"debug_image_{timestamp}.png")

    print("Debugging image information:")
    print(f"Type of image: {type(image)}")
    if isinstance(image, torch.Tensor):
        print(f"Image tensor shape: {image.shape}")
        print(f"Image tensor dtype: {image.dtype}")
        print(f"Image tensor device: {image.device}")
        if image.numel() > 0:
            # min()/max() raise on empty tensors.
            print(f"Image tensor min: {image.min()}, max: {image.max()}")
        tensor = image.detach().cpu()
        try:
            image_np = tensor.numpy()
        except TypeError:
            # NumPy has no counterpart for some dtypes (e.g. bfloat16).
            image_np = tensor.float().numpy()
    elif isinstance(image, np.ndarray):
        image_np = image
    else:
        print(f"Unexpected image type: {type(image)}")
        return

    print(f"Numpy array shape: {image_np.shape}")
    print(f"Numpy array dtype: {image_np.dtype}")
    if image_np.size == 0:
        print("Image is empty, nothing to save")
        return
    print(f"Numpy array min: {image_np.min()}, max: {image_np.max()}")

    # Handle different array shapes
    image_np = _debug_image_to_hwc(image_np)
    print(f"Processed numpy array shape: {image_np.shape}")

    # Normalize the image, accounting for noise
    image_np = _debug_image_to_uint8(image_np, noise_threshold)
    print(f"Normalized array min: {image_np.min()}, max: {image_np.max()}")

    try:
        image_pil = Image.fromarray(image_np.squeeze() if image_np.shape[-1] == 1 else image_np)
        image_pil.save(filepath)
        print(f"Debug image saved as '{filepath}'")
    except Exception as e:
        # Deliberate best-effort: a debug helper should not abort the run.
        print(f"Error saving image: {str(e)}")
        print("Attempting to save as numpy array...")
        # Swap only the file extension; str.replace would also rewrite a
        # ".png" that happens to appear in the folder name.
        np_filepath = os.path.splitext(filepath)[0] + ".npy"
        np.save(np_filepath, image_np)
        print(f"Numpy array saved as '{np_filepath}'")
def zero_module(module):
    """Set every parameter of ``module`` to zero in place and return ``module``."""
    # Zeroing happens outside autograd so the parameters stay leaf tensors.
    with torch.no_grad():
        for weight in module.parameters():
            weight.zero_()
    return module