import math
from abc import abstractmethod
from functools import partial
from typing import Iterable

import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

from ...modules.attention import SpatialTransformer
from ...modules.diffusionmodules.util import (
    avg_pool_nd,
    checkpoint,
    conv_nd,
    linear,
    normalization,
    timestep_embedding,
    zero_module,
)
from ...util import default, exists


def convert_module_to_f16(x):
    # No-op stub: explicit fp16 conversion was dropped (see the use_fp16
    # warning in UNetModel), but the hook is kept so the convert_to_fp16()
    # methods below remain callable.
    pass


def convert_module_to_f32(x):
    # No-op stub, kept for the same reason as convert_module_to_f16.
    pass


class AttentionPool2d(nn.Module):
    """
    Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
    """

    def __init__(
        self,
        spacial_dim: int,
        embed_dim: int,
        num_heads_channels: int,
        output_dim: int = None,
    ):
        super().__init__()
        self.positional_embedding = nn.Parameter(
            th.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5
        )
        self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
        self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
        self.num_heads = embed_dim // num_heads_channels
        self.attention = QKVAttention(self.num_heads)

    def forward(self, x):
        b, c, *_spatial = x.shape
        x = x.reshape(b, c, -1)  # [N, C, H*W]
        x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1)  # [N, C, H*W + 1]
        x = x + self.positional_embedding[None, :, :].to(x.dtype)
        x = self.qkv_proj(x)
        x = self.attention(x)
        x = self.c_proj(x)
        return x[:, :, 0]  # the attended mean-pooled token


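# Illustrative usage sketch (a hypothetical helper, not used anywhere in this
# module): AttentionPool2d flattens an [N, C, H, W] feature map, prepends a
# mean-pooled token, runs QKV attention, and returns that token as
# [N, output_dim].
def _demo_attention_pool2d():
    pool = AttentionPool2d(
        spacial_dim=8, embed_dim=64, num_heads_channels=16, output_dim=32
    )
    x = th.randn(2, 64, 8, 8)
    assert pool(x).shape == (2, 32)

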
class TimestepBlock(nn.Module):
    """
    Any module where forward() takes timestep embeddings as a second argument.
    """

    @abstractmethod
    def forward(self, x, emb):
        """
        Apply the module to `x` given `emb` timestep embeddings.
        """


class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
    """
    A sequential module that passes timestep embeddings to the children that
    support it as an extra input.
    """

    def forward(
        self,
        x,
        emb,
        context=None,
        # The following video-specific keyword arguments are accepted for
        # interface compatibility with video variants and are ignored here.
        skip_time_mix=False,
        time_context=None,
        num_video_frames=None,
        time_context_cat=None,
        use_crossframe_attention_in_spatial_layers=False,
    ):
        for layer in self:
            if isinstance(layer, TimestepBlock):
                x = layer(x, emb)
            elif isinstance(layer, SpatialTransformer):
                x = layer(x, context)
            else:
                x = layer(x)
        return x


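# Small sketch of the routing rule above (hypothetical helper, not used by the
# model): TimestepBlock children receive `emb` as a second argument, while
# plain modules are called with `x` alone.
def _demo_timestep_embed_sequential():
    class _AddEmbMean(TimestepBlock):
        def forward(self, x, emb):
            return x + emb.mean()

    block = TimestepEmbedSequential(_AddEmbMean(), nn.SiLU())
    x = th.randn(2, 32, 16, 16)
    emb = th.randn(2, 128)
    assert block(x, emb).shape == x.shape

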
class Upsample(nn.Module):
    """
    An upsampling layer with an optional convolution.
    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
        upsampling occurs in the inner-two dimensions.
    """

    def __init__(
        self, channels, use_conv, dims=2, out_channels=None, padding=1, third_up=False
    ):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        self.third_up = third_up
        if use_conv:
            self.conv = conv_nd(
                dims, self.channels, self.out_channels, 3, padding=padding
            )

    def forward(self, x):
        # Interpolate in float32 and cast back afterwards: nearest-neighbour
        # interpolation is unsupported or lossy for some reduced-precision
        # dtypes on some backends.
        _dtype = x.dtype
        x = x.to(th.float32)

        assert x.shape[1] == self.channels
        if self.dims == 3:
            t_factor = 1 if not self.third_up else 2
            x = F.interpolate(
                x,
                (t_factor * x.shape[2], x.shape[3] * 2, x.shape[4] * 2),
                mode="nearest",
            )
        else:
            x = F.interpolate(x, scale_factor=2, mode="nearest")

        x = x.to(_dtype)

        if self.use_conv:
            x = self.conv(x)
        return x


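# Sketch of the 3D behaviour (hypothetical helper): with third_up=False the
# time axis is preserved while both spatial axes double.
def _demo_upsample_3d():
    up = Upsample(channels=8, use_conv=False, dims=3, third_up=False)
    x = th.randn(1, 8, 4, 16, 16)  # [N, C, T, H, W]
    assert up(x).shape == (1, 8, 4, 32, 32)

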
class TransposedUpsample(nn.Module):
    """Learned 2x upsampling without padding."""

    def __init__(self, channels, out_channels=None, ks=5):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels

        self.up = nn.ConvTranspose2d(
            self.channels, self.out_channels, kernel_size=ks, stride=2
        )

    def forward(self, x):
        return self.up(x)


class Downsample(nn.Module):
    """
    A downsampling layer with an optional convolution.
    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
        downsampling occurs in the inner-two dimensions.
    """

    def __init__(
        self, channels, use_conv, dims=2, out_channels=None, padding=1, third_down=False
    ):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        stride = 2 if dims != 3 else ((1, 2, 2) if not third_down else (2, 2, 2))
        if use_conv:
            print(f"Building a Downsample layer with {dims} dims.")
            print(
                f"  --> settings are: \n in-chn: {self.channels}, out-chn: {self.out_channels}, "
                f"kernel-size: 3, stride: {stride}, padding: {padding}"
            )
            if dims == 3:
                print(f"  --> Downsampling third axis (time): {third_down}")
            self.op = conv_nd(
                dims,
                self.channels,
                self.out_channels,
                3,
                stride=stride,
                padding=padding,
            )
        else:
            assert self.channels == self.out_channels
            self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)

    def forward(self, x):
        assert x.shape[1] == self.channels
        return self.op(x)


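# Mirror sketch for Downsample (hypothetical helper): with dims=3 and
# third_down=False the stride is (1, 2, 2), so time is preserved while H and W
# are halved.
def _demo_downsample_3d():
    down = Downsample(channels=8, use_conv=True, dims=3, third_down=False)
    x = th.randn(1, 8, 4, 16, 16)  # [N, C, T, H, W]
    assert down(x).shape == (1, 8, 4, 8, 8)

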
class ResBlock(TimestepBlock):
    """
    A residual block that can optionally change the number of channels.
    :param channels: the number of input channels.
    :param emb_channels: the number of timestep embedding channels.
    :param dropout: the rate of dropout.
    :param out_channels: if specified, the number of out channels.
    :param use_conv: if True and out_channels is specified, use a spatial
        convolution instead of a smaller 1x1 convolution to change the
        channels in the skip connection.
    :param use_scale_shift_norm: if True, condition via a FiLM-like
        scale-shift on the output normalization.
    :param dims: determines if the signal is 1D, 2D, or 3D.
    :param use_checkpoint: if True, use gradient checkpointing on this module.
    :param up: if True, use this block for upsampling.
    :param down: if True, use this block for downsampling.
    :param kernel_size: the size of the spatial convolution kernels.
    :param exchange_temb_dims: if True, rearrange the embedding output from
        [b, t, c, ...] to [b, c, t, ...] before adding it to the hidden state.
    :param skip_t_emb: if True, ignore the timestep embedding entirely.
    """

    def __init__(
        self,
        channels,
        emb_channels,
        dropout,
        out_channels=None,
        use_conv=False,
        use_scale_shift_norm=False,
        dims=2,
        use_checkpoint=False,
        up=False,
        down=False,
        kernel_size=3,
        exchange_temb_dims=False,
        skip_t_emb=False,
    ):
        super().__init__()
        self.channels = channels
        self.emb_channels = emb_channels
        self.dropout = dropout
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_checkpoint = use_checkpoint
        self.use_scale_shift_norm = use_scale_shift_norm
        self.exchange_temb_dims = exchange_temb_dims

        if isinstance(kernel_size, Iterable):
            padding = [k // 2 for k in kernel_size]
        else:
            padding = kernel_size // 2

        self.in_layers = nn.Sequential(
            normalization(channels),
            nn.SiLU(),
            conv_nd(dims, channels, self.out_channels, kernel_size, padding=padding),
        )

        self.updown = up or down

        if up:
            self.h_upd = Upsample(channels, False, dims)
            self.x_upd = Upsample(channels, False, dims)
        elif down:
            self.h_upd = Downsample(channels, False, dims)
            self.x_upd = Downsample(channels, False, dims)
        else:
            self.h_upd = self.x_upd = nn.Identity()

        self.skip_t_emb = skip_t_emb
        self.emb_out_channels = (
            2 * self.out_channels if use_scale_shift_norm else self.out_channels
        )
        if self.skip_t_emb:
            print(f"Skipping timestep embedding in {self.__class__.__name__}")
            assert not self.use_scale_shift_norm
            self.emb_layers = None
            self.exchange_temb_dims = False
        else:
            self.emb_layers = nn.Sequential(
                nn.SiLU(),
                linear(
                    emb_channels,
                    self.emb_out_channels,
                ),
            )

        self.out_layers = nn.Sequential(
            normalization(self.out_channels),
            nn.SiLU(),
            nn.Dropout(p=dropout),
            zero_module(
                conv_nd(
                    dims,
                    self.out_channels,
                    self.out_channels,
                    kernel_size,
                    padding=padding,
                )
            ),
        )

        if self.out_channels == channels:
            self.skip_connection = nn.Identity()
        elif use_conv:
            self.skip_connection = conv_nd(
                dims, channels, self.out_channels, kernel_size, padding=padding
            )
        else:
            self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)

    def forward(self, x, emb):
        """
        Apply the block to a Tensor, conditioned on a timestep embedding.
        :param x: an [N x C x ...] Tensor of features.
        :param emb: an [N x emb_channels] Tensor of timestep embeddings.
        :return: an [N x C x ...] Tensor of outputs.
        """
        return checkpoint(
            self._forward, (x, emb), self.parameters(), self.use_checkpoint
        )

    def _forward(self, x, emb):
        if self.updown:
            in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
            h = in_rest(x)
            h = self.h_upd(h)
            x = self.x_upd(x)
            h = in_conv(h)
        else:
            h = self.in_layers(x)

        if self.skip_t_emb:
            emb_out = th.zeros_like(h)
        else:
            emb_out = self.emb_layers(emb).type(h.dtype)
        while len(emb_out.shape) < len(h.shape):
            emb_out = emb_out[..., None]
        if self.use_scale_shift_norm:
            out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
            scale, shift = th.chunk(emb_out, 2, dim=1)
            h = out_norm(h) * (1 + scale) + shift
            h = out_rest(h)
        else:
            if self.exchange_temb_dims:
                emb_out = rearrange(emb_out, "b t c ... -> b c t ...")
            h = h + emb_out
            h = self.out_layers(h)
        return self.skip_connection(x) + h


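# Worked sketch of the scale-shift (FiLM-like) path above (hypothetical
# helper): with use_scale_shift_norm=True the embedding is projected to
# 2 * out_channels and split into (scale, shift), giving
# h = norm(h) * (1 + scale) + shift before the rest of out_layers runs.
def _demo_resblock_scale_shift():
    block = ResBlock(
        channels=32, emb_channels=128, dropout=0.0, use_scale_shift_norm=True
    )
    x = th.randn(2, 32, 16, 16)
    emb = th.randn(2, 128)
    assert block(x, emb).shape == x.shape

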
class AttentionBlock(nn.Module):
    """
    An attention block that allows spatial positions to attend to each other.
    Originally ported from here, but adapted to the N-d case.
    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
    """

    def __init__(
        self,
        channels,
        num_heads=1,
        num_head_channels=-1,
        use_checkpoint=False,
        use_new_attention_order=False,
    ):
        super().__init__()
        self.channels = channels
        if num_head_channels == -1:
            self.num_heads = num_heads
        else:
            assert (
                channels % num_head_channels == 0
            ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
            self.num_heads = channels // num_head_channels
        self.use_checkpoint = use_checkpoint
        self.norm = normalization(channels)
        self.qkv = conv_nd(1, channels, channels * 3, 1)
        if use_new_attention_order:
            # split qkv before splitting heads
            self.attention = QKVAttention(self.num_heads)
        else:
            # split heads before splitting qkv
            self.attention = QKVAttentionLegacy(self.num_heads)

        self.proj_out = zero_module(conv_nd(1, channels, channels, 1))

    def forward(self, x, **kwargs):
        # NB: checkpointing is unconditionally enabled here, independent of
        # self.use_checkpoint; extra keyword arguments are ignored.
        return checkpoint(self._forward, (x,), self.parameters(), True)

    def _forward(self, x):
        b, c, *spatial = x.shape
        x = x.reshape(b, c, -1)
        qkv = self.qkv(self.norm(x))
        h = self.attention(qkv)
        h = self.proj_out(h)
        return (x + h).reshape(b, c, *spatial)


def count_flops_attn(model, _x, y):
    """
    A counter for the `thop` package to count the operations in an
    attention operation.
    Meant to be used like:
        macs, params = thop.profile(
            model,
            inputs=(inputs, timesteps),
            custom_ops={QKVAttention: QKVAttention.count_flops},
        )
    """
    b, c, *spatial = y[0].shape
    num_spatial = int(np.prod(spatial))
    # We perform two matmuls with the same number of ops: the first computes
    # the attention weight matrix, the second the combination of the value
    # vectors, hence the factor of 2.
    matmul_ops = 2 * b * (num_spatial**2) * c
    model.total_ops += th.DoubleTensor([matmul_ops])


class QKVAttentionLegacy(nn.Module):
    """
    A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping.
    """

    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv):
        """
        Apply QKV attention.
        :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
        :return: an [N x (H * C) x T] tensor after attention.
        """
        bs, width, length = qkv.shape
        assert width % (3 * self.n_heads) == 0
        ch = width // (3 * self.n_heads)
        q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
        # Scale q and k by 1/sqrt(sqrt(ch)) each rather than dividing the
        # product by sqrt(ch) afterwards: more stable in float16.
        scale = 1 / math.sqrt(math.sqrt(ch))
        weight = th.einsum("bct,bcs->bts", q * scale, k * scale)
        weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
        a = th.einsum("bts,bcs->bct", weight, v)
        return a.reshape(bs, -1, length)

    @staticmethod
    def count_flops(model, _x, y):
        return count_flops_attn(model, _x, y)


class QKVAttention(nn.Module):
    """
    A module which performs QKV attention and splits in a different order.
    """

    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv):
        """
        Apply QKV attention.
        :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
        :return: an [N x (H * C) x T] tensor after attention.
        """
        bs, width, length = qkv.shape
        assert width % (3 * self.n_heads) == 0
        ch = width // (3 * self.n_heads)
        q, k, v = qkv.chunk(3, dim=1)
        scale = 1 / math.sqrt(math.sqrt(ch))
        weight = th.einsum(
            "bct,bcs->bts",
            (q * scale).view(bs * self.n_heads, ch, length),
            (k * scale).view(bs * self.n_heads, ch, length),
        )
        weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
        a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
        return a.reshape(bs, -1, length)

    @staticmethod
    def count_flops(model, _x, y):
        return count_flops_attn(model, _x, y)


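# Shape sketch for the two attention modules (hypothetical helper): both
# consume a fused [N, 3*H*C, T] tensor and return [N, H*C, T]; they differ
# only in how the heads are interleaved when splitting q, k, v (legacy
# reshapes per head first, the new order chunks q/k/v first), so their
# outputs generally differ numerically for the same weights.
def _demo_qkv_attention_orders():
    n_heads, ch, length = 2, 16, 10
    qkv = th.randn(4, 3 * n_heads * ch, length)
    legacy = QKVAttentionLegacy(n_heads)(qkv)
    new = QKVAttention(n_heads)(qkv)
    assert legacy.shape == new.shape == (4, n_heads * ch, length)

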
class Timestep(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, t):
        return timestep_embedding(t, self.dim)


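# Thin-wrapper sketch (hypothetical helper): Timestep(dim) maps an [N] tensor
# of timesteps to [N, dim] sinusoidal embeddings via timestep_embedding.
def _demo_timestep():
    emb = Timestep(128)(th.arange(4))
    assert emb.shape == (4, 128)

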
class UNetModel(nn.Module):
    """
    The full UNet model with attention and timestep embedding.
    :param in_channels: channels in the input Tensor.
    :param model_channels: base channel count for the model.
    :param out_channels: channels in the output Tensor.
    :param num_res_blocks: number of residual blocks per downsample.
    :param attention_resolutions: a collection of downsample rates at which
        attention will take place. May be a set, list, or tuple.
        For example, if this contains 4, then at 4x downsampling, attention
        will be used.
    :param dropout: the dropout probability.
    :param channel_mult: channel multiplier for each level of the UNet.
    :param conv_resample: if True, use learned convolutions for upsampling and
        downsampling.
    :param dims: determines if the signal is 1D, 2D, or 3D.
    :param num_classes: if specified as an int, then this model will be
        class-conditional with `num_classes` classes; may also be one of the
        strings "continuous", "timestep", or "sequential" to select the
        corresponding conditioning embedder.
    :param use_checkpoint: use gradient checkpointing to reduce memory usage.
    :param num_heads: the number of attention heads in each attention layer.
    :param num_head_channels: if specified, ignore num_heads and instead use
        a fixed channel width per attention head.
    :param num_heads_upsample: works with num_heads to set a different number
        of heads for upsampling. Deprecated.
    :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
    :param resblock_updown: use residual blocks for up/downsampling.
    :param use_new_attention_order: use a different attention pattern for potentially
        increased efficiency.
    """

    def __init__(
        self,
        in_channels,
        model_channels,
        out_channels,
        num_res_blocks,
        attention_resolutions,
        dropout=0,
        channel_mult=(1, 2, 4, 8),
        conv_resample=True,
        dims=2,
        num_classes=None,
        use_checkpoint=False,
        use_fp16=False,
        num_heads=-1,
        num_head_channels=-1,
        num_heads_upsample=-1,
        use_scale_shift_norm=False,
        resblock_updown=False,
        use_new_attention_order=False,
        use_spatial_transformer=False,  # custom transformer support
        transformer_depth=1,  # custom transformer support
        context_dim=None,  # custom transformer support
        n_embed=None,  # custom support for predicting discrete ids into the codebook of a first-stage VQ model
        legacy=True,
        disable_self_attentions=None,
        num_attention_blocks=None,
        disable_middle_self_attn=False,
        use_linear_in_transformer=False,
        spatial_transformer_attn_type="softmax",
        adm_in_channels=None,
        use_fairscale_checkpoint=False,
        offload_to_cpu=False,
        transformer_depth_middle=None,
    ):
        super().__init__()
        from omegaconf.listconfig import ListConfig

        if use_spatial_transformer:
            assert (
                context_dim is not None
            ), "Fool!! You forgot to include the dimension of your cross-attention conditioning..."

        if context_dim is not None:
            assert (
                use_spatial_transformer
            ), "Fool!! You forgot to use the spatial transformer for your cross-attention conditioning..."
            if type(context_dim) == ListConfig:
                context_dim = list(context_dim)

        if num_heads_upsample == -1:
            num_heads_upsample = num_heads

        if num_heads == -1:
            assert (
                num_head_channels != -1
            ), "Either num_heads or num_head_channels has to be set"

        if num_head_channels == -1:
            assert (
                num_heads != -1
            ), "Either num_heads or num_head_channels has to be set"

        self.in_channels = in_channels
        self.model_channels = model_channels
        self.out_channels = out_channels
        if isinstance(transformer_depth, int):
            transformer_depth = len(channel_mult) * [transformer_depth]
        elif isinstance(transformer_depth, ListConfig):
            transformer_depth = list(transformer_depth)
        transformer_depth_middle = default(
            transformer_depth_middle, transformer_depth[-1]
        )

        if isinstance(num_res_blocks, int):
            self.num_res_blocks = len(channel_mult) * [num_res_blocks]
        else:
            if len(num_res_blocks) != len(channel_mult):
                raise ValueError(
                    "provide num_res_blocks either as an int (globally constant) or "
                    "as a list/tuple (per-level) with the same length as channel_mult"
                )
            self.num_res_blocks = num_res_blocks

        if disable_self_attentions is not None:
            # Should be a list of booleans, one per level, indicating whether
            # to disable self-attention in the SpatialTransformer at that level.
            assert len(disable_self_attentions) == len(channel_mult)
        if num_attention_blocks is not None:
            assert len(num_attention_blocks) == len(self.num_res_blocks)
            assert all(
                map(
                    lambda i: self.num_res_blocks[i] >= num_attention_blocks[i],
                    range(len(num_attention_blocks)),
                )
            )
            print(
                f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
                f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
                f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
                f"attention will still not be set."
            )

        self.attention_resolutions = attention_resolutions
        self.dropout = dropout
        self.channel_mult = channel_mult
        self.conv_resample = conv_resample
        self.num_classes = num_classes
        self.use_checkpoint = use_checkpoint
        if use_fp16:
            print("WARNING: use_fp16 was dropped and has no effect anymore.")

        self.num_heads = num_heads
        self.num_head_channels = num_head_channels
        self.num_heads_upsample = num_heads_upsample
        self.predict_codebook_ids = n_embed is not None

        # At most one of the two checkpointing mechanisms may be enabled.
        assert use_fairscale_checkpoint != use_checkpoint or not (
            use_checkpoint or use_fairscale_checkpoint
        )

        # NOTE: fairscale checkpointing is hard-disabled here; the deferred
        # import keeps `fairscale` an optional dependency and avoids the
        # undefined-name pitfall if this flag is ever turned back on.
        self.use_fairscale_checkpoint = False
        if self.use_fairscale_checkpoint:
            from fairscale.nn import checkpoint_wrapper

            checkpoint_wrapper_fn = partial(
                checkpoint_wrapper, offload_to_cpu=offload_to_cpu
            )
        else:
            checkpoint_wrapper_fn = lambda x: x

        time_embed_dim = model_channels * 4
        self.time_embed = checkpoint_wrapper_fn(
            nn.Sequential(
                linear(model_channels, time_embed_dim),
                nn.SiLU(),
                linear(time_embed_dim, time_embed_dim),
            )
        )

        if self.num_classes is not None:
            if isinstance(self.num_classes, int):
                self.label_emb = nn.Embedding(num_classes, time_embed_dim)
            elif self.num_classes == "continuous":
                print("setting up linear c_adm embedding layer")
                self.label_emb = nn.Linear(1, time_embed_dim)
            elif self.num_classes == "timestep":
                self.label_emb = checkpoint_wrapper_fn(
                    nn.Sequential(
                        Timestep(model_channels),
                        nn.Sequential(
                            linear(model_channels, time_embed_dim),
                            nn.SiLU(),
                            linear(time_embed_dim, time_embed_dim),
                        ),
                    )
                )
            elif self.num_classes == "sequential":
                assert adm_in_channels is not None
                self.label_emb = nn.Sequential(
                    nn.Sequential(
                        linear(adm_in_channels, time_embed_dim),
                        nn.SiLU(),
                        linear(time_embed_dim, time_embed_dim),
                    )
                )
            else:
                raise ValueError(
                    f"Unsupported num_classes specification: {self.num_classes}"
                )

        self.input_blocks = nn.ModuleList(
            [
                TimestepEmbedSequential(
                    conv_nd(dims, in_channels, model_channels, 3, padding=1)
                )
            ]
        )
        self._feature_size = model_channels
        input_block_chans = [model_channels]
        ch = model_channels
        ds = 1
        for level, mult in enumerate(channel_mult):
            for nr in range(self.num_res_blocks[level]):
                layers = [
                    checkpoint_wrapper_fn(
                        ResBlock(
                            ch,
                            time_embed_dim,
                            dropout,
                            out_channels=mult * model_channels,
                            dims=dims,
                            use_checkpoint=use_checkpoint,
                            use_scale_shift_norm=use_scale_shift_norm,
                        )
                    )
                ]
                ch = mult * model_channels
                if ds in attention_resolutions:
                    if num_head_channels == -1:
                        dim_head = ch // num_heads
                    else:
                        num_heads = ch // num_head_channels
                        dim_head = num_head_channels
                    if legacy:
                        dim_head = (
                            ch // num_heads
                            if use_spatial_transformer
                            else num_head_channels
                        )
                    if exists(disable_self_attentions):
                        disabled_sa = disable_self_attentions[level]
                    else:
                        disabled_sa = False

                    if (
                        not exists(num_attention_blocks)
                        or nr < num_attention_blocks[level]
                    ):
                        layers.append(
                            checkpoint_wrapper_fn(
                                AttentionBlock(
                                    ch,
                                    use_checkpoint=use_checkpoint,
                                    num_heads=num_heads,
                                    num_head_channels=dim_head,
                                    use_new_attention_order=use_new_attention_order,
                                )
                            )
                            if not use_spatial_transformer
                            else checkpoint_wrapper_fn(
                                SpatialTransformer(
                                    ch,
                                    num_heads,
                                    dim_head,
                                    depth=transformer_depth[level],
                                    context_dim=context_dim,
                                    disable_self_attn=disabled_sa,
                                    use_linear=use_linear_in_transformer,
                                    attn_type=spatial_transformer_attn_type,
                                    use_checkpoint=use_checkpoint,
                                )
                            )
                        )
                self.input_blocks.append(TimestepEmbedSequential(*layers))
                self._feature_size += ch
                input_block_chans.append(ch)
            if level != len(channel_mult) - 1:
                out_ch = ch
                self.input_blocks.append(
                    TimestepEmbedSequential(
                        checkpoint_wrapper_fn(
                            ResBlock(
                                ch,
                                time_embed_dim,
                                dropout,
                                out_channels=out_ch,
                                dims=dims,
                                use_checkpoint=use_checkpoint,
                                use_scale_shift_norm=use_scale_shift_norm,
                                down=True,
                            )
                        )
                        if resblock_updown
                        else Downsample(
                            ch, conv_resample, dims=dims, out_channels=out_ch
                        )
                    )
                )
                ch = out_ch
                input_block_chans.append(ch)
                ds *= 2
                self._feature_size += ch

        if num_head_channels == -1:
            dim_head = ch // num_heads
        else:
            num_heads = ch // num_head_channels
            dim_head = num_head_channels
        if legacy:
            dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
        self.middle_block = TimestepEmbedSequential(
            checkpoint_wrapper_fn(
                ResBlock(
                    ch,
                    time_embed_dim,
                    dropout,
                    dims=dims,
                    use_checkpoint=use_checkpoint,
                    use_scale_shift_norm=use_scale_shift_norm,
                )
            ),
            checkpoint_wrapper_fn(
                AttentionBlock(
                    ch,
                    use_checkpoint=use_checkpoint,
                    num_heads=num_heads,
                    num_head_channels=dim_head,
                    use_new_attention_order=use_new_attention_order,
                )
            )
            if not use_spatial_transformer
            else checkpoint_wrapper_fn(
                SpatialTransformer(
                    ch,
                    num_heads,
                    dim_head,
                    depth=transformer_depth_middle,
                    context_dim=context_dim,
                    disable_self_attn=disable_middle_self_attn,
                    use_linear=use_linear_in_transformer,
                    attn_type=spatial_transformer_attn_type,
                    use_checkpoint=use_checkpoint,
                )
            ),
            checkpoint_wrapper_fn(
                ResBlock(
                    ch,
                    time_embed_dim,
                    dropout,
                    dims=dims,
                    use_checkpoint=use_checkpoint,
                    use_scale_shift_norm=use_scale_shift_norm,
                )
            ),
        )
        self._feature_size += ch

        self.output_blocks = nn.ModuleList([])
        for level, mult in list(enumerate(channel_mult))[::-1]:
            for i in range(self.num_res_blocks[level] + 1):
                ich = input_block_chans.pop()
                layers = [
                    checkpoint_wrapper_fn(
                        ResBlock(
                            ch + ich,
                            time_embed_dim,
                            dropout,
                            out_channels=model_channels * mult,
                            dims=dims,
                            use_checkpoint=use_checkpoint,
                            use_scale_shift_norm=use_scale_shift_norm,
                        )
                    )
                ]
                ch = model_channels * mult
                if ds in attention_resolutions:
                    if num_head_channels == -1:
                        dim_head = ch // num_heads
                    else:
                        num_heads = ch // num_head_channels
                        dim_head = num_head_channels
                    if legacy:
                        dim_head = (
                            ch // num_heads
                            if use_spatial_transformer
                            else num_head_channels
                        )
                    if exists(disable_self_attentions):
                        disabled_sa = disable_self_attentions[level]
                    else:
                        disabled_sa = False

                    if (
                        not exists(num_attention_blocks)
                        or i < num_attention_blocks[level]
                    ):
                        layers.append(
                            checkpoint_wrapper_fn(
                                AttentionBlock(
                                    ch,
                                    use_checkpoint=use_checkpoint,
                                    num_heads=num_heads_upsample,
                                    num_head_channels=dim_head,
                                    use_new_attention_order=use_new_attention_order,
                                )
                            )
                            if not use_spatial_transformer
                            else checkpoint_wrapper_fn(
                                SpatialTransformer(
                                    ch,
                                    num_heads,
                                    dim_head,
                                    depth=transformer_depth[level],
                                    context_dim=context_dim,
                                    disable_self_attn=disabled_sa,
                                    use_linear=use_linear_in_transformer,
                                    attn_type=spatial_transformer_attn_type,
                                    use_checkpoint=use_checkpoint,
                                )
                            )
                        )
                if level and i == self.num_res_blocks[level]:
                    out_ch = ch
                    layers.append(
                        checkpoint_wrapper_fn(
                            ResBlock(
                                ch,
                                time_embed_dim,
                                dropout,
                                out_channels=out_ch,
                                dims=dims,
                                use_checkpoint=use_checkpoint,
                                use_scale_shift_norm=use_scale_shift_norm,
                                up=True,
                            )
                        )
                        if resblock_updown
                        else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
                    )
                    ds //= 2
                self.output_blocks.append(TimestepEmbedSequential(*layers))
                self._feature_size += ch

        self.out = checkpoint_wrapper_fn(
            nn.Sequential(
                normalization(ch),
                nn.SiLU(),
                zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
            )
        )
        if self.predict_codebook_ids:
            self.id_predictor = checkpoint_wrapper_fn(
                nn.Sequential(
                    normalization(ch),
                    conv_nd(dims, model_channels, n_embed, 1),
                )
            )

    def convert_to_fp16(self):
        """
        Convert the torso of the model to float16 (currently a no-op, since
        convert_module_to_f16 is a stub).
        """
        self.input_blocks.apply(convert_module_to_f16)
        self.middle_block.apply(convert_module_to_f16)
        self.output_blocks.apply(convert_module_to_f16)

    def convert_to_fp32(self):
        """
        Convert the torso of the model to float32 (currently a no-op, since
        convert_module_to_f32 is a stub).
        """
        self.input_blocks.apply(convert_module_to_f32)
        self.middle_block.apply(convert_module_to_f32)
        self.output_blocks.apply(convert_module_to_f32)

    def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
        """
        Apply the model to an input batch.
        :param x: an [N x C x ...] Tensor of inputs.
        :param timesteps: a 1-D batch of timesteps.
        :param context: conditioning plugged in via crossattn.
        :param y: an [N] Tensor of labels, if class-conditional.
        :return: an [N x C x ...] Tensor of outputs.
        """
        assert (y is not None) == (
            self.num_classes is not None
        ), "must specify y if and only if the model is class-conditional"
        hs = []

        t_emb = timestep_embedding(
            timesteps, self.model_channels, repeat_only=False
        ).to(x.dtype)
        emb = self.time_embed(t_emb)

        if self.num_classes is not None:
            assert y.shape[0] == x.shape[0]
            emb = emb + self.label_emb(y)

        h = x
        for module in self.input_blocks:
            h = module(h, emb, context)
            hs.append(h)
        h = self.middle_block(h, emb, context)
        for module in self.output_blocks:
            h = th.cat([h, hs.pop()], dim=1)
            h = module(h, emb, context)
        h = h.type(x.dtype)
        if self.predict_codebook_ids:
            raise NotImplementedError("predict_codebook_ids is no longer supported")
        else:
            return self.out(h)


class NoTimeUNetModel(UNetModel):
    """A UNetModel variant that ignores its timestep input."""

    def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
        # Zero out the timesteps so the model is effectively time-unconditional.
        timesteps = th.zeros_like(timesteps)
        return super().forward(x, timesteps, context, y, **kwargs)


class EncoderUNetModel(nn.Module):
    """
    The half UNet model with attention and timestep embedding.
    For usage, see UNetModel.
    """

    def __init__(
        self,
        image_size,
        in_channels,
        model_channels,
        out_channels,
        num_res_blocks,
        attention_resolutions,
        dropout=0,
        channel_mult=(1, 2, 4, 8),
        conv_resample=True,
        dims=2,
        use_checkpoint=False,
        use_fp16=False,
        num_heads=1,
        num_head_channels=-1,
        num_heads_upsample=-1,
        use_scale_shift_norm=False,
        resblock_updown=False,
        use_new_attention_order=False,
        pool="adaptive",
        *args,
        **kwargs,
    ):
        super().__init__()

        if num_heads_upsample == -1:
            num_heads_upsample = num_heads

        self.in_channels = in_channels
        self.model_channels = model_channels
        self.out_channels = out_channels
        self.num_res_blocks = num_res_blocks
        self.attention_resolutions = attention_resolutions
        self.dropout = dropout
        self.channel_mult = channel_mult
        self.conv_resample = conv_resample
        self.use_checkpoint = use_checkpoint
        self.dtype = th.float16 if use_fp16 else th.float32
        self.num_heads = num_heads
        self.num_head_channels = num_head_channels
        self.num_heads_upsample = num_heads_upsample

        time_embed_dim = model_channels * 4
        self.time_embed = nn.Sequential(
            linear(model_channels, time_embed_dim),
            nn.SiLU(),
            linear(time_embed_dim, time_embed_dim),
        )

        self.input_blocks = nn.ModuleList(
            [
                TimestepEmbedSequential(
                    conv_nd(dims, in_channels, model_channels, 3, padding=1)
                )
            ]
        )
        self._feature_size = model_channels
        input_block_chans = [model_channels]
        ch = model_channels
        ds = 1
        for level, mult in enumerate(channel_mult):
            for _ in range(num_res_blocks):
                layers = [
                    ResBlock(
                        ch,
                        time_embed_dim,
                        dropout,
                        out_channels=mult * model_channels,
                        dims=dims,
                        use_checkpoint=use_checkpoint,
                        use_scale_shift_norm=use_scale_shift_norm,
                    )
                ]
                ch = mult * model_channels
                if ds in attention_resolutions:
                    layers.append(
                        AttentionBlock(
                            ch,
                            use_checkpoint=use_checkpoint,
                            num_heads=num_heads,
                            num_head_channels=num_head_channels,
                            use_new_attention_order=use_new_attention_order,
                        )
                    )
                self.input_blocks.append(TimestepEmbedSequential(*layers))
                self._feature_size += ch
                input_block_chans.append(ch)
            if level != len(channel_mult) - 1:
                out_ch = ch
                self.input_blocks.append(
                    TimestepEmbedSequential(
                        ResBlock(
                            ch,
                            time_embed_dim,
                            dropout,
                            out_channels=out_ch,
                            dims=dims,
                            use_checkpoint=use_checkpoint,
                            use_scale_shift_norm=use_scale_shift_norm,
                            down=True,
                        )
                        if resblock_updown
                        else Downsample(
                            ch, conv_resample, dims=dims, out_channels=out_ch
                        )
                    )
                )
                ch = out_ch
                input_block_chans.append(ch)
                ds *= 2
                self._feature_size += ch

        self.middle_block = TimestepEmbedSequential(
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
            AttentionBlock(
                ch,
                use_checkpoint=use_checkpoint,
                num_heads=num_heads,
                num_head_channels=num_head_channels,
                use_new_attention_order=use_new_attention_order,
            ),
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
        )
        self._feature_size += ch
        self.pool = pool
        if pool == "adaptive":
            self.out = nn.Sequential(
                normalization(ch),
                nn.SiLU(),
                nn.AdaptiveAvgPool2d((1, 1)),
                zero_module(conv_nd(dims, ch, out_channels, 1)),
                nn.Flatten(),
            )
        elif pool == "attention":
            assert num_head_channels != -1
            self.out = nn.Sequential(
                normalization(ch),
                nn.SiLU(),
                AttentionPool2d(
                    (image_size // ds), ch, num_head_channels, out_channels
                ),
            )
        elif pool == "spatial":
            self.out = nn.Sequential(
                nn.Linear(self._feature_size, 2048),
                nn.ReLU(),
                nn.Linear(2048, self.out_channels),
            )
        elif pool == "spatial_v2":
            self.out = nn.Sequential(
                nn.Linear(self._feature_size, 2048),
                normalization(2048),
                nn.SiLU(),
                nn.Linear(2048, self.out_channels),
            )
        else:
            raise NotImplementedError(f"Unexpected {pool} pooling")

    def convert_to_fp16(self):
        """
        Convert the torso of the model to float16 (currently a no-op, since
        convert_module_to_f16 is a stub).
        """
        self.input_blocks.apply(convert_module_to_f16)
        self.middle_block.apply(convert_module_to_f16)

    def convert_to_fp32(self):
        """
        Convert the torso of the model to float32 (currently a no-op, since
        convert_module_to_f32 is a stub).
        """
        self.input_blocks.apply(convert_module_to_f32)
        self.middle_block.apply(convert_module_to_f32)

    def forward(self, x, timesteps):
        """
        Apply the model to an input batch.
        :param x: an [N x C x ...] Tensor of inputs.
        :param timesteps: a 1-D batch of timesteps.
        :return: an [N x K] Tensor of outputs.
        """
        emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))

        results = []

        h = x
        for module in self.input_blocks:
            h = module(h, emb)
            if self.pool.startswith("spatial"):
                results.append(h.type(x.dtype).mean(dim=(2, 3)))
        h = self.middle_block(h, emb)
        if self.pool.startswith("spatial"):
            results.append(h.type(x.dtype).mean(dim=(2, 3)))
            h = th.cat(results, dim=-1)
            return self.out(h)
        else:
            h = h.type(x.dtype)
            return self.out(h)


if __name__ == "__main__":

    class Dummy(nn.Module):
        def __init__(self, in_channels=3, model_channels=64):
            super().__init__()
            self.input_blocks = nn.ModuleList(
                [
                    TimestepEmbedSequential(
                        conv_nd(2, in_channels, model_channels, 3, padding=1)
                    )
                ]
            )

    # Smoke test: UNetModel takes no image_size argument, so none is passed.
    model = UNetModel(
        use_checkpoint=True,
        in_channels=4,
        out_channels=4,
        model_channels=128,
        attention_resolutions=[4, 2],
        num_res_blocks=2,
        channel_mult=[1, 2, 4],
        num_head_channels=64,
        use_spatial_transformer=False,
        use_linear_in_transformer=True,
        transformer_depth=1,
        legacy=False,
    ).cuda()
    x = th.randn(11, 4, 64, 64).cuda()
    t = th.randint(low=0, high=10, size=(11,), device="cuda")
    o = model(x, t)
    print("done.")