FashionFlow / models /diffusion_model.py
tasin
init
f075308
import math
import functools
from operator import mul
import torch
import torch.nn.functional as F
from torch import nn, einsum
from einops import rearrange, repeat, pack, unpack
from einops.layers.torch import Rearrange
# helper functions
def exists(val):
    """Return True when ``val`` is anything other than ``None``."""
    if val is None:
        return False
    return True
def default(val, d):
    """Return ``val`` unless it is ``None``, in which case return the fallback ``d``."""
    if val is None:
        return d
    return val
def mul_reduce(tup):
    """Return the left-to-right product of the items in ``tup``.

    Raises TypeError for an empty iterable (no initial value is supplied).
    """
    return functools.reduce(lambda acc, item: acc * item, tup)
def divisible_by(numer, denom):
    """Return True when ``numer`` is an exact multiple of ``denom``."""
    remainder = numer % denom
    return remainder == 0
# Short alias for torch.nn.ModuleList, used when building the nested U-Net layer lists.
mlist = torch.nn.ModuleList
# for time conditioning
class SinusoidalPosEmb(nn.Module):
    """Sinusoidal embedding of scalar positions (used here for diffusion timesteps).

    Args:
        dim: embedding size; the output has ``2 * (dim // 2)`` features, the
            first half sines and the second half cosines.
        theta: base of the geometric frequency progression.
    """

    def __init__(self, dim, theta=10000):
        super().__init__()
        self.dim = dim
        self.theta = theta

    def forward(self, x):
        """Embed a 1-D float tensor of shape (n,) into shape (n, 2 * (dim // 2))."""
        dtype = x.dtype
        device = x.device
        assert dtype == torch.float, 'input to sinusoidal pos emb must be a float type'
        half = self.dim // 2
        # frequencies decay geometrically from 1 down to 1 / theta
        log_step = math.log(self.theta) / (half - 1)
        freqs = (torch.arange(half, device=device, dtype=dtype) * -log_step).exp()
        angles = rearrange(x, 'i -> i 1') * rearrange(freqs, 'j -> 1 j')
        return torch.cat((angles.sin(), angles.cos()), dim=-1).type(dtype)
# layer norm over the channel axis (dim 1) of channels-first tensors
class ChanLayerNorm(nn.Module):
    """Layer norm over the channel axis (dim 1) of a channels-first tensor.

    Each position is normalized across channels (biased variance, no learned
    bias) and scaled by a learned per-channel gain ``g``. The gain is stored
    with shape (dim, 1, 1, 1), as used for 5-D (b, c, f, h, w) tensors. It is
    reshaped at call time so it lines up with dim 1 for any input rank,
    e.g. 4-D (b, c, h, w) images.

    Args:
        dim: number of channels (size of dim 1 of the input).
    """

    def __init__(self, dim):
        super().__init__()
        # parameter shape kept as (dim, 1, 1, 1) so existing checkpoints still load
        self.g = nn.Parameter(torch.ones(dim, 1, 1, 1))

    def forward(self, x):
        # looser epsilon for any dtype other than float32
        eps = 1e-5 if x.dtype == torch.float32 else 1e-3
        var = torch.var(x, dim=1, unbiased=False, keepdim=True)
        mean = torch.mean(x, dim=1, keepdim=True)
        # (dim, 1, ..., 1) with x.ndim - 2 trailing ones: aligns the gain with
        # dim 1 regardless of rank (identical to self.g for 5-D input)
        g = self.g.reshape(-1, *((1,) * (x.ndim - 2)))
        return (x - mean) * var.clamp(min=eps).rsqrt() * g
# feedforward
def shift_token(t):
    """Shift the second half of the channels forward by one step along the frame axis.

    Channels are split in two along dim 1. The first half is returned unchanged.
    In the second half, position i of the third-from-last axis (the frame axis of
    a (b, c, f, h, w) tensor) receives what was at i - 1. Zeros enter at position
    0 and the last position is dropped.
    """
    kept, moved = t.chunk(2, dim=1)
    # pad one zero slot at the front of the third-from-last axis and crop one at the back
    moved = F.pad(moved, (0, 0, 0, 0, 1, -1), value=0.)
    return torch.cat([kept, moved], dim=1)
class GEGLU(nn.Module):
    """Gated GELU over channel dim 1.

    Splits the channels into a value half and a gate half and returns
    ``value * gelu(gate)``, so the output has half as many channels.
    """

    def forward(self, x):
        value, gate = x.chunk(2, dim=1)
        activated_gate = F.gelu(gate)
        return value * activated_gate
