tellurion's picture
initialize huggingface space demo
d066167
import torch
import torch.nn as nn
from functools import partial
from einops import rearrange
from refnet.util import checkpoint_wrapper, exists
from refnet.modules.layers import FeedForward, Normalize, zero_module, RMSNorm
from refnet.modules.attention import MemoryEfficientAttention, MultiModalAttention, MultiScaleCausalAttention
class BasicTransformerBlock(nn.Module):
ATTENTION_MODES = {
"vanilla": MemoryEfficientAttention,
"multi-scale": MultiScaleCausalAttention,
"multi-modal": MultiModalAttention,
}
def __init__(
self,
dim,
n_heads = None,
d_head = 64,
dropout = 0.,
context_dim = None,
gated_ff = True,
ff_mult = 4,
checkpoint = True,
disable_self_attn = False,
disable_cross_attn = False,
self_attn_type = "vanilla",
cross_attn_type = "vanilla",
rotary_positional_embedding = False,
context_dim_2 = None,
casual_self_attn = False,
casual_cross_attn = False,
qk_norm = False,
norm_type = "layer",
):
super().__init__()
assert self_attn_type in self.ATTENTION_MODES
assert cross_attn_type in self.ATTENTION_MODES
self_attn_cls = self.ATTENTION_MODES[self_attn_type]
crossattn_cls = self.ATTENTION_MODES[cross_attn_type]
if norm_type == "layer":
norm_cls = nn.LayerNorm
elif norm_type == "rms":
norm_cls = RMSNorm
else:
raise NotImplementedError(f"Normalization {norm_type} is not implemented.")
self.dim = dim
self.disable_self_attn = disable_self_attn
self.disable_cross_attn = disable_cross_attn
self.attn1 = self_attn_cls(
query_dim = dim,
heads = n_heads,
dim_head = d_head,
dropout = dropout,
context_dim = context_dim if self.disable_self_attn else None,
casual = casual_self_attn,
rope = rotary_positional_embedding,
qk_norm = qk_norm
)
self.attn2 = crossattn_cls(
query_dim = dim,
context_dim = context_dim,
context_dim_2 = context_dim_2,
heads = n_heads,
dim_head = d_head,
dropout = dropout,
casual = casual_cross_attn
) if not disable_cross_attn else None
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff, mult=ff_mult)
self.norm1 = norm_cls(dim)
self.norm2 = norm_cls(dim) if not disable_cross_attn else None
self.norm3 = norm_cls(dim)
self.reference_scale = 1
self.scale_factor = None
self.checkpoint = checkpoint
@checkpoint_wrapper
def forward(self, x, context=None, mask=None, emb=None, **kwargs):
x = self.attn1(self.norm1(x), **kwargs) + x
if not self.disable_cross_attn:
x = self.attn2(self.norm2(x), context, mask, self.reference_scale, self.scale_factor) + x
x = self.ff(self.norm3(x)) + x
return x
class SelfInjectedTransformerBlock(BasicTransformerBlock):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.bank = None
self.time_proj = None
self.injection_type = "concat"
self.forward_without_bank = super().forward
@checkpoint_wrapper
def forward(self, x, context=None, mask=None, emb=None, **kwargs):
if exists(self.bank):
bank = self.bank
if bank.shape[0] != x.shape[0]:
bank = bank.repeat(x.shape[0], 1, 1)
if exists(self.time_proj) and exists(emb):
bank = bank + self.time_proj(emb).unsqueeze(1)
x_in = self.norm1(x)
self.attn1.mask_threshold = self.attn2.mask_threshold
x = self.attn1(
x_in,
torch.cat([x_in, bank], 1) if self.injection_type == "concat" else x_in + bank,
mask = mask,
scale_factor = self.scale_factor,
**kwargs
) + x
x = self.attn2(
self.norm2(x),
context,
mask = mask,
scale = self.reference_scale,
scale_factor = self.scale_factor
) + x
x = self.ff(self.norm3(x)) + x
else:
x = self.forward_without_bank(x, context, mask, emb)
return x
class SelfTransformerBlock(nn.Module):
def __init__(
self,
dim,
dim_head = 64,
dropout = 0.,
mlp_ratio = 4,
checkpoint = True,
casual_attn = False,
reshape = True
):
super().__init__()
self.attn = MemoryEfficientAttention(query_dim=dim, heads=dim//dim_head, dropout=dropout, casual=casual_attn)
self.ff = nn.Sequential(
nn.Linear(dim, dim * mlp_ratio),
nn.SiLU(),
zero_module(nn.Linear(dim * mlp_ratio, dim))
)
self.norm1 = nn.LayerNorm(dim)
self.norm2 = nn.LayerNorm(dim)
self.reshape = reshape
self.checkpoint = checkpoint
@checkpoint_wrapper
def forward(self, x, context=None):
b, c, h, w = x.shape
if self.reshape:
x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
x = self.attn(self.norm1(x), context if exists(context) else None) + x
x = self.ff(self.norm2(x)) + x
if self.reshape:
x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
return x
class Transformer(nn.Module):
transformer_type = {
"vanilla": BasicTransformerBlock,
"self-injection": SelfInjectedTransformerBlock,
}
def __init__(self, in_channels, n_heads, d_head,
depth=1, dropout=0., context_dim=None, use_linear=False,
use_checkpoint=True, type="vanilla", transformer_config=None, **kwargs):
super().__init__()
transformer_block = self.transformer_type[type]
if not isinstance(context_dim, list):
context_dim = [context_dim]
if isinstance(context_dim, list):
if depth != len(context_dim):
context_dim = depth * [context_dim[0]]
proj_layer = nn.Linear if use_linear else partial(nn.Conv2d, kernel_size=1, stride=1, padding=0)
inner_dim = n_heads * d_head
self.in_channels = in_channels
self.proj_in = proj_layer(in_channels, inner_dim)
self.transformer_blocks = nn.ModuleList([
transformer_block(
inner_dim,
n_heads,
d_head,
dropout = dropout,
context_dim = context_dim[d],
checkpoint = use_checkpoint,
**(transformer_config or {}),
**kwargs
) for d in range(depth)
])
self.proj_out = zero_module(proj_layer(inner_dim, in_channels))
self.norm = Normalize(in_channels)
self.use_linear = use_linear
def forward(self, x, context=None, mask=None, emb=None, *args, **additional_context):
# note: if no context is given, cross-attention defaults to self-attention
b, c, h, w = x.shape
x_in = x
x = self.norm(x)
if not self.use_linear:
x = self.proj_in(x)
x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
if self.use_linear:
x = self.proj_in(x)
for i, block in enumerate(self.transformer_blocks):
x = block(x, context=context, mask=mask, emb=emb, grid_size=(h, w), *args, **additional_context)
if self.use_linear:
x = self.proj_out(x)
x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
if not self.use_linear:
x = self.proj_out(x)
return x + x_in
def SpatialTransformer(*args, **kwargs):
return Transformer(type="vanilla", *args, **kwargs)
def SelfInjectTransformer(*args, **kwargs):
return Transformer(type="self-injection", *args, **kwargs)