Spaces:
Running on Zero
Running on Zero
| import torch | |
| import torch.nn as nn | |
| from functools import partial | |
| from einops import rearrange | |
| from refnet.util import checkpoint_wrapper, exists | |
| from refnet.modules.layers import FeedForward, Normalize, zero_module, RMSNorm | |
| from refnet.modules.attention import MemoryEfficientAttention, MultiModalAttention, MultiScaleCausalAttention | |
class BasicTransformerBlock(nn.Module):
    """Pre-norm transformer block: self-attention, optional cross-attention, feed-forward.

    Every sub-layer is applied residually as ``x = sublayer(norm(x)) + x``.

    Args:
        dim: Token feature size; the last dimension of the input tokens.
        n_heads: Number of attention heads (passed as ``heads``).
        d_head: Size of each head (passed as ``dim_head``).
        dropout: Dropout passed to both attention layers and the feed-forward.
        context_dim: Feature size of ``context`` for ``attn2``. When
            ``disable_self_attn`` is True it is also given to ``attn1``,
            which then attends to ``context`` instead of to ``x``.
        gated_ff: Passed to ``FeedForward`` as ``glu``.
        ff_mult: Passed to ``FeedForward`` as ``mult``.
        checkpoint: Stored as ``self.checkpoint``; not read by this class.
        disable_self_attn: Make ``attn1`` attend to ``context``.
        disable_cross_attn: Omit ``attn2`` and ``norm2`` (both set to None).
        self_attn_type: Key of ``ATTENTION_MODES`` used to build ``attn1``.
        cross_attn_type: Key of ``ATTENTION_MODES`` used to build ``attn2``.
        rotary_positional_embedding: Passed to ``attn1`` as ``rope``.
        context_dim_2: Passed through to ``attn2``.
        casual_self_attn: Passed to ``attn1`` as ``casual`` (presumably
            "causal"; the spelling matches the attention classes' keyword).
        casual_cross_attn: Passed to ``attn2`` as ``casual``.
        qk_norm: Passed through to ``attn1``.
        norm_type: ``"layer"`` for ``nn.LayerNorm`` or ``"rms"`` for ``RMSNorm``.

    Raises:
        NotImplementedError: If ``norm_type`` is neither ``"layer"`` nor ``"rms"``.
    """

    # Attention implementations selectable by name for attn1 / attn2.
    ATTENTION_MODES = {
        "vanilla": MemoryEfficientAttention,
        "multi-scale": MultiScaleCausalAttention,
        "multi-modal": MultiModalAttention,
    }

    def __init__(
            self,
            dim,
            n_heads = None,
            d_head = 64,
            dropout = 0.,
            context_dim = None,
            gated_ff = True,
            ff_mult = 4,
            checkpoint = True,
            disable_self_attn = False,
            disable_cross_attn = False,
            self_attn_type = "vanilla",
            cross_attn_type = "vanilla",
            rotary_positional_embedding = False,
            context_dim_2 = None,
            casual_self_attn = False,
            casual_cross_attn = False,
            qk_norm = False,
            norm_type = "layer",
    ):
        super().__init__()
        # NOTE(review): asserts are stripped under ``python -O``.
        assert self_attn_type in self.ATTENTION_MODES, f"Unknown self_attn_type: {self_attn_type}"
        assert cross_attn_type in self.ATTENTION_MODES, f"Unknown cross_attn_type: {cross_attn_type}"
        self_attn_cls = self.ATTENTION_MODES[self_attn_type]
        crossattn_cls = self.ATTENTION_MODES[cross_attn_type]

        if norm_type == "layer":
            norm_cls = nn.LayerNorm
        elif norm_type == "rms":
            norm_cls = RMSNorm
        else:
            raise NotImplementedError(f"Normalization {norm_type} is not implemented.")

        self.dim = dim
        self.disable_self_attn = disable_self_attn
        self.disable_cross_attn = disable_cross_attn
        self.attn1 = self_attn_cls(
            query_dim = dim,
            heads = n_heads,
            dim_head = d_head,
            dropout = dropout,
            # With self-attention disabled, attn1 attends to the context instead.
            context_dim = context_dim if self.disable_self_attn else None,
            casual = casual_self_attn,
            rope = rotary_positional_embedding,
            qk_norm = qk_norm
        )
        self.attn2 = crossattn_cls(
            query_dim = dim,
            context_dim = context_dim,
            context_dim_2 = context_dim_2,
            heads = n_heads,
            dim_head = d_head,
            dropout = dropout,
            casual = casual_cross_attn
        ) if not disable_cross_attn else None
        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff, mult=ff_mult)
        self.norm1 = norm_cls(dim)
        self.norm2 = norm_cls(dim) if not disable_cross_attn else None
        self.norm3 = norm_cls(dim)
        # Passed to attn2 on every call; presumably adjusted externally.
        self.reference_scale = 1
        self.scale_factor = None
        self.checkpoint = checkpoint

    def forward(self, x, context=None, mask=None, emb=None, **kwargs):
        """Apply self-attention, optional cross-attention and feed-forward to ``x``.

        Args:
            x: Token tensor whose last dimension is ``dim``.
            context: Context for ``attn2``; also used by ``attn1`` when
                ``disable_self_attn`` is set.
            mask: Passed to ``attn2`` only.
            emb: Unused here; accepted so subclasses share this signature.
            **kwargs: Extra keyword arguments forwarded to ``attn1``.

        Returns:
            The updated tokens (``x`` plus the residual branches).
        """
        h = self.norm1(x)
        if self.disable_self_attn:
            # attn1 was constructed with ``context_dim`` in this mode, so it must
            # receive the context; calling it without one would make it attend
            # to ``x`` with projections sized for the context.
            h = self.attn1(h, context, **kwargs)
        else:
            h = self.attn1(h, **kwargs)
        x = h + x

        if not self.disable_cross_attn:
            x = self.attn2(self.norm2(x), context, mask, self.reference_scale, self.scale_factor) + x

        x = self.ff(self.norm3(x)) + x
        return x