class FeedForward(nn.Module):
    """Channels-first GEGLU feed-forward with an optional temporal token shift.

    The projections are 1x1x1 Conv3d layers, so the native input is a 5-D
    (b, c, f, h, w) video. A 4-D (b, c, h, w) image gets a singleton frame
    axis for the duration of the call and is returned as 4-D. The token shift
    is only applied to video input.

    Args:
        dim: input and output channel count.
        mult: expansion factor; the hidden width is ``int(dim * mult * 2 / 3)``.
    """

    def __init__(self, dim, mult=4):
        super().__init__()
        inner_dim = int(dim * mult * 2 / 3)
        self.proj_in = nn.Sequential(
            # doubled width: GEGLU splits it into value and gate halves
            nn.Conv3d(dim, inner_dim * 2, 1, bias=False),
            GEGLU()
        )
        self.proj_out = nn.Sequential(
            ChanLayerNorm(inner_dim),
            nn.Conv3d(inner_dim, dim, 1, bias=False)
        )

    def forward(self, x, enable_time=True):
        is_video = x.ndim == 5
        # shifting along frames only makes sense when there is a frame axis
        enable_time &= is_video
        if not is_video:
            # Conv3d would otherwise read a 4-D tensor as one unbatched (C, D, H, W) volume
            x = rearrange(x, 'b c h w -> b c 1 h w')
        x = self.proj_in(x)
        if enable_time:
            x = shift_token(x)
        out = self.proj_out(x)
        if not is_video:
            out = rearrange(out, 'b c 1 h w -> b c h w')
        return out
# continuous relative position bias (Swin Transformer V2, arXiv:2111.09883)
class ContinuousPositionBias(nn.Module):
    """Per-head attention bias produced by an MLP from relative positions.

    Based on the continuous position bias of https://arxiv.org/abs/2111.09883.
    Every possible relative offset along ``num_dims`` axes is fed through a small
    MLP (``layers`` Linear+SiLU layers of width ``dim``, then a Linear to
    ``heads``). The results are gathered into a (heads, n, n) bias, where n is
    the product of the axis sizes passed to ``forward``.

    Args:
        dim: hidden width of the MLP.
        heads: number of attention heads (one bias value per head).
        num_dims: number of positional axes (1 for time, 2 for height/width).
        layers: number of hidden Linear+SiLU layers.
    """
    def __init__(
        self,
        *,
        dim,
        heads,
        num_dims=1,
        layers=2
    ):
        super().__init__()
        self.num_dims = num_dims
        self.net = nn.ModuleList([])
        self.net.append(nn.Sequential(nn.Linear(self.num_dims, dim), nn.SiLU()))
        for _ in range(layers - 1):
            self.net.append(nn.Sequential(nn.Linear(dim, dim), nn.SiLU()))
        self.net.append(nn.Linear(dim, heads))

    @property
    def device(self):
        # device of the MLP parameters; index tensors below are created there
        return next(self.parameters()).device

    def forward(self, *dimensions):
        """Return a (heads, n, n) bias for a grid with the given integer axis sizes."""
        device = self.device
        shape = torch.tensor(dimensions, device=device)
        # each axis of size d has 2d - 1 distinct relative offsets: -(d-1) .. (d-1)
        rel_pos_shape = 2 * shape - 1
        # calculate strides
        # row-major strides over the relative-offset grid: the last axis gets
        # stride 1 and each earlier axis the product of the later axes' sizes
        strides = torch.flip(rel_pos_shape, (0,)).cumprod(dim=-1)
        strides = torch.flip(F.pad(strides, (1, -1), value=1), (0,))
        # get all positions and calculate all the relative distances
        positions = [torch.arange(d, device=device) for d in dimensions]
        grid = torch.stack(torch.meshgrid(*positions, indexing='ij'), dim=-1)
        grid = rearrange(grid, '... c -> (...) c')
        # (n, n, num_dims) pairwise coordinate differences
        rel_dist = rearrange(grid, 'i c -> i 1 c') - rearrange(grid, 'j c -> 1 j c')
        # get all relative positions across all dimensions
        rel_positions = [torch.arange(-d + 1, d, device=device) for d in dimensions]
        rel_pos_grid = torch.stack(torch.meshgrid(*rel_positions, indexing='ij'), dim=-1)
        rel_pos_grid = rearrange(rel_pos_grid, '... c -> (...) c')
        # mlp input
        bias = rel_pos_grid.float()
        for layer in self.net:
            bias = layer(bias)
        # bias is now (number of distinct offsets, heads)
        # convert relative distances to indices of the bias
        rel_dist += (shape - 1) # make sure all positive
        rel_dist *= strides
        rel_dist_indices = rel_dist.sum(dim=-1)
        # now select the bias for each unique relative position combination
        bias = bias[rel_dist_indices]
        return rearrange(bias, 'i j h -> h i j')
