# Eji-Sensei14's picture
# Upload folder using huggingface_hub
# c6535db verified
import math
from typing import Any, Dict, Optional, Tuple, Union
import torch
import torch.nn.functional as F
from torch import nn
import math
from diffusers import UNet2DConditionModel
from diffusers.models.embeddings import Timesteps, TimestepEmbedding
from diffusers.models.unets.unet_2d_blocks import (
get_down_block,
get_up_block,
get_mid_block,
)
from diffusers.models.activations import get_activation
from diffusers.models.unets.unet_2d_condition import UNet2DConditionOutput
from diffusers.utils import USE_PEFT_BACKEND, scale_lora_layers, unscale_lora_layers
def custom_prepare_attention_mask(
    self, attention_mask: torch.Tensor, target_length: int, batch_size: int, out_dim: int = 3
) -> torch.Tensor:
    r"""
    Prepare the attention mask for the attention computation.

    When the last dimension of the mask differs from `target_length`, the mask is interpreted as a flattened
    square grid (`current_length == current_size**2`) and resized to the target grid with nearest-neighbour
    interpolation, instead of being padded.

    Args:
        attention_mask (`torch.Tensor`):
            The attention mask to prepare. `None` is returned unchanged.
        target_length (`int`):
            The length the last dimension of the mask must have after resizing.
        batch_size (`int`):
            The batch size, which is used to decide whether the mask must be repeated per head.
        out_dim (`int`, *optional*, defaults to `3`):
            The output dimension of the attention mask. Can be either `3` or `4`.

    Returns:
        `torch.Tensor`: The prepared attention mask.

    Raises:
        `ValueError`: If a resize is needed and `current_length` or `target_length` is not a perfect square.
    """
    head_size = self.heads
    if attention_mask is None:
        return attention_mask

    current_length: int = attention_mask.shape[-1]
    if current_length != target_length:
        # TODO: for pipelines such as stable-diffusion, padding cross-attn mask:
        #       we want to instead pad by (0, remaining_length), where remaining_length is:
        #       remaining_length: int = target_length - current_length
        # TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding
        #
        # Every device takes the interpolation path. The former MPS-only branch concatenated `target_length`
        # zeros (a workaround for a padding limitation), which produced a mask of length
        # `current_length + target_length` rather than `target_length`.
        # NOTE(review): not exercised on an MPS device here - confirm nearest interpolation behaves there.
        current_size = math.isqrt(current_length)
        target_size = math.isqrt(target_length)
        # Explicit exceptions instead of `assert`, so the checks survive `python -O`.
        if current_size**2 != current_length:
            raise ValueError(f"current_length ({current_length}) cannot be squared to an integer size")
        if target_size**2 != target_length:
            raise ValueError(f"target_length ({target_length}) cannot be squared to an integer size")

        num_masks = attention_mask.shape[0]
        # Unflatten to (B, C, H, W) so the mask can be resized spatially.
        attention_mask = attention_mask.reshape(num_masks, -1, current_size, current_size)
        attention_mask = F.interpolate(attention_mask, size=(target_size, target_size), mode="nearest")
        attention_mask = attention_mask.view(num_masks, 1, target_length)

    if out_dim == 3:
        # Only repeat when the mask has not already been expanded to one entry per (batch, head).
        if attention_mask.shape[0] < batch_size * head_size:
            attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
    elif out_dim == 4:
        attention_mask = attention_mask.unsqueeze(1)
        attention_mask = attention_mask.repeat_interleave(head_size, dim=1)

    return attention_mask
def custom_get_attention_scores(self, query: torch.Tensor, key: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
    r"""
    Compute softmax attention probabilities from `query` and `key`.

    The scores are `self.scale * query @ key^T`; when a mask is given it is added to the scores before the
    softmax over the last dimension.

    Args:
        query (`torch.Tensor`): The query tensor.
        key (`torch.Tensor`): The key tensor.
        attention_mask (`torch.Tensor`, *optional*): Additive mask. If `None`, no mask is applied.

    Returns:
        `torch.Tensor`: The attention probabilities, cast back to the dtype `query` had on entry.
    """
    result_dtype = query.dtype

    if self.upcast_attention:
        query, key = query.float(), key.float()

    if attention_mask is None:
        # With beta == 0 the bias operand is ignored, so an uninitialised buffer is sufficient.
        bias = torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device)
        bias_weight = 0
    else:
        bias = attention_mask
        bias_weight = 1

    # A disabled experiment in the original applied masks with more than two distinct values
    # multiplicatively (`1 - mask / -10000.0`) after this step; every mask is applied additively here.
    logits = torch.baddbmm(bias, query, key.transpose(-1, -2), beta=bias_weight, alpha=self.scale)
    del bias

    if self.upcast_softmax:
        logits = logits.float()

    probabilities = logits.softmax(dim=-1)
    del logits

    return probabilities.to(result_dtype)
