Spaces:
Runtime error
Runtime error
| # Copyright 2024 EPFL and Apple Inc. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| from typing import Optional, Tuple, Union | |
| import math | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import torch.utils.checkpoint | |
| from einops import rearrange, repeat | |
| from diffusers.models.embeddings import ( | |
| GaussianFourierProjection, | |
| TimestepEmbedding, | |
| Timesteps, | |
| ) | |
| from diffusers.models.unet_2d_blocks import ( | |
| DownBlock2D, | |
| UpBlock2D, | |
| ) | |
| from diffusers.models.resnet import Downsample2D, Upsample2D | |
| from diffusers.configuration_utils import ConfigMixin | |
| from diffusers.models.modeling_utils import ModelMixin | |
| # xFormers imports | |
| try: | |
| from xformers.ops import memory_efficient_attention, unbind | |
| XFORMERS_AVAILABLE = True | |
| except ImportError: | |
| print("xFormers not available") | |
| XFORMERS_AVAILABLE = False | |
def modulate(x, shift, scale):
    """Apply adaLN-style affine modulation ``x * (1 + scale) + shift``.

    Tensor ``shift``/``scale`` with at least one dimension get a singleton axis
    inserted at position 1, so a per-sample ``[B, D]`` vector broadcasts over
    the token axis of ``x`` (``[B, N, D]``). Python scalars and 0-d tensors are
    applied as-is, so callers may pass ``0.0`` to mean "no modulation".
    """
    if isinstance(scale, torch.Tensor) and scale.dim() > 0:
        scale = scale.unsqueeze(1)
    if isinstance(shift, torch.Tensor) and shift.dim() > 0:
        shift = shift.unsqueeze(1)
    return x * (1.0 + scale) + shift
def pair(t):
    """Return ``t`` unchanged when it is already a tuple, otherwise ``(t, t)``."""
    if isinstance(t, tuple):
        return t
    return (t, t)
def build_2d_sincos_posemb(h, w, embed_dim=1024, temperature=10000.):
    """Build a fixed 2D sine-cosine positional embedding (as used in MoCo-v3).

    The channel axis is split into four equal groups
    ``[sin(row * omega), cos(row * omega), sin(col * omega), cos(col * omega)]``
    with ``omega_k = 1 / temperature ** (k / (embed_dim // 4))``.

    Args:
        h: Number of rows of the grid.
        w: Number of columns of the grid.
        embed_dim: Embedding dimension; must be divisible by 4.
        temperature: Base of the geometric frequency progression.

    Returns:
        Float32 tensor of shape [1, embed_dim, h, w].
    """
    assert embed_dim % 4 == 0, 'Embed dimension must be divisible by 4 for 2D sin-cos position embedding'
    pos_dim = embed_dim // 4
    omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim
    omega = 1. / (temperature ** omega)
    # Build the grid with shape (h, w) so that flattening gives row-major (h w)
    # order, which is exactly the order the reshape below expects. This keeps
    # positions consistent for non-square grids too.
    rows, cols = torch.meshgrid(
        torch.arange(h, dtype=torch.float32),
        torch.arange(w, dtype=torch.float32),
        indexing='ij',
    )
    out_row = torch.einsum('m,d->md', rows.flatten(), omega)
    out_col = torch.einsum('m,d->md', cols.flatten(), omega)
    pos_emb = torch.cat(
        [torch.sin(out_row), torch.cos(out_row), torch.sin(out_col), torch.cos(out_col)], dim=1
    )
    # [(h w), D] -> [1, D, h, w]
    return pos_emb.reshape(1, h, w, embed_dim).permute(0, 3, 1, 2)
def drop_path(x, drop_prob: Optional[float] = 0., training: bool = False):
    """Drop paths (Stochastic Depth) per sample, for use on residual branches.

    Each sample along dim 0 has its whole branch zeroed with probability
    ``drop_prob``. Surviving samples are rescaled by ``1 / (1 - drop_prob)`` so
    the expectation is unchanged. Adapted from the timm implementation.

    Args:
        x: Input tensor; dim 0 is treated as the batch dimension.
        drop_prob: Probability of dropping a sample's path. ``None`` or 0
            disables dropping.
        training: Dropping is only applied when True.

    Returns:
        Tensor of the same shape as ``x``.
    """
    if drop_prob is None or drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    # One random value per sample, broadcast over all remaining dims.
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # binarize: 1 with probability keep_prob, else 0
    if keep_prob > 0.:
        return x.div(keep_prob) * random_tensor
    # drop_prob == 1: every path is dropped; avoid 0-division (which gives NaN).
    return x * random_tensor
