from collections import OrderedDict
from dataclasses import dataclass, fields
from functools import partial
from typing import Optional, Tuple, Union

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from datasets import load_dataset
from diffusers import DDPMScheduler, UNet2DModel
from packaging import version
from PIL import Image
from torch.utils.data import DataLoader
from torchvision import transforms


class BaseOutput(OrderedDict):
    """
    Base class for all model outputs as dataclass. Fields can be read by
    attribute (`out.sample`) or by string key (`out["sample"]`); fields whose
    value is `None` are omitted from the mapping. Otherwise behaves like a
    regular Python dictionary. Subclasses are also registered as pytree nodes
    so that pytree-based utilities (e.g. `torch.compile`) can traverse them.
    """

    def __init_subclass__(cls) -> None:
        import torch.utils._pytree as pytree

        # Compare parsed versions, not strings ("2.10" < "2.2" lexically).
        # torch >= 2.2 exposes a public registration helper that also takes a
        # `serialized_type_name`; older versions only have the private one.
        if version.parse(torch.__version__) >= version.parse("2.2"):
            pytree.register_pytree_node(
                cls,
                pytree._dict_flatten,
                lambda values, context: cls(**pytree._dict_unflatten(values, context)),
                serialized_type_name=f"{cls.__module__}.{cls.__name__}",
            )
        else:
            pytree._register_pytree_node(
                cls,
                pytree._dict_flatten,
                lambda values, context: cls(**pytree._dict_unflatten(values, context)),
            )

    def __post_init__(self) -> None:
        # Copy the dataclass fields into the underlying OrderedDict so that
        # dict-style access and pytree flattening actually see the data;
        # `None` fields are skipped, as documented above.
        for field in fields(self):
            value = getattr(self, field.name)
            if value is not None:
                self[field.name] = value


@dataclass
class UNet2DOutput(BaseOutput):
    """
    The output of [`UNet2DModel`].

    Args:
        sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)`):
            The hidden states output from the last layer of the model.
    """

    sample: torch.Tensor
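
# A quick sanity check of the output container (a plain sketch, no model
# needed): with `__post_init__` filling the mapping, attribute access,
# dict-style access, and pytree flattening all see the same tensor.
#
#     import torch.utils._pytree as pytree
#     out = UNet2DOutput(sample=torch.zeros(1))
#     assert out.sample is out["sample"]
#     leaves, spec = pytree.tree_flatten(out)
#     assert isinstance(pytree.tree_unflatten(leaves, spec), UNet2DOutput)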


class DPM(UNet2DModel):
    """
    A [`UNet2DModel`] with an extra prototype-conditioned cross-attention
    block at the bottleneck: the lowest-resolution feature map attends to a
    caller-supplied `prototype` tensor just before the mid block.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # The bottleneck feature map has `block_out_channels[-1]` channels;
        # prototype tokens must share this embedding dimension, and it must
        # be divisible by `num_heads`.
        hidden_size = self.config.block_out_channels[-1]
        self.bottleneck_attn = nn.MultiheadAttention(
            embed_dim=hidden_size,
            num_heads=8,
            batch_first=True,
        )
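
    # How `prototype` is produced is left to the caller. One hypothetical
    # recipe (not part of this model) is a learned bank of per-class tokens,
    # indexed by label at call time:
    #
    #     protos = nn.Parameter(torch.randn(num_classes, num_tokens, hidden_size))
    #     prototype = protos[class_labels]  # (batch, num_tokens, hidden_size)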

    def forward(
        self,
        sample: torch.Tensor,
        timestep: Union[torch.Tensor, float, int],
        class_labels: Optional[torch.Tensor] = None,
        return_dict: bool = True,
        prototype: Optional[torch.Tensor] = None,
    ) -> Union[UNet2DOutput, Tuple]:
        r"""
        The [`UNet2DModel`] forward method.

        Args:
            sample (`torch.Tensor`):
                The noisy input tensor with the following shape `(batch, channel, height, width)`.
            timestep (`torch.Tensor` or `float` or `int`): The number of timesteps to denoise an input.
            class_labels (`torch.Tensor`, *optional*, defaults to `None`):
                Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.unets.unet_2d.UNet2DOutput`] instead of a plain tuple.
            prototype (`torch.Tensor`):
                Key/value tokens for the bottleneck cross-attention, of shape
                `(batch, num_tokens, block_out_channels[-1])`. Required by this
                subclass despite the `Optional` annotation, which is kept only
                for signature compatibility.

        Returns:
            [`~models.unets.unet_2d.UNet2DOutput`] or `tuple`:
                If `return_dict` is True, an [`~models.unets.unet_2d.UNet2DOutput`] is returned, otherwise a `tuple` is
                returned where the first element is the sample tensor.
        """
        # 0. Center the input around zero if the model was configured for it.
        if self.config.center_input_sample:
            sample = 2 * sample - 1.0

        # 1. Time embedding: accept a Python number, a 0-d tensor, or a 1-d
        # batch of timesteps.
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
        elif len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)

        # Broadcast to the batch dimension in a way that is compatible with
        # ONNX/Core ML export.
        timesteps = timesteps * torch.ones(sample.shape[0], dtype=timesteps.dtype, device=timesteps.device)

        t_emb = self.time_proj(timesteps)

        # `time_proj` always returns float32, but the model itself may run in
        # a lower precision; cast before the embedding MLP.
        t_emb = t_emb.to(dtype=self.dtype)
        emb = self.time_embedding(t_emb)

        if self.class_embedding is not None:
            if class_labels is None:
                raise ValueError("class_labels should be provided when doing class conditioning")

            if self.config.class_embed_type == "timestep":
                class_labels = self.time_proj(class_labels)

            class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
            emb = emb + class_emb
        elif self.class_embedding is None and class_labels is not None:
            raise ValueError("class_embedding needs to be initialized in order to use class conditioning")

        # 2. Pre-process.
        skip_sample = sample
        sample = self.conv_in(sample)

        # 3. Down blocks, collecting residuals for the skip connections.
        down_block_res_samples = (sample,)
        for downsample_block in self.down_blocks:
            if hasattr(downsample_block, "skip_conv"):
                sample, res_samples, skip_sample = downsample_block(
                    hidden_states=sample, temb=emb, skip_sample=skip_sample
                )
            else:
                sample, res_samples = downsample_block(hidden_states=sample, temb=emb)

            down_block_res_samples += res_samples

        # 4. Bottleneck cross-attention: the flattened feature map queries the
        # prototype tokens and the result is added back as a residual.
        if prototype is None:
            raise ValueError("You must provide a `prototype` tensor for cross-attention")

        b, c, h, w = sample.shape
        query = sample.view(b, c, h * w).transpose(1, 2)  # (b, h*w, c)

        # Convenience: promote a single token per sample, (b, c), to (b, 1, c).
        if prototype.dim() == 2:
            prototype = prototype.unsqueeze(1)
        key = value = prototype.to(dtype=sample.dtype)

        attn_output, _ = self.bottleneck_attn(query, key, value)
        # `transpose` yields a non-contiguous tensor, so `reshape` is needed
        # here (`.view` would raise a RuntimeError).
        attn_output = attn_output.transpose(1, 2).reshape(b, c, h, w)

        sample = sample + attn_output

        # 5. Mid block.
        if self.mid_block is not None:
            sample = self.mid_block(sample, emb)

        # 6. Up blocks, consuming the stored residuals in reverse order.
        skip_sample = None
        for upsample_block in self.up_blocks:
            res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]

            if hasattr(upsample_block, "skip_conv"):
                sample, skip_sample = upsample_block(sample, res_samples, emb, skip_sample)
            else:
                sample = upsample_block(sample, res_samples, emb)

        # 7. Post-process.
        sample = self.conv_norm_out(sample)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        if skip_sample is not None:
            sample += skip_sample

        if self.config.time_embedding_type == "fourier":
            timesteps = timesteps.reshape((sample.shape[0], *([1] * len(sample.shape[1:]))))
            sample = sample / timesteps

        if not return_dict:
            return (sample,)

        return UNet2DOutput(sample=sample)
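

# A minimal smoke test under an assumed small config (hypothetical values):
# the last block's channel count (128) must be divisible by the 8 attention
# heads, and `prototype` can be any float tensor of shape
# (batch, num_tokens, block_out_channels[-1]).
if __name__ == "__main__":
    model = DPM(
        sample_size=32,
        in_channels=3,
        out_channels=3,
        layers_per_block=1,
        block_out_channels=(32, 64, 128),
        down_block_types=("DownBlock2D", "DownBlock2D", "AttnDownBlock2D"),
        up_block_types=("AttnUpBlock2D", "UpBlock2D", "UpBlock2D"),
    )
    x = torch.randn(2, 3, 32, 32)   # a batch of noisy images
    t = torch.tensor([10, 500])     # one timestep per sample
    proto = torch.randn(2, 4, 128)  # 4 prototype tokens per sample
    out = model(x, t, prototype=proto)
    print(out.sample.shape)  # expected: torch.Size([2, 3, 32, 32])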