| import math
|
| from typing import Any, Dict, Optional, Tuple, Union
|
|
|
| import torch
|
| import torch.nn.functional as F
|
| from torch import nn
|
| import math
|
| from diffusers import UNet2DConditionModel
|
| from diffusers.models.embeddings import Timesteps, TimestepEmbedding
|
| from diffusers.models.unets.unet_2d_blocks import (
|
| get_down_block,
|
| get_up_block,
|
| get_mid_block,
|
| )
|
| from diffusers.models.activations import get_activation
|
| from diffusers.models.unets.unet_2d_condition import UNet2DConditionOutput
|
| from diffusers.utils import USE_PEFT_BACKEND, scale_lora_layers, unscale_lora_layers
|
|
|
|
|
def custom_prepare_attention_mask(
    self, attention_mask: Optional[torch.Tensor], target_length: int, batch_size: int, out_dim: int = 3
) -> Optional[torch.Tensor]:
    r"""
    Prepare the attention mask for the attention computation.

    When the last dimension of the mask differs from `target_length`, the mask is interpreted as a flattened square
    grid and resized to the target grid with nearest-neighbour interpolation. On `mps` devices the mask is instead
    extended by concatenating `target_length` zeros along the last dimension.

    Args:
        attention_mask (`torch.Tensor`, *optional*):
            The attention mask to prepare. `None` is returned unchanged.
        target_length (`int`):
            The length the attention mask should have after resizing.
        batch_size (`int`):
            The batch size, which is used to decide whether the mask must be repeated per attention head.
        out_dim (`int`, *optional*, defaults to `3`):
            The output dimension of the attention mask. Can be either `3` or `4`.

    Returns:
        `torch.Tensor`: The prepared attention mask, or `None` if no mask was given.

    Raises:
        `ValueError`: If the mask has to be resized and either the current or the target length is not a perfect
        square.
    """
    if attention_mask is None:
        return attention_mask

    head_size = self.heads
    current_length: int = attention_mask.shape[-1]

    if current_length != target_length:
        if attention_mask.device.type == "mps":
            # NOTE(review): this branch appends `target_length` zeros, so the resulting length is
            # current_length + target_length rather than target_length -- confirm this is intended on MPS.
            padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length)
            padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device)
            attention_mask = torch.cat([attention_mask, padding], dim=2)
        else:
            # math.isqrt is exact for integers, unlike int(math.sqrt(...)) which goes through a float.
            current_size = math.isqrt(current_length)
            target_size = math.isqrt(target_length)
            # Explicit exceptions instead of `assert`, which is removed when Python runs with -O.
            if current_size**2 != current_length:
                raise ValueError(f"current_length ({current_length}) cannot be squared to an integer size")
            if target_size**2 != target_length:
                raise ValueError(f"target_length ({target_length}) cannot be squared to an integer size")

            mask_batch = attention_mask.shape[0]
            # reshape (not view) so that non-contiguous masks are accepted as well.
            attention_mask = attention_mask.reshape(mask_batch, -1, current_size, current_size)
            attention_mask = F.interpolate(attention_mask, size=(target_size, target_size), mode="nearest")
            attention_mask = attention_mask.reshape(mask_batch, 1, target_length)

    if out_dim == 3:
        # Repeat per head only when the mask has not already been expanded to batch_size * heads rows.
        if attention_mask.shape[0] < batch_size * head_size:
            attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
    elif out_dim == 4:
        attention_mask = attention_mask.unsqueeze(1)
        attention_mask = attention_mask.repeat_interleave(head_size, dim=1)

    return attention_mask