# helper classes
class CrossAttention(nn.Module):
    """Multi-head attention with queries from ``x`` and keys/values from ``y``.

    Args:
        n_heads: number of heads; ``d_embed`` is split evenly across them.
        d_embed: feature size of ``x`` and of the output.
        d_cross: feature size of the context ``y``.
        in_proj_bias: whether the q/k/v projections have a bias.
        out_proj_bias: whether the output projection has a bias.
    """

    def __init__(self, n_heads, d_embed, d_cross, in_proj_bias=True, out_proj_bias=True):
        super().__init__()
        self.q_proj = nn.Linear(d_embed, d_embed, bias=in_proj_bias)
        self.k_proj = nn.Linear(d_cross, d_embed, bias=in_proj_bias)
        self.v_proj = nn.Linear(d_cross, d_embed, bias=in_proj_bias)
        self.out_proj = nn.Linear(d_embed, d_embed, bias=out_proj_bias)
        self.n_heads = n_heads
        self.d_head = d_embed // n_heads

    def forward(self, x, y):
        """Attend from ``x`` (batch, seq, d_embed) to ``y``; returns x's shape."""
        out_shape = x.shape
        batch_size, _, _ = out_shape
        head_shape = (batch_size, -1, self.n_heads, self.d_head)

        def to_heads(t):
            # (batch, tokens, d_embed) -> (batch, heads, tokens, d_head)
            return t.view(head_shape).transpose(1, 2)

        q = to_heads(self.q_proj(x))
        k = to_heads(self.k_proj(y))
        v = to_heads(self.v_proj(y))

        scores = (q @ k.transpose(-1, -2)) / math.sqrt(self.d_head)
        probs = F.softmax(scores, dim=-1)
        merged = (probs @ v).transpose(1, 2).contiguous().view(out_shape)
        return self.out_proj(merged)
class AttentionBlock(nn.Module):
    """Token-wise cross-attention plus GEGLU MLP applied to a video feature map.

    The (b, c, f, h, w) input is flattened to (b, h*w*f, c) tokens. These pass
    through pre-norm cross-attention to ``context`` and a pre-norm GEGLU MLP,
    each with a residual. The tokens are then reshaped back and passed through
    a 1x1 PseudoConv3d.

    Args:
        n_head: number of attention heads.
        n_embd: per-head width; the block works on ``n_head * n_embd`` channels,
            which must equal the input's channel count.
        d_context: feature size of the context tensor.
    """

    def __init__(self, n_head: int, n_embd: int, d_context=768):
        super().__init__()
        channels = n_head * n_embd
        self.layernorm_2 = nn.LayerNorm(channels)
        self.attention_2 = CrossAttention(n_head, channels, d_context, in_proj_bias=False)
        self.layernorm_3 = nn.LayerNorm(channels)
        # emits value and gate halves (4 * channels each) for the GEGLU
        self.linear_geglu_1 = nn.Linear(channels, 4 * channels * 2)
        self.linear_geglu_2 = nn.Linear(4 * channels, channels)
        # NOTE(review): PseudoConv3d accepts but does not forward extra kwargs, so
        # bias=False has no effect here; honouring it would change the saved parameters.
        self.conv_output = PseudoConv3d(channels, channels, 1, bias=False)

    def forward(self, x, context):
        b, c, *_, h, w = x.shape
        tokens = rearrange(x, 'b c f h w -> b (h w f) c')

        # pre-norm cross-attention with residual
        attended = self.attention_2(self.layernorm_2(tokens), context)
        tokens = attended + tokens

        # pre-norm GEGLU MLP with residual
        value, gate = self.linear_geglu_1(self.layernorm_3(tokens)).chunk(2, dim=-1)
        mlp_out = self.linear_geglu_2(value * F.gelu(gate))
        tokens = mlp_out + tokens

        x = rearrange(tokens, 'b (h w f) c -> b c f h w', b=b, c=c, h=h, w=w)
        return self.conv_output(x)