class CustomUNet(UNet2DConditionModel):
    r"""
    `UNet2DConditionModel` variant with extra conditioning.

    On top of the usual timestep embedding, the model embeds a second scalar `trans` (called "opacity" in the
    code comments) through the same `time_proj`/`time_embedding` modules, and adds an embedding of either point
    coordinates or bbox-mask coordinates taken from `added_cond_kwargs`. The self-attention mask and the
    encoder hidden states can be toggled separately for the down, mid and up stages, and the intermediate
    outputs of a forward pass are collected in `self.feature_map` (labelled "distillation" in the code).
    """

    def __init__(
        self,
        sample_size: Optional[int] = None,
        in_channels: int = 4,
        out_channels: int = 4,
        flip_sin_to_cos: bool = True,
        freq_shift: int = 0,
        down_block_types: Tuple[str, ...] = (
            "CrossAttnDownBlock2D",
            "CrossAttnDownBlock2D",
            "CrossAttnDownBlock2D",
            "DownBlock2D",
        ),
        mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
        up_block_types: Tuple[str, ...] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
        only_cross_attention: Union[bool, Tuple[bool, ...]] = False,
        block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
        layers_per_block: Union[int, Tuple[int, ...]] = 2,
        downsample_padding: int = 1,
        mid_block_scale_factor: float = 1,
        dropout: float = 0.0,
        act_fn: str = "silu",
        norm_num_groups: Optional[int] = 32,
        norm_eps: float = 1e-5,
        cross_attention_dim: Union[int, Tuple[int, ...]] = 1280,
        transformer_layers_per_block: Union[int, Tuple[int, ...], Tuple[Tuple]] = 1,
        reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None,
        attention_head_dim: Union[int, Tuple[int, ...]] = 8,
        num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None,
        dual_cross_attention: bool = False,
        use_linear_projection: bool = False,
        upcast_attention: bool = False,
        resnet_time_scale_shift: str = "default",
        resnet_skip_time_act: bool = False,
        resnet_out_scale_factor: float = 1.0,
        time_embedding_dim: Optional[int] = None,
        timestep_post_act: Optional[str] = None,
        time_cond_proj_dim: Optional[int] = None,
        conv_in_kernel: int = 3,
        conv_out_kernel: int = 3,
        bbox_time_embed_dim: Optional[int] = None,
        point_embeddings_input_dim: Optional[int] = None,
        bbox_embeddings_input_dim: Optional[int] = None,
        attention_type: str = "default",
        class_embeddings_concat: bool = False,
        mid_block_only_cross_attention: Optional[bool] = None,
        cross_attention_norm: Optional[str] = None,
        use_attention_mask_list=(True, True, True),
        use_encoder_hidden_states_list=(True, True, True),
    ):
        r"""
        Build the UNet.

        Most arguments are forwarded to the diffusers block factories (`get_down_block`, `get_mid_block`,
        `get_up_block`). The arguments specific to this class are:

        Args:
            bbox_time_embed_dim (`int`, *optional*):
                Channel count of the sinusoidal projection applied to every bbox-mask coordinate. The bbox
                projection is only created when this is given.
            point_embeddings_input_dim (`int`, *optional*):
                Input width of the point-coordinate embedding MLP. The embedding is only created when this
                is given.
            bbox_embeddings_input_dim (`int`, *optional*):
                Input width of the bbox embedding MLP. The embedding is only created when this is given.
            use_attention_mask_list (sequence of 3 `bool`):
                Whether the down, mid and up stages (in that order) receive `attention_mask`.
            use_encoder_hidden_states_list (sequence of 3 `bool`):
                Whether the down, mid and up stages (in that order) receive `encoder_hidden_states`; stages
                set to `False` receive `encoder_hidden_states_2` instead.
        """
        # The parent constructor is called without arguments; the modules assigned below replace the
        # parent's attributes of the same name.
        # NOTE(review): presumably `self.config` then reflects the parent defaults rather than the
        # arguments given here - confirm before relying on config values.
        super().__init__()

        # Copy so instances never share (or alias) the caller's or the default sequence.
        self.use_attention_mask_list = list(use_attention_mask_list)
        self.use_encoder_hidden_states_list = list(use_encoder_hidden_states_list)
        self.sample_size = sample_size

        num_attention_heads = num_attention_heads or attention_head_dim

        # input
        conv_in_padding = (conv_in_kernel - 1) // 2
        self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding)

        # time
        time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
        self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
        timestep_input_dim = block_out_channels[0]
        self.time_embedding = TimestepEmbedding(
            timestep_input_dim,
            time_embed_dim,
            act_fn=act_fn,
            post_act_fn=timestep_post_act,
            cond_proj_dim=time_cond_proj_dim,
        )

        # Coordinate conditioning. Each module is only built when its input size is configured, so the
        # `None` defaults no longer fail inside the embedding constructors; `forward` reports a clear
        # error if a conditioning is requested whose module was not built.
        self.point_embedding = (
            TimestepEmbedding(point_embeddings_input_dim, time_embed_dim)
            if point_embeddings_input_dim is not None
            else None
        )
        self.bbox_time_proj = (
            Timesteps(bbox_time_embed_dim, flip_sin_to_cos, freq_shift) if bbox_time_embed_dim is not None else None
        )
        self.bbox_embedding = (
            TimestepEmbedding(bbox_embeddings_input_dim, time_embed_dim)
            if bbox_embeddings_input_dim is not None
            else None
        )

        self.down_blocks = nn.ModuleList([])
        self.up_blocks = nn.ModuleList([])

        if isinstance(only_cross_attention, bool):
            if mid_block_only_cross_attention is None:
                mid_block_only_cross_attention = only_cross_attention
            only_cross_attention = [only_cross_attention] * len(down_block_types)

        if mid_block_only_cross_attention is None:
            mid_block_only_cross_attention = False

        # Broadcast scalar settings to one value per down block.
        if isinstance(num_attention_heads, int):
            num_attention_heads = (num_attention_heads,) * len(down_block_types)
        if isinstance(attention_head_dim, int):
            attention_head_dim = (attention_head_dim,) * len(down_block_types)
        if isinstance(cross_attention_dim, int):
            cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
        if isinstance(layers_per_block, int):
            layers_per_block = [layers_per_block] * len(down_block_types)
        if isinstance(transformer_layers_per_block, int):
            transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)

        if class_embeddings_concat:
            blocks_time_embed_dim = time_embed_dim * 2
        else:
            blocks_time_embed_dim = time_embed_dim

        # down
        output_channel = block_out_channels[0]
        for i, down_block_type in enumerate(down_block_types):
            input_channel = output_channel
            output_channel = block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1

            down_block = get_down_block(
                down_block_type,
                num_layers=layers_per_block[i],
                transformer_layers_per_block=transformer_layers_per_block[i],
                in_channels=input_channel,
                out_channels=output_channel,
                temb_channels=blocks_time_embed_dim,
                add_downsample=not is_final_block,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                cross_attention_dim=cross_attention_dim[i],
                num_attention_heads=num_attention_heads[i],
                downsample_padding=downsample_padding,
                dual_cross_attention=dual_cross_attention,
                use_linear_projection=use_linear_projection,
                only_cross_attention=only_cross_attention[i],
                upcast_attention=upcast_attention,
                resnet_time_scale_shift=resnet_time_scale_shift,
                attention_type=attention_type,
                resnet_skip_time_act=resnet_skip_time_act,
                resnet_out_scale_factor=resnet_out_scale_factor,
                cross_attention_norm=cross_attention_norm,
                attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
                dropout=dropout,
            )
            self.down_blocks.append(down_block)

        # mid
        self.mid_block = get_mid_block(
            mid_block_type,
            temb_channels=blocks_time_embed_dim,
            in_channels=block_out_channels[-1],
            resnet_eps=norm_eps,
            resnet_act_fn=act_fn,
            resnet_groups=norm_num_groups,
            output_scale_factor=mid_block_scale_factor,
            transformer_layers_per_block=transformer_layers_per_block[-1],
            num_attention_heads=num_attention_heads[-1],
            cross_attention_dim=cross_attention_dim[-1],
            dual_cross_attention=dual_cross_attention,
            use_linear_projection=use_linear_projection,
            mid_block_only_cross_attention=mid_block_only_cross_attention,
            upcast_attention=upcast_attention,
            resnet_time_scale_shift=resnet_time_scale_shift,
            attention_type=attention_type,
            resnet_skip_time_act=resnet_skip_time_act,
            cross_attention_norm=cross_attention_norm,
            attention_head_dim=attention_head_dim[-1],
            dropout=dropout,
        )

        # count how many layers upsample the images
        self.num_upsamplers = 0

        # up
        reversed_block_out_channels = list(reversed(block_out_channels))
        reversed_num_attention_heads = list(reversed(num_attention_heads))
        reversed_layers_per_block = list(reversed(layers_per_block))
        reversed_cross_attention_dim = list(reversed(cross_attention_dim))
        reversed_transformer_layers_per_block = (
            list(reversed(transformer_layers_per_block))
            if reverse_transformer_layers_per_block is None
            else reverse_transformer_layers_per_block
        )
        only_cross_attention = list(reversed(only_cross_attention))

        output_channel = reversed_block_out_channels[0]
        for i, up_block_type in enumerate(up_block_types):
            is_final_block = i == len(block_out_channels) - 1

            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]
            input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]

            # add upsample block for all BUT final layer
            if not is_final_block:
                add_upsample = True
                self.num_upsamplers += 1
            else:
                add_upsample = False

            up_block = get_up_block(
                up_block_type,
                num_layers=reversed_layers_per_block[i] + 1,
                transformer_layers_per_block=reversed_transformer_layers_per_block[i],
                in_channels=input_channel,
                out_channels=output_channel,
                prev_output_channel=prev_output_channel,
                temb_channels=blocks_time_embed_dim,
                add_upsample=add_upsample,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                resolution_idx=i,
                resnet_groups=norm_num_groups,
                cross_attention_dim=reversed_cross_attention_dim[i],
                num_attention_heads=reversed_num_attention_heads[i],
                dual_cross_attention=dual_cross_attention,
                use_linear_projection=use_linear_projection,
                only_cross_attention=only_cross_attention[i],
                upcast_attention=upcast_attention,
                resnet_time_scale_shift=resnet_time_scale_shift,
                attention_type=attention_type,
                resnet_skip_time_act=resnet_skip_time_act,
                resnet_out_scale_factor=resnet_out_scale_factor,
                cross_attention_norm=cross_attention_norm,
                attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
                dropout=dropout,
            )
            self.up_blocks.append(up_block)

        # out
        if norm_num_groups is not None:
            self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
            self.conv_act = get_activation(act_fn)
        else:
            self.conv_norm_out = None
            self.conv_act = None

        conv_out_padding = (conv_out_kernel - 1) // 2
        self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding)

        # distillation: intermediate outputs of the most recent forward pass
        self.feature_map = []

    def _get_value(self, use_list, true_value, false_value):
        """Return `(down, mid, up)`, each `true_value` where the matching `use_list` flag is set, else `false_value`."""
        down_value, mid_value, up_value = (true_value if use_list[stage] else false_value for stage in range(3))
        return down_value, mid_value, up_value

    @staticmethod
    def _as_batch_tensor(value, sample):
        """Return `value` as a tensor expanded to one entry per batch element of `sample`.

        Tensors are expanded as-is. Python scalars are first wrapped in a one-element tensor on the device of
        `sample`, following the scalar handling of the upstream diffusers UNet (32-bit dtypes on MPS).
        """
        if not torch.is_tensor(value):
            is_mps = sample.device.type == "mps"
            if isinstance(value, float):
                dtype = torch.float32 if is_mps else torch.float64
            else:
                dtype = torch.int32 if is_mps else torch.int64
            value = torch.tensor([value], dtype=dtype, device=sample.device)
        return value.expand(sample.shape[0])

    def _scalar_embedding(self, value, sample, timestep_cond):
        """Embed a per-sample scalar (timestep or `trans`) with `time_proj` followed by `time_embedding`."""
        projected = self.time_proj(self._as_batch_tensor(value, sample))
        # `time_proj` has no weights, so its output is cast to the working dtype explicitly.
        projected = projected.to(dtype=sample.dtype)
        return self.time_embedding(projected, timestep_cond)

    def _coords_embedding(self, added_cond_kwargs, batch_size, dtype):
        """Embed `point_coords` (preferred) or `bbox_mask_coords` from `added_cond_kwargs`.

        Raises:
            ValueError: If neither key is present, or the module for the supplied key was not built.
        """
        cond = added_cond_kwargs if added_cond_kwargs is not None else {}

        if "point_coords" in cond:
            if self.point_embedding is None:
                raise ValueError(
                    f"{self.__class__} received point_coords but was built without point_embeddings_input_dim."
                )
            coords_embeds = cond.get("point_coords")
            coords_embeds = coords_embeds.reshape((batch_size, -1))
            coords_embeds = coords_embeds.to(dtype)
            return self.point_embedding(coords_embeds)

        if "bbox_mask_coords" in cond:
            if self.bbox_time_proj is None or self.bbox_embedding is None:
                raise ValueError(
                    f"{self.__class__} received bbox_mask_coords but was built without "
                    "bbox_time_embed_dim / bbox_embeddings_input_dim."
                )
            coords = cond.get("bbox_mask_coords")
            # Every coordinate is projected on its own, then all projections of a sample are concatenated.
            coords_embeds = self.bbox_time_proj(coords.flatten())
            coords_embeds = coords_embeds.reshape((batch_size, -1))
            coords_embeds = coords_embeds.to(dtype)
            return self.bbox_embedding(coords_embeds)

        raise ValueError(f"{self.__class__} cannot find point_coords or bbox_mask_coords in added_cond_kwargs.")

    def forward(
        self,
        sample: torch.FloatTensor,
        timestep: Union[torch.Tensor, float, int],
        trans: Union[torch.Tensor, float, int],
        encoder_hidden_states: torch.Tensor,
        encoder_hidden_states_2: Optional[torch.Tensor] = None,
        timestep_cond: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
    ) -> Union[UNet2DConditionOutput, Tuple]:
        r"""
        Run the UNet.

        Args:
            sample (`torch.FloatTensor`): The input tensor; its last two dimensions are treated as spatial.
            timestep (`torch.Tensor` or `float` or `int`): The timestep; may be `None` if `trans` is given.
            trans (`torch.Tensor` or `float` or `int`):
                Second scalar condition ("opacity"), embedded like the timestep; may be `None` if
                `timestep` is given.
            encoder_hidden_states (`torch.Tensor`):
                Passed to the stages enabled in `use_encoder_hidden_states_list`.
            encoder_hidden_states_2 (`torch.Tensor`, *optional*):
                Passed to the stages that are disabled in `use_encoder_hidden_states_list`.
            timestep_cond (`torch.Tensor`, *optional*): Extra condition forwarded to `time_embedding`.
            attention_mask (`torch.Tensor`, *optional*):
                Mask with 1 = keep, 0 = discard; converted to an additive bias and passed to the stages
                enabled in `use_attention_mask_list`.
            cross_attention_kwargs (`dict`, *optional*):
                Forwarded to the attention blocks; its `"scale"` entry is used as the LoRA scale.
            added_cond_kwargs (`dict`):
                Must contain `"point_coords"` or `"bbox_mask_coords"`.
            encoder_attention_mask (`torch.Tensor`, *optional*):
                Cross-attention mask with 1 = keep, 0 = discard; converted to an additive bias.

        Returns:
            `UNet2DConditionOutput`: The output sample. The intermediate outputs of the pass are left in
            `self.feature_map`.

        Raises:
            `ValueError`: If both `timestep` and `trans` are `None`, or no usable coordinates are supplied.
        """
        # Upsample blocks need an explicit target size when a spatial dim is not a multiple of the
        # overall upsampling factor.
        default_overall_up_factor = 2**self.num_upsamplers
        forward_upsample_size = any(dim % default_overall_up_factor != 0 for dim in sample.shape[-2:])
        upsample_size = None

        # Convert keep/discard masks into additive biases with a singleton query dimension.
        if attention_mask is not None:
            attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
            attention_mask = attention_mask.unsqueeze(1)

        if encoder_attention_mask is not None:
            encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)

        # 0. center input if necessary
        if self.config.center_input_sample:
            sample = 2 * sample - 1.0

        down_attn_mask, mid_attn_mask, up_attn_mask = self._get_value(self.use_attention_mask_list, attention_mask, None)
        down_encoder_hidden_states, mid_encoder_hidden_states, up_encoder_hidden_states = self._get_value(
            self.use_encoder_hidden_states_list, encoder_hidden_states, encoder_hidden_states_2
        )

        # 1. time + opacity + coordinates
        t_emb = self._scalar_embedding(timestep, sample, timestep_cond) if timestep is not None else None
        op_emb = self._scalar_embedding(trans, sample, timestep_cond) if trans is not None else None

        if t_emb is not None and op_emb is not None:
            emb = t_emb + op_emb
        elif op_emb is not None:
            emb = op_emb
        elif t_emb is not None:
            emb = t_emb
        else:
            raise ValueError("Missing required field: 'timestep' and 'trans'. Please ensure it is included in your input.")

        emb = emb + self._coords_embedding(added_cond_kwargs, sample.shape[0], emb.dtype)

        # 2. pre-process
        sample = self.conv_in(sample)

        # distillation: start a fresh list for this pass
        # NOTE(review): the original indentation was lost; one entry per down block, one for the mid
        # block and one per up block is assumed - confirm against the consumer of `feature_map`.
        self.feature_map = []

        # 3. down
        lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
        if USE_PEFT_BACKEND:
            scale_lora_layers(self, lora_scale)

        down_block_res_samples = (sample,)
        for downsample_block in self.down_blocks:
            if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
                additional_residuals = {}
                sample, res_samples = downsample_block(
                    hidden_states=sample,
                    temb=emb,
                    encoder_hidden_states=down_encoder_hidden_states,
                    attention_mask=down_attn_mask,
                    cross_attention_kwargs=cross_attention_kwargs,
                    encoder_attention_mask=encoder_attention_mask,
                    **additional_residuals,
                )
            else:
                sample, res_samples = downsample_block(hidden_states=sample, temb=emb, scale=lora_scale)

            down_block_res_samples += res_samples
            self.feature_map.append(sample)

        # 4. mid
        if self.mid_block is not None:
            if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
                sample = self.mid_block(
                    sample,
                    emb,
                    encoder_hidden_states=mid_encoder_hidden_states,
                    attention_mask=mid_attn_mask,
                    cross_attention_kwargs=cross_attention_kwargs,
                    encoder_attention_mask=encoder_attention_mask,
                )
            else:
                sample = self.mid_block(sample, emb)
            self.feature_map.append(sample)

        # 5. up
        for i, upsample_block in enumerate(self.up_blocks):
            is_final_block = i == len(self.up_blocks) - 1

            # Each up block consumes one skip connection per resnet, taken from the end of the stack.
            res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]

            if not is_final_block and forward_upsample_size:
                upsample_size = down_block_res_samples[-1].shape[2:]

            if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
                sample = upsample_block(
                    hidden_states=sample,
                    temb=emb,
                    res_hidden_states_tuple=res_samples,
                    encoder_hidden_states=up_encoder_hidden_states,
                    cross_attention_kwargs=cross_attention_kwargs,
                    upsample_size=upsample_size,
                    attention_mask=up_attn_mask,
                    encoder_attention_mask=encoder_attention_mask,
                )
            else:
                sample = upsample_block(
                    hidden_states=sample,
                    temb=emb,
                    res_hidden_states_tuple=res_samples,
                    upsample_size=upsample_size,
                    scale=lora_scale,
                )
            self.feature_map.append(sample)

        # 6. post-process
        if self.conv_norm_out is not None:
            sample = self.conv_norm_out(sample)
            sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        if USE_PEFT_BACKEND:
            unscale_lora_layers(self, lora_scale)

        return UNet2DConditionOutput(sample=sample)