|
|
|
|
|
def custom_get_attention_scores(self, query: torch.Tensor, key: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
    r"""
    Compute softmax-normalised attention probabilities for `query` against `key`.

    The raw scores are `self.scale * query @ key^T`, with `attention_mask` added when one is given. The softmax is
    taken over the last dimension and the result is cast back to the dtype `query` had on entry.

    Args:
        query (`torch.Tensor`): The query tensor.
        key (`torch.Tensor`): The key tensor.
        attention_mask (`torch.Tensor`, *optional*): Additive mask for the scores. If `None`, no mask is applied.

    Returns:
        `torch.Tensor`: The attention probabilities/scores.
    """
    output_dtype = query.dtype
    if self.upcast_attention:
        query, key = query.float(), key.float()

    if attention_mask is None:
        # With beta=0 the additive input is ignored by baddbmm, so an uninitialised buffer is sufficient.
        additive_term = torch.empty(
            query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device
        )
        mask_weight = 0
    else:
        additive_term, mask_weight = attention_mask, 1

    logits = torch.baddbmm(
        additive_term,
        query,
        key.transpose(-1, -2),
        beta=mask_weight,
        alpha=self.scale,
    )
    # Drop the local reference to the additive input before the softmax allocates its output.
    del additive_term

    if self.upcast_softmax:
        logits = logits.float()

    probabilities = logits.softmax(dim=-1)
    del logits

    return probabilities.to(output_dtype)
|
|
|
|
|
class CustomUNet(UNet2DConditionModel):
    """`UNet2DConditionModel` subclass with an extra scalar embedding and coordinate conditioning.

    The constructor rebuilds the input convolution, time embedding, down/mid/up blocks and output head from its own
    arguments, and adds a point embedding and a bounding-box embedding. The forward pass

    * embeds `timestep` and a second scalar `trans` with the same `time_proj` / `time_embedding` modules and sums them,
    * adds an embedding computed from `added_cond_kwargs["point_coords"]` or `added_cond_kwargs["bbox_mask_coords"]`,
    * selects, per stage (down / mid / up), which self-attention mask and which encoder hidden states are used,
    * records intermediate activations in `self.feature_map`.
    """

    def __init__(
        self,
        sample_size: Optional[int] = None,
        in_channels: int = 4,
        out_channels: int = 4,
        flip_sin_to_cos: bool = True,
        freq_shift: int = 0,
        down_block_types: Tuple[str] = (
            "CrossAttnDownBlock2D",
            "CrossAttnDownBlock2D",
            "CrossAttnDownBlock2D",
            "DownBlock2D",
        ),
        mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
        up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
        only_cross_attention: Union[bool, Tuple[bool]] = False,
        block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
        layers_per_block: Union[int, Tuple[int]] = 2,
        downsample_padding: int = 1,
        mid_block_scale_factor: float = 1,
        dropout: float = 0.0,
        act_fn: str = "silu",
        norm_num_groups: Optional[int] = 32,
        norm_eps: float = 1e-5,
        cross_attention_dim: Union[int, Tuple[int]] = 1280,
        transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
        reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None,
        attention_head_dim: Union[int, Tuple[int]] = 8,
        num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
        dual_cross_attention: bool = False,
        use_linear_projection: bool = False,
        upcast_attention: bool = False,
        resnet_time_scale_shift: str = "default",
        resnet_skip_time_act: bool = False,
        resnet_out_scale_factor: float = 1.0,
        time_embedding_dim: Optional[int] = None,
        timestep_post_act: Optional[str] = None,
        time_cond_proj_dim: Optional[int] = None,
        conv_in_kernel: int = 3,
        conv_out_kernel: int = 3,
        bbox_time_embed_dim: Optional[int] = None,
        point_embeddings_input_dim: Optional[int] = None,
        bbox_embeddings_input_dim: Optional[int] = None,
        attention_type: str = "default",
        class_embeddings_concat: bool = False,
        mid_block_only_cross_attention: Optional[bool] = None,
        cross_attention_norm: Optional[str] = None,
        use_attention_mask_list: Optional[list] = None,
        use_encoder_hidden_states_list: Optional[list] = None,
    ):
        """Build the network.

        Most arguments are forwarded unchanged to `get_down_block`, `get_mid_block`, `get_up_block`, `Timesteps` and
        `TimestepEmbedding`. Arguments specific to this class:

        Args:
            bbox_time_embed_dim: channel count given to the `Timesteps` projection used for bounding-box coordinates.
            point_embeddings_input_dim: input width of the point-coordinate `TimestepEmbedding`.
            bbox_embeddings_input_dim: input width of the bounding-box `TimestepEmbedding`.
            use_attention_mask_list: three flags (down, mid, up); a true flag makes that stage receive the
                self-attention mask, a false flag makes it receive `None`. Defaults to `[True, True, True]`.
            use_encoder_hidden_states_list: three flags (down, mid, up); a true flag makes that stage use
                `encoder_hidden_states`, a false flag makes it use `encoder_hidden_states_2`. Defaults to
                `[True, True, True]`.
        """
        # NOTE(review): the parent constructor is invoked with its own defaults and the attributes it sets are
        # overwritten below -- presumably it is only needed for the registered config; confirm.
        super().__init__()

        # `None` sentinels replace the former mutable list defaults, so instances no longer share one list object.
        self.use_attention_mask_list = (
            [True, True, True] if use_attention_mask_list is None else use_attention_mask_list
        )
        self.use_encoder_hidden_states_list = (
            [True, True, True] if use_encoder_hidden_states_list is None else use_encoder_hidden_states_list
        )
        self.sample_size = sample_size
        num_attention_heads = num_attention_heads or attention_head_dim

        # Input convolution; the padding keeps the spatial size for odd kernel sizes.
        conv_in_padding = (conv_in_kernel - 1) // 2
        self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding)

        # Scalar (timestep / trans) embedding.
        time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
        self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
        timestep_input_dim = block_out_channels[0]
        self.time_embedding = TimestepEmbedding(
            timestep_input_dim,
            time_embed_dim,
            act_fn=act_fn,
            post_act_fn=timestep_post_act,
            cond_proj_dim=time_cond_proj_dim,
        )

        # Coordinate embeddings. The three dimensions below default to `None` and are passed straight through,
        # so callers are expected to supply them.
        self.point_embedding = TimestepEmbedding(point_embeddings_input_dim, time_embed_dim)
        self.bbox_time_proj = Timesteps(bbox_time_embed_dim, flip_sin_to_cos, freq_shift)
        self.bbox_embedding = TimestepEmbedding(bbox_embeddings_input_dim, time_embed_dim)

        self.down_blocks = nn.ModuleList([])
        self.up_blocks = nn.ModuleList([])

        # Broadcast scalar settings to one entry per down block.
        if isinstance(only_cross_attention, bool):
            if mid_block_only_cross_attention is None:
                mid_block_only_cross_attention = only_cross_attention
            only_cross_attention = [only_cross_attention] * len(down_block_types)

        if mid_block_only_cross_attention is None:
            mid_block_only_cross_attention = False

        if isinstance(num_attention_heads, int):
            num_attention_heads = (num_attention_heads,) * len(down_block_types)

        if isinstance(attention_head_dim, int):
            attention_head_dim = (attention_head_dim,) * len(down_block_types)

        if isinstance(cross_attention_dim, int):
            cross_attention_dim = (cross_attention_dim,) * len(down_block_types)

        if isinstance(layers_per_block, int):
            layers_per_block = [layers_per_block] * len(down_block_types)

        if isinstance(transformer_layers_per_block, int):
            transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)

        blocks_time_embed_dim = time_embed_dim * 2 if class_embeddings_concat else time_embed_dim

        # Down path.
        output_channel = block_out_channels[0]
        for i, down_block_type in enumerate(down_block_types):
            input_channel = output_channel
            output_channel = block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1

            down_block = get_down_block(
                down_block_type,
                num_layers=layers_per_block[i],
                transformer_layers_per_block=transformer_layers_per_block[i],
                in_channels=input_channel,
                out_channels=output_channel,
                temb_channels=blocks_time_embed_dim,
                add_downsample=not is_final_block,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                cross_attention_dim=cross_attention_dim[i],
                num_attention_heads=num_attention_heads[i],
                downsample_padding=downsample_padding,
                dual_cross_attention=dual_cross_attention,
                use_linear_projection=use_linear_projection,
                only_cross_attention=only_cross_attention[i],
                upcast_attention=upcast_attention,
                resnet_time_scale_shift=resnet_time_scale_shift,
                attention_type=attention_type,
                resnet_skip_time_act=resnet_skip_time_act,
                resnet_out_scale_factor=resnet_out_scale_factor,
                cross_attention_norm=cross_attention_norm,
                attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
                dropout=dropout,
            )
            self.down_blocks.append(down_block)

        # Mid block.
        self.mid_block = get_mid_block(
            mid_block_type,
            temb_channels=blocks_time_embed_dim,
            in_channels=block_out_channels[-1],
            resnet_eps=norm_eps,
            resnet_act_fn=act_fn,
            resnet_groups=norm_num_groups,
            output_scale_factor=mid_block_scale_factor,
            transformer_layers_per_block=transformer_layers_per_block[-1],
            num_attention_heads=num_attention_heads[-1],
            cross_attention_dim=cross_attention_dim[-1],
            dual_cross_attention=dual_cross_attention,
            use_linear_projection=use_linear_projection,
            mid_block_only_cross_attention=mid_block_only_cross_attention,
            upcast_attention=upcast_attention,
            resnet_time_scale_shift=resnet_time_scale_shift,
            attention_type=attention_type,
            resnet_skip_time_act=resnet_skip_time_act,
            cross_attention_norm=cross_attention_norm,
            attention_head_dim=attention_head_dim[-1],
            dropout=dropout,
        )

        # Counts the up blocks that upsample; `forward` uses it to check input size divisibility.
        self.num_upsamplers = 0

        # Up path: per-block settings are consumed in reverse order of the down path.
        reversed_block_out_channels = list(reversed(block_out_channels))
        reversed_num_attention_heads = list(reversed(num_attention_heads))
        reversed_layers_per_block = list(reversed(layers_per_block))
        reversed_cross_attention_dim = list(reversed(cross_attention_dim))
        reversed_transformer_layers_per_block = (
            list(reversed(transformer_layers_per_block))
            if reverse_transformer_layers_per_block is None
            else reverse_transformer_layers_per_block
        )
        only_cross_attention = list(reversed(only_cross_attention))

        output_channel = reversed_block_out_channels[0]
        for i, up_block_type in enumerate(up_block_types):
            is_final_block = i == len(block_out_channels) - 1

            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]
            input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]

            # Every up block except the last one upsamples.
            add_upsample = not is_final_block
            if add_upsample:
                self.num_upsamplers += 1

            up_block = get_up_block(
                up_block_type,
                num_layers=reversed_layers_per_block[i] + 1,
                transformer_layers_per_block=reversed_transformer_layers_per_block[i],
                in_channels=input_channel,
                out_channels=output_channel,
                prev_output_channel=prev_output_channel,
                temb_channels=blocks_time_embed_dim,
                add_upsample=add_upsample,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                resolution_idx=i,
                resnet_groups=norm_num_groups,
                cross_attention_dim=reversed_cross_attention_dim[i],
                num_attention_heads=reversed_num_attention_heads[i],
                dual_cross_attention=dual_cross_attention,
                use_linear_projection=use_linear_projection,
                only_cross_attention=only_cross_attention[i],
                upcast_attention=upcast_attention,
                resnet_time_scale_shift=resnet_time_scale_shift,
                attention_type=attention_type,
                resnet_skip_time_act=resnet_skip_time_act,
                resnet_out_scale_factor=resnet_out_scale_factor,
                cross_attention_norm=cross_attention_norm,
                # NOTE(review): unlike the other per-block settings this one is indexed without reversal -- confirm.
                attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
                dropout=dropout,
            )
            self.up_blocks.append(up_block)

        # Output head.
        if norm_num_groups is not None:
            self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
            self.conv_act = get_activation(act_fn)
        else:
            self.conv_norm_out = None
            self.conv_act = None

        conv_out_padding = (conv_out_kernel - 1) // 2
        self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding)

        # Filled by `forward` with intermediate activations of the most recent call.
        self.feature_map = []

    def _get_value(self, use_list, true_value, false_value):
        """Return a `(down, mid, up)` tuple holding `true_value` where the matching flag in `use_list` is truthy.

        Positions whose flag is falsy hold `false_value`. Only the first three entries of `use_list` are read.
        """
        return tuple(true_value if use_list[stage] else false_value for stage in range(3))

    def _embed_scalar_condition(self, value, sample, timestep_cond):
        """Embed one scalar condition (`timestep` or `trans`) with `time_proj` and `time_embedding`.

        Returns `None` when `value` is `None`. Python numbers are first wrapped into a one-element tensor on the
        device of `sample`; tensors are used as given. The value is then expanded to the batch size of `sample`.
        """
        if value is None:
            return None

        if not torch.is_tensor(value):
            # The signature of `forward` allows plain floats/ints, which have no `.expand`.
            is_mps = sample.device.type == "mps"
            if isinstance(value, float):
                dtype = torch.float32 if is_mps else torch.float64
            else:
                dtype = torch.int32 if is_mps else torch.int64
            value = torch.tensor([value], dtype=dtype, device=sample.device)

        projected = self.time_proj(value.expand(sample.shape[0]))
        # Cast the projection to the sample dtype before it enters the embedding MLP.
        projected = projected.to(dtype=sample.dtype)
        return self.time_embedding(projected, timestep_cond)

    def forward(
        self,
        sample: torch.FloatTensor,
        timestep: Union[torch.Tensor, float, int],
        trans: Union[torch.Tensor, float, int],
        encoder_hidden_states: torch.Tensor,
        encoder_hidden_states_2: Optional[torch.Tensor] = None,
        timestep_cond: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
    ) -> Union[UNet2DConditionOutput, Tuple]:
        """Run the UNet.

        Args:
            sample: input tensor passed through `conv_in`.
            timestep: scalar condition embedded with `time_proj` / `time_embedding`; may be `None` if `trans` is given.
            trans: second scalar condition embedded with the same modules; may be `None` if `timestep` is given.
            encoder_hidden_states: cross-attention context for stages whose flag in
                `use_encoder_hidden_states_list` is true.
            encoder_hidden_states_2: cross-attention context for stages whose flag is false.
            timestep_cond: optional conditioning forwarded to `time_embedding`.
            attention_mask: optional mask, converted to an additive bias (`1 -> 0`, `0 -> -10000`) and passed to
                the stages whose flag in `use_attention_mask_list` is true.
            cross_attention_kwargs: forwarded to the attention blocks; its `"scale"` entry (default `1.0`) is used
                as the LoRA scale.
            added_cond_kwargs: must contain `"point_coords"` or `"bbox_mask_coords"`; `"point_coords"` takes
                precedence when both are present.
            encoder_attention_mask: optional cross-attention mask, converted like `attention_mask`.

        Returns:
            `UNet2DConditionOutput` whose `sample` is the output of `conv_out`.

        Raises:
            ValueError: if both `timestep` and `trans` are `None`, or if `added_cond_kwargs` is missing or contains
                neither coordinate key.

        Side effects:
            `self.feature_map` is replaced by a new list of intermediate activations from this call.
        """
        # If the spatial size is not a multiple of the total upsampling factor, explicit target sizes are
        # forwarded to the up blocks.
        default_overall_up_factor = 2**self.num_upsamplers
        forward_upsample_size = any(dim % default_overall_up_factor != 0 for dim in sample.shape[-2:])
        upsample_size = None

        # Turn keep(1)/discard(0) masks into additive biases and add a broadcast dimension.
        if attention_mask is not None:
            attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
            attention_mask = attention_mask.unsqueeze(1)

        if encoder_attention_mask is not None:
            encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)

        if self.config.center_input_sample:
            sample = 2 * sample - 1.0

        # Per-stage choice of self-attention mask and cross-attention context.
        down_attn_mask, mid_attn_mask, up_attn_mask = self._get_value(self.use_attention_mask_list, attention_mask, None)
        down_encoder_hidden_states, mid_encoder_hidden_states, up_encoder_hidden_states = self._get_value(
            self.use_encoder_hidden_states_list, encoder_hidden_states, encoder_hidden_states_2
        )

        # Scalar embeddings: both conditions share the same projection and MLP.
        t_emb = self._embed_scalar_condition(timestep, sample, timestep_cond)
        op_emb = self._embed_scalar_condition(trans, sample, timestep_cond)

        if t_emb is not None and op_emb is not None:
            emb = t_emb + op_emb
        elif op_emb is not None:
            emb = op_emb
        elif t_emb is not None:
            emb = t_emb
        else:
            raise ValueError("Missing required field: 'timestep' and 'trans'. Please ensure it is included in your input.")

        # Coordinate embedding. A missing dict is treated as empty so that the ValueError below is raised instead
        # of a TypeError from the membership test.
        if added_cond_kwargs is None:
            added_cond_kwargs = {}

        if "point_coords" in added_cond_kwargs:
            coords_embeds = added_cond_kwargs.get("point_coords")
            coords_embeds = coords_embeds.reshape((sample.shape[0], -1))
            coords_embeds = coords_embeds.to(emb.dtype)
            aug_emb = self.point_embedding(coords_embeds)
        elif "bbox_mask_coords" in added_cond_kwargs:
            coords = added_cond_kwargs.get("bbox_mask_coords")
            # Each coordinate is projected individually, then the projections are concatenated per batch item.
            coords_embeds = self.bbox_time_proj(coords.flatten())
            coords_embeds = coords_embeds.reshape((sample.shape[0], -1))
            coords_embeds = coords_embeds.to(emb.dtype)
            aug_emb = self.bbox_embedding(coords_embeds)
        else:
            raise ValueError(f"{self.__class__} cannot find point_coords or bbox_mask_coords in added_cond_kwargs.")

        emb = emb + aug_emb

        sample = self.conv_in(sample)

        # Start a fresh activation record for this call.
        self.feature_map = []

        lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
        if USE_PEFT_BACKEND:
            scale_lora_layers(self, lora_scale)

        # Down path.
        down_block_res_samples = (sample,)
        for downsample_block in self.down_blocks:
            if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
                sample, res_samples = downsample_block(
                    hidden_states=sample,
                    temb=emb,
                    encoder_hidden_states=down_encoder_hidden_states,
                    attention_mask=down_attn_mask,
                    cross_attention_kwargs=cross_attention_kwargs,
                    encoder_attention_mask=encoder_attention_mask,
                )
            else:
                sample, res_samples = downsample_block(hidden_states=sample, temb=emb, scale=lora_scale)

            down_block_res_samples += res_samples

            # NOTE(review): the original indentation was lost; recording once per down block is assumed -- confirm.
            self.feature_map.append(sample)

        # Mid block.
        if self.mid_block is not None:
            if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
                sample = self.mid_block(
                    sample,
                    emb,
                    encoder_hidden_states=mid_encoder_hidden_states,
                    attention_mask=mid_attn_mask,
                    cross_attention_kwargs=cross_attention_kwargs,
                    encoder_attention_mask=encoder_attention_mask,
                )
            else:
                sample = self.mid_block(sample, emb)

            # NOTE(review): assumed to be recorded only when a mid block exists -- confirm against the original.
            self.feature_map.append(sample)

        # Up path.
        for i, upsample_block in enumerate(self.up_blocks):
            is_final_block = i == len(self.up_blocks) - 1

            # Pop as many skip connections as this block has resnets.
            res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]

            if not is_final_block and forward_upsample_size:
                upsample_size = down_block_res_samples[-1].shape[2:]

            if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
                sample = upsample_block(
                    hidden_states=sample,
                    temb=emb,
                    res_hidden_states_tuple=res_samples,
                    encoder_hidden_states=up_encoder_hidden_states,
                    cross_attention_kwargs=cross_attention_kwargs,
                    upsample_size=upsample_size,
                    attention_mask=up_attn_mask,
                    encoder_attention_mask=encoder_attention_mask,
                )
            else:
                sample = upsample_block(
                    hidden_states=sample,
                    temb=emb,
                    res_hidden_states_tuple=res_samples,
                    upsample_size=upsample_size,
                    scale=lora_scale,
                )

            # NOTE(review): the original indentation was lost; recording once per up block is assumed -- confirm.
            self.feature_map.append(sample)

        # Output head.
        if self.conv_norm_out is not None:
            sample = self.conv_norm_out(sample)
            sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        if USE_PEFT_BACKEND:
            unscale_lora_layers(self, lora_scale)

        return UNet2DConditionOutput(sample=sample)
|
|
|