import os
from collections import OrderedDict
from dataclasses import dataclass
from functools import partial
from typing import Optional, Tuple, Union

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from accelerate import Accelerator
from datasets import load_dataset
from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel
# `UNet2DOutput` is referenced below but was missing from the original imports.
# Recent diffusers releases expose it here; older releases use diffusers.models.unet_2d.
from diffusers.models.unets.unet_2d import UNet2DOutput
from huggingface_hub import create_repo, upload_folder
from PIL import Image
from torch.utils.data import DataLoader
from torchvision import transforms

class DPM(UNet2DModel):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Cross-attention at the bottleneck: queries come from the UNet's deepest
        # feature map, keys/values from an external `prototype` tensor whose last
        # dimension must equal block_out_channels[-1].
        self.bottleneck_attn = nn.MultiheadAttention(
            embed_dim=self.config.block_out_channels[-1],
            num_heads=8,
            batch_first=True,
        )

    def forward(
        self,
        sample: torch.Tensor,
        timestep: Union[torch.Tensor, float, int],
        class_labels: Optional[torch.Tensor] = None,
        return_dict: bool = True,
        prototype: Optional[torch.Tensor] = None,
    ) -> Union[UNet2DOutput, Tuple]:
        r"""
        The [`UNet2DModel`] forward method, extended with prototype cross-attention at the bottleneck.

        Args:
            sample (`torch.Tensor`):
                The noisy input tensor with the following shape `(batch, channel, height, width)`.
            timestep (`torch.Tensor` or `float` or `int`): The number of timesteps to denoise an input.
            class_labels (`torch.Tensor`, *optional*, defaults to `None`):
                Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.unets.unet_2d.UNet2DOutput`] instead of a plain tuple.
            prototype (`torch.Tensor`):
                Conditioning tensor of shape `(batch, seq_len, block_out_channels[-1])` used as the keys and
                values of the bottleneck cross-attention. Required despite the `Optional` annotation.

        Returns:
            [`~models.unets.unet_2d.UNet2DOutput`] or `tuple`:
                If `return_dict` is True, an [`~models.unets.unet_2d.UNet2DOutput`] is returned, otherwise a
                `tuple` is returned where the first element is the sample tensor.
        """
        # 0. Center the input if the config requests it (maps [0, 1] -> [-1, 1]).
        if self.config.center_input_sample:
            sample = 2 * sample - 1.0

        # 1. Time embedding.
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
        elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)

        # Broadcast to the batch dimension in a way that is compatible with ONNX/Core ML.
        timesteps = timesteps * torch.ones(sample.shape[0], dtype=timesteps.dtype, device=timesteps.device)

        t_emb = self.time_proj(timesteps)

        # `time_proj` always returns f32 tensors, but `time_embedding` may run in
        # fp16, so cast the projection to the model dtype first.
        t_emb = t_emb.to(dtype=self.dtype)
        emb = self.time_embedding(t_emb)

        if self.class_embedding is not None:
            if class_labels is None:
                raise ValueError("class_labels should be provided when doing class conditioning")

            if self.config.class_embed_type == "timestep":
                class_labels = self.time_proj(class_labels)

            class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
            emb = emb + class_emb
        elif self.class_embedding is None and class_labels is not None:
            raise ValueError("class_embedding needs to be initialized in order to use class conditioning")

        # 2. Pre-process: keep the raw input around for skip-sample UNets.
        skip_sample = sample
        sample = self.conv_in(sample)

        # 3. Down blocks: collect residual feature maps for the skip connections.
        down_block_res_samples = (sample,)
        for downsample_block in self.down_blocks:
            if hasattr(downsample_block, "skip_conv"):
                sample, res_samples, skip_sample = downsample_block(
                    hidden_states=sample, temb=emb, skip_sample=skip_sample
                )
            else:
                sample, res_samples = downsample_block(hidden_states=sample, temb=emb)

            down_block_res_samples += res_samples

        # 4. Prototype cross-attention at the bottleneck.
        if prototype is None:
            raise ValueError("You must provide a `prototype` tensor for cross-attention")

        # Flatten the spatial feature map into a sequence of h*w tokens: (b, h*w, c).
        b, c, h, w = sample.shape
        query = sample.view(b, c, h * w).transpose(1, 2)

        # The prototype supplies keys and values; its last dimension must match
        # the bottleneck channel width, block_out_channels[-1].
        key = value = prototype.to(dtype=sample.dtype)

        attn_output, _ = self.bottleneck_attn(query, key, value)
        attn_output = attn_output.transpose(1, 2).view(b, c, h, w)

        # Residual connection: keep the original features alongside the attended ones.
        sample = sample + attn_output

        # 5. Mid block.
        if self.mid_block is not None:
            sample = self.mid_block(sample, emb)

        # 6. Up blocks: consume the residuals collected on the way down.
        skip_sample = None
        for upsample_block in self.up_blocks:
            res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]

            if hasattr(upsample_block, "skip_conv"):
                sample, skip_sample = upsample_block(sample, res_samples, emb, skip_sample)
            else:
                sample = upsample_block(sample, res_samples, emb)

        # 7. Post-process.
        sample = self.conv_norm_out(sample)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        if skip_sample is not None:
            sample += skip_sample

        if self.config.time_embedding_type == "fourier":
            timesteps = timesteps.reshape((sample.shape[0], *([1] * len(sample.shape[1:]))))
            sample = sample / timesteps

        if not return_dict:
            return (sample,)

        return UNet2DOutput(sample=sample)
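

# ---------------------------------------------------------------------------
# Minimal smoke test: an illustrative sketch, not part of the original model
# code. The config values and the 16-token prototype below are assumptions;
# any UNet2DModel config works as long as `prototype` has shape
# (batch, seq_len, block_out_channels[-1]).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    model = DPM(
        sample_size=32,
        in_channels=3,
        out_channels=3,
        block_out_channels=(64, 128, 256),
        down_block_types=("DownBlock2D", "DownBlock2D", "DownBlock2D"),
        up_block_types=("UpBlock2D", "UpBlock2D", "UpBlock2D"),
    )

    x = torch.randn(2, 3, 32, 32)     # batch of noisy images
    t = torch.randint(0, 1000, (2,))  # one timestep per sample
    proto = torch.randn(2, 16, 256)   # 16 prototype tokens, dim = block_out_channels[-1]

    out = model(x, t, prototype=proto)
    print(out.sample.shape)  # expected: torch.Size([2, 3, 32, 32])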