class SelfInjectedTransformerBlock(BasicTransformerBlock):
    """BasicTransformerBlock whose self-attention can also see an external feature bank.

    While ``self.bank`` is None the block behaves exactly like
    ``BasicTransformerBlock``. Once a bank is assigned (by code outside this
    class), the self-attention context is built from the normalized tokens
    and the bank.

    Attributes:
        bank: Optional tensor injected into self-attention. In ``"concat"``
            mode it is concatenated with the normalized tokens along dim 1, so
            after batch expansion it must match them in dims 0 and 2
            (assumed ``(batch, n_bank, dim)`` — TODO confirm with callers).
        time_proj: Optional callable; ``time_proj(emb).unsqueeze(1)`` is added
            to the bank when both it and ``emb`` are provided.
        injection_type: ``"concat"`` concatenates the bank with the tokens as
            the attn1 context; any other value adds the bank to the tokens.
        forward_without_bank: The parent class ``forward``, used when no bank
            is set.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.bank = None
        self.time_proj = None
        self.injection_type = "concat"
        self.forward_without_bank = super().forward

    @staticmethod
    def _expand_bank(bank, batch_size):
        """Tile ``bank`` along dim 0 so that its batch size equals ``batch_size``.

        Raises:
            ValueError: If ``batch_size`` is not a multiple of the bank's batch size.
        """
        bank_batch = bank.shape[0]
        if bank_batch == batch_size:
            return bank
        if batch_size % bank_batch != 0:
            raise ValueError(
                f"Cannot expand bank with batch size {bank_batch} to batch size {batch_size}."
            )
        # For a single-sample bank this equals the original repeat(batch_size, 1, 1).
        return bank.repeat(batch_size // bank_batch, 1, 1)

    def forward(self, x, context=None, mask=None, emb=None, **kwargs):
        """Run the block, injecting ``self.bank`` into self-attention when it is set.

        Args:
            x: Token tensor.
            context: Cross-attention context for ``attn2``.
            mask: Passed to both ``attn1`` and ``attn2`` on the bank path.
            emb: Optional embedding fed to ``time_proj``.
            **kwargs: Extra keyword arguments forwarded to ``attn1`` on both
                paths (e.g. ``grid_size``).

        Returns:
            The updated tokens.
        """
        if not exists(self.bank):
            # Forward **kwargs too, so attn1 receives the same extras as on the bank path.
            return self.forward_without_bank(x, context, mask, emb, **kwargs)

        bank = self._expand_bank(self.bank, x.shape[0])
        if exists(self.time_proj) and exists(emb):
            bank = bank + self.time_proj(emb).unsqueeze(1)

        x_in = self.norm1(x)
        if self.injection_type == "concat":
            attn_context = torch.cat([x_in, bank], 1)
        else:
            attn_context = x_in + bank

        has_cross_attn = self.attn2 is not None
        if has_cross_attn:
            # Keep attn1's masking threshold in sync with the cross-attention layer.
            self.attn1.mask_threshold = self.attn2.mask_threshold

        x = self.attn1(
            x_in,
            attn_context,
            mask = mask,
            scale_factor = self.scale_factor,
            **kwargs
        ) + x

        # attn2 / norm2 are None when the block was built with disable_cross_attn=True.
        if has_cross_attn:
            x = self.attn2(
                self.norm2(x),
                context,
                mask = mask,
                scale = self.reference_scale,
                scale_factor = self.scale_factor
            ) + x

        x = self.ff(self.norm3(x)) + x
        return x
class SelfTransformerBlock(nn.Module):
    """Lightweight pre-norm attention + MLP block.

    Args:
        dim: Channel / token feature size.
        dim_head: Per-head size; the head count is ``dim // dim_head``.
        dropout: Dropout passed to the attention layer.
        mlp_ratio: Hidden-size multiplier of the MLP.
        checkpoint: Stored as ``self.checkpoint``; not read by this class.
        casual_attn: Passed to the attention layer as ``casual``.
        reshape: If True, ``forward`` takes a ``(b, c, h, w)`` feature map,
            flattens it to tokens and restores the layout afterwards. If
            False, ``x`` is used as-is (presumably ``(b, n, dim)`` tokens).
    """

    def __init__(
            self,
            dim,
            dim_head = 64,
            dropout = 0.,
            mlp_ratio = 4,
            checkpoint = True,
            casual_attn = False,
            reshape = True
    ):
        super().__init__()
        self.attn = MemoryEfficientAttention(query_dim=dim, heads=dim//dim_head, dropout=dropout, casual=casual_attn)
        self.ff = nn.Sequential(
            nn.Linear(dim, dim * mlp_ratio),
            nn.SiLU(),
            # Output projection is wrapped in zero_module (presumably zero-initialised).
            zero_module(nn.Linear(dim * mlp_ratio, dim))
        )
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.reshape = reshape
        self.checkpoint = checkpoint

    def forward(self, x, context=None):
        """Apply residual attention and residual MLP to ``x``.

        Args:
            x: ``(b, c, h, w)`` feature map when ``self.reshape`` is True,
                otherwise a token tensor whose last dimension is ``dim``.
            context: Optional attention context; None means self-attention.

        Returns:
            A tensor in the same layout as ``x``.
        """
        if self.reshape:
            # Only the 4-D layout carries spatial dims; reading them
            # unconditionally would break the reshape=False (token) path.
            _, _, h, w = x.shape
            x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
        x = self.attn(self.norm1(x), context if exists(context) else None) + x
        x = self.ff(self.norm2(x)) + x
        if self.reshape:
            x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
        return x
class Transformer(nn.Module):
    """Spatial transformer: feature map -> tokens -> transformer blocks -> feature map.

    The input is normalized, projected to ``n_heads * d_head`` channels,
    flattened to tokens and run through ``depth`` blocks. It is then
    projected back to ``in_channels`` and added to the input as a residual.

    Args:
        in_channels: Channels of the input feature map.
        n_heads: Attention heads per block.
        d_head: Per-head size; the inner width is ``n_heads * d_head``.
        depth: Number of transformer blocks.
        dropout: Dropout passed to each block.
        context_dim: Context feature size, either a single value or a list with
            one entry per block. A scalar, or a list whose length differs from
            ``depth``, is broadcast from its first entry.
        use_linear: Use ``nn.Linear`` projections on tokens instead of 1x1
            ``nn.Conv2d`` projections on the feature map.
        use_checkpoint: Passed to each block as ``checkpoint``.
        type: Key of ``transformer_type`` selecting the block class.
        transformer_config: Optional dict of extra keyword arguments for every block.
        **kwargs: Further keyword arguments for every block.
    """

    transformer_type = {
        "vanilla": BasicTransformerBlock,
        "self-injection": SelfInjectedTransformerBlock,
    }

    def __init__(self, in_channels, n_heads, d_head,
                 depth=1, dropout=0., context_dim=None, use_linear=False,
                 use_checkpoint=True, type="vanilla", transformer_config=None, **kwargs):
        super().__init__()
        transformer_block = self.transformer_type[type]

        # Normalise context_dim to exactly one entry per block.
        if not isinstance(context_dim, list):
            context_dim = [context_dim]
        if len(context_dim) != depth:
            # NOTE(review): a mismatched list silently uses only its first entry.
            context_dim = depth * [context_dim[0]]

        proj_layer = nn.Linear if use_linear else partial(nn.Conv2d, kernel_size=1, stride=1, padding=0)
        inner_dim = n_heads * d_head
        self.in_channels = in_channels
        self.proj_in = proj_layer(in_channels, inner_dim)
        block_config = transformer_config or {}
        self.transformer_blocks = nn.ModuleList([
            transformer_block(
                inner_dim,
                n_heads,
                d_head,
                dropout = dropout,
                context_dim = block_context_dim,
                checkpoint = use_checkpoint,
                **block_config,
                **kwargs
            ) for block_context_dim in context_dim
        ])
        # Output projection is wrapped in zero_module (presumably zero-initialised,
        # so the whole module starts out as an identity through the residual).
        self.proj_out = zero_module(proj_layer(inner_dim, in_channels))
        self.norm = Normalize(in_channels)
        self.use_linear = use_linear

    def forward(self, x, context=None, mask=None, emb=None, *args, **additional_context):
        """Run the transformer over a ``(b, c, h, w)`` feature map and add the input back.

        Args:
            x: Feature map of shape ``(b, in_channels, h, w)``.
            context: Context passed to every block (if None, cross-attention
                defaults to self-attention inside the attention layers).
            mask: Passed to every block.
            emb: Passed to every block.
            *args: Forwarded positionally to every block.
                NOTE(review): positional extras bind to the block's ``context``
                slot and collide with the ``context=`` keyword (TypeError).
            **additional_context: Extra keyword arguments for every block.

        Returns:
            Tensor with the same shape as ``x``.
        """
        _, _, h, w = x.shape
        x_in = x
        x = self.norm(x)
        if not self.use_linear:
            x = self.proj_in(x)
        x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
        if self.use_linear:
            x = self.proj_in(x)
        for block in self.transformer_blocks:
            # grid_size lets blocks recover the spatial layout of the flattened tokens.
            x = block(x, context=context, mask=mask, emb=emb, grid_size=(h, w), *args, **additional_context)
        if self.use_linear:
            x = self.proj_out(x)
        x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
        if not self.use_linear:
            x = self.proj_out(x)
        return x + x_in
def SpatialTransformer(*args, **kwargs):
    """Build a ``Transformer`` whose blocks are ``BasicTransformerBlock`` ("vanilla").

    All arguments are forwarded to ``Transformer`` unchanged.
    """
    return Transformer(*args, type="vanilla", **kwargs)
def SelfInjectTransformer(*args, **kwargs):
    """Build a ``Transformer`` whose blocks are ``SelfInjectedTransformerBlock`` ("self-injection").

    All arguments are forwarded to ``Transformer`` unchanged.
    """
    return Transformer(*args, type="self-injection", **kwargs)