class Attention(nn.Module):
    """Pre-norm multi-head self-attention over (batch, seq, dim) tokens.

    An optional additive bias (e.g. a relative position bias) is added to the
    attention logits. The output projection is zero-initialised, so the block
    contributes nothing at initialisation and acts as the identity when wrapped
    in a skip connection.

    Args:
        dim: token feature size.
        dim_head: per-head width.
        heads: number of heads.
    """

    def __init__(self, dim, dim_head=64, heads=8):
        super().__init__()
        self.heads = heads
        self.scale = dim_head ** -0.5
        inner_dim = dim_head * heads
        self.norm = nn.LayerNorm(dim)
        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)
        nn.init.zeros_(self.to_out.weight.data)

    def forward(self, x, rel_pos_bias=None):
        normed = self.norm(x)
        query = self.to_q(normed)
        key, value = self.to_kv(normed).chunk(2, dim=-1)
        query, key, value = (
            rearrange(t, 'b n (h d) -> b h n d', h=self.heads) for t in (query, key, value)
        )

        logits = einsum('b h i d, b h j d -> b h i j', query * self.scale, key)
        if exists(rel_pos_bias):
            logits = logits + rel_pos_bias
        weights = logits.softmax(dim=-1)

        mixed = einsum('b h i j, b h j d -> b h i d', weights, value)
        return self.to_out(rearrange(mixed, 'b h n d -> b n (h d)'))