class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample; module wrapper around ``drop_path``.

    Active only in training mode.

    Args:
        drop_prob: Probability of dropping a sample's path. ``None`` (the
            default) is treated as 0, i.e. the module acts as the identity.
    """
    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        # Map the None default to 0 so training-mode calls do not try to
        # compute `1 - None`.
        drop_prob = 0. if self.drop_prob is None else self.drop_prob
        return drop_path(x, drop_prob, self.training)

    def extra_repr(self) -> str:
        return 'p={}'.format(self.drop_prob)
class Mlp(nn.Module):
    """Two-layer feed-forward network with optional time-embedding modulation.

    Computes ``drop(fc2(act(fc1(x))))``. When ``temb_dim`` is given, the hidden
    activations are additionally shifted/scaled by a linear projection of
    ``silu(temb)`` (see https://arxiv.org/abs/2301.11093).

    Args:
        in_features: Input feature dimension.
        temb_dim: Time-embedding dimension; enables modulation when not None.
        hidden_features: Hidden dimension (defaults to ``in_features``).
        out_features: Output dimension (defaults to ``in_features``).
        act_layer: Activation class, instantiated with no arguments.
        drop: Dropout probability applied to the output.
    """
    def __init__(self, in_features, temb_dim=None, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_dim = out_features or in_features
        self.hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, self.hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(self.hidden_features, out_dim)
        self.drop = nn.Dropout(drop)
        if temb_dim is not None:
            self.adaLN_modulation = nn.Linear(temb_dim, 2 * self.hidden_features)

    def forward(self, x, temb=None):
        """Run the MLP; ``temb`` is required when modulation was configured."""
        hidden = self.act(self.fc1(x))
        if hasattr(self, 'adaLN_modulation'):
            shift, scale = self.adaLN_modulation(F.silu(temb)).chunk(2, dim=-1)
            hidden = modulate(hidden, shift, scale)
        return self.drop(self.fc2(hidden))
class Attention(nn.Module):
    """Multi-head self-attention over a token sequence.

    Uses xFormers' ``memory_efficient_attention`` when ``XFORMERS_AVAILABLE`` is
    true, otherwise an explicit ``softmax(q @ k^T * scale) @ v`` computation.

    Args:
        dim: Token dimension (input and output), split evenly across heads.
        num_heads: Number of attention heads.
        qkv_bias: Whether the fused QKV projection has a bias term.
        attn_drop: Dropout probability on the attention weights (applied in
            both code paths).
        proj_drop: Dropout probability after the output projection.
    """
    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, mask=None):
        """Self-attend over ``x``.

        Args:
            x: Tokens of shape [B, N, C].
            mask: Optional boolean mask; True entries are excluded from
                attention. The fallback path broadcasts it over heads via
                ``unsqueeze(1)`` (so [B, N, N] is expected there).

        Returns:
            Tensor of shape [B, N, C].
        """
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
        if XFORMERS_AVAILABLE:
            # NOTE(review): self.scale is not passed; this relies on xFormers'
            # default softmax scale.
            q, k, v = unbind(qkv, 2)
            if mask is not None:
                # Wherever mask is True it becomes a large negative bias, otherwise 0.
                # NOTE(review): no head axis is added here; confirm the installed
                # xFormers accepts this bias shape.
                mask = mask.to(q.dtype) * -torch.finfo(q.dtype).max
            # Apply attention dropout here too so both paths regularize the same way.
            dropout_p = self.attn_drop.p if self.training else 0.0
            x = memory_efficient_attention(q, k, v, attn_bias=mask, p=dropout_p)
            x = x.reshape([B, N, C])
        else:
            qkv = qkv.permute(2, 0, 3, 1, 4)  # -> [3, B, heads, N, head_dim]
            q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)
            attn = (q @ k.transpose(-2, -1)) * self.scale
            if mask is not None:
                mask = mask.unsqueeze(1)
                attn = attn.masked_fill(mask, -torch.finfo(attn.dtype).max)
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
class CrossAttention(nn.Module):
    """Multi-head cross-attention from query tokens ``x`` to ``context`` tokens.

    Uses xFormers' ``memory_efficient_attention`` when ``XFORMERS_AVAILABLE`` is
    true, otherwise an explicit ``softmax(q @ k^T * scale) @ v`` computation.

    Args:
        dim: Query token dimension (input and output), split across heads.
        dim_context: Context token dimension (defaults to ``dim``).
        num_heads: Number of attention heads.
        qkv_bias: Whether the Q and KV projections have bias terms.
        attn_drop: Dropout probability on the attention weights (applied in
            both code paths).
        proj_drop: Dropout probability after the output projection.
    """
    def __init__(self, dim, dim_context=None, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
        super().__init__()
        dim_context = dim_context or dim
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5
        self.q = nn.Linear(dim, dim, bias=qkv_bias)
        self.kv = nn.Linear(dim_context, dim * 2, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, context, mask=None):
        """Attend from ``x`` to ``context``.

        Args:
            x: Query tokens of shape [B, N, C].
            context: Context tokens of shape [B, M, dim_context].
            mask: Optional boolean mask; True entries are excluded from
                attention. The fallback path expects [B, N, M].

        Returns:
            Tensor of shape [B, N, C].
        """
        B, N, C = x.shape
        _, M, _ = context.shape
        q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads)
        kv = self.kv(context).reshape(B, M, 2, self.num_heads, C // self.num_heads)
        if XFORMERS_AVAILABLE:
            # NOTE(review): self.scale is not passed; this relies on xFormers'
            # default softmax scale.
            k, v = unbind(kv, 2)
            if mask is not None:
                # Wherever mask is True it becomes a large negative bias, otherwise 0.
                # NOTE(review): no head axis is added here; confirm the installed
                # xFormers accepts this bias shape.
                mask = mask.to(q.dtype) * -torch.finfo(q.dtype).max
            # Apply attention dropout here too so both paths regularize the same way.
            dropout_p = self.attn_drop.p if self.training else 0.0
            x = memory_efficient_attention(q, k, v, attn_bias=mask, p=dropout_p)
            x = x.reshape([B, N, C])
        else:
            q = q.permute(0, 2, 1, 3)  # -> [B, heads, N, head_dim]
            kv = kv.permute(2, 0, 3, 1, 4)  # -> [2, B, heads, M, head_dim]
            k, v = kv[0], kv[1]
            attn = (q @ k.transpose(-2, -1)) * self.scale
            if mask is not None:
                mask = rearrange(mask, "b n m -> b 1 n m")
                attn = attn.masked_fill(mask, -torch.finfo(attn.dtype).max)
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
class Block(nn.Module):
    """Pre-norm Transformer block (self-attention + MLP) with optional adaLN conditioning.

    Args:
        dim: Token dimension.
        num_heads: Number of attention heads.
        temb_dim: Conditioning (time) embedding dimension. If None, no adaLN
            modulation or gating layers are created.
        mlp_ratio: Ratio of the MLP hidden dimension to ``dim``.
        qkv_bias: Whether the attention QKV projection has a bias.
        drop: Dropout after the attention output projection and in the MLP.
        attn_drop: Dropout on the attention weights.
        drop_path: Stochastic depth rate applied to both residual branches.
        act_layer: MLP activation class.
        norm_layer: Normalization layer class.
        skip: If True, the input is fused with a long skip connection through a
            linear layer (see https://arxiv.org/abs/2209.12152).
        temb_in_mlp: If True (and ``temb_dim`` is set), also modulate the MLP
            hidden activations.
        temb_after_norm: If True (and ``temb_dim`` is set), shift/scale the
            normalized inputs of both branches (adaLN).
        temb_gate: If True (and ``temb_dim`` is set), gate both residual
            branches with a zero-initialized projection (adaLN-Zero).
    """
    def __init__(self, dim, num_heads, temb_dim=None, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, skip=False, temb_in_mlp=False,
                 temb_after_norm=True, temb_gate=True):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, temb_dim=temb_dim if temb_in_mlp else None, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
        if temb_after_norm and temb_dim is not None:
            # adaLN modulation (see https://arxiv.org/abs/2212.09748)
            self.adaLN_modulation = nn.Linear(temb_dim, 4 * dim)
        if temb_gate and temb_dim is not None:
            # adaLN-Zero gate (see https://arxiv.org/abs/2212.09748)
            self.adaLN_gate = nn.Linear(temb_dim, 2 * dim)
            nn.init.zeros_(self.adaLN_gate.weight)
            nn.init.zeros_(self.adaLN_gate.bias)
        self.skip_linear = nn.Linear(2*dim, dim) if skip else None

    @staticmethod
    def _norm_and_modulate(norm, x, shift, scale):
        """Normalize ``x`` and, if adaLN parameters were computed, shift/scale it."""
        normed = norm(x)
        if shift is None:
            return normed
        return modulate(normed, shift, scale)

    def forward(self, x, temb=None, mask=None, skip_connection=None):
        """Apply the block.

        Args:
            x: Token sequence (the attention layer unpacks it as [B, N, dim]).
            temb: Conditioning embedding; required whenever adaLN layers exist.
            mask: Optional boolean self-attention mask (True = masked out).
            skip_connection: Tensor concatenated with ``x`` when built with ``skip=True``.

        Returns:
            Updated tokens with the same shape as ``x``.
        """
        if hasattr(self, 'adaLN_gate'):
            gate_msa, gate_mlp = self.adaLN_gate(F.silu(temb)).unsqueeze(1).chunk(2, dim=-1)
        else:
            gate_msa, gate_mlp = 1.0, 1.0
        if hasattr(self, 'adaLN_modulation'):
            shift_msa, scale_msa, shift_mlp, scale_mlp = self.adaLN_modulation(F.silu(temb)).chunk(4, dim=-1)
        else:
            # No modulation layer: normalized inputs are used unmodified (passing
            # scalar 0.0 to `modulate` would fail on `.unsqueeze`).
            shift_msa = scale_msa = shift_mlp = scale_mlp = None
        if self.skip_linear is not None:
            x = self.skip_linear(torch.cat([x, skip_connection], dim=-1))
        x = x + gate_msa * self.drop_path(self.attn(self._norm_and_modulate(self.norm1, x, shift_msa, scale_msa), mask))
        x = x + gate_mlp * self.drop_path(self.mlp(self._norm_and_modulate(self.norm2, x, shift_mlp, scale_mlp), temb))
        return x
class DecoderBlock(nn.Module):
    """Pre-norm Transformer decoder block: self-attention, cross-attention, MLP.

    Each of the three residual branches can be adaLN-modulated and
    adaLN-Zero-gated by a conditioning (time) embedding.

    Args:
        dim: Token dimension.
        num_heads: Number of attention heads.
        temb_dim: Conditioning embedding dimension. If None, no adaLN layers
            are created.
        dim_context: Context token dimension (defaults to ``dim``).
        mlp_ratio: Ratio of the MLP hidden dimension to ``dim``.
        qkv_bias: Whether attention projections have biases.
        drop: Dropout after attention output projections and in the MLP.
        attn_drop: Dropout on attention weights.
        drop_path: Stochastic depth rate for all residual branches.
        act_layer: MLP activation class.
        norm_layer: Normalization layer class.
        skip: If True, fuse a long skip connection into the input via a linear layer.
        temb_in_mlp: If True (and ``temb_dim`` is set), also modulate MLP hidden activations.
        temb_after_norm: If True (and ``temb_dim`` is set), shift/scale normalized branch inputs.
        temb_gate: If True (and ``temb_dim`` is set), gate branches with a zero-initialized projection.
    """
    def __init__(self, dim, num_heads, temb_dim=None, dim_context=None, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, skip=False, temb_in_mlp=False,
                 temb_after_norm=True, temb_gate=True):
        super().__init__()
        dim_context = dim_context or dim
        self.norm1 = norm_layer(dim)
        self.self_attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
        self.cross_attn = CrossAttention(dim, dim_context=dim_context, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
        self.query_norm = norm_layer(dim)
        self.context_norm = norm_layer(dim_context)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, temb_dim=temb_dim if temb_in_mlp else None, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
        if temb_after_norm and temb_dim is not None:
            # adaLN modulation (see https://arxiv.org/abs/2212.09748)
            self.adaLN_modulation = nn.Linear(temb_dim, 6 * dim)
        if temb_gate and temb_dim is not None:
            # adaLN-Zero gate (see https://arxiv.org/abs/2212.09748)
            self.adaLN_gate = nn.Linear(temb_dim, 3 * dim)
            nn.init.zeros_(self.adaLN_gate.weight)
            nn.init.zeros_(self.adaLN_gate.bias)
        self.skip_linear = nn.Linear(2*dim, dim) if skip else None

    @staticmethod
    def _norm_and_modulate(norm, x, shift, scale):
        """Normalize ``x`` and, if adaLN parameters were computed, shift/scale it."""
        normed = norm(x)
        if shift is None:
            return normed
        return modulate(normed, shift, scale)

    def forward(self, x, context, temb=None, sa_mask=None, xa_mask=None, skip_connection=None):
        """Apply the block.

        Args:
            x: Query token sequence (attention unpacks it as [B, N, dim]).
            context: Context tokens for cross-attention ([B, M, dim_context]).
            temb: Conditioning embedding; required whenever adaLN layers exist.
            sa_mask: Optional boolean self-attention mask (True = masked out).
            xa_mask: Optional boolean cross-attention mask (True = masked out).
            skip_connection: Tensor concatenated with ``x`` when built with ``skip=True``.

        Returns:
            Updated tokens with the same shape as ``x``.
        """
        if hasattr(self, 'adaLN_gate'):
            gate_msa, gate_mxa, gate_mlp = self.adaLN_gate(F.silu(temb)).unsqueeze(1).chunk(3, dim=-1)
        else:
            gate_msa, gate_mxa, gate_mlp = 1.0, 1.0, 1.0
        if hasattr(self, 'adaLN_modulation'):
            shift_msa, scale_msa, shift_mxa, scale_mxa, shift_mlp, scale_mlp = self.adaLN_modulation(F.silu(temb)).chunk(6, dim=-1)
        else:
            # No modulation layer: normalized inputs are used unmodified (passing
            # scalar 0.0 to `modulate` would fail on `.unsqueeze`).
            shift_msa = scale_msa = shift_mxa = scale_mxa = shift_mlp = scale_mlp = None
        if self.skip_linear is not None:
            x = self.skip_linear(torch.cat([x, skip_connection], dim=-1))
        x = x + gate_msa * self.drop_path(self.self_attn(self._norm_and_modulate(self.norm1, x, shift_msa, scale_msa), sa_mask))
        x = x + gate_mxa * self.drop_path(self.cross_attn(self._norm_and_modulate(self.query_norm, x, shift_mxa, scale_mxa), self.context_norm(context), xa_mask))
        x = x + gate_mlp * self.drop_path(self.mlp(self._norm_and_modulate(self.norm2, x, shift_mlp, scale_mlp), temb))
        return x
class TransformerConcatCond(nn.Module):
    """UViT Transformer bottleneck that adds the projected condition to the input tokens.

    Args:
        unet_dim: Channel count of the last UNet down block.
        cond_dim: Channel count of the condition.
        mid_layers: Number of Transformer layers.
        mid_num_heads: Number of attention heads.
        mid_dim: Transformer width.
        mid_mlp_ratio: Ratio of MLP hidden width to Transformer width.
        mid_qkv_bias: Whether the QKV projections have biases.
        mid_drop_rate: Dropout rate.
        mid_attn_drop_rate: Attention dropout rate.
        mid_drop_path_rate: Maximum stochastic depth rate (linearly increasing over layers).
        time_embed_dim: Time embedding dimension.
        hw_posemb: Side length of the fixed 2D positional embedding grid.
        use_long_skip: Whether to use long skip connections between the first and
            second halves of the stack (see https://arxiv.org/abs/2209.12152).
    """
    def __init__(
        self,
        unet_dim: int = 1024,
        cond_dim: int = 32,
        mid_layers: int = 12,
        mid_num_heads: int = 12,
        mid_dim: int = 768,
        mid_mlp_ratio: int = 4,
        mid_qkv_bias: bool = True,
        mid_drop_rate: float = 0.0,
        mid_attn_drop_rate: float = 0.0,
        mid_drop_path_rate: float = 0.0,
        time_embed_dim: int = 512,
        hw_posemb: int = 16,
        use_long_skip: bool = False,
    ):
        super().__init__()
        # Fixed, non-trainable sin-cos grid; resized to the feature map in forward().
        self.mid_pos_emb = nn.Parameter(
            build_2d_sincos_posemb(h=hw_posemb, w=hw_posemb, embed_dim=mid_dim),
            requires_grad=False,
        )
        self.use_long_skip = use_long_skip
        if use_long_skip:
            assert mid_layers % 2 == 1, 'mid_layers must be odd when using long skip connection'
        # Stochastic depth decay rule: rates rise linearly from 0 to mid_drop_path_rate.
        drop_rates = torch.linspace(0, mid_drop_path_rate, mid_layers).tolist()
        middle = mid_layers // 2
        blocks = []
        for idx, rate in enumerate(drop_rates):
            blocks.append(Block(
                dim=mid_dim, temb_dim=time_embed_dim, num_heads=mid_num_heads,
                mlp_ratio=mid_mlp_ratio, qkv_bias=mid_qkv_bias, drop=mid_drop_rate,
                attn_drop=mid_attn_drop_rate, drop_path=rate,
                # Blocks after the middle one receive a long skip from the first half.
                skip=use_long_skip and idx > middle,
            ))
        self.mid_block = nn.ModuleList(blocks)
        self.mid_cond_proj = nn.Linear(cond_dim, mid_dim)
        self.mid_proj_in = nn.Linear(unet_dim, mid_dim)
        self.mid_proj_out = nn.Linear(mid_dim, unet_dim)
        self.mask_token = nn.Parameter(torch.zeros(mid_dim), requires_grad=True)

    def _run_blocks(self, tokens, temb):
        """Run the Transformer stack, wiring long skips when enabled."""
        if not self.use_long_skip:
            for layer in self.mid_block:
                tokens = layer(tokens, temb)
            return tokens
        middle = len(self.mid_block) // 2
        stack = []
        for layer in self.mid_block[:middle]:
            tokens = layer(tokens, temb)
            stack.append(tokens)
        tokens = self.mid_block[middle](tokens, temb)
        for layer in self.mid_block[middle + 1:]:
            tokens = layer(tokens, temb, skip_connection=stack.pop())
        return tokens

    def forward(self,
                x: torch.Tensor,
                temb: torch.Tensor,
                cond: torch.Tensor,
                cond_mask: Optional[torch.BoolTensor] = None) -> torch.Tensor:
        """Run the bottleneck.

        Args:
            x: UNet features from the last down block, shape [B, C_mid, H_mid, W_mid].
            temb: Time embedding, shape [B, temb_dim].
            cond: Condition, shape [B, cond_dim, H_cond, W_cond]; resized to
                (H_mid, W_mid) before projection.
            cond_mask: Optional boolean mask of shape [B, H, W] (any spatial size;
                it is resized to (H_mid, W_mid) with nearest interpolation).
                Wherever it is True, the projected condition token is replaced
                by a learnable mask token.

        Returns:
            Features of shape [B, C_mid, H_mid, W_mid] for the UNet up blocks.
        """
        _, _, height, width = x.shape
        tokens = self.mid_proj_in(rearrange(x, 'b d h w -> b (h w) d'))

        resized_cond = F.interpolate(cond, (height, width))  # no-op resize if sizes already match
        cond_tokens = self.mid_cond_proj(rearrange(resized_cond, 'b d h w -> b (h w) d'))
        if cond_mask is not None:
            resized_mask = F.interpolate(cond_mask.unsqueeze(1).float(), (height, width), mode='nearest') > 0.5
            flat_mask = rearrange(resized_mask, 'b 1 h w -> b (h w)')
            cond_tokens[flat_mask] = self.mask_token.to(dtype=cond_tokens.dtype)

        pos = F.interpolate(self.mid_pos_emb, (height, width), mode='bicubic', align_corners=False)
        pos = rearrange(pos, 'b d h w -> b (h w) d')
        tokens = tokens + cond_tokens + pos

        tokens = self._run_blocks(tokens, temb)

        out = self.mid_proj_out(tokens)  # back to UNet channels
        return rearrange(out, 'b (h w) d -> b d h w', h=height, w=width)
class TransformerXattnCond(nn.Module):
    """UViT Transformer bottleneck that incorporates the condition via cross-attention.

    Args:
        unet_dim: Channel count of the last UNet down block.
        cond_dim: Channel count of the condition (used as the cross-attention context dim).
        mid_layers: Number of Transformer layers.
        mid_num_heads: Number of attention heads.
        mid_dim: Transformer width.
        mid_mlp_ratio: Ratio of MLP hidden width to Transformer width.
        mid_qkv_bias: Whether the attention projections have biases.
        mid_drop_rate: Dropout rate.
        mid_attn_drop_rate: Attention dropout rate.
        mid_drop_path_rate: Maximum stochastic depth rate (linearly increasing over layers).
        time_embed_dim: Time embedding dimension.
        hw_posemb: Side length of the fixed 2D positional embedding grid.
        use_long_skip: Whether to use long skip connections between the first and
            second halves of the stack (see https://arxiv.org/abs/2209.12152).
    """
    def __init__(
        self,
        unet_dim: int = 1024,
        cond_dim: int = 32,
        mid_layers: int = 12,
        mid_num_heads: int = 12,
        mid_dim: int = 768,
        mid_mlp_ratio: int = 4,
        mid_qkv_bias: bool = True,
        mid_drop_rate: float = 0.0,
        mid_attn_drop_rate: float = 0.0,
        mid_drop_path_rate: float = 0.0,
        time_embed_dim: int = 512,
        hw_posemb: int = 16,
        use_long_skip: bool = False,
    ):
        super().__init__()
        # Fixed, non-trainable sin-cos grid; resized to the feature map in forward().
        self.mid_pos_emb = nn.Parameter(
            build_2d_sincos_posemb(h=hw_posemb, w=hw_posemb, embed_dim=mid_dim),
            requires_grad=False,
        )
        self.use_long_skip = use_long_skip
        if use_long_skip:
            assert mid_layers % 2 == 1, 'mid_layers must be odd when using long skip connection'
        # Stochastic depth decay rule: rates rise linearly from 0 to mid_drop_path_rate.
        drop_rates = torch.linspace(0, mid_drop_path_rate, mid_layers).tolist()
        middle = mid_layers // 2
        blocks = []
        for idx, rate in enumerate(drop_rates):
            blocks.append(DecoderBlock(
                dim=mid_dim, temb_dim=time_embed_dim, num_heads=mid_num_heads, dim_context=cond_dim,
                mlp_ratio=mid_mlp_ratio, qkv_bias=mid_qkv_bias, drop=mid_drop_rate,
                attn_drop=mid_attn_drop_rate, drop_path=rate,
                # Blocks after the middle one receive a long skip from the first half.
                skip=use_long_skip and idx > middle,
            ))
        self.mid_block = nn.ModuleList(blocks)
        self.mid_proj_in = nn.Linear(unet_dim, mid_dim)
        self.mid_proj_out = nn.Linear(mid_dim, unet_dim)

    def _run_blocks(self, tokens, context, temb, xa_mask):
        """Run the decoder stack (cross-attending to ``context``), wiring long skips when enabled."""
        if not self.use_long_skip:
            for layer in self.mid_block:
                tokens = layer(tokens, context, temb, xa_mask=xa_mask)
            return tokens
        middle = len(self.mid_block) // 2
        stack = []
        for layer in self.mid_block[:middle]:
            tokens = layer(tokens, context, temb, xa_mask=xa_mask)
            stack.append(tokens)
        tokens = self.mid_block[middle](tokens, context, temb, xa_mask=xa_mask)
        for layer in self.mid_block[middle + 1:]:
            tokens = layer(tokens, context, temb, xa_mask=xa_mask, skip_connection=stack.pop())
        return tokens

    def forward(self,
                x: torch.Tensor,
                temb: torch.Tensor,
                cond: torch.Tensor,
                cond_mask: Optional[torch.BoolTensor] = None) -> torch.Tensor:
        """Run the bottleneck.

        Args:
            x: UNet features from the last down block, shape [B, C_mid, H_mid, W_mid].
            temb: Time embedding, shape [B, temb_dim].
            cond: Condition, shape [B, cond_dim, H_cond, W_cond]; flattened into
                H_cond * W_cond context tokens.
            cond_mask: Optional boolean mask of shape [B, H_cond, W_cond]. Wherever
                it is True, that condition location is not cross-attended to.

        Returns:
            Features of shape [B, C_mid, H_mid, W_mid] for the UNet up blocks.
        """
        _, _, height, width = x.shape
        tokens = self.mid_proj_in(rearrange(x, 'b d h w -> b (h w) d'))
        context = rearrange(cond, 'b d h w -> b (h w) d')

        pos = F.interpolate(self.mid_pos_emb, (height, width), mode='nearest')
        tokens = tokens + rearrange(pos, 'b d h w -> b (h w) d')

        # Same condition mask for every query token: [B, H_cond, W_cond] -> [B, N, M].
        xa_mask = None
        if cond_mask is not None:
            xa_mask = repeat(cond_mask, 'b h w -> b n (h w)', n=tokens.shape[1])

        tokens = self._run_blocks(tokens, context, temb, xa_mask)

        out = self.mid_proj_out(tokens)  # back to UNet channels
        return rearrange(out, 'b (h w) d -> b d h w', h=height, w=width)
| class UViT(ModelMixin, ConfigMixin): | |
| """UViT model = Conditional UNet with Transformer bottleneck | |
| blocks and optionalpatching. | |
| See https://arxiv.org/abs/2301.11093 for more details. | |
| Args: | |
| sample_size: Size of the input images. | |
| in_channels: Number of input channels. | |
| out_channels: Number of output channels. | |
| patch_size: Size of the input patching operation. | |
| See https://arxiv.org/abs/2207.04316 for more details. | |
| block_out_channels: Number of output channels of each UNet ResNet-block. | |
| layers_per_block: Number of ResNet blocks per UNet block. | |
| downsample_before_mid: Whether to downsample before the Transformer bottleneck. | |
| mid_layers: Number of Transformer blocks. | |
| mid_num_heads: Number of attention heads. | |
| mid_dim: Transformer dimension. | |
| mid_mlp_ratio: Transformer MLP ratio. | |
| mid_qkv_bias: Whether to use bias in the Transformer QKV projection. | |
| mid_drop_rate: Dropout rate of the Transformer MLP and attention output projection. | |
| mid_attn_drop_rate: Dropout rate of the Transformer attention. | |
| mid_drop_path_rate: Stochastic depth rate of the Transformer blocks. | |
| mid_hw_posemb: Size (side) of the Transformer positional embedding. | |
| mid_use_long_skip: Whether to use long skip connections in the Transformer blocks. | |
| See https://arxiv.org/abs/2209.12152 for more details. | |
| cond_dim: Dimension of the conditioning vector. | |
| cond_type: Type of conditioning. | |
| 'concat' for concatenation, 'xattn' for cross-attention. | |
| downsample_padding: Padding of the UNet downsampling convolutions. | |
| act_fn: Activation function. | |
| norm_num_groups: Number of groups in the UNet ResNet-block normalization. | |
| norm_eps: Epsilon of the UNet ResNet-block normalization. | |
| resnet_time_scale_shift: Time scale shift of the UNet ResNet-blocks. | |
| resnet_out_scale_factor: Output scale factor of the UNet ResNet-blocks. | |
| time_embedding_type: Type of the time embedding. | |
| 'positional' for positional, 'fourier' for Fourier. | |
| time_embedding_dim: Dimension of the time embedding. | |
| time_embedding_act_fn: Activation function of the time embedding. | |
| timestep_post_act: Activation function after the time embedding. | |
| time_cond_proj_dim: Dimension of the optional conditioning projection. | |
| flip_sin_to_cos: Whether to flip the sine to cosine in the time embedding. | |
| freq_shift: Frequency shift of the time embedding. | |
| res_embedding: Whether to perform original resolution conditioning. | |
| See SDXL https://arxiv.org/abs/2307.01952 for more details. | |
| """ | |
| def __init__(self, | |
| # UNet settings | |
| sample_size: Optional[int] = None, | |
| in_channels: int = 3, | |
| out_channels: int = 3, | |
| patch_size: int = 4, | |
| block_out_channels: Tuple[int] = (128, 256, 512), | |
| layers_per_block: Union[int, Tuple[int]] = 2, | |
| downsample_before_mid: bool = False, | |
| # Mid-block Transformer settings | |
| mid_layers: int = 12, | |
| mid_num_heads: int = 12, | |
| mid_dim: int = 768, | |
| mid_mlp_ratio: int = 4, | |
| mid_qkv_bias: bool = True, | |
| mid_drop_rate: float = 0.0, | |
| mid_attn_drop_rate: float = 0.0, | |
| mid_drop_path_rate: float = 0.0, | |
| mid_hw_posemb: int = 32, | |
| mid_use_long_skip: bool = False, | |
| # Conditioning settings | |
| cond_dim: int = 32, | |
| cond_type: str = 'concat', | |
| # ResNet blocks settings | |
| downsample_padding: int = 1, | |
| act_fn: str = "silu", | |
| norm_num_groups: Optional[int] = 32, | |
| norm_eps: float = 1e-5, | |
| resnet_time_scale_shift: str = "default", | |
| resnet_out_scale_factor: int = 1.0, | |
| # Time embedding settings | |
| time_embedding_type: str = "positional", | |
| time_embedding_dim: Optional[int] = None, | |
| time_embedding_act_fn: Optional[str] = None, | |
| timestep_post_act: Optional[str] = None, | |
| time_cond_proj_dim: Optional[int] = None, | |
| flip_sin_to_cos: bool = True, | |
| freq_shift: int = 0, | |
| # Original resolution embedding settings | |
| res_embedding: bool = False): | |
| super().__init__() | |
| self.sample_size = sample_size | |
| self.in_channels = in_channels | |
| self.out_channels = out_channels | |
| self.mid_dim = block_out_channels[-1] | |
| self.res_embedding = res_embedding | |
| # input patching | |
| self.conv_in = nn.Conv2d( | |
| in_channels, block_out_channels[0], kernel_size=patch_size, padding=0, stride=patch_size | |
| ) | |
| # time | |
| if time_embedding_type == "fourier": | |
| time_embed_dim = time_embedding_dim or block_out_channels[0] * 2 | |
| if time_embed_dim % 2 != 0: | |
| raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.") | |
| self.time_proj = GaussianFourierProjection( | |
| time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos | |
| ) | |
| timestep_input_dim = time_embed_dim | |
| elif time_embedding_type == "positional": | |
| time_embed_dim = time_embedding_dim or block_out_channels[0] * 4 | |
| self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) | |
| timestep_input_dim = block_out_channels[0] | |
| else: | |
| raise ValueError( | |
| f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`." | |
| ) | |
| self.time_embedding = TimestepEmbedding( | |
| timestep_input_dim, | |
| time_embed_dim, | |
| act_fn=act_fn, | |
| post_act_fn=timestep_post_act, | |
| cond_proj_dim=time_cond_proj_dim, | |
| ) | |
| if time_embedding_act_fn is None: | |
| self.time_embed_act = None | |
| elif time_embedding_act_fn == "swish": | |
| self.time_embed_act = lambda x: F.silu(x) | |
| elif time_embedding_act_fn == "mish": | |
| self.time_embed_act = nn.Mish() | |
| elif time_embedding_act_fn == "silu": | |
| self.time_embed_act = nn.SiLU() | |
| elif time_embedding_act_fn == "gelu": | |
| self.time_embed_act = nn.GELU() | |
| else: | |
| raise ValueError(f"Unsupported activation function: {time_embedding_act_fn}") | |
| # original resolution embedding | |
| if res_embedding: | |
| if time_embedding_type == "fourier": | |
| self.h_proj = GaussianFourierProjection( | |
| time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos | |
| ) | |
| self.w_proj = GaussianFourierProjection( | |
| time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos | |
| ) | |
| elif time_embedding_type == "positional": | |
| self.height_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) | |
| self.width_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) | |
| self.height_embedding = TimestepEmbedding( | |
| timestep_input_dim, time_embed_dim, act_fn=act_fn, | |
| post_act_fn=timestep_post_act, cond_proj_dim=time_cond_proj_dim, | |
| ) | |
| self.width_embedding = TimestepEmbedding( | |
| timestep_input_dim, time_embed_dim, act_fn=act_fn, | |
| post_act_fn=timestep_post_act, cond_proj_dim=time_cond_proj_dim, | |
| ) | |
| self.down_blocks = nn.ModuleList([]) | |
| self.up_blocks = nn.ModuleList([]) | |
| if isinstance(layers_per_block, int): | |
| layers_per_block = [layers_per_block] * len(block_out_channels) | |
| # down | |
| output_channel = block_out_channels[0] | |
| for i in range(len(block_out_channels)): | |
| input_channel = output_channel | |
| output_channel = block_out_channels[i] | |
| is_final_block = i == len(block_out_channels) - 1 | |
| down_block = DownBlock2D( | |
| num_layers=layers_per_block[i], | |
| in_channels=input_channel, | |
| out_channels=output_channel, | |
| temb_channels=time_embed_dim, | |
| add_downsample=not is_final_block, | |
| resnet_eps=norm_eps, | |
| resnet_act_fn=act_fn, | |
| resnet_groups=norm_num_groups, | |
| downsample_padding=downsample_padding, | |
| resnet_time_scale_shift=resnet_time_scale_shift, | |
| output_scale_factor=resnet_out_scale_factor, | |
| ) | |
| self.down_blocks.append(down_block) | |
| if downsample_before_mid: | |
| self.downsample_mid = Downsample2D(self.mid_dim, use_conv=True, out_channels=self.mid_dim) | |
| # mid | |
| if cond_type == 'concat': | |
| self.mid_block = TransformerConcatCond( | |
| unet_dim=self.mid_dim, cond_dim=cond_dim, mid_layers=mid_layers, mid_num_heads=mid_num_heads, | |
| mid_dim=mid_dim, mid_mlp_ratio=mid_mlp_ratio, mid_qkv_bias=mid_qkv_bias, | |
| mid_drop_rate=mid_drop_rate, mid_attn_drop_rate=mid_attn_drop_rate, mid_drop_path_rate=mid_drop_path_rate, | |
| time_embed_dim=time_embed_dim, hw_posemb=mid_hw_posemb, use_long_skip=mid_use_long_skip, | |
| ) | |
| elif cond_type == 'xattn': | |
| self.mid_block = TransformerXattnCond( | |
| unet_dim=self.mid_dim, cond_dim=cond_dim, mid_layers=mid_layers, mid_num_heads=mid_num_heads, | |
| mid_dim=mid_dim, mid_mlp_ratio=mid_mlp_ratio, mid_qkv_bias=mid_qkv_bias, | |
| mid_drop_rate=mid_drop_rate, mid_attn_drop_rate=mid_attn_drop_rate, mid_drop_path_rate=mid_drop_path_rate, | |
| time_embed_dim=time_embed_dim, hw_posemb=mid_hw_posemb, use_long_skip=mid_use_long_skip, | |
| ) | |
| else: | |
| raise ValueError(f"Unsupported cond_type: {cond_type}") | |
| # count how many layers upsample the images | |
| self.num_upsamplers = 0 | |
| # up | |
| if downsample_before_mid: | |
| self.upsample_mid = Upsample2D(self.mid_dim, use_conv=True, out_channels=self.mid_dim) | |
| reversed_block_out_channels = list(reversed(block_out_channels)) | |
| reversed_layers_per_block = list(reversed(layers_per_block)) | |
| output_channel = reversed_block_out_channels[0] | |
| for i in range(len(reversed_block_out_channels)): | |
| is_final_block = i == len(block_out_channels) - 1 | |
| prev_output_channel = output_channel | |
| output_channel = reversed_block_out_channels[i] | |
| input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] | |
| # add upsample block for all BUT final layer | |
| if not is_final_block: | |
| add_upsample = True | |
| self.num_upsamplers += 1 | |
| else: | |
| add_upsample = False | |
| up_block = UpBlock2D( | |
| num_layers=reversed_layers_per_block[i] + 1, | |
| in_channels=input_channel, | |
| out_channels=output_channel, | |
| prev_output_channel=prev_output_channel, | |
| temb_channels=time_embed_dim, | |
| add_upsample=add_upsample, | |
| resnet_eps=norm_eps, | |
| resnet_act_fn=act_fn, | |
| resnet_groups=norm_num_groups, | |
| resnet_time_scale_shift=resnet_time_scale_shift, | |
| output_scale_factor=resnet_out_scale_factor, | |
| ) | |
| self.up_blocks.append(up_block) | |
| prev_output_channel = output_channel | |
| # out | |
| if norm_num_groups is not None: | |
| self.conv_norm_out = nn.GroupNorm( | |
| num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps | |
| ) | |
| if act_fn == "swish": | |
| self.conv_act = lambda x: F.silu(x) | |
| elif act_fn == "mish": | |
| self.conv_act = nn.Mish() | |
| elif act_fn == "silu": | |
| self.conv_act = nn.SiLU() | |
| elif act_fn == "gelu": | |
| self.conv_act = nn.GELU() | |
| else: | |
| raise ValueError(f"Unsupported activation function: {act_fn}") | |
| else: | |
| self.conv_norm_out = None | |
| self.conv_act = None | |
| self.conv_out = nn.ConvTranspose2d( | |
| block_out_channels[0], out_channels, kernel_size=patch_size, stride=patch_size | |
| ) | |
| self.init_weights() | |
def init_weights(self) -> None:
    """Initialize weights following MAE's initialization scheme.

    Walks every submodule yielded by ``self.named_modules()`` and applies, in order:

    * names containing ``"adaLN_gate"``: skipped; these gates are expected to be
      zero-initialized already.
    * names containing ``"conv2"`` (the ResNet "gates" that diffusers does not
      zero-initialize): weight and bias set to zero. This presumably makes each
      residual branch start out contributing nothing; it assumes ``conv2`` is the
      last layer of the residual branch, as in diffusers' ``ResnetBlock2D``.
    * ``nn.Linear``: Xavier-uniform weight and zero bias. Fused ``qkv`` / ``kv``
      projections use a bound computed per Q/K/V slice.
    * ``nn.LayerNorm``: weight 1, bias 0 (only for the affine parameters that exist).
    * ``nn.Embedding``: normal init with ``std=self.init_std``. This attribute must
      exist if the model contains embeddings.
    * ``conv_in`` / ``conv_out`` (``nn.Conv2d`` or ``nn.ConvTranspose2d``):
      Xavier-uniform on the kernel flattened to 2-D, i.e. initialized like an
      ``nn.Linear`` (as in MAE).

    Every other module keeps its default PyTorch / diffusers initialization.
    """
    for name, m in self.named_modules():
        # Gates that are already zero-initialized elsewhere: leave them alone.
        if "adaLN_gate" in name:
            continue

        # Last component of the dotted module path. This is "conv_in" for both a
        # top-level child ("conv_in") and a nested one ("foo.conv_in").
        leaf_name = name.rsplit(".", 1)[-1]

        # ResNet gates that were not initialized by diffusers.
        if "conv2" in name:
            nn.init.zeros_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Linear):
            if "qkv" in name:
                # Fused Q, K, V projection: compute the Xavier bound for a single
                # slice so Q, K and V are each treated as their own Linear layer.
                val = math.sqrt(6.0 / float(m.weight.shape[0] // 3 + m.weight.shape[1]))
                nn.init.uniform_(m.weight, -val, val)
            elif "kv" in name:
                # Fused K, V projection: same idea with two slices.
                val = math.sqrt(6.0 / float(m.weight.shape[0] // 2 + m.weight.shape[1]))
                nn.init.uniform_(m.weight, -val, val)
            else:
                nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            # Non-affine LayerNorms have no weight/bias to initialize.
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            if m.weight is not None:
                nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Embedding):
            nn.init.normal_(m.weight, std=self.init_std)
        elif isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            # From MAE: initialize the patch projections like nn.Linear instead of
            # nn.Conv2d. Matching on the last path component finds the top-level
            # "conv_in"/"conv_out". A ".conv_in" substring test never matches those
            # names, and conv_out is a ConvTranspose2d, not a Conv2d.
            if leaf_name in ("conv_in", "conv_out"):
                w = m.weight.data
                nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
def forward(
    self,
    sample: torch.FloatTensor,
    timestep: Union[torch.Tensor, float, int],
    condition: torch.Tensor,
    cond_mask: Optional[torch.Tensor] = None,
    timestep_cond: Optional[torch.Tensor] = None,
    orig_res: Optional[Union[torch.LongTensor, Tuple[int, int]]] = None,
    **kwargs,
) -> torch.Tensor:
    """UViT forward pass.

    Args:
        sample: Noisy image of shape (B, C, H, W).
        timestep: Timestep(s) of the current batch. A Python number or a 0-d tensor is
            broadcast to the whole batch.
        condition: Conditioning tensor of shape (B, C_cond, H_cond, W_cond). When concatenating
            the condition, it is interpolated to the resolution of the transformer (H_mid, W_mid).
        cond_mask: Mask tensor of shape (B, H_mid, W_mid) when concatenating the condition
            to the transformer, and (B, H_cond, W_cond) when using cross-attention. True for
            masked out / ignored regions.
        timestep_cond: Optional conditioning to add to the timestep embedding.
        orig_res: The original resolution of the image to condition the diffusion on.
            Ignored if None, or if the model has no resolution embedding. It is either an
            (h, w) pair shared by the whole batch, or a tensor whose columns 0 and 1 hold
            the per-sample h and w (presumably shape (B, 2)).
            See SDXL https://arxiv.org/abs/2307.01952 for more details.
        **kwargs: Accepted for call-site compatibility and ignored.

    Returns:
        Diffusion objective target image of shape (B, C, H, W).
    """
    is_mps = sample.device.type == "mps"

    # 1. Time embedding.
    timesteps = timestep
    if not torch.is_tensor(timesteps):
        # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can.
        # The 32-bit dtypes on MPS presumably work around missing 64-bit support there.
        if isinstance(timestep, float):
            dtype = torch.float32 if is_mps else torch.float64
        else:
            dtype = torch.int32 if is_mps else torch.int64
        timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
    elif len(timesteps.shape) == 0:
        timesteps = timesteps[None].to(sample.device)
    # Broadcast to the batch dimension in a way that's compatible with ONNX/Core ML.
    timesteps = timesteps.expand(sample.shape[0])

    # `Timesteps` does not contain any weights and always returns f32 tensors,
    # but time_embedding might be running in fp16, so cast here.
    t_emb = self.time_proj(timesteps).to(dtype=sample.dtype)
    emb = self.time_embedding(t_emb, timestep_cond)

    # 1.5 Original-resolution conditioning (see SDXL paper).
    if orig_res is not None and self.res_embedding:
        if not torch.is_tensor(orig_res):
            h_orig, w_orig = orig_res
            dtype = torch.int32 if is_mps else torch.int64
            h_orig = torch.tensor([h_orig], dtype=dtype, device=sample.device).expand(sample.shape[0])
            w_orig = torch.tensor([w_orig], dtype=dtype, device=sample.device).expand(sample.shape[0])
        else:
            h_orig, w_orig = orig_res[:, 0], orig_res[:, 1]
        h_emb = self.height_proj(h_orig).to(dtype=sample.dtype)
        w_emb = self.width_proj(w_orig).to(dtype=sample.dtype)
        emb = emb + self.height_embedding(h_emb)
        emb = emb + self.width_embedding(w_emb)

    if self.time_embed_act is not None:
        emb = self.time_embed_act(emb)

    # 2. Pre-process. conv_in is presumably a stride-``patch_size`` patch embedding,
    # mirroring the stride-``patch_size`` ConvTranspose2d used for conv_out.
    sample = self.conv_in(sample)

    # Decide whether explicit upsample sizes must be forwarded.
    # - The feature map is downsampled once per upsampling up block (presumably
    #   mirrored in the down path), plus once more by ``downsample_mid`` when present.
    # - If the patchified resolution is not divisible by that overall factor, a
    #   downsample rounds, a plain 2x upsample no longer matches the stored skip
    #   tensors, and the skip concatenation fails.
    # - This is decided on the feature map *after* conv_in. The raw input
    #   resolution ignores the patchify stride.
    has_mid_downsample = hasattr(self, "downsample_mid")
    overall_up_factor = 2 ** (self.num_upsamplers + int(has_mid_downsample))
    forward_upsample_size = any(s % overall_up_factor != 0 for s in sample.shape[-2:])
    upsample_size = None

    # 3. Down path. Every intermediate output is kept for the skip connections.
    down_block_res_samples = (sample,)
    for downsample_block in self.down_blocks:
        sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
        down_block_res_samples += res_samples

    # Spatial size that the mid upsample must restore.
    mid_input_size = sample.shape[2:]
    if has_mid_downsample:
        sample = self.downsample_mid(sample)

    # 4. Mid transformer block.
    sample = self.mid_block(sample, emb, condition, cond_mask)

    # 5. Up path.
    if hasattr(self, "upsample_mid"):
        if forward_upsample_size:
            sample = self.upsample_mid(sample, output_size=mid_input_size)
        else:
            sample = self.upsample_mid(sample)
    for i, upsample_block in enumerate(self.up_blocks):
        is_final_block = i == len(self.up_blocks) - 1
        num_res = len(upsample_block.resnets)
        res_samples = down_block_res_samples[-num_res:]
        down_block_res_samples = down_block_res_samples[:-num_res]
        # Size the upsample to match the next skip tensor when needed.
        # The final block does not upsample.
        if not is_final_block and forward_upsample_size:
            upsample_size = down_block_res_samples[-1].shape[2:]
        sample = upsample_block(
            hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
        )

    # 6. Post-process.
    if self.conv_norm_out is not None:
        sample = self.conv_norm_out(sample)
        sample = self.conv_act(sample)
    sample = self.conv_out(sample)
    return sample
def uvit_b_p4_f16(**kwargs):
    """Build a UViT with a ViT-Base-like transformer mid block.

    Architecture:
    * Patch size 4.
    * Conv levels with 128 and 256 channels, 2 layers per block.
    * An extra downsample before the mid block. The "f16" suffix presumably refers
      to the resulting 4 x 2 x 2 = 16x spatial reduction.
    * Mid block: 12 layers, 12 heads, width 768, MLP ratio 4, QKV bias.

    Args:
        **kwargs: Further ``UViT`` constructor arguments. Repeating one of the fixed
            settings above raises ``TypeError`` (duplicate keyword argument).
    """
    config = dict(
        patch_size=4,
        block_out_channels=(128, 256),
        layers_per_block=2,
        downsample_before_mid=True,
        mid_layers=12,
        mid_num_heads=12,
        mid_dim=768,
        mid_mlp_ratio=4,
        mid_qkv_bias=True,
    )
    return UViT(**config, **kwargs)
def uvit_l_p4_f16(**kwargs):
    """Build a UViT with a ViT-Large-like transformer mid block.

    Architecture:
    * Patch size 4.
    * Conv levels with 128 and 256 channels, 2 layers per block.
    * An extra downsample before the mid block. The "f16" suffix presumably refers
      to the resulting 4 x 2 x 2 = 16x spatial reduction.
    * Mid block: 24 layers, 16 heads, width 1024, MLP ratio 4, QKV bias.

    Args:
        **kwargs: Further ``UViT`` constructor arguments. Repeating one of the fixed
            settings above raises ``TypeError`` (duplicate keyword argument).
    """
    config = dict(
        patch_size=4,
        block_out_channels=(128, 256),
        layers_per_block=2,
        downsample_before_mid=True,
        mid_layers=24,
        mid_num_heads=16,
        mid_dim=1024,
        mid_mlp_ratio=4,
        mid_qkv_bias=True,
    )
    return UViT(**config, **kwargs)
def uvit_h_p4_f16(**kwargs):
    """Build a UViT with a ViT-Huge-like transformer mid block.

    Architecture:
    * Patch size 4.
    * Conv levels with 128 and 256 channels, 2 layers per block.
    * An extra downsample before the mid block. The "f16" suffix presumably refers
      to the resulting 4 x 2 x 2 = 16x spatial reduction.
    * Mid block: 32 layers, 16 heads, width 1280, MLP ratio 4, QKV bias.

    Args:
        **kwargs: Further ``UViT`` constructor arguments. Repeating one of the fixed
            settings above raises ``TypeError`` (duplicate keyword argument).
    """
    config = dict(
        patch_size=4,
        block_out_channels=(128, 256),
        layers_per_block=2,
        downsample_before_mid=True,
        mid_layers=32,
        mid_num_heads=16,
        mid_dim=1280,
        mid_mlp_ratio=4,
        mid_qkv_bias=True,
    )
    return UViT(**config, **kwargs)
def uvit_b_p4_f16_longskip(**kwargs):
    """Build the long-skip variant of ``uvit_b_p4_f16``.

    Architecture:
    * Same conv stages as ``uvit_b_p4_f16``: patch size 4, conv levels with 128 and
      256 channels, 2 layers per block, extra downsample before the mid block.
    * Mid block: 13 layers (one more than the plain variant), 12 heads, width 768,
      MLP ratio 4, QKV bias.
    * ``mid_use_long_skip=True``, which is forwarded to the mid block as
      ``use_long_skip``.

    Args:
        **kwargs: Further ``UViT`` constructor arguments. Repeating one of the fixed
            settings above raises ``TypeError`` (duplicate keyword argument).
    """
    config = dict(
        patch_size=4,
        block_out_channels=(128, 256),
        layers_per_block=2,
        downsample_before_mid=True,
        mid_layers=13,
        mid_num_heads=12,
        mid_dim=768,
        mid_mlp_ratio=4,
        mid_qkv_bias=True,
        mid_use_long_skip=True,
    )
    return UViT(**config, **kwargs)
def uvit_l_p4_f16_longskip(**kwargs):
    """Build the long-skip variant of ``uvit_l_p4_f16``.

    Architecture:
    * Same conv stages as ``uvit_l_p4_f16``: patch size 4, conv levels with 128 and
      256 channels, 2 layers per block, extra downsample before the mid block.
    * Mid block: 25 layers (one more than the plain variant), 16 heads, width 1024,
      MLP ratio 4, QKV bias.
    * ``mid_use_long_skip=True``, which is forwarded to the mid block as
      ``use_long_skip``.

    Args:
        **kwargs: Further ``UViT`` constructor arguments. Repeating one of the fixed
            settings above raises ``TypeError`` (duplicate keyword argument).
    """
    config = dict(
        patch_size=4,
        block_out_channels=(128, 256),
        layers_per_block=2,
        downsample_before_mid=True,
        mid_layers=25,
        mid_num_heads=16,
        mid_dim=1024,
        mid_mlp_ratio=4,
        mid_qkv_bias=True,
        mid_use_long_skip=True,
    )
    return UViT(**config, **kwargs)
def uvit_b_p4_f8(**kwargs):
    """Build a UViT with a ViT-Base-like mid block and no extra mid downsample.

    Architecture:
    * Patch size 4.
    * Conv levels with 128 and 256 channels, 2 layers per block.
    * No downsample before the mid block. The "f8" suffix presumably refers to the
      resulting 4 x 2 = 8x spatial reduction.
    * Mid block: 12 layers, 12 heads, width 768, MLP ratio 4, QKV bias.

    Args:
        **kwargs: Further ``UViT`` constructor arguments. Repeating one of the fixed
            settings above raises ``TypeError`` (duplicate keyword argument).
    """
    config = dict(
        patch_size=4,
        block_out_channels=(128, 256),
        layers_per_block=2,
        downsample_before_mid=False,
        mid_layers=12,
        mid_num_heads=12,
        mid_dim=768,
        mid_mlp_ratio=4,
        mid_qkv_bias=True,
    )
    return UViT(**config, **kwargs)
def uvit_l_p4_f8(**kwargs):
    """Build a UViT with a ViT-Large-like mid block and no extra mid downsample.

    Architecture:
    * Patch size 4.
    * Conv levels with 128 and 256 channels, 2 layers per block.
    * No downsample before the mid block. The "f8" suffix presumably refers to the
      resulting 4 x 2 = 8x spatial reduction.
    * Mid block: 24 layers, 16 heads, width 1024, MLP ratio 4, QKV bias.

    Args:
        **kwargs: Further ``UViT`` constructor arguments. Repeating one of the fixed
            settings above raises ``TypeError`` (duplicate keyword argument).
    """
    config = dict(
        patch_size=4,
        block_out_channels=(128, 256),
        layers_per_block=2,
        downsample_before_mid=False,
        mid_layers=24,
        mid_num_heads=16,
        mid_dim=1024,
        mid_mlp_ratio=4,
        mid_qkv_bias=True,
    )
    return UViT(**config, **kwargs)
def uvit_b_p4_f16_extraconv(**kwargs):
    """Build a UViT with a third conv level instead of a mid downsample (Base size).

    Architecture:
    * Patch size 4.
    * Three conv levels with 128, 256 and 512 channels, 2 layers per block.
    * No downsample before the mid block. The "f16" suffix presumably refers to the
      resulting 4 x 2 x 2 = 16x spatial reduction via the two conv downsamplings.
    * Mid block: 12 layers, 12 heads, width 768, MLP ratio 4, QKV bias.

    Args:
        **kwargs: Further ``UViT`` constructor arguments. Repeating one of the fixed
            settings above raises ``TypeError`` (duplicate keyword argument).
    """
    config = dict(
        patch_size=4,
        block_out_channels=(128, 256, 512),
        layers_per_block=2,
        downsample_before_mid=False,
        mid_layers=12,
        mid_num_heads=12,
        mid_dim=768,
        mid_mlp_ratio=4,
        mid_qkv_bias=True,
    )
    return UViT(**config, **kwargs)
def uvit_l_p4_f16_extraconv(**kwargs):
    """Build a UViT with a third conv level instead of a mid downsample (Large size).

    Architecture:
    * Patch size 4.
    * Three conv levels with 128, 256 and 512 channels, 2 layers per block.
    * No downsample before the mid block. The "f16" suffix presumably refers to the
      resulting 4 x 2 x 2 = 16x spatial reduction via the two conv downsamplings.
    * Mid block: 24 layers, 16 heads, width 1024, MLP ratio 4, QKV bias.

    Args:
        **kwargs: Further ``UViT`` constructor arguments. Repeating one of the fixed
            settings above raises ``TypeError`` (duplicate keyword argument).
    """
    config = dict(
        patch_size=4,
        block_out_channels=(128, 256, 512),
        layers_per_block=2,
        downsample_before_mid=False,
        mid_layers=24,
        mid_num_heads=16,
        mid_dim=1024,
        mid_mlp_ratio=4,
        mid_qkv_bias=True,
    )
    return UViT(**config, **kwargs)