| | |
| | |
| | |
| | |
| |
|
| | import torch |
| | import torch.nn as nn |
| |
|
| | from modules.encoder.position_encoder import PositionEncoder |
| | from modules.general.utils import append_dims, ConvNd, normalization, zero_module |
| | from .attention import AttentionBlock |
| | from .resblock import Downsample, ResBlock, Upsample |
| |
|
| |
|
class UNet(nn.Module):
    r"""The full UNet model with attention and timestep embedding.

    Args:
        dims: determines if the signal is 1D (temporal), 2D(spatial).
        in_channels: channels in the input Tensor.
        model_channels: base channel count for the model.
        out_channels: channels in the output Tensor.
        h_dim: forwarded to every ``AttentionBlock`` as ``h_dim``, integer-divided
            by ``level + 1`` (so deeper levels receive a smaller value).
        num_res_blocks: number of residual blocks per downsample.
        channel_mult: channel multiplier for each level of the UNet.
        num_attn_blocks: number of attention blocks appended after each residual
            block at a level whose downsample rate is in ``attention_resolutions``.
        attention_resolutions: a collection of downsample rates at which attention will
            take place. May be a set, list, or tuple. For example, if this contains 4,
            then at 4x downsampling, attention will be used.
        num_heads: the number of attention heads in each attention layer.
        num_head_channels: if specified, ignore num_heads and instead use a fixed
            channel width per attention head.
        d_context: if specified, use for cross-attention channel project.
        context_hdim: forwarded to every ``AttentionBlock`` as ``encoder_hdim``.
        p_dropout: the dropout probability.
        num_classes: if specified (as an int), then this model will be class-conditional
            with ``num_classes`` classes.
        use_extra_film: if specified (``"add"`` or ``"concat"``), use an extra
            FiLM-like conditioning mechanism. Mutually exclusive with ``num_classes``.
        d_emb: if specified, use for FiLM-like conditioning. Required when
            ``use_extra_film`` is set.
        use_scale_shift_norm: use a FiLM-like conditioning mechanism.
        resblock_updown: use residual blocks for up/downsampling.
    """

    def __init__(
        self,
        dims: int = 1,
        in_channels: int = 100,
        model_channels: int = 128,
        out_channels: int = 100,
        h_dim: int = 128,
        num_res_blocks: int = 1,
        channel_mult: tuple = (1, 2, 4),
        num_attn_blocks: int = 1,
        attention_resolutions: tuple = (1, 2, 4),
        num_heads: int = 1,
        num_head_channels: int = -1,
        d_context: int = None,
        context_hdim: int = 128,
        p_dropout: float = 0.0,
        num_classes: int = -1,
        use_extra_film: str = None,
        d_emb: int = None,
        use_scale_shift_norm: bool = True,
        resblock_updown: bool = False,
    ):
        super().__init__()

        self.dims = dims
        self.in_channels = in_channels
        self.model_channels = model_channels
        self.out_channels = out_channels
        self.h_dim = h_dim
        self.num_res_blocks = num_res_blocks
        self.channel_mult = channel_mult
        self.num_attn_blocks = num_attn_blocks
        self.attention_resolutions = attention_resolutions
        self.num_heads = num_heads
        self.num_head_channels = num_head_channels
        self.d_context = d_context
        self.context_hdim = context_hdim
        self.p_dropout = p_dropout
        self.num_classes = num_classes
        self.use_extra_film = use_extra_film
        self.d_emb = d_emb
        self.use_scale_shift_norm = use_scale_shift_norm
        self.resblock_updown = resblock_updown

        time_embed_dim = model_channels * 4
        self.pos_enc = PositionEncoder(model_channels, time_embed_dim)

        assert (
            num_classes == -1 or use_extra_film is None
        ), "You cannot set both num_classes and use_extra_film."

        if self.num_classes > 0:
            self.label_emb = nn.Embedding(num_classes, time_embed_dim, max_norm=1.0)
        elif use_extra_film is not None:
            assert (
                d_emb is not None
            ), "d_emb must be specified if use_extra_film is not None"
            assert use_extra_film in [
                "add",
                "concat",
            ], f"use_extra_film only supported by add or concat. Your input is {use_extra_film}"
            self.film_emb = ConvNd(dims, d_emb, time_embed_dim, 1)
            if use_extra_film == "concat":
                # The FiLM embedding is concatenated to the timestep embedding
                # along the channel axis, so every ResBlock sees twice the width.
                time_embed_dim *= 2

        # Local factories. They are defined only after ``time_embed_dim`` has
        # reached its final value, and are always invoked in the same order the
        # layers are registered, so parameter names and init order are stable.
        def make_resblock(in_ch, **extra):
            return ResBlock(
                in_ch,
                time_embed_dim,
                p_dropout,
                dims=dims,
                use_scale_shift_norm=use_scale_shift_norm,
                **extra,
            )

        def make_attention(attn_ch, attn_level):
            return AttentionBlock(
                attn_ch,
                num_heads=num_heads,
                num_head_channels=num_head_channels,
                encoder_channels=d_context,
                dims=dims,
                h_dim=h_dim // (attn_level + 1),
                encoder_hdim=context_hdim,
                p_dropout=p_dropout,
            )

        last_level = len(channel_mult) - 1

        # ---- Input (encoder) blocks ----
        ch = input_ch = int(channel_mult[0] * model_channels)
        self.input_blocks = nn.ModuleList(
            [UNetSequential(ConvNd(dims, in_channels, ch, 3, padding=1))]
        )
        self._feature_size = ch
        # Channel count of every encoder output, consumed (LIFO) by the decoder
        # to size its skip-connection inputs.
        input_block_chans = [ch]
        ds = 1  # current downsample rate
        for level, mult in enumerate(channel_mult):
            level_ch = int(mult * model_channels)
            for _ in range(num_res_blocks):
                layers = [make_resblock(ch, out_channels=level_ch)]
                ch = level_ch
                if ds in attention_resolutions:
                    layers.extend(
                        make_attention(ch, level) for _ in range(num_attn_blocks)
                    )
                self.input_blocks.append(UNetSequential(*layers))
                self._feature_size += ch
                input_block_chans.append(ch)
            if level != last_level:
                out_ch = ch
                self.input_blocks.append(
                    UNetSequential(
                        make_resblock(ch, out_channels=out_ch, down=True)
                        if resblock_updown
                        else Downsample(ch, dims=dims, out_channels=out_ch)
                    )
                )
                ch = out_ch
                input_block_chans.append(ch)
                ds *= 2
                self._feature_size += ch

        # ---- Middle (bottleneck) block ----
        # Uses the deepest level explicitly instead of relying on the loop
        # variable leaking out of the encoder loop above.
        self.middle_block = UNetSequential(
            make_resblock(ch),
            make_attention(ch, last_level),
            make_resblock(ch),
        )
        self._feature_size += ch

        # ---- Output (decoder) blocks ----
        self.output_blocks = nn.ModuleList([])
        for level, mult in reversed(list(enumerate(channel_mult))):
            level_ch = int(model_channels * mult)
            for i in range(num_res_blocks + 1):
                # Each decoder block receives the previous output concatenated
                # with one stored encoder activation.
                ich = input_block_chans.pop()
                layers = [make_resblock(ch + ich, out_channels=level_ch)]
                ch = level_ch
                if ds in attention_resolutions:
                    layers.extend(
                        make_attention(ch, level) for _ in range(num_attn_blocks)
                    )
                if level and i == num_res_blocks:
                    # Last block of every level except the shallowest upsamples.
                    out_ch = ch
                    layers.append(
                        make_resblock(ch, out_channels=out_ch, up=True)
                        if resblock_updown
                        else Upsample(ch, dims=dims, out_channels=out_ch)
                    )
                    ds //= 2
                self.output_blocks.append(UNetSequential(*layers))
                self._feature_size += ch

        # ---- Output projection ----
        # After the decoder loop ``ch`` equals ``input_ch`` (both are
        # ``channel_mult[0] * model_channels``).
        self.out = nn.Sequential(
            normalization(ch),
            nn.SiLU(),
            zero_module(ConvNd(dims, input_ch, out_channels, 3, padding=1)),
        )

    def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
        r"""Apply the model to an input batch.

        Args:
            x: an [N x C x ...] Tensor of inputs.
            timesteps: a 1-D batch of timesteps, i.e. [N].
            context: conditioning Tensor with shape of [N x ``d_context`` x ...] plugged
                in via cross attention.
            y: an [N] Tensor of labels, if **class-conditional**.
                an [N x ``d_emb`` x ...] Tensor if **film-embed conditional**.

        Returns:
            an [N x C x ...] Tensor of outputs.

        Raises:
            AssertionError: if ``y`` is given to an unconditional model, is missing
                for a conditional one, or has an unexpected shape.
        """
        conditional = (self.num_classes > 0) or (self.use_extra_film is not None)
        assert (
            y is None or conditional
        ), f"y may only be given if num_classes or use_extra_film is set. \nGot num_classes: {self.num_classes}\t\nuse_extra_film: {self.use_extra_film}\t\n"
        assert (
            y is not None or not conditional
        ), f"y must be given if num_classes or use_extra_film is set. \nGot num_classes: {self.num_classes}\t\nuse_extra_film: {self.use_extra_film}\t\n"

        hs = []
        emb = self.pos_enc(timesteps)
        # NOTE(review): append_dims is assumed to add trailing singleton axes
        # until emb has the same rank as x -- confirm against modules.general.utils.
        emb = append_dims(emb, x.dim())

        if self.num_classes > 0:
            assert y.size() == (x.size(0),)
            label = self.label_emb(y)
            # The embedding lookup is [N, D]; give it the same rank as ``emb``
            # so the sum is element-wise rather than mis-broadcast.
            label = label.reshape(label.shape + (1,) * (emb.dim() - label.dim()))
            emb = emb + label
        elif self.use_extra_film is not None:
            assert y.size() == (x.size(0), self.d_emb, *x.size()[2:])
            y = self.film_emb(y)
            if self.use_extra_film == "add":
                emb = emb + y
            elif self.use_extra_film == "concat":
                # torch.cat needs matching sizes on every non-concatenated
                # axis, so broadcast the timestep embedding over y's trailing
                # axes first (a no-op when they already match).
                emb = torch.cat([emb.expand(-1, -1, *y.shape[2:]), y], dim=1)

        h = x
        for module in self.input_blocks:
            h = module(h, emb, context)
            hs.append(h)
        h = self.middle_block(h, emb, context)
        for module in self.output_blocks:
            # Skip connection: most recent encoder activation first.
            h = torch.cat([h, hs.pop()], dim=1)
            h = module(h, emb, context)

        return self.out(h)
| |
|
| |
|
class UNetSequential(nn.Sequential):
    r"""Sequential container that routes the extra inputs to the layers needing them.

    ``ResBlock`` children are called with ``(x, emb)``, ``AttentionBlock``
    children with ``(x, context)``, and every other child with ``x`` alone.
    """

    def forward(self, x, emb=None, context=None):
        out = x
        for child in self:
            # Choose the extra positional argument (if any) by layer type.
            if isinstance(child, ResBlock):
                extra_args = (emb,)
            elif isinstance(child, AttentionBlock):
                extra_args = (context,)
            else:
                extra_args = ()
            out = child(out, *extra_args)
        return out
| |
|