| |
| |
| |
| |
|
|
| |
| |
| |
|
|
| """ |
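Differences between the v1.5 and v2.1 U-Net configs:
- attention_head_dim: an int (v1.5) vs. a list[int] (v2.1)
- cross_attention_dim: 768 (v1.5) vs. 1024 (v2.1)
- use_linear_projection: absent, i.e. False (v1.5) vs. true (v2.1)
- upcast_attention: False (v1.5) vs. True (v2.1)
- (the differences below can probably be ignored)
- sample_size: 64 (v1.5) vs. 96 (v2.1)
- dual_cross_attention: absent (v1.5) vs. present (v2.1)
- num_class_embeds: absent (v1.5) vs. present (v2.1)
- only_cross_attention: absent (v1.5) vs. present (v2.1)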
| |
| v1.5 |
| { |
| "_class_name": "UNet2DConditionModel", |
| "_diffusers_version": "0.6.0", |
| "act_fn": "silu", |
| "attention_head_dim": 8, |
| "block_out_channels": [ |
| 320, |
| 640, |
| 1280, |
| 1280 |
| ], |
| "center_input_sample": false, |
| "cross_attention_dim": 768, |
| "down_block_types": [ |
| "CrossAttnDownBlock2D", |
| "CrossAttnDownBlock2D", |
| "CrossAttnDownBlock2D", |
| "DownBlock2D" |
| ], |
| "downsample_padding": 1, |
| "flip_sin_to_cos": true, |
| "freq_shift": 0, |
| "in_channels": 4, |
| "layers_per_block": 2, |
| "mid_block_scale_factor": 1, |
| "norm_eps": 1e-05, |
| "norm_num_groups": 32, |
| "out_channels": 4, |
| "sample_size": 64, |
| "up_block_types": [ |
| "UpBlock2D", |
| "CrossAttnUpBlock2D", |
| "CrossAttnUpBlock2D", |
| "CrossAttnUpBlock2D" |
| ] |
| } |
| |
| v2.1 |
| { |
| "_class_name": "UNet2DConditionModel", |
| "_diffusers_version": "0.10.0.dev0", |
| "act_fn": "silu", |
| "attention_head_dim": [ |
| 5, |
| 10, |
| 20, |
| 20 |
| ], |
| "block_out_channels": [ |
| 320, |
| 640, |
| 1280, |
| 1280 |
| ], |
| "center_input_sample": false, |
| "cross_attention_dim": 1024, |
| "down_block_types": [ |
| "CrossAttnDownBlock2D", |
| "CrossAttnDownBlock2D", |
| "CrossAttnDownBlock2D", |
| "DownBlock2D" |
| ], |
| "downsample_padding": 1, |
| "dual_cross_attention": false, |
| "flip_sin_to_cos": true, |
| "freq_shift": 0, |
| "in_channels": 4, |
| "layers_per_block": 2, |
| "mid_block_scale_factor": 1, |
| "norm_eps": 1e-05, |
| "norm_num_groups": 32, |
| "num_class_embeds": null, |
| "only_cross_attention": false, |
| "out_channels": 4, |
| "sample_size": 96, |
| "up_block_types": [ |
| "UpBlock2D", |
| "CrossAttnUpBlock2D", |
| "CrossAttnUpBlock2D", |
| "CrossAttnUpBlock2D" |
| ], |
| "use_linear_projection": true, |
| "upcast_attention": true |
| } |
| """ |
|
|
| import math |
| from types import SimpleNamespace |
| from typing import Dict, Optional, Tuple, Union |
| import torch |
| from torch import nn |
| from torch.nn import functional as F |
| from einops import rearrange |
| from .utils import setup_logging |
| setup_logging() |
| import logging |
| logger = logging.getLogger(__name__) |
|
|
# Fixed architecture constants of the Stable Diffusion U-Net; they mirror the
# diffusers UNet2DConditionModel configs quoted in the module docstring.
BLOCK_OUT_CHANNELS: Tuple[int, ...] = (320, 640, 1280, 1280)  # variable-length tuple, not a 1-tuple
TIMESTEP_INPUT_DIM = BLOCK_OUT_CHANNELS[0]
TIME_EMBED_DIM = BLOCK_OUT_CHANNELS[0] * 4  # input width of ResnetBlock2D.time_emb_proj
IN_CHANNELS: int = 4
OUT_CHANNELS: int = 4
LAYERS_PER_BLOCK: int = 2
LAYERS_PER_BLOCK_UP: int = LAYERS_PER_BLOCK + 1  # up blocks run one more resnet than down blocks
TIME_EMBED_FLIP_SIN_TO_COS: bool = True
TIME_EMBED_FREQ_SHIFT: int = 0
NORM_GROUPS: int = 32
NORM_EPS: float = 1e-5
TRANSFORMER_NORM_NUM_GROUPS = 32


DOWN_BLOCK_TYPES = ["CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"]
UP_BLOCK_TYPES = ["UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"]


# Lower clamp for per-tile softmax row sums in FlashAttentionFunction.forward,
# keeping the running-sum divisor away from zero.
EPSILON = 1e-6
|
|
| |
|
|
|
|
def exists(val):
    """Return True when *val* is anything other than ``None``."""
    return not (val is None)


def default(val, d):
    """Return *val*, or the fallback *d* when *val* is ``None``."""
    if exists(val):
        return val
    return d
|
|
|
|
| |
|
|
| |
|
|
|
|
class FlashAttentionFunction(torch.autograd.Function):
    """Memory-efficient (FlashAttention-style) attention as a custom autograd op.

    Queries and keys/values are processed in buckets along dim -2 so the full
    attention matrix is never materialized. The softmax is accumulated online
    using running per-row maxima and per-row sums, and ``backward`` recomputes
    each score tile instead of storing it.

    Tensors are presumably laid out as (batch, heads, seq, dim_head): the caller
    visible here (CrossAttention.forward_memory_efficient_mem_eff) rearranges to
    "b h n d" before calling ``apply``. TODO confirm for other callers.
    """

    @staticmethod
    @torch.no_grad()
    def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size):
        """Algorithm 2 in the paper.

        Args:
            q, k, v: query / key / value tensors, bucketed along dim -2.
            mask: None, or a boolean mask rearranged from "b n" to "b 1 1 n".
                NOTE(review): the mask is split into chunks of ``q_bucket_size``
                along its last axis and paired with query buckets, yet it is
                applied to score tiles whose last axis has ``k_bucket_size``
                columns. Confirm this lines up before passing a non-None mask.
            causal: when True, mask out scores above the (end-aligned) diagonal.
            q_bucket_size, k_bucket_size: chunk sizes along the query / key
                sequence axes.

        Returns:
            The attention output ``o`` with the same shape as ``q``.
        """

        device = q.device
        dtype = q.dtype
        # Most negative finite value of the dtype; used instead of -inf when masking.
        max_neg_value = -torch.finfo(q.dtype).max
        # When k is longer than q, query positions are shifted so the causal
        # diagonal is aligned to the end of the key sequence.
        qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)

        # Output accumulator plus running softmax statistics (sum, max) per query row.
        o = torch.zeros_like(q)
        all_row_sums = torch.zeros((*q.shape[:-1], 1), dtype=dtype, device=device)
        all_row_maxes = torch.full((*q.shape[:-1], 1), max_neg_value, dtype=dtype, device=device)

        scale = q.shape[-1] ** -0.5

        # Normalize `mask` into one entry per query bucket so it can be zipped below.
        if not exists(mask):
            mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size)
        else:
            mask = rearrange(mask, "b n -> b 1 1 n")
            mask = mask.split(q_bucket_size, dim=-1)

        # `.split` yields views, so the in-place writes to oc / row_sums / row_maxes
        # below update o / all_row_sums / all_row_maxes.
        row_splits = zip(
            q.split(q_bucket_size, dim=-2),
            o.split(q_bucket_size, dim=-2),
            mask,
            all_row_sums.split(q_bucket_size, dim=-2),
            all_row_maxes.split(q_bucket_size, dim=-2),
        )

        for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits):
            q_start_index = ind * q_bucket_size - qk_len_diff

            col_splits = zip(
                k.split(k_bucket_size, dim=-2),
                v.split(k_bucket_size, dim=-2),
            )

            for k_ind, (kc, vc) in enumerate(col_splits):
                k_start_index = k_ind * k_bucket_size

                # Scaled dot-product scores for this (query bucket, key bucket) tile.
                attn_weights = torch.einsum("... i d, ... j d -> ... i j", qc, kc) * scale

                if exists(row_mask):
                    attn_weights.masked_fill_(~row_mask, max_neg_value)

                # A causal mask is only built for tiles that cross the diagonal.
                if causal and q_start_index < (k_start_index + k_bucket_size - 1):
                    causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device).triu(
                        q_start_index - k_start_index + 1
                    )
                    attn_weights.masked_fill_(causal_mask, max_neg_value)

                # Exponentiate relative to this tile's row max for numerical stability.
                block_row_maxes = attn_weights.amax(dim=-1, keepdims=True)
                attn_weights -= block_row_maxes
                exp_weights = torch.exp(attn_weights)

                if exists(row_mask):
                    exp_weights.masked_fill_(~row_mask, 0.0)

                # Clamped so new_row_sums (a divisor below) cannot become zero
                # when a tile is fully masked.
                block_row_sums = exp_weights.sum(dim=-1, keepdims=True).clamp(min=EPSILON)

                new_row_maxes = torch.maximum(block_row_maxes, row_maxes)

                exp_values = torch.einsum("... i j, ... j d -> ... i d", exp_weights, vc)

                # Factors that rescale the previous partial result and this tile's
                # partial result to the common reference max `new_row_maxes`.
                exp_row_max_diff = torch.exp(row_maxes - new_row_maxes)
                exp_block_row_max_diff = torch.exp(block_row_maxes - new_row_maxes)

                new_row_sums = exp_row_max_diff * row_sums + exp_block_row_max_diff * block_row_sums

                # Online-softmax update: oc remains the normalized weighted sum of
                # all value rows seen so far.
                oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_((exp_block_row_max_diff / new_row_sums) * exp_values)

                row_maxes.copy_(new_row_maxes)
                row_sums.copy_(new_row_sums)

        # Non-tensor arguments live on ctx; tensors needed by backward are saved.
        ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size)
        ctx.save_for_backward(q, k, v, o, all_row_sums, all_row_maxes)

        return o

    @staticmethod
    @torch.no_grad()
    def backward(ctx, do):
        """Algorithm 4 in the paper.

        Recomputes each score tile from the saved q / k and the final softmax
        statistics saved by ``forward`` (``l`` = row sums, ``m`` = row maxes),
        then accumulates dq / dk / dv tile by tile.

        Returns:
            Gradients for q, k, v, and None for mask, causal, q_bucket_size and
            k_bucket_size.
        """

        causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args
        q, k, v, o, l, m = ctx.saved_tensors

        device = q.device

        max_neg_value = -torch.finfo(q.dtype).max
        qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)

        dq = torch.zeros_like(q)
        dk = torch.zeros_like(k)
        dv = torch.zeros_like(v)

        # `mask` is the per-query-bucket tuple prepared in forward.
        row_splits = zip(
            q.split(q_bucket_size, dim=-2),
            o.split(q_bucket_size, dim=-2),
            do.split(q_bucket_size, dim=-2),
            mask,
            l.split(q_bucket_size, dim=-2),
            m.split(q_bucket_size, dim=-2),
            dq.split(q_bucket_size, dim=-2),
        )

        for ind, (qc, oc, doc, row_mask, lc, mc, dqc) in enumerate(row_splits):
            q_start_index = ind * q_bucket_size - qk_len_diff

            col_splits = zip(
                k.split(k_bucket_size, dim=-2),
                v.split(k_bucket_size, dim=-2),
                dk.split(k_bucket_size, dim=-2),
                dv.split(k_bucket_size, dim=-2),
            )

            for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits):
                k_start_index = k_ind * k_bucket_size

                attn_weights = torch.einsum("... i d, ... j d -> ... i j", qc, kc) * scale

                if causal and q_start_index < (k_start_index + k_bucket_size - 1):
                    causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device).triu(
                        q_start_index - k_start_index + 1
                    )
                    attn_weights.masked_fill_(causal_mask, max_neg_value)

                # Softmax probabilities rebuilt from the saved row max / row sum.
                # Unlike forward, row_mask is applied after exponentiation only.
                exp_attn_weights = torch.exp(attn_weights - mc)

                if exists(row_mask):
                    exp_attn_weights.masked_fill_(~row_mask, 0.0)

                p = exp_attn_weights / lc

                dv_chunk = torch.einsum("... i j, ... i d -> ... j d", p, doc)
                dp = torch.einsum("... i d, ... j d -> ... i j", doc, vc)

                # D = rowsum(dO * O): the correction term of the softmax gradient.
                D = (doc * oc).sum(dim=-1, keepdims=True)
                ds = p * scale * (dp - D)

                dq_chunk = torch.einsum("... i j, ... j d -> ... i d", ds, kc)
                dk_chunk = torch.einsum("... i j, ... i d -> ... j d", ds, qc)

                # dqc / dkc / dvc are split views of dq / dk / dv, so this accumulates in place.
                dqc.add_(dq_chunk)
                dkc.add_(dk_chunk)
                dvc.add_(dv_chunk)

        return dq, dk, dv, None, None, None, None
|
|
|
|
| |
|
|
|
|
def get_parameter_dtype(parameter: torch.nn.Module):
    """Return the dtype of the first parameter of *parameter*.

    Raises StopIteration when the module has no parameters.
    """
    first_param = next(iter(parameter.parameters()))
    return first_param.dtype


def get_parameter_device(parameter: torch.nn.Module):
    """Return the device of the first parameter of *parameter*.

    Raises StopIteration when the module has no parameters.
    """
    first_param = next(iter(parameter.parameters()))
    return first_param.device
|
|
|
|
def get_timestep_embedding(
    timesteps: torch.Tensor,
    embedding_dim: int,
    flip_sin_to_cos: bool = False,
    downscale_freq_shift: float = 1,
    scale: float = 1,
    max_period: int = 10000,
):
    """Create sinusoidal timestep embeddings, as in Denoising Diffusion Probabilistic Models.

    :param timesteps: 1-D tensor with one (possibly fractional) timestep per batch element.
    :param embedding_dim: width of the output embedding.
    :param flip_sin_to_cos: place the cosine half before the sine half.
    :param downscale_freq_shift: subtracted from ``embedding_dim // 2`` in the frequency denominator.
    :param scale: multiplier applied to the phase before taking sin / cos.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an [N x embedding_dim] float tensor; an odd ``embedding_dim`` gets one zero column appended.
    """
    assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"

    half = embedding_dim // 2
    steps = torch.arange(start=0, end=half, dtype=torch.float32, device=timesteps.device)
    freqs = torch.exp(-math.log(max_period) * steps / (half - downscale_freq_shift))

    phase = scale * (timesteps[:, None].float() * freqs[None, :])

    sin_part = torch.sin(phase)
    cos_part = torch.cos(phase)
    halves = (cos_part, sin_part) if flip_sin_to_cos else (sin_part, cos_part)
    emb = torch.cat(halves, dim=-1)

    # Pad one zero column so the width matches an odd embedding_dim.
    if embedding_dim % 2 == 1:
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb
|
|
|
|
| |
def resize_like(x, target, mode="bicubic", align_corners=False):
    """Resize the last two (spatial) dimensions of *x* to match those of *target*.

    Args:
        x: tensor to resize; interpolated over its last two dimensions.
        target: tensor whose ``shape[-2:]`` gives the output size.
        mode: interpolation mode passed to ``F.interpolate``.
        align_corners: forwarded to ``F.interpolate`` only for the modes that
            accept it (linear, bilinear, bicubic, trilinear). It is ignored for
            "nearest", "nearest-exact" and "area", where F.interpolate would
            raise ValueError if it were given.

    Returns:
        *x* itself when the sizes already match, otherwise the resized tensor
        in *x*'s original dtype.
    """
    target_size = target.shape[-2:]
    if x.shape[-2:] == target_size:
        return x

    org_dtype = x.dtype
    # Presumably some interpolation kernels lack bfloat16 support, so bf16
    # inputs are resized in float32 and cast back afterwards.
    if org_dtype == torch.bfloat16:
        x = x.to(torch.float32)

    if mode in ("linear", "bilinear", "bicubic", "trilinear"):
        x = F.interpolate(x, size=target_size, mode=mode, align_corners=align_corners)
    else:
        x = F.interpolate(x, size=target_size, mode=mode)

    if org_dtype == torch.bfloat16:
        x = x.to(org_dtype)
    return x
|
|
|
|
class SampleOutput:
    """Lightweight result container exposing the output through a ``.sample`` attribute."""

    def __init__(self, sample):
        # Stored verbatim; Transformer2DModel.forward wraps its output tensor here.
        self.sample = sample
|
|
|
|
class TimestepEmbedding(nn.Module):
    """Two-layer MLP applied to a timestep embedding: linear -> activation -> linear.

    Args:
        in_channels: width of the incoming embedding.
        time_embed_dim: hidden width, and the output width unless ``out_dim`` is given.
        act_fn: "silu" or "mish"; any other value means no activation is applied.
        out_dim: optional output width.
    """

    def __init__(self, in_channels: int, time_embed_dim: int, act_fn: str = "silu", out_dim: int = None):
        super().__init__()

        self.linear_1 = nn.Linear(in_channels, time_embed_dim)

        if act_fn == "silu":
            activation = nn.SiLU()
        elif act_fn == "mish":
            activation = nn.Mish()
        else:
            activation = None
        self.act = activation

        out_features = time_embed_dim if out_dim is None else out_dim
        self.linear_2 = nn.Linear(time_embed_dim, out_features)

    def forward(self, sample):
        hidden = self.linear_1(sample)
        if self.act is not None:
            hidden = self.act(hidden)
        return self.linear_2(hidden)
|
|
|
|
class Timesteps(nn.Module):
    """Parameter-free module that turns timesteps into sinusoidal embeddings.

    Args:
        num_channels: embedding width passed to ``get_timestep_embedding``.
        flip_sin_to_cos: put the cosine half first.
        downscale_freq_shift: frequency-denominator shift.
    """

    def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float):
        super().__init__()
        # Configuration forwarded unchanged on every call.
        self.num_channels, self.flip_sin_to_cos, self.downscale_freq_shift = (
            num_channels,
            flip_sin_to_cos,
            downscale_freq_shift,
        )

    def forward(self, timesteps):
        return get_timestep_embedding(
            timesteps,
            self.num_channels,
            flip_sin_to_cos=self.flip_sin_to_cos,
            downscale_freq_shift=self.downscale_freq_shift,
        )
|
|
|
|
class ResnetBlock2D(nn.Module):
    """Residual block conditioned on a timestep embedding.

    GroupNorm -> SiLU -> 3x3 conv, plus a per-channel bias projected from the
    timestep embedding, then GroupNorm -> SiLU -> 3x3 conv. The result is added
    to the input, which is first passed through a 1x1 conv when the channel
    count changes.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels of the output feature map.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels

        self.norm1 = torch.nn.GroupNorm(num_groups=NORM_GROUPS, num_channels=in_channels, eps=NORM_EPS, affine=True)

        self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)

        # Maps the timestep embedding (width TIME_EMBED_DIM) to one value per output channel.
        self.time_emb_proj = torch.nn.Linear(TIME_EMBED_DIM, out_channels)

        self.norm2 = torch.nn.GroupNorm(num_groups=NORM_GROUPS, num_channels=out_channels, eps=NORM_EPS, affine=True)
        self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)

        # Store a reference to a module-level function instead of a lambda. A
        # lambda kept in the instance __dict__ makes the whole module
        # unpicklable (e.g. torch.save(model) fails).
        self.nonlinearity = F.silu

        self.use_in_shortcut = self.in_channels != self.out_channels

        self.conv_shortcut = None
        if self.use_in_shortcut:
            self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, input_tensor, temb):
        """
        Args:
            input_tensor: NCHW feature map with ``in_channels`` channels.
            temb: timestep embedding with TIME_EMBED_DIM features per batch
                element. The ``[:, :, None, None]`` below implies shape
                (batch, TIME_EMBED_DIM).

        Returns:
            Feature map with ``out_channels`` channels and the same spatial size.
        """
        hidden_states = input_tensor

        hidden_states = self.norm1(hidden_states)
        hidden_states = self.nonlinearity(hidden_states)

        hidden_states = self.conv1(hidden_states)

        # Broadcast the projected embedding over the spatial dims as a per-channel bias.
        temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None]
        hidden_states = hidden_states + temb

        hidden_states = self.norm2(hidden_states)
        hidden_states = self.nonlinearity(hidden_states)

        hidden_states = self.conv2(hidden_states)

        if self.conv_shortcut is not None:
            input_tensor = self.conv_shortcut(input_tensor)

        output_tensor = input_tensor + hidden_states

        return output_tensor
|
|
|
|
class DownBlock2D(nn.Module):
    """U-Net down block without attention.

    Runs LAYERS_PER_BLOCK resnets, then an optional stride-2 Downsample2D.

    Args:
        in_channels: channels entering the first resnet.
        out_channels: channels produced by every resnet (and the downsampler).
        add_downsample: append a Downsample2D after the resnets.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        add_downsample=True,
    ):
        super().__init__()

        self.has_cross_attention = False
        resnets = []

        for i in range(LAYERS_PER_BLOCK):
            in_channels = in_channels if i == 0 else out_channels
            resnets.append(
                ResnetBlock2D(
                    in_channels=in_channels,
                    out_channels=out_channels,
                )
            )
        self.resnets = nn.ModuleList(resnets)

        if add_downsample:
            # nn.ModuleList (not a plain list) so the downsampler's conv is
            # registered: it must show up in parameters() / state_dict() and
            # follow .to(device / dtype), as in CrossAttnDownBlock2D.
            self.downsamplers = nn.ModuleList([Downsample2D(out_channels, out_channels=out_channels)])
        else:
            self.downsamplers = None

        self.gradient_checkpointing = False

    def set_use_memory_efficient_attention(self, xformers, mem_eff):
        # No attention layers here; kept for interface parity with attention blocks.
        pass

    def set_use_sdpa(self, sdpa):
        # No attention layers here; kept for interface parity with attention blocks.
        pass

    def forward(self, hidden_states, temb=None):
        """Return ``(hidden_states, output_states)``.

        ``output_states`` holds the output of every resnet and, when present,
        of the downsampler. These are the skip-connection features.
        """
        output_states = ()

        for resnet in self.resnets:
            if self.training and self.gradient_checkpointing:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs)

                    return custom_forward

                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
            else:
                hidden_states = resnet(hidden_states, temb)

            output_states += (hidden_states,)

        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states)

            output_states += (hidden_states,)

        return hidden_states, output_states
|
|
|
|
class Downsample2D(nn.Module):
    """Spatial downsampling with a stride-2, padding-1 3x3 convolution (H -> ceil(H / 2)).

    Args:
        channels: expected number of input channels (asserted in ``forward``).
        out_channels: number of output channels.
    """

    def __init__(self, channels, out_channels):
        super().__init__()
        self.channels, self.out_channels = channels, out_channels
        self.conv = nn.Conv2d(channels, out_channels, 3, stride=2, padding=1)

    def forward(self, hidden_states):
        assert hidden_states.shape[1] == self.channels
        return self.conv(hidden_states)
|
|
|
|
class CrossAttention(nn.Module):
    """Multi-head self- or cross-attention with selectable backends.

    ``forward`` uses the first enabled backend in this order:
      1. a diffusers-style processor installed with ``set_processor``;
      2. xformers ``memory_efficient_attention``;
      3. the bucketed ``FlashAttentionFunction`` from this module;
      4. ``torch.nn.functional.scaled_dot_product_attention``;
      5. the plain baddbmm / softmax / bmm path in ``_attention``.

    Args:
        query_dim: channel width of ``hidden_states`` (and of the output).
        cross_attention_dim: channel width of ``context``. Defaults to
            ``query_dim``, i.e. self-attention.
        heads: number of attention heads.
        dim_head: width of each head.
        upcast_attention: compute scores and softmax in float32 on the default path.
    """

    # Sentinel that lets ``set_processor()`` with no argument keep its old
    # behaviour (return the current processor unchanged).
    _UNSET = object()

    def __init__(
        self,
        query_dim: int,
        cross_attention_dim: Optional[int] = None,
        heads: int = 8,
        dim_head: int = 64,
        upcast_attention: bool = False,
    ):
        super().__init__()
        inner_dim = dim_head * heads
        cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
        self.upcast_attention = upcast_attention

        self.scale = dim_head**-0.5
        self.heads = heads

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=False)

        # A ModuleList, so the output projection's parameters are named "to_out.0.*".
        self.to_out = nn.ModuleList([])
        self.to_out.append(nn.Linear(inner_dim, query_dim))

        self.use_memory_efficient_attention_xformers = False
        self.use_memory_efficient_attention_mem_eff = False
        self.use_sdpa = False

        # Optional diffusers-style attention processor; see set_processor().
        self.processor = None

    def set_use_memory_efficient_attention(self, xformers, mem_eff):
        self.use_memory_efficient_attention_xformers = xformers
        self.use_memory_efficient_attention_mem_eff = mem_eff

    def set_use_sdpa(self, sdpa):
        self.use_sdpa = sdpa

    def reshape_heads_to_batch_dim(self, tensor):
        """(batch, seq, heads * d) -> (batch * heads, seq, d)."""
        batch_size, seq_len, dim = tensor.shape
        head_size = self.heads
        tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
        return tensor

    def reshape_batch_dim_to_heads(self, tensor):
        """(batch * heads, seq, d) -> (batch, seq, heads * d); inverse of reshape_heads_to_batch_dim."""
        batch_size, seq_len, dim = tensor.shape
        head_size = self.heads
        tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
        return tensor

    def set_processor(self, processor=_UNSET):
        """Install a diffusers-style attention processor (``None`` removes it).

        The processor is called as ``processor(attn=self, hidden_states=...,
        encoder_hidden_states=..., attention_mask=..., **kwargs)``. Calling this
        method with no argument leaves the current processor unchanged.

        Returns:
            The processor now in effect.
        """
        if processor is not self._UNSET:
            # An nn.Module processor is stored in self._modules, and
            # nn.Module.__setattr__ refuses to overwrite that entry with a
            # non-Module value, so drop it first.
            if isinstance(self.processor, nn.Module) and not isinstance(processor, nn.Module):
                self._modules.pop("processor", None)
            self.processor = processor
        return self.processor

    def get_processor(self):
        return self.processor

    def forward(self, hidden_states, context=None, mask=None, **kwargs):
        """Attend from ``hidden_states`` to ``context``, or to itself when ``context`` is None.

        Args:
            hidden_states: query input. ``reshape_heads_to_batch_dim`` unpacks
                exactly three dims, presumably (batch, seq_len, query_dim).
            context: key/value input; falls back to ``hidden_states``.
            mask: attention mask. NOTE(review): only the processor, mem_eff and
                SDPA paths receive it; the xformers and default paths ignore it.
            **kwargs: diffusers-style extras. ``encoder_hidden_states`` and
                ``attention_mask`` are accepted as aliases for ``context`` and
                ``mask``; anything else is forwarded to the processor only.

        Returns:
            Tensor with the leading dims of ``hidden_states`` and ``query_dim`` channels.
        """
        if self.processor is not None:
            # Pop the diffusers aliases so they are not forwarded a second time
            # through **kwargs (which raised "multiple values for keyword").
            # Any other extras go to the processor untouched.
            hidden_states, context, mask = translate_attention_names_from_diffusers(
                hidden_states=hidden_states,
                context=context,
                mask=mask,
                encoder_hidden_states=kwargs.pop("encoder_hidden_states", None),
                attention_mask=kwargs.pop("attention_mask", None),
            )
            return self.processor(
                attn=self,
                hidden_states=hidden_states,
                encoder_hidden_states=context,
                attention_mask=mask,
                **kwargs
            )
        if self.use_memory_efficient_attention_xformers:
            return self.forward_memory_efficient_xformers(hidden_states, context, mask)
        if self.use_memory_efficient_attention_mem_eff:
            return self.forward_memory_efficient_mem_eff(hidden_states, context, mask)
        if self.use_sdpa:
            return self.forward_sdpa(hidden_states, context, mask)

        query = self.to_q(hidden_states)
        context = context if context is not None else hidden_states
        key = self.to_k(context)
        value = self.to_v(context)

        query = self.reshape_heads_to_batch_dim(query)
        key = self.reshape_heads_to_batch_dim(key)
        value = self.reshape_heads_to_batch_dim(value)

        hidden_states = self._attention(query, key, value)

        hidden_states = self.to_out[0](hidden_states)
        return hidden_states

    def _attention(self, query, key, value):
        """Scaled dot-product attention on (batch * heads, seq, d) tensors; returns (batch, seq, heads * d)."""
        if self.upcast_attention:
            query = query.float()
            key = key.float()

        # With beta=0 the uninitialised `torch.empty` input is ignored by baddbmm.
        attention_scores = torch.baddbmm(
            torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
            query,
            key.transpose(-1, -2),
            beta=0,
            alpha=self.scale,
        )
        attention_probs = attention_scores.softmax(dim=-1)

        # Cast back so the bmm runs in the value dtype (relevant when upcasting).
        attention_probs = attention_probs.to(value.dtype)

        hidden_states = torch.bmm(attention_probs, value)

        hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
        return hidden_states

    def forward_memory_efficient_xformers(self, x, context=None, mask=None):
        """xformers backend. NOTE(review): ``mask`` is not forwarded (attn_bias=None)."""
        import xformers.ops

        h = self.heads
        q_in = self.to_q(x)
        context = context if context is not None else x
        context = context.to(x.dtype)
        k_in = self.to_k(context)
        v_in = self.to_v(context)

        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b n h d", h=h), (q_in, k_in, v_in))
        del q_in, k_in, v_in

        q = q.contiguous()
        k = k.contiguous()
        v = v.contiguous()
        out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)

        out = rearrange(out, "b n h d -> b n (h d)", h=h)

        out = self.to_out[0](out)
        return out

    def forward_memory_efficient_mem_eff(self, x, context=None, mask=None):
        """Bucketed FlashAttentionFunction backend (query buckets of 512, key buckets of 1024)."""
        flash_func = FlashAttentionFunction

        q_bucket_size = 512
        k_bucket_size = 1024

        h = self.heads
        q = self.to_q(x)
        context = context if context is not None else x
        context = context.to(x.dtype)
        k = self.to_k(context)
        v = self.to_v(context)
        del context, x

        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))

        out = flash_func.apply(q, k, v, mask, False, q_bucket_size, k_bucket_size)

        out = rearrange(out, "b h n d -> b n (h d)")

        out = self.to_out[0](out)
        return out

    def forward_sdpa(self, x, context=None, mask=None):
        """PyTorch scaled_dot_product_attention backend; ``mask`` is passed as ``attn_mask``."""
        h = self.heads
        q_in = self.to_q(x)
        context = context if context is not None else x
        context = context.to(x.dtype)
        k_in = self.to_k(context)
        v_in = self.to_v(context)

        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q_in, k_in, v_in))
        del q_in, k_in, v_in

        out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)

        out = rearrange(out, "b h n d -> b n (h d)", h=h)

        out = self.to_out[0](out)
        return out
|
|
def translate_attention_names_from_diffusers(
    hidden_states: torch.FloatTensor,
    context: Optional[torch.FloatTensor] = None,
    mask: Optional[torch.FloatTensor] = None,
    # HF Diffusers argument names
    encoder_hidden_states: Optional[torch.FloatTensor] = None,
    attention_mask: Optional[torch.FloatTensor] = None
):
    """Map diffusers-style argument names onto the names used in this module.

    ``encoder_hidden_states`` and ``attention_mask`` are used only as fallbacks
    when ``context`` / ``mask`` are None.

    Returns:
        ``(hidden_states, context, mask)``.
    """
    if context is None:
        context = encoder_hidden_states
    if mask is None:
        mask = attention_mask
    return hidden_states, context, mask
|
|
| |
class GEGLU(nn.Module):
    r"""
    Gated linear unit with a GELU gate (https://arxiv.org/abs/2002.05202).

    One linear layer produces ``2 * dim_out`` features; the first half is
    multiplied by the GELU of the second half.

    Parameters:
        dim_in (`int`): The number of channels in the input.
        dim_out (`int`): The number of channels in the output.
    """

    def __init__(self, dim_in: int, dim_out: int):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def gelu(self, gate):
        # On MPS, GELU is evaluated in float32 and cast back to the input dtype
        # (presumably a workaround for missing low-precision GELU support there).
        if gate.device.type == "mps":
            return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
        return F.gelu(gate)

    def forward(self, hidden_states):
        projected = self.proj(hidden_states)
        value, gate = projected.chunk(2, dim=-1)
        return value * self.gelu(gate)
|
|
|
|
class FeedForward(nn.Module):
    """Transformer feed-forward: GEGLU(dim -> 4 * dim) -> Identity -> Linear(4 * dim -> dim).

    The ``nn.Identity`` keeps the output Linear at index 2, so parameters are
    named ``net.0.*`` and ``net.2.*`` (presumably matching checkpoints that have
    a dropout layer in that slot).
    """

    def __init__(
        self,
        dim: int,
    ):
        super().__init__()
        hidden_dim = int(dim * 4)

        self.net = nn.ModuleList(
            [
                GEGLU(dim, hidden_dim),
                nn.Identity(),
                nn.Linear(hidden_dim, dim),
            ]
        )

    def forward(self, hidden_states):
        out = hidden_states
        for layer in self.net:
            out = layer(out)
        return out
|
|
|
|
class BasicTransformerBlock(nn.Module):
    """Pre-norm transformer block: self-attention, cross-attention, then feed-forward, each residual.

    Args:
        dim: channel width of the token sequence.
        num_attention_heads: head count for both attention layers.
        attention_head_dim: per-head width for both attention layers.
        cross_attention_dim: width of the ``context`` tokens consumed by ``attn2``.
        upcast_attention: forwarded to both attention layers.
    """

    def __init__(
        self, dim: int, num_attention_heads: int, attention_head_dim: int, cross_attention_dim: int, upcast_attention: bool = False
    ):
        super().__init__()

        attn_kwargs = dict(heads=num_attention_heads, dim_head=attention_head_dim, upcast_attention=upcast_attention)

        # Creation order (attn1, ff, attn2, norms) is kept so random initialisation
        # and state_dict ordering stay the same.
        self.attn1 = CrossAttention(query_dim=dim, cross_attention_dim=None, **attn_kwargs)  # self-attention
        self.ff = FeedForward(dim)
        self.attn2 = CrossAttention(query_dim=dim, cross_attention_dim=cross_attention_dim, **attn_kwargs)  # cross-attention

        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.norm3 = nn.LayerNorm(dim)

    def set_use_memory_efficient_attention(self, xformers: bool, mem_eff: bool):
        for layer in (self.attn1, self.attn2):
            layer.set_use_memory_efficient_attention(xformers, mem_eff)

    def set_use_sdpa(self, sdpa: bool):
        for layer in (self.attn1, self.attn2):
            layer.set_use_sdpa(sdpa)

    def forward(self, hidden_states, context=None, timestep=None):
        # `timestep` is accepted for interface compatibility but not used here.
        hidden_states = self.attn1(self.norm1(hidden_states)) + hidden_states
        hidden_states = self.attn2(self.norm2(hidden_states), context=context) + hidden_states
        hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
        return hidden_states
|
|
|
|
class Transformer2DModel(nn.Module):
    """Runs BasicTransformerBlock(s) over a feature map, treating each pixel as a token.

    GroupNorm -> proj_in -> transformer block -> proj_out, with a residual
    connection around the whole stack. ``proj_in`` / ``proj_out`` are 1x1
    convolutions, or linear layers when ``use_linear_projection`` is True (the
    v2.1 setting according to the module docstring).

    Args:
        num_attention_heads: head count; times ``attention_head_dim`` gives the
            transformer width ``inner_dim``.
        attention_head_dim: per-head width.
        in_channels: channels of the input feature map (and of the output).
        cross_attention_dim: width of ``encoder_hidden_states``.
        use_linear_projection: use nn.Linear instead of 1x1 nn.Conv2d projections.
        upcast_attention: forwarded to the attention layers.
    """

    def __init__(
        self,
        num_attention_heads: int = 16,
        attention_head_dim: int = 88,
        in_channels: Optional[int] = None,
        cross_attention_dim: Optional[int] = None,
        use_linear_projection: bool = False,
        upcast_attention: bool = False,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.num_attention_heads = num_attention_heads
        self.attention_head_dim = attention_head_dim
        inner_dim = num_attention_heads * attention_head_dim
        self.use_linear_projection = use_linear_projection

        self.norm = torch.nn.GroupNorm(num_groups=TRANSFORMER_NORM_NUM_GROUPS, num_channels=in_channels, eps=1e-6, affine=True)

        if use_linear_projection:
            self.proj_in = nn.Linear(in_channels, inner_dim)
        else:
            self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)

        self.transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlock(
                    inner_dim,
                    num_attention_heads,
                    attention_head_dim,
                    cross_attention_dim=cross_attention_dim,
                    upcast_attention=upcast_attention,
                )
            ]
        )

        # proj_out maps inner_dim back to in_channels in both variants. The
        # blocks in this file use head_dim = channels // heads, so normally
        # inner_dim == in_channels and the weight shape matches existing checkpoints.
        if use_linear_projection:
            self.proj_out = nn.Linear(inner_dim, in_channels)
        else:
            self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)

    def set_use_memory_efficient_attention(self, xformers, mem_eff):
        for transformer in self.transformer_blocks:
            transformer.set_use_memory_efficient_attention(xformers, mem_eff)

    def set_use_sdpa(self, sdpa):
        for transformer in self.transformer_blocks:
            transformer.set_use_sdpa(sdpa)

    def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True):
        """
        Args:
            hidden_states: NCHW feature map with ``in_channels`` channels.
            encoder_hidden_states: context tokens for cross-attention.
            timestep: passed through to the transformer blocks.
            return_dict: return a SampleOutput when True, otherwise a 1-tuple.
        """
        batch, _, height, width = hidden_states.shape
        residual = hidden_states

        hidden_states = self.norm(hidden_states)
        if not self.use_linear_projection:
            # 1x1 conv projection, then flatten pixels into (batch, H * W, inner_dim) tokens.
            hidden_states = self.proj_in(hidden_states)
            inner_dim = hidden_states.shape[1]
            hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
        else:
            # Flatten pixels into (batch, H * W, channels) tokens, then project linearly.
            channels = hidden_states.shape[1]
            hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, channels)
            hidden_states = self.proj_in(hidden_states)

        for block in self.transformer_blocks:
            hidden_states = block(hidden_states, context=encoder_hidden_states, timestep=timestep)

        if not self.use_linear_projection:
            hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
            hidden_states = self.proj_out(hidden_states)
        else:
            # proj_out has already restored the input channel count.
            hidden_states = self.proj_out(hidden_states)
            hidden_states = hidden_states.reshape(batch, height, width, channels).permute(0, 3, 1, 2).contiguous()

        output = hidden_states + residual

        if not return_dict:
            return (output,)

        return SampleOutput(sample=output)
|
|
|
|
class CrossAttnDownBlock2D(nn.Module):
    """U-Net down block with attention.

    Runs LAYERS_PER_BLOCK (resnet, transformer) pairs, then an optional stride-2
    Downsample2D.

    Args:
        in_channels: channels entering the first resnet.
        out_channels: channels produced by every resnet / transformer.
        add_downsample: append a Downsample2D after the pairs.
        cross_attention_dim: width of ``encoder_hidden_states``.
        attn_num_head_channels: number of attention heads; head width is
            ``out_channels // attn_num_head_channels``.
        use_linear_projection: forwarded to Transformer2DModel.
        upcast_attention: forwarded to Transformer2DModel.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        add_downsample=True,
        cross_attention_dim=1280,
        attn_num_head_channels=1,
        use_linear_projection=False,
        upcast_attention=False,
    ):
        super().__init__()
        self.has_cross_attention = True
        self.attn_num_head_channels = attn_num_head_channels

        resnets, attentions = [], []
        for i in range(LAYERS_PER_BLOCK):
            block_in = in_channels if i == 0 else out_channels
            resnets.append(ResnetBlock2D(in_channels=block_in, out_channels=out_channels))
            attentions.append(
                Transformer2DModel(
                    attn_num_head_channels,
                    out_channels // attn_num_head_channels,
                    in_channels=out_channels,
                    cross_attention_dim=cross_attention_dim,
                    use_linear_projection=use_linear_projection,
                    upcast_attention=upcast_attention,
                )
            )
        # Registration order (attentions before resnets) is kept for state_dict ordering.
        self.attentions = nn.ModuleList(attentions)
        self.resnets = nn.ModuleList(resnets)

        self.downsamplers = nn.ModuleList([Downsample2D(out_channels, out_channels)]) if add_downsample else None

        self.gradient_checkpointing = False

    def set_use_memory_efficient_attention(self, xformers, mem_eff):
        for transformer in self.attentions:
            transformer.set_use_memory_efficient_attention(xformers, mem_eff)

    def set_use_sdpa(self, sdpa):
        for transformer in self.attentions:
            transformer.set_use_sdpa(sdpa)

    def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
        """Return ``(hidden_states, output_states)``.

        ``output_states`` holds the output after every (resnet, attention) pair
        and, when present, after the downsampler. These are the skip-connection features.
        """

        def wrap(module, **extra):
            # Adapter for torch.utils.checkpoint, which passes tensors positionally.
            def call(*inputs):
                return module(*inputs, **extra)

            return call

        skip_states = ()
        use_checkpoint = self.training and self.gradient_checkpointing

        for resnet, attn in zip(self.resnets, self.attentions):
            if use_checkpoint:
                hidden_states = torch.utils.checkpoint.checkpoint(wrap(resnet), hidden_states, temb)
                hidden_states = torch.utils.checkpoint.checkpoint(
                    wrap(attn, return_dict=False), hidden_states, encoder_hidden_states
                )[0]
            else:
                hidden_states = resnet(hidden_states, temb)
                hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
            skip_states += (hidden_states,)

        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states)
            skip_states += (hidden_states,)

        return hidden_states, skip_states
|
|
|
|
class UNetMidBlock2DCrossAttn(nn.Module):
    """U-Net middle block at constant width: resnet -> transformer -> resnet.

    Args:
        in_channels: channel width of the input and output.
        attn_num_head_channels: number of attention heads; head width is
            ``in_channels // attn_num_head_channels``.
        cross_attention_dim: width of ``encoder_hidden_states``.
        use_linear_projection: forwarded to Transformer2DModel.
        upcast_attention: forwarded to Transformer2DModel, as CrossAttnDownBlock2D
            already does. Defaults to False, which is the previous behaviour.
            NOTE(review): callers must pass it (True for v2.1 per the module
            docstring) for the flag to take effect.
    """

    def __init__(
        self,
        in_channels: int,
        attn_num_head_channels=1,
        cross_attention_dim=1280,
        use_linear_projection=False,
        upcast_attention=False,
    ):
        super().__init__()

        self.has_cross_attention = True
        self.attn_num_head_channels = attn_num_head_channels

        # Resnets are built before the transformer, preserving initialisation order.
        resnets = [
            ResnetBlock2D(
                in_channels=in_channels,
                out_channels=in_channels,
            ),
            ResnetBlock2D(
                in_channels=in_channels,
                out_channels=in_channels,
            ),
        ]
        attentions = [
            Transformer2DModel(
                attn_num_head_channels,
                in_channels // attn_num_head_channels,
                in_channels=in_channels,
                cross_attention_dim=cross_attention_dim,
                use_linear_projection=use_linear_projection,
                upcast_attention=upcast_attention,
            )
        ]

        self.attentions = nn.ModuleList(attentions)
        self.resnets = nn.ModuleList(resnets)

        self.gradient_checkpointing = False

    def set_use_memory_efficient_attention(self, xformers, mem_eff):
        for attn in self.attentions:
            attn.set_use_memory_efficient_attention(xformers, mem_eff)

    def set_use_sdpa(self, sdpa):
        for attn in self.attentions:
            attn.set_use_sdpa(sdpa)

    def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
        """Run resnets[0], then (attentions[i - 1], resnets[i]) for each remaining resnet."""
        for i, resnet in enumerate(self.resnets):
            # The first resnet has no preceding attention layer.
            attn = None if i == 0 else self.attentions[i - 1]

            if self.training and self.gradient_checkpointing:

                def create_custom_forward(module, return_dict=None):
                    def custom_forward(*inputs):
                        if return_dict is not None:
                            return module(*inputs, return_dict=return_dict)
                        else:
                            return module(*inputs)

                    return custom_forward

                if attn is not None:
                    hidden_states = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states
                    )[0]

                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
            else:
                if attn is not None:
                    hidden_states = attn(hidden_states, encoder_hidden_states).sample
                hidden_states = resnet(hidden_states, temb)

        return hidden_states
|
|
|
|
class Upsample2D(nn.Module):
    """Nearest-neighbour upsampling followed by a 3x3 convolution (padding 1).

    Without ``output_size`` the spatial size is doubled; with it, the interpolation is forced to
    exactly that size.
    """

    def __init__(self, channels, out_channels):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels
        self.conv = nn.Conv2d(channels, out_channels, 3, padding=1)

    def forward(self, hidden_states, output_size):
        """Upsample ``hidden_states`` (channel dim 1 must equal ``self.channels``) and convolve.

        Args:
            hidden_states: input feature map.
            output_size: ``None`` for a 2x upscale, otherwise the target spatial size.
        """
        assert hidden_states.shape[1] == self.channels

        source_dtype = hidden_states.dtype
        via_fp32 = source_dtype == torch.bfloat16
        # NOTE(review): bfloat16 goes through float32 for the interpolation, presumably because
        # nearest upsampling lacked bf16 support in some PyTorch versions — confirm before removing.
        x = hidden_states.to(torch.float32) if via_fp32 else hidden_states

        # Presumably a workaround for nearest upsampling on large non-contiguous batches — TODO confirm.
        if x.shape[0] >= 64:
            x = x.contiguous()

        resize = {"scale_factor": 2.0} if output_size is None else {"size": output_size}
        x = F.interpolate(x, mode="nearest", **resize)

        if via_fp32:
            x = x.to(source_dtype)

        return self.conv(x)
|
|
|
|
class UpBlock2D(nn.Module):
    """Decoder block without attention: ``LAYERS_PER_BLOCK_UP`` resnets, then an optional upsampler.

    Each resnet consumes the running hidden state concatenated (on dim 1) with one skip
    connection taken from the end of ``res_hidden_states_tuple``.
    """

    def __init__(self, in_channels: int, prev_output_channel: int, out_channels: int, add_upsample=True):
        super().__init__()

        self.has_cross_attention = False

        last_index = LAYERS_PER_BLOCK_UP - 1
        layers = []
        for index in range(LAYERS_PER_BLOCK_UP):
            # The last resnet receives the skip connection coming from the block's own input width.
            skip_width = in_channels if index == last_index else out_channels
            incoming_width = out_channels if index else prev_output_channel
            layers.append(ResnetBlock2D(in_channels=incoming_width + skip_width, out_channels=out_channels))
        self.resnets = nn.ModuleList(layers)

        self.upsamplers = nn.ModuleList([Upsample2D(out_channels, out_channels)]) if add_upsample else None

        self.gradient_checkpointing = False

    def set_use_memory_efficient_attention(self, xformers, mem_eff):
        """No-op: this block has no attention layers."""

    def set_use_sdpa(self, sdpa):
        """No-op: this block has no attention layers."""

    def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
        """Apply the resnets (each on hidden state + newest remaining skip), then the upsamplers.

        Args:
            hidden_states: running decoder features.
            res_hidden_states_tuple: skip connections; consumed from the end, one per resnet.
            temb: passed to every resnet.
            upsample_size: passed to every upsampler (``None`` means the default scale).
        """

        def bind(module):
            # Bind the module as an argument: reentrant checkpointing calls this again during backward.
            def call(*inputs):
                return module(*inputs)

            return call

        for depth, resnet in enumerate(self.resnets):
            skip = res_hidden_states_tuple[-(depth + 1)]
            merged = torch.cat([hidden_states, skip], dim=1)
            if self.training and self.gradient_checkpointing:
                hidden_states = torch.utils.checkpoint.checkpoint(bind(resnet), merged, temb)
            else:
                hidden_states = resnet(merged, temb)

        for upsampler in self.upsamplers or ():
            hidden_states = upsampler(hidden_states, upsample_size)

        return hidden_states
|
|
|
|
class CrossAttnUpBlock2D(nn.Module):
    """Decoder block with attention: (resnet -> transformer) x ``LAYERS_PER_BLOCK_UP``, then an optional upsampler.

    Each resnet consumes the running hidden state concatenated (on dim 1) with one skip
    connection taken from the end of ``res_hidden_states_tuple``; each transformer attends to
    ``encoder_hidden_states``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        prev_output_channel: int,
        attn_num_head_channels=1,
        cross_attention_dim=1280,
        add_upsample=True,
        use_linear_projection=False,
        upcast_attention=False,
    ):
        super().__init__()

        self.has_cross_attention = True
        self.attn_num_head_channels = attn_num_head_channels

        last_index = LAYERS_PER_BLOCK_UP - 1
        resnet_layers = []
        attention_layers = []
        for index in range(LAYERS_PER_BLOCK_UP):
            skip_width = in_channels if index == last_index else out_channels
            incoming_width = out_channels if index else prev_output_channel
            # Resnet and transformer are created alternately, keeping the original construction order.
            resnet_layers.append(ResnetBlock2D(in_channels=incoming_width + skip_width, out_channels=out_channels))
            attention_layers.append(
                Transformer2DModel(
                    attn_num_head_channels,
                    out_channels // attn_num_head_channels,
                    in_channels=out_channels,
                    cross_attention_dim=cross_attention_dim,
                    use_linear_projection=use_linear_projection,
                    upcast_attention=upcast_attention,
                )
            )

        self.attentions = nn.ModuleList(attention_layers)
        self.resnets = nn.ModuleList(resnet_layers)

        self.upsamplers = nn.ModuleList([Upsample2D(out_channels, out_channels)]) if add_upsample else None

        self.gradient_checkpointing = False

    def set_use_memory_efficient_attention(self, xformers, mem_eff):
        """Forward the ``xformers`` / ``mem_eff`` flags to every transformer of this block."""
        for layer in self.attentions:
            layer.set_use_memory_efficient_attention(xformers, mem_eff)

    def set_use_sdpa(self, sdpa):
        """Forward the ``sdpa`` flag to every transformer of this block."""
        for layer in self.attentions:
            layer.set_use_sdpa(sdpa)

    def forward(
        self,
        hidden_states,
        res_hidden_states_tuple,
        temb=None,
        encoder_hidden_states=None,
        upsample_size=None,
    ):
        """Apply (resnet, transformer) pairs on hidden state + newest remaining skip, then the upsamplers.

        Args:
            hidden_states: running decoder features.
            res_hidden_states_tuple: skip connections; consumed from the end, one per resnet.
            temb: passed to every resnet.
            encoder_hidden_states: passed to every transformer.
            upsample_size: passed to every upsampler (``None`` means the default scale).
        """

        def bind(module, return_dict=None):
            # Bind the module as an argument: reentrant checkpointing calls this again during backward.
            def call(*inputs):
                if return_dict is None:
                    return module(*inputs)
                return module(*inputs, return_dict=return_dict)

            return call

        for depth, (resnet, attention) in enumerate(zip(self.resnets, self.attentions)):
            skip = res_hidden_states_tuple[-(depth + 1)]
            merged = torch.cat([hidden_states, skip], dim=1)
            if self.training and self.gradient_checkpointing:
                hidden_states = torch.utils.checkpoint.checkpoint(bind(resnet), merged, temb)
                hidden_states = torch.utils.checkpoint.checkpoint(
                    bind(attention, return_dict=False), hidden_states, encoder_hidden_states
                )[0]
            else:
                hidden_states = resnet(merged, temb)
                hidden_states = attention(hidden_states, encoder_hidden_states=encoder_hidden_states).sample

        for upsampler in self.upsamplers or ():
            hidden_states = upsampler(hidden_states, upsample_size)

        return hidden_states
|
|
|
|
def get_down_block(
    down_block_type,
    in_channels,
    out_channels,
    add_downsample,
    attn_num_head_channels,
    cross_attention_dim,
    use_linear_projection,
    upcast_attention,
):
    """Build one encoder (down) block from its type name.

    Args:
        down_block_type: ``"DownBlock2D"`` or ``"CrossAttnDownBlock2D"``.
        in_channels, out_channels, add_downsample: forwarded to the block.
        attn_num_head_channels, cross_attention_dim, use_linear_projection, upcast_attention:
            forwarded to ``CrossAttnDownBlock2D`` only; ignored for ``DownBlock2D``.

    Returns:
        The constructed block.

    Raises:
        ValueError: for any other ``down_block_type``. Returning ``None`` instead would only fail
            later, far from the cause, when the U-Net first uses the block.
    """
    if down_block_type == "DownBlock2D":
        return DownBlock2D(
            in_channels=in_channels,
            out_channels=out_channels,
            add_downsample=add_downsample,
        )
    if down_block_type == "CrossAttnDownBlock2D":
        return CrossAttnDownBlock2D(
            in_channels=in_channels,
            out_channels=out_channels,
            add_downsample=add_downsample,
            cross_attention_dim=cross_attention_dim,
            attn_num_head_channels=attn_num_head_channels,
            use_linear_projection=use_linear_projection,
            upcast_attention=upcast_attention,
        )
    raise ValueError(f"Unsupported down block type: {down_block_type}")
|
|
|
|
def get_up_block(
    up_block_type,
    in_channels,
    out_channels,
    prev_output_channel,
    add_upsample,
    attn_num_head_channels,
    cross_attention_dim=None,
    use_linear_projection=False,
    upcast_attention=False,
):
    """Build one decoder (up) block from its type name.

    Args:
        up_block_type: ``"UpBlock2D"`` or ``"CrossAttnUpBlock2D"``.
        in_channels, out_channels, prev_output_channel, add_upsample: forwarded to the block.
        attn_num_head_channels, cross_attention_dim, use_linear_projection, upcast_attention:
            forwarded to ``CrossAttnUpBlock2D`` only; ignored for ``UpBlock2D``.

    Returns:
        The constructed block.

    Raises:
        ValueError: for any other ``up_block_type``. Returning ``None`` instead would only fail
            later, far from the cause, when the U-Net first uses the block.
    """
    if up_block_type == "UpBlock2D":
        return UpBlock2D(
            in_channels=in_channels,
            prev_output_channel=prev_output_channel,
            out_channels=out_channels,
            add_upsample=add_upsample,
        )
    if up_block_type == "CrossAttnUpBlock2D":
        return CrossAttnUpBlock2D(
            in_channels=in_channels,
            out_channels=out_channels,
            prev_output_channel=prev_output_channel,
            attn_num_head_channels=attn_num_head_channels,
            cross_attention_dim=cross_attention_dim,
            add_upsample=add_upsample,
            use_linear_projection=use_linear_projection,
            upcast_attention=upcast_attention,
        )
    raise ValueError(f"Unsupported up block type: {up_block_type}")
|
|
|
|
class UNet2DConditionModel(nn.Module):
    """Conditional U-Net with a fixed block layout (SD v1.x / v2.x).

    The layout (channel widths, block types, normalisation settings, ...) comes from module-level
    constants such as ``BLOCK_OUT_CHANNELS``, ``DOWN_BLOCK_TYPES`` and ``UP_BLOCK_TYPES``. Only the
    options that differ between model versions are constructor arguments.

    Args:
        sample_size: latent sample size; required. Stored on ``self.sample_size`` and ``self.config``.
        attention_head_dim: one int used for all four levels, or a per-level sequence (down order).
            Passed to each block as ``attn_num_head_channels``.
        cross_attention_dim: forwarded to all attention blocks.
        use_linear_projection: forwarded to all attention blocks.
        upcast_attention: forwarded to the down/up block factories (the mid block is built without it).
        **kwargs: ignored.
    """

    _supports_gradient_checkpointing = True

    def __init__(
        self,
        sample_size: Optional[int] = None,
        attention_head_dim: Union[int, Tuple[int]] = 8,
        cross_attention_dim: int = 1280,
        use_linear_projection: bool = False,
        upcast_attention: bool = False,
        **kwargs,
    ):
        super().__init__()
        assert sample_size is not None, "sample_size must be specified"
        logger.info(
            f"UNet2DConditionModel: {sample_size}, {attention_head_dim}, {cross_attention_dim}, {use_linear_projection}, {upcast_attention}"
        )

        # Exposed so callers can read the model's input/output channel counts.
        self.in_channels = IN_CHANNELS
        self.out_channels = OUT_CHANNELS

        self.sample_size = sample_size
        self.prepare_config(sample_size=sample_size)

        # NOTE: attribute names and nesting below define the state_dict layout; do not restructure.

        # input
        self.conv_in = nn.Conv2d(IN_CHANNELS, BLOCK_OUT_CHANNELS[0], kernel_size=3, padding=(1, 1))

        # time
        self.time_proj = Timesteps(BLOCK_OUT_CHANNELS[0], TIME_EMBED_FLIP_SIN_TO_COS, TIME_EMBED_FREQ_SHIFT)
        self.time_embedding = TimestepEmbedding(TIMESTEP_INPUT_DIM, TIME_EMBED_DIM)

        self.down_blocks = nn.ModuleList([])
        self.mid_block = None
        self.up_blocks = nn.ModuleList([])

        if isinstance(attention_head_dim, int):
            attention_head_dim = (attention_head_dim,) * 4

        # down
        output_channel = BLOCK_OUT_CHANNELS[0]
        for i, down_block_type in enumerate(DOWN_BLOCK_TYPES):
            input_channel = output_channel
            output_channel = BLOCK_OUT_CHANNELS[i]
            is_final_block = i == len(BLOCK_OUT_CHANNELS) - 1

            down_block = get_down_block(
                down_block_type,
                in_channels=input_channel,
                out_channels=output_channel,
                add_downsample=not is_final_block,
                attn_num_head_channels=attention_head_dim[i],
                cross_attention_dim=cross_attention_dim,
                use_linear_projection=use_linear_projection,
                upcast_attention=upcast_attention,
            )
            self.down_blocks.append(down_block)

        # mid
        self.mid_block = UNetMidBlock2DCrossAttn(
            in_channels=BLOCK_OUT_CHANNELS[-1],
            attn_num_head_channels=attention_head_dim[-1],
            cross_attention_dim=cross_attention_dim,
            use_linear_projection=use_linear_projection,
        )

        # Number of up blocks that upsample; forward() uses it to check the input size.
        self.num_upsamplers = 0

        # up
        reversed_block_out_channels = list(reversed(BLOCK_OUT_CHANNELS))
        reversed_attention_head_dim = list(reversed(attention_head_dim))
        output_channel = reversed_block_out_channels[0]
        for i, up_block_type in enumerate(UP_BLOCK_TYPES):
            is_final_block = i == len(BLOCK_OUT_CHANNELS) - 1

            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]
            input_channel = reversed_block_out_channels[min(i + 1, len(BLOCK_OUT_CHANNELS) - 1)]

            # Every up block except the last one upsamples.
            if not is_final_block:
                add_upsample = True
                self.num_upsamplers += 1
            else:
                add_upsample = False

            up_block = get_up_block(
                up_block_type,
                in_channels=input_channel,
                out_channels=output_channel,
                prev_output_channel=prev_output_channel,
                add_upsample=add_upsample,
                attn_num_head_channels=reversed_attention_head_dim[i],
                cross_attention_dim=cross_attention_dim,
                use_linear_projection=use_linear_projection,
                upcast_attention=upcast_attention,
            )
            self.up_blocks.append(up_block)
            prev_output_channel = output_channel

        # out
        self.conv_norm_out = nn.GroupNorm(num_channels=BLOCK_OUT_CHANNELS[0], num_groups=NORM_GROUPS, eps=NORM_EPS)
        self.conv_act = nn.SiLU()
        self.conv_out = nn.Conv2d(BLOCK_OUT_CHANNELS[0], OUT_CHANNELS, kernel_size=3, padding=1)

    def prepare_config(self, *args, **kwargs):
        """Store keyword arguments on ``self.config`` as a ``SimpleNamespace``; positional args are ignored.

        Presumably this provides diffusers-style ``unet.config.<name>`` access — confirm with callers.
        """
        self.config = SimpleNamespace(**kwargs)

    @property
    def dtype(self) -> torch.dtype:
        """Parameter dtype, as reported by ``get_parameter_dtype``."""
        return get_parameter_dtype(self)

    @property
    def device(self) -> torch.device:
        """Parameter device, as reported by ``get_parameter_device``."""
        return get_parameter_device(self)

    def set_attention_slice(self, slice_size):
        """Not supported by this implementation; always raises ``NotImplementedError``."""
        raise NotImplementedError("Attention slicing is not supported for this model.")

    def is_gradient_checkpointing(self) -> bool:
        """Return True if any submodule has a truthy ``gradient_checkpointing`` attribute."""
        return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules())

    def enable_gradient_checkpointing(self):
        """Turn gradient checkpointing on for every down/mid/up block."""
        self.set_gradient_checkpointing(value=True)

    def disable_gradient_checkpointing(self):
        """Turn gradient checkpointing off for every down/mid/up block."""
        self.set_gradient_checkpointing(value=False)

    def _iter_blocks(self):
        """Return the down blocks, the mid block and the up blocks as one flat list, in that order."""
        # Plain list unpacking: does not rely on `ModuleList.__add__` and builds no throwaway ModuleList.
        return [*self.down_blocks, self.mid_block, *self.up_blocks]

    def set_use_memory_efficient_attention(self, xformers: bool, mem_eff: bool) -> None:
        """Forward the ``xformers`` / ``mem_eff`` flags to every down/mid/up block."""
        for module in self._iter_blocks():
            module.set_use_memory_efficient_attention(xformers, mem_eff)

    def set_use_sdpa(self, sdpa: bool) -> None:
        """Forward the ``sdpa`` flag to every down/mid/up block."""
        for module in self._iter_blocks():
            module.set_use_sdpa(sdpa)

    def set_gradient_checkpointing(self, value=False):
        """Set ``gradient_checkpointing`` on every down/mid/up block, logging each change."""
        for module in self._iter_blocks():
            logger.info(f"{module.__class__.__name__} {module.gradient_checkpointing} -> {value}")
            module.gradient_checkpointing = value

    @staticmethod
    def _add_down_block_residuals(down_block_res_samples, down_block_additional_residuals):
        """Return a new tuple where each skip-connection sample has its matching residual added.

        The addition is out of place on purpose: the skip samples are outputs of earlier layers, and
        modifying them in place can invalidate tensors autograd saved for the backward pass. Each
        result is cast back to its sample's dtype, as the former in-place ``+=`` did.

        Raises:
            ValueError: if the number of residuals differs from the number of skip samples.
        """
        if len(down_block_additional_residuals) != len(down_block_res_samples):
            raise ValueError(
                f"expected {len(down_block_res_samples)} down block residuals, "
                f"got {len(down_block_additional_residuals)}"
            )
        return tuple(
            (res_sample + residual).to(res_sample.dtype)
            for res_sample, residual in zip(down_block_res_samples, down_block_additional_residuals)
        )

    def forward(
        self,
        sample: torch.FloatTensor,
        timestep: Union[torch.Tensor, float, int],
        encoder_hidden_states: torch.Tensor,
        class_labels: Optional[torch.Tensor] = None,
        return_dict: bool = True,
        down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
        mid_block_additional_residual: Optional[torch.Tensor] = None,
    ) -> Union[Dict, Tuple]:
        r"""
        Args:
            sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
            timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
            encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
            class_labels: unused; accepted for signature compatibility.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a dict instead of a plain tuple.
            down_block_additional_residuals: optional residuals, one per skip connection of the
                down path, added (out of place) before the mid block.
            mid_block_additional_residual: optional residual added (out of place) to the mid block output.

        Returns:
            `SampleOutput` or `tuple`:
            `SampleOutput` if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.

        Raises:
            ValueError: if ``down_block_additional_residuals`` does not have one entry per skip connection.
        """
        # Each upsampler doubles the spatial size, so inputs are expected to be a multiple of
        # 2**num_upsamplers. For other sizes the upsamplers are given explicit target sizes,
        # taken from the matching skip connection (likely at some cost in quality).
        default_overall_up_factor = 2**self.num_upsamplers

        forward_upsample_size = False
        upsample_size = None

        if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
            forward_upsample_size = True

        # 1. time
        timesteps = self.handle_unusual_timesteps(sample, timestep)
        t_emb = self.time_proj(timesteps)
        # Cast to the model's parameter dtype (e.g. fp16) before the learned time embedding.
        t_emb = t_emb.to(dtype=self.dtype)
        emb = self.time_embedding(t_emb)

        # 2. pre-process
        sample = self.conv_in(sample)

        # 3. down
        down_block_res_samples = (sample,)
        for downsample_block in self.down_blocks:
            if downsample_block.has_cross_attention:
                sample, res_samples = downsample_block(
                    hidden_states=sample,
                    temb=emb,
                    encoder_hidden_states=encoder_hidden_states,
                )
            else:
                sample, res_samples = downsample_block(hidden_states=sample, temb=emb)

            down_block_res_samples += res_samples

        if down_block_additional_residuals is not None:
            down_block_res_samples = self._add_down_block_residuals(
                down_block_res_samples, down_block_additional_residuals
            )

        # 4. mid
        sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states)

        if mid_block_additional_residual is not None:
            # Out of place for the same autograd reason as the down residuals; keep the dtype.
            sample = (sample + mid_block_additional_residual).to(sample.dtype)

        # 5. up
        for i, upsample_block in enumerate(self.up_blocks):
            is_final_block = i == len(self.up_blocks) - 1

            # Take this block's skip connections from the end of the tuple.
            n_skip = len(upsample_block.resnets)
            res_samples = down_block_res_samples[-n_skip:]
            down_block_res_samples = down_block_res_samples[:-n_skip]

            # Non-final blocks upsample: force their output to the next skip connection's size.
            if not is_final_block and forward_upsample_size:
                upsample_size = down_block_res_samples[-1].shape[2:]

            if upsample_block.has_cross_attention:
                sample = upsample_block(
                    hidden_states=sample,
                    temb=emb,
                    res_hidden_states_tuple=res_samples,
                    encoder_hidden_states=encoder_hidden_states,
                    upsample_size=upsample_size,
                )
            else:
                sample = upsample_block(
                    hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
                )

        # 6. post-process
        sample = self.conv_norm_out(sample)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        if not return_dict:
            return (sample,)

        return SampleOutput(sample=sample)

    def handle_unusual_timesteps(self, sample, timesteps):
        r"""
        Normalise ``timesteps`` to a 1-D tensor with one entry per batch element of ``sample``.

        A Python number becomes a one-element tensor on ``sample``'s device: float64 for floats and
        int64 otherwise, or float32/int32 on MPS. A 0-d tensor gains a batch dimension. The result is
        broadcast with ``expand`` (rather than repeated), in a form compatible with ONNX / Core ML export.
        """
        if not torch.is_tensor(timesteps):
            # Building a tensor from a Python scalar may force a CPU/GPU sync; prefer passing tensors.
            is_mps = sample.device.type == "mps"
            if isinstance(timesteps, float):
                dtype = torch.float32 if is_mps else torch.float64
            else:
                dtype = torch.int32 if is_mps else torch.int64
            timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
        elif len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)

        timesteps = timesteps.expand(sample.shape[0])

        return timesteps
|
|
|
|
class InferUNet2DConditionModel:
    """Inference wrapper that adds "Deep Shrink" to a ``UNet2DConditionModel``.

    On construction it patches, on the wrapped instance only:

    * the model's ``forward``, replaced by :meth:`forward` below (a copy of
      ``UNet2DConditionModel.forward`` that can downscale the input of a chosen down block);
    * ``forward`` of each ``UpBlock2D`` / ``CrossAttnUpBlock2D``, replaced by versions that resize
      the incoming hidden states to the skip connection's spatial size when they differ.

    Attributes not defined on the wrapper are read from the wrapped model, and calling the wrapper
    calls the wrapped model.
    """

    def __init__(self, original_unet: "UNet2DConditionModel"):
        self.delegate = original_unet

        # An instance attribute named `forward` shadows the class method, so calling the wrapped
        # nn.Module (which dispatches to `self.forward`) runs the Deep Shrink version below.
        self.delegate.forward = self.forward

        for up_block in self.delegate.up_blocks:
            block_type = up_block.__class__.__name__
            if block_type == "UpBlock2D":
                up_block.forward = self._bind_block_forward(self.up_block_forward, up_block)
            elif block_type == "CrossAttnUpBlock2D":
                up_block.forward = self._bind_block_forward(self.cross_attn_up_block_forward, up_block)

        # Deep Shrink stays off until set_deep_shrink() is called with a depth.
        self.ds_depth_1 = None
        self.ds_depth_2 = None
        self.ds_timesteps_1 = None
        self.ds_timesteps_2 = None
        self.ds_ratio = None

    @staticmethod
    def _bind_block_forward(func, block):
        """Return a callable that invokes ``func(block, *args, **kwargs)``."""

        def forward(*args, **kwargs):
            return func(block, *args, **kwargs)

        return forward

    def __getattr__(self, name):
        """Fall back to the wrapped model for attributes the wrapper itself does not define."""
        # __getattr__ only runs when normal lookup fails. If `delegate` itself is missing (e.g. on an
        # instance created via __new__ without __init__, as copy/pickle do), reading self.delegate
        # here would recurse until RecursionError; raise a proper AttributeError instead.
        if name == "delegate":
            raise AttributeError(name)
        return getattr(self.delegate, name)

    def __call__(self, *args, **kwargs):
        return self.delegate(*args, **kwargs)

    def set_deep_shrink(self, ds_depth_1, ds_timesteps_1=650, ds_depth_2=None, ds_timesteps_2=None, ds_ratio=0.5):
        """Enable or disable Deep Shrink.

        Args:
            ds_depth_1: index of the down block whose input is downscaled while the first timestep of
                the batch is >= ``ds_timesteps_1``. ``None`` disables Deep Shrink and clears all settings.
            ds_timesteps_1: timestep threshold for ``ds_depth_1``.
            ds_depth_2: optional second down-block index, used while
                ``ds_timesteps_2 <= timestep < ds_timesteps_1``. ``None`` is stored as ``-1``, which
                never matches a block index.
            ds_timesteps_2: lower timestep bound for ``ds_depth_2``; ``None`` is stored as ``1000``.
            ds_ratio: ``scale_factor`` passed to ``F.interpolate`` when shrinking.
        """
        if ds_depth_1 is None:
            logger.info("Deep Shrink is disabled.")
            self.ds_depth_1 = None
            self.ds_timesteps_1 = None
            self.ds_depth_2 = None
            self.ds_timesteps_2 = None
            self.ds_ratio = None
        else:
            logger.info(
                f"Deep Shrink is enabled: [depth={ds_depth_1}/{ds_depth_2}, timesteps={ds_timesteps_1}/{ds_timesteps_2}, ratio={ds_ratio}]"
            )
            self.ds_depth_1 = ds_depth_1
            self.ds_timesteps_1 = ds_timesteps_1
            self.ds_depth_2 = ds_depth_2 if ds_depth_2 is not None else -1
            self.ds_timesteps_2 = ds_timesteps_2 if ds_timesteps_2 is not None else 1000
            self.ds_ratio = ds_ratio

    def _deep_shrink_applies(self, depth, timesteps):
        """Return True if the input of down block ``depth`` should be shrunk at this step.

        Only ``timesteps[0]`` is consulted, and only when Deep Shrink is enabled.
        """
        if self.ds_depth_1 is None:
            return False
        t = timesteps[0]
        if depth == self.ds_depth_1 and t >= self.ds_timesteps_1:
            return True
        return bool(
            self.ds_depth_2 is not None
            and depth == self.ds_depth_2
            and t < self.ds_timesteps_1
            and t >= self.ds_timesteps_2
        )

    def up_block_forward(self, _self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
        """Replacement ``forward`` for an ``UpBlock2D`` (passed as ``_self``).

        Same as the original forward without gradient checkpointing, except the hidden states are
        resized with ``resize_like`` whenever their spatial size differs from the skip connection's
        (e.g. after Deep Shrink).
        """
        for resnet in _self.resnets:
            res_hidden_states = res_hidden_states_tuple[-1]
            res_hidden_states_tuple = res_hidden_states_tuple[:-1]

            if res_hidden_states.shape[-2:] != hidden_states.shape[-2:]:
                hidden_states = resize_like(hidden_states, res_hidden_states)

            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
            hidden_states = resnet(hidden_states, temb)

        if _self.upsamplers is not None:
            for upsampler in _self.upsamplers:
                hidden_states = upsampler(hidden_states, upsample_size)

        return hidden_states

    def cross_attn_up_block_forward(
        self,
        _self,
        hidden_states,
        res_hidden_states_tuple,
        temb=None,
        encoder_hidden_states=None,
        upsample_size=None,
    ):
        """Replacement ``forward`` for a ``CrossAttnUpBlock2D`` (passed as ``_self``).

        Same as the original forward without gradient checkpointing, except the hidden states are
        resized with ``resize_like`` whenever their spatial size differs from the skip connection's.
        """
        for resnet, attn in zip(_self.resnets, _self.attentions):
            res_hidden_states = res_hidden_states_tuple[-1]
            res_hidden_states_tuple = res_hidden_states_tuple[:-1]

            if res_hidden_states.shape[-2:] != hidden_states.shape[-2:]:
                hidden_states = resize_like(hidden_states, res_hidden_states)

            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
            hidden_states = resnet(hidden_states, temb)
            hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample

        if _self.upsamplers is not None:
            for upsampler in _self.upsamplers:
                hidden_states = upsampler(hidden_states, upsample_size)

        return hidden_states

    def forward(
        self,
        sample: torch.FloatTensor,
        timestep: Union[torch.Tensor, float, int],
        encoder_hidden_states: torch.Tensor,
        class_labels: Optional[torch.Tensor] = None,
        return_dict: bool = True,
        down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
        mid_block_additional_residual: Optional[torch.Tensor] = None,
    ) -> Union[Dict, Tuple]:
        r"""
        Copy of `UNet2DConditionModel.forward()` with Deep Shrink: before the configured down block,
        `sample` is resized by `ds_ratio` (bicubic) when the Deep Shrink conditions hold.

        Args:
            sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
            timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
            encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
            class_labels: unused; accepted for signature compatibility.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a dict instead of a plain tuple.
            down_block_additional_residuals: optional residuals, one per skip connection, added out of place.
            mid_block_additional_residual: optional residual added out of place to the mid block output.

        Returns:
            `SampleOutput` or `tuple`:
            `SampleOutput` if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.

        Raises:
            ValueError: if ``down_block_additional_residuals`` does not have one entry per skip connection.
        """
        _self = self.delegate

        # Inputs that are not a multiple of 2**num_upsamplers get explicit upsample target sizes.
        default_overall_up_factor = 2**_self.num_upsamplers

        forward_upsample_size = False
        upsample_size = None

        if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
            forward_upsample_size = True

        # 1. time
        timesteps = _self.handle_unusual_timesteps(sample, timestep)
        t_emb = _self.time_proj(timesteps)
        # Cast to the model's parameter dtype (e.g. fp16) before the learned time embedding.
        t_emb = t_emb.to(dtype=_self.dtype)
        emb = _self.time_embedding(t_emb)

        # 2. pre-process
        sample = _self.conv_in(sample)

        # 3. down (with Deep Shrink)
        down_block_res_samples = (sample,)
        for depth, downsample_block in enumerate(_self.down_blocks):
            if self._deep_shrink_applies(depth, timesteps):
                org_dtype = sample.dtype
                # bfloat16 goes through float32 for the interpolation, then back.
                if org_dtype == torch.bfloat16:
                    sample = sample.to(torch.float32)
                sample = F.interpolate(sample, scale_factor=self.ds_ratio, mode="bicubic", align_corners=False).to(org_dtype)

            if downsample_block.has_cross_attention:
                sample, res_samples = downsample_block(
                    hidden_states=sample,
                    temb=emb,
                    encoder_hidden_states=encoder_hidden_states,
                )
            else:
                sample, res_samples = downsample_block(hidden_states=sample, temb=emb)

            down_block_res_samples += res_samples

        if down_block_additional_residuals is not None:
            if len(down_block_additional_residuals) != len(down_block_res_samples):
                raise ValueError(
                    f"expected {len(down_block_res_samples)} down block residuals, "
                    f"got {len(down_block_additional_residuals)}"
                )
            # Out of place (no mutation of earlier layer outputs); keep each sample's dtype.
            down_block_res_samples = tuple(
                (res_sample + residual).to(res_sample.dtype)
                for res_sample, residual in zip(down_block_res_samples, down_block_additional_residuals)
            )

        # 4. mid
        sample = _self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states)

        if mid_block_additional_residual is not None:
            sample = (sample + mid_block_additional_residual).to(sample.dtype)

        # 5. up (up blocks were patched to resize to their skip connections)
        for i, upsample_block in enumerate(_self.up_blocks):
            is_final_block = i == len(_self.up_blocks) - 1

            n_skip = len(upsample_block.resnets)
            res_samples = down_block_res_samples[-n_skip:]
            down_block_res_samples = down_block_res_samples[:-n_skip]

            if not is_final_block and forward_upsample_size:
                upsample_size = down_block_res_samples[-1].shape[2:]

            if upsample_block.has_cross_attention:
                sample = upsample_block(
                    hidden_states=sample,
                    temb=emb,
                    res_hidden_states_tuple=res_samples,
                    encoder_hidden_states=encoder_hidden_states,
                    upsample_size=upsample_size,
                )
            else:
                sample = upsample_block(
                    hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
                )

        # 6. post-process
        sample = _self.conv_norm_out(sample)
        sample = _self.conv_act(sample)
        sample = _self.conv_out(sample)

        if not return_dict:
            return (sample,)

        return SampleOutput(sample=sample)
|
|