# main contribution - pseudo 3d conv
class PseudoConv3d(nn.Module):
    """Factorised 3-D convolution: per-frame 2-D spatial conv, then per-pixel 1-D temporal conv.

    The temporal conv is dirac-initialised with zero bias, so a freshly built
    layer behaves like the spatial conv alone. No temporal conv is created when
    ``kernel_size`` is 1. Extra keyword arguments are accepted and ignored.

    Args:
        dim: input channels.
        dim_out: output channels (defaults to ``dim``).
        kernel_size: spatial kernel size.
        temporal_kernel_size: temporal kernel size (defaults to ``kernel_size``).
    """

    def __init__(self, dim, dim_out=None, kernel_size=3, *, temporal_kernel_size=None, **kwargs):
        super().__init__()
        dim_out = default(dim_out, dim)
        temporal_kernel_size = default(temporal_kernel_size, kernel_size)
        self.spatial_conv = nn.Conv2d(dim, dim_out, kernel_size=kernel_size, padding=kernel_size // 2)
        if kernel_size > 1:
            self.temporal_conv = nn.Conv1d(
                dim_out, dim_out,
                kernel_size=temporal_kernel_size,
                padding=temporal_kernel_size // 2,
            )
            nn.init.dirac_(self.temporal_conv.weight.data)  # starts as the identity
            nn.init.zeros_(self.temporal_conv.bias.data)
        else:
            self.temporal_conv = None

    def forward(self, x, enable_time=True):
        b, c, *_, h, w = x.shape
        is_video = x.ndim == 5
        enable_time &= is_video

        if not is_video:
            # images only get the spatial conv
            return self.spatial_conv(x)

        frames = rearrange(x, 'b c f h w -> (b f) c h w')
        x = rearrange(self.spatial_conv(frames), '(b f) c h w -> b c f h w', b=b)
        if not enable_time or self.temporal_conv is None:
            return x

        # treat every pixel location as an independent 1-D sequence over frames
        columns = rearrange(x, 'b c f h w -> (b h w) c f')
        columns = self.temporal_conv(columns)
        return rearrange(columns, '(b h w) c f -> b c f h w', h=h, w=w)
# factorized spatial temporal attention from Ho et al.
class SpatioTemporalAttention(nn.Module):
    """Factorised attention: spatial attention per frame, then temporal attention per pixel.

    Each attention uses a continuous relative position bias and a residual.
    An optional residual FeedForward follows. The temporal stage (and the
    feed-forward's token shift) runs only for 5-D video input with
    ``enable_time`` set. The output has the same shape as the input.

    Args:
        dim: channel count.
        dim_head: per-head width of both attentions.
        heads: number of heads of both attentions.
        add_feed_forward: whether to append a FeedForward.
        ff_mult: FeedForward expansion factor.
    """

    def __init__(self, dim, *, dim_head=64, heads=8, add_feed_forward=True, ff_mult=4):
        super().__init__()
        self.spatial_attn = Attention(dim=dim, dim_head=dim_head, heads=heads)
        self.spatial_rel_pos_bias = ContinuousPositionBias(dim=dim // 2, heads=heads, num_dims=2)
        self.temporal_attn = Attention(dim=dim, dim_head=dim_head, heads=heads)
        self.temporal_rel_pos_bias = ContinuousPositionBias(dim=dim // 2, heads=heads, num_dims=1)
        self.has_feed_forward = add_feed_forward
        if add_feed_forward:
            self.ff = FeedForward(dim=dim, mult=ff_mult)

    def forward(self, x, enable_time=True):
        b, c, *_, h, w = x.shape
        is_video = x.ndim == 5
        enable_time &= is_video

        # spatial stage: every frame (or the single image) attends over its h*w positions
        if is_video:
            to_tokens, from_tokens = 'b c f h w -> (b f) (h w) c', '(b f) (h w) c -> b c f h w'
        else:
            to_tokens, from_tokens = 'b c h w -> b (h w) c', 'b (h w) c -> b c h w'
        tokens = rearrange(x, to_tokens)
        space_bias = self.spatial_rel_pos_bias(h, w)
        tokens = self.spatial_attn(tokens, rel_pos_bias=space_bias) + tokens
        x = rearrange(tokens, from_tokens, b=b, h=h, w=w)

        # temporal stage: every pixel attends over the frames
        if enable_time:
            tokens = rearrange(x, 'b c f h w -> (b h w) f c')
            time_bias = self.temporal_rel_pos_bias(tokens.shape[1])
            tokens = self.temporal_attn(tokens, rel_pos_bias=time_bias) + tokens
            x = rearrange(tokens, '(b h w) f c -> b c f h w', w=w, h=h)

        if self.has_feed_forward:
            x = self.ff(x, enable_time=enable_time) + x
        return x
# resnet block
class Block(nn.Module):
    """PseudoConv3d -> GroupNorm -> optional scale/shift modulation -> SiLU.

    Args:
        dim: input channels.
        dim_out: output channels.
        kernel_size: spatial kernel size of the pseudo-3D conv (previously
            accepted but ignored; the default of 3 matches the old behavior).
        temporal_kernel_size: temporal kernel size of the pseudo-3D conv
            (defaults to ``kernel_size``).
        groups: number of groups for the GroupNorm over ``dim_out`` channels.
    """

    def __init__(
        self,
        dim,
        dim_out,
        kernel_size=3,
        temporal_kernel_size=None,
        groups=8
    ):
        super().__init__()
        # honour the conv configuration instead of hard-coding a kernel size of 3
        self.project = PseudoConv3d(dim, dim_out, kernel_size, temporal_kernel_size=temporal_kernel_size)
        self.norm = nn.GroupNorm(groups, dim_out)
        self.act = nn.SiLU()

    def forward(
        self,
        x,
        scale_shift=None,
        enable_time=False
    ):
        """Apply the block; ``scale_shift`` is an optional (scale, shift) pair broadcastable to x."""
        x = self.project(x, enable_time=enable_time)
        x = self.norm(x)
        if exists(scale_shift):
            scale, shift = scale_shift
            # (scale + 1): a zero scale leaves the normalised features unchanged
            x = x * (scale + 1) + shift
        return self.act(x)
class ResnetBlock(nn.Module):
    """Residual unit of two ``Block``s with optional timestep conditioning.

    When ``timestep_cond_dim`` is given, the timestep embedding is mapped to a
    per-channel (scale, shift) pair that modulates the first block. The skip
    path uses a 1x1 PseudoConv3d when the channel count changes.

    Args:
        dim: input channels.
        dim_out: output channels.
        timestep_cond_dim: size of the timestep embedding, or None for no conditioning.
        groups: GroupNorm groups for both blocks.
    """

    def __init__(self, dim, dim_out, *, timestep_cond_dim=None, groups=8):
        super().__init__()
        if exists(timestep_cond_dim):
            self.timestep_mlp = nn.Sequential(
                nn.SiLU(),
                nn.Linear(timestep_cond_dim, dim_out * 2)
            )
        else:
            self.timestep_mlp = None
        self.block1 = Block(dim, dim_out, groups=groups)
        self.block2 = Block(dim_out, dim_out, groups=groups)
        self.res_conv = PseudoConv3d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def forward(self, x, timestep_emb=None, enable_time=True):
        # an embedding must be passed exactly when the block was built to accept one
        assert not (exists(timestep_emb) ^ exists(self.timestep_mlp))

        scale_shift = None
        if exists(self.timestep_mlp) and exists(timestep_emb):
            cond = self.timestep_mlp(timestep_emb)
            # append singleton axes so the (b, c) conditioning broadcasts over the feature map
            target = 'b c 1 1 1' if x.ndim == 5 else 'b c 1 1'
            cond = rearrange(cond, f'b c -> {target}')
            scale_shift = cond.chunk(2, dim=1)

        out = self.block1(x, scale_shift=scale_shift, enable_time=enable_time)
        out = self.block2(out, enable_time=enable_time)
        return out + self.res_conv(x)
# pixelshuffle upsamples and downsamples
# where time dimension can be configured
class Downsample(nn.Module):
    """Pixel-unshuffle downsampling by 2 in space and/or time.

    Space: each 2x2 patch is folded into channels and a 1x1 Conv2d maps
    4*dim back to dim. Time: each pair of frames is folded into channels and
    a 1x1x1 Conv3d maps 2*dim back to dim. The temporal step only runs for
    5-D input with ``enable_time`` set.

    Args:
        dim: channel count (unchanged by the module).
        downsample_space: halve height and width.
        downsample_time: halve the number of frames.
        nonlin: append a SiLU after each projection.
    """

    def __init__(self, dim, downsample_space=True, downsample_time=False, nonlin=False):
        super().__init__()
        assert downsample_space or downsample_time

        def activation():
            return nn.SiLU() if nonlin else nn.Identity()

        self.down_space = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (c p1 p2) h w', p1=2, p2=2),
            nn.Conv2d(dim * 4, dim, 1, bias=False),
            activation()
        ) if downsample_space else None
        self.down_time = nn.Sequential(
            Rearrange('b c (f p) h w -> b (c p) f h w', p=2),
            nn.Conv3d(dim * 2, dim, 1, bias=False),
            activation()
        ) if downsample_time else None

    def forward(self, x, enable_time=True):
        video = x.ndim == 5
        if video:
            # fold frames into the batch so the 2-D path sees (N, c, h, w)
            x = rearrange(x, 'b c f h w -> b f c h w')
        x, packed_shape = pack([x], '* c h w')
        if self.down_space is not None:
            x = self.down_space(x)
        if not video:
            return x

        (x,) = unpack(x, packed_shape, '* c h w')
        x = rearrange(x, 'b f c h w -> b c f h w')
        if self.down_time is None or not enable_time:
            return x
        return self.down_time(x)
class Upsample(nn.Module):
    """Pixel-shuffle upsampling by 2 in space and/or time.

    Space: a 1x1 Conv2d expands dim to 4*dim and the extra channels are
    unfolded into 2x2 patches. Time: a 1x1x1 Conv3d expands dim to 2*dim and
    the extra channels are unfolded into pairs of frames. The temporal step
    only runs for 5-D input with ``enable_time`` set.

    Args:
        dim: channel count (unchanged by the module).
        upsample_space: double height and width.
        upsample_time: double the number of frames.
        nonlin: insert a SiLU after each projection.
    """

    def __init__(self, dim, upsample_space=True, upsample_time=False, nonlin=False):
        super().__init__()
        assert upsample_space or upsample_time
        self.up_space = None if not upsample_space else nn.Sequential(
            nn.Conv2d(dim, dim * 4, 1),
            nn.SiLU() if nonlin else nn.Identity(),
            Rearrange('b (c p1 p2) h w -> b c (h p1) (w p2)', p1=2, p2=2)
        )
        self.up_time = None if not upsample_time else nn.Sequential(
            nn.Conv3d(dim, dim * 2, 1),
            nn.SiLU() if nonlin else nn.Identity(),
            Rearrange('b (c p) f h w -> b c (f p) h w', p=2)
        )
        self.init_()

    def init_(self):
        """Initialise the expanding convs (spatial first, then temporal)."""
        for branch, factor in ((self.up_space, 4), (self.up_time, 2)):
            if exists(branch):
                self.init_conv_(branch[0], factor)

    def init_conv_(self, conv, factor):
        """Repeat each of ``out // factor`` kaiming-uniform filters ``factor`` times and zero the bias.

        Consecutive output channels share weights, and the Rearrange unfolds
        consecutive channels into the sub-positions of one upsampled channel.
        So at initialisation all sub-positions of a channel start out identical.
        """
        out_channels, *kernel_dims = conv.weight.shape
        base = torch.empty(out_channels // factor, *kernel_dims)
        nn.init.kaiming_uniform_(base)
        conv.weight.data.copy_(repeat(base, 'o ... -> (o r) ...', r=factor))
        nn.init.zeros_(conv.bias.data)

    def forward(self, x, enable_time=True):
        video = x.ndim == 5
        if video:
            # fold frames into the batch so the 2-D path sees (N, c, h, w)
            x = rearrange(x, 'b c f h w -> b f c h w')
        x, packed_shape = pack([x], '* c h w')
        if self.up_space is not None:
            x = self.up_space(x)
        if not video:
            return x

        (x,) = unpack(x, packed_shape, '* c h w')
        x = rearrange(x, 'b f c h w -> b c f h w')
        if self.up_time is None or not enable_time:
            return x
        return self.up_time(x)
class SpaceTimeUnet(nn.Module):
    """Space-time U-Net for diffusion on videos (or images) with cross-attention conditioning.

    Encoder level structure, in order:
      * a (optionally timestep-conditioned) ResnetBlock;
      * ``resnet_block_depth`` further ResnetBlocks;
      * optional SpatioTemporalAttention;
      * a Downsample (space always, time when the level's
        ``temporal_compression`` flag is set);
      * an AttentionBlock that cross-attends to ``clip_vae_embed``.

    The decoder mirrors this, concatenating two skip tensors per level scaled
    by ``2 ** -0.5``.

    Args:
        dim: base channel width; level ``i`` has ``dim * dim_mult[i]`` channels.
        channels: number of input and output channels.
        dim_mult: per-level channel multipliers.
        self_attns: per-level flag enabling SpatioTemporalAttention.
        temporal_compression: per-level flag enabling 2x frame down/upsampling.
        resnet_block_depths: per-level count (>= 1) of extra ResnetBlocks.
        attn_dim_head: head width of the per-level self-attention.
        attn_heads: head count of the per-level self-attention.
        condition_on_timestep: if True, ``forward`` requires ``timestep``; its
            sinusoidal embedding modulates the timestep-conditioned ResnetBlocks.
        context_dim: feature size of ``clip_vae_embed`` consumed by the
            cross-attention blocks (768 matches ``AttentionBlock``'s default).
    """

    def __init__(
        self,
        *,
        dim,
        channels=4,
        dim_mult=(1, 2, 4, 8),
        self_attns=(False, False, False, True),
        temporal_compression=(False, True, True, True),
        resnet_block_depths=(2, 2, 2, 2),
        attn_dim_head=64,
        attn_heads=8,
        condition_on_timestep=False,
        context_dim=768,
    ):
        super().__init__()
        assert len(dim_mult) == len(self_attns) == len(temporal_compression) == len(resnet_block_depths)
        num_layers = len(dim_mult)
        dims = [dim, *(mult * dim for mult in dim_mult)]
        # materialised: iterated more than once below
        dim_in_out = list(zip(dims[:-1], dims[1:]))

        # determine the valid multiples of the image size and frames of the video
        self.frame_multiple = 2 ** sum(tuple(map(int, temporal_compression)))
        self.image_size_multiple = 2 ** num_layers

        # timestep conditioning for DDPM, not to be confused with the time dimension of the video
        self.to_timestep_cond = None
        timestep_cond_dim = (dim * 4) if condition_on_timestep else None
        if condition_on_timestep:
            self.to_timestep_cond = nn.Sequential(
                SinusoidalPosEmb(dim),
                nn.Linear(dim, timestep_cond_dim),
                nn.SiLU()
            )

        # Cross attention: one AttentionBlock per level on each path, sized from that
        # level's channel count instead of being hard-coded for dim=64 / 4 levels.
        # Head width targets 128 on the way down and 64 on the way up; for dim=64 with
        # dim_mult=(1, 2, 4, 8) this yields (1x64, 1x128, 2x128, 4x128) and
        # (1x64, 1x64, 2x64, 4x64).
        def make_cross_attn(num_channels, channels_per_head):
            heads = max(1, num_channels // channels_per_head)
            while num_channels % heads:  # the per-head width must be an integer
                heads -= 1
            return AttentionBlock(heads, num_channels // heads, d_context=context_dim)

        # built before the other layers (downs shallow->deep, then ups deep->shallow)
        # to keep parameter-initialisation order stable
        cross_attns_down = [make_cross_attn(d_out, 128) for _, d_out in dim_in_out]
        cross_attns_up = [make_cross_attn(d_in, 64) for d_in, _ in reversed(dim_in_out)][::-1]

        # layers
        self.downs = mlist([])
        self.ups = mlist([])
        attn_kwargs = dict(
            dim_head=attn_dim_head,
            heads=attn_heads
        )
        mid_dim = dims[-1]
        self.mid_block1 = ResnetBlock(mid_dim, mid_dim, timestep_cond_dim=timestep_cond_dim)
        # NOTE(review): unlike the per-level attention this does not receive attn_kwargs,
        # so attn_dim_head / attn_heads do not affect it (kept for checkpoint compatibility)
        self.mid_attn = SpatioTemporalAttention(dim=mid_dim)
        self.mid_block2 = ResnetBlock(mid_dim, mid_dim, timestep_cond_dim=timestep_cond_dim)

        for self_attend, (dim_in, dim_out), compress_time, resnet_block_depth, cross_attn_d, cross_attn_u in zip(
                self_attns,
                dim_in_out,
                temporal_compression,
                resnet_block_depths,
                cross_attns_down,
                cross_attns_up):
            assert resnet_block_depth >= 1
            self.downs.append(mlist([
                ResnetBlock(dim_in, dim_out, timestep_cond_dim=timestep_cond_dim),
                mlist([ResnetBlock(dim_out, dim_out) for _ in range(resnet_block_depth)]),
                SpatioTemporalAttention(dim=dim_out, **attn_kwargs) if self_attend else None,
                Downsample(dim_out, downsample_time=compress_time),
                cross_attn_d
            ]))
            self.ups.append(mlist([
                # input: upsampled features concatenated with the second skip tensor
                ResnetBlock(dim_out * 2, dim_in, timestep_cond_dim=timestep_cond_dim),
                # the first inner block also absorbs the first skip tensor (dim_out channels)
                mlist([ResnetBlock(dim_in + (dim_out if ind == 0 else 0), dim_in)
                       for ind in range(resnet_block_depth)]),
                SpatioTemporalAttention(dim=dim_in, **attn_kwargs) if self_attend else None,
                Upsample(dim_out, upsample_time=compress_time),
                cross_attn_u
            ]))

        self.skip_scale = 2 ** -0.5  # paper shows faster convergence
        self.conv_in = PseudoConv3d(dim=channels, dim_out=dim, kernel_size=7, temporal_kernel_size=3)
        self.conv_out = PseudoConv3d(dim=dim, dim_out=channels, kernel_size=3, temporal_kernel_size=3)

    def forward(
        self,
        x,
        clip_vae_embed,
        timestep=None,
        enable_time=True
    ):
        """Run the U-Net.

        Args:
            x: (b, channels, f, h, w) video or (b, channels, h, w) image; h and w
                must be multiples of ``image_size_multiple`` and, for video with
                ``enable_time``, f a multiple of ``frame_multiple``.
            clip_vae_embed: cross-attention context whose last dimension is
                ``context_dim`` (presumably (b, tokens, context_dim) — confirm
                with callers).
            timestep: float tensor, required exactly when the model was built
                with ``condition_on_timestep``; it is flattened before embedding.
            enable_time: when False, temporal convs/attention/resampling are skipped.

        Returns:
            Tensor with the spatial/temporal shape of ``x`` and ``channels`` channels.
        """
        assert not (exists(self.to_timestep_cond) ^ exists(timestep))
        is_video = x.ndim == 5
        if enable_time and is_video:
            frames = x.shape[2]
            assert divisible_by(frames,
                                self.frame_multiple), f'number of frames on the video ({frames}) must be divisible by the frame multiple ({self.frame_multiple})'
        height, width = x.shape[-2:]
        assert divisible_by(height, self.image_size_multiple) and divisible_by(width,
                                                                              self.image_size_multiple), f'height and width of the image or video must be a multiple of {self.image_size_multiple}'

        # main logic
        t = self.to_timestep_cond(rearrange(timestep, '... -> (...)')) if exists(timestep) else None
        x = self.conv_in(x, enable_time=enable_time)

        hiddens = []
        for init_block, blocks, maybe_attention, downsample, cross_attn in self.downs:
            x = init_block(x, t, enable_time=enable_time)
            hiddens.append(x.clone())
            for block in blocks:
                x = block(x, enable_time=enable_time)
            if exists(maybe_attention):
                x = maybe_attention(x, enable_time=enable_time)
            hiddens.append(x.clone())
            x = downsample(x, enable_time=enable_time)
            if exists(cross_attn):
                x = cross_attn(x, clip_vae_embed)

        x = self.mid_block1(x, t, enable_time=enable_time)
        x = self.mid_attn(x, enable_time=enable_time)
        x = self.mid_block2(x, t, enable_time=enable_time)

        for init_block, blocks, maybe_attention, upsample, cross_attn in reversed(self.ups):
            x = upsample(x, enable_time=enable_time)
            # skip tensors are consumed in reverse order of the encoder's appends
            x = torch.cat((hiddens.pop() * self.skip_scale, x), dim=1)
            x = init_block(x, t, enable_time=enable_time)
            x = torch.cat((hiddens.pop() * self.skip_scale, x), dim=1)
            for block in blocks:
                x = block(x, enable_time=enable_time)
            if exists(maybe_attention):
                x = maybe_attention(x, enable_time=enable_time)
            if exists(cross_attn):
                x = cross_attn(x, clip_vae_embed)

        x = self.conv_out(x, enable_time=enable_time)
        return x
if __name__ == '__main__':
    # Smoke test: push one random 8-frame 32x32 RGB clip through an untrained model.
    net = SpaceTimeUnet(
        dim=64,
        channels=3,
        dim_mult=(1, 2, 4, 8),
        temporal_compression=(False, False, False, True),
        self_attns=(False, False, False, True),
        condition_on_timestep=False)
    x = torch.randn([1, 8, 3, 32, 32])  # (b, f, c, h, w); permuted to (b, c, f, h, w) below
    # forward() requires a cross-attention context; 768 matches AttentionBlock's default d_context
    clip_vae_embed = torch.randn([1, 1, 768])
    with torch.no_grad():
        sample_output = net(x.permute(0, 2, 1, 3, 4), clip_vae_embed)
    print(sample_output.shape)