import math

import torch
from torch import nn
from einops import rearrange

from .basic_transformer_block import PatchedBasicTransformerBlock as BasicTransformerBlock
def Normalize(in_channels):
    # 32-group GroupNorm applied to the feature map before the transformer blocks.
    return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
def init_(tensor):
    # In-place uniform initialization in [-1/sqrt(dim), 1/sqrt(dim)],
    # scaled by the tensor's last dimension.
    dim = tensor.shape[-1]
    std = 1 / math.sqrt(dim)
    tensor.uniform_(-std, std)
    return tensor
def zero_module(module):
    # Zero out all parameters in place so the module outputs zeros at initialization.
    for p in module.parameters():
        p.detach().zero_()
    return module
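# Illustration (a sketch, not part of the original module): a zeroed layer maps
# every input to zeros, which is why the zero-initialized proj_out below makes
# the whole SpatialTransformer an identity map at initialization (see forward()):
#
#     layer = zero_module(nn.Linear(4, 4))
#     assert all((p == 0).all() for p in layer.parameters())
#     assert torch.equal(layer(torch.randn(2, 4)), torch.zeros(2, 4))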
class SpatialTransformer(nn.Module):
    """
    Transformer block for image-like data.
    First, project the input (aka embedding) and reshape it to (b, t, d).
    Then apply standard transformer layers, and finally reshape back to an image.
    NEW: `use_linear` swaps the 1x1 convolutions in the in/out projections for
    linear layers, which is more efficient.
    """
    def __init__(self, in_channels, n_heads, d_head,
                 depth=1, dropout=0., context_dim=None,
                 disable_self_attn=False, use_linear=False,
                 use_checkpoint=True):
        super().__init__()
        # Broadcast a single context_dim (or None) across all `depth` blocks; the
        # original length-1 list failed for depth > 1 and for context_dim=None.
        if not isinstance(context_dim, list):
            context_dim = [context_dim] * depth
        self.in_channels = in_channels
        inner_dim = n_heads * d_head
        self.norm = Normalize(in_channels)
        if not use_linear:
            self.proj_in = nn.Conv2d(in_channels, inner_dim,
                                     kernel_size=1, stride=1, padding=0)
        else:
            self.proj_in = nn.Linear(in_channels, inner_dim)
        self.transformer_blocks = nn.ModuleList(
            [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout,
                                   context_dim=context_dim[d],
                                   disable_self_attn=disable_self_attn,
                                   checkpoint=use_checkpoint)
             for d in range(depth)]
        )
        # Zero-initialized output projection: combined with the residual connection
        # in forward(), the whole block starts out as an identity map.
        if not use_linear:
            self.proj_out = zero_module(nn.Conv2d(inner_dim, in_channels,
                                                  kernel_size=1, stride=1, padding=0))
        else:
            # Bug fix: the output projection must map inner_dim back to in_channels;
            # the original nn.Linear(in_channels, inner_dim) only worked when equal.
            self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
        self.use_linear = use_linear
    def forward(self, x, context=None):
        # note: if no context is given, cross-attention defaults to self-attention
        if not isinstance(context, list):
            # Broadcast a single context (or None) across all transformer blocks.
            context = [context] * len(self.transformer_blocks)
        b, c, h, w = x.shape
        x_in = x
        x = self.norm(x)
        if not self.use_linear:
            x = self.proj_in(x)
        # Flatten spatial dimensions into a token axis: (b, c, h, w) -> (b, h*w, c).
        x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
        if self.use_linear:
            x = self.proj_in(x)
        for i, block in enumerate(self.transformer_blocks):
            x = block(x, context=context[i])
        if self.use_linear:
            x = self.proj_out(x)
        # Restore the image layout: (b, h*w, c) -> (b, c, h, w).
        x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
        if not self.use_linear:
            x = self.proj_out(x)
        return x + x_in
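# Minimal usage sketch (an illustration, not part of the original file): shapes
# are arbitrary, and PatchedBasicTransformerBlock is assumed to follow the
# standard ldm BasicTransformerBlock interface. Run with
# `python -m <package>.<this_module>` so the relative import resolves.
if __name__ == "__main__":
    st = SpatialTransformer(in_channels=64, n_heads=4, d_head=16,
                            depth=1, context_dim=768,
                            use_linear=True, use_checkpoint=False)
    x = torch.randn(2, 64, 8, 8)        # (b, c, h, w) feature map
    ctx = torch.randn(2, 77, 768)       # (b, tokens, context_dim) conditioning
    out = st(x, context=ctx)            # cross-attention over ctx
    assert out.shape == x.shape         # residual design preserves the shape
    assert torch.allclose(out, x)       # identity at init: proj_out is zeroed
    out_self = st(x)                    # no context: falls back to self-attention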