import math
import functools
from operator import mul

import torch
import torch.nn.functional as F
from torch import nn, einsum

from einops import rearrange, repeat, pack, unpack
from einops.layers.torch import Rearrange

# helper functions

def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

def mul_reduce(tup):
    return functools.reduce(mul, tup)

def divisible_by(numer, denom):
    return (numer % denom) == 0

mlist = nn.ModuleList
# for time conditioning

class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim, theta=10000):
        super().__init__()
        self.theta = theta
        self.dim = dim

    def forward(self, x):
        dtype, device = x.dtype, x.device
        assert dtype == torch.float, 'input to sinusoidal pos emb must be a float type'

        half_dim = self.dim // 2
        emb = math.log(self.theta) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device, dtype=dtype) * -emb)
        emb = rearrange(x, 'i -> i 1') * rearrange(emb, 'j -> 1 j')
        return torch.cat((emb.sin(), emb.cos()), dim=-1).type(dtype)
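
# usage sketch (illustrative, not part of the original module; dims chosen arbitrarily):
# embed a batch of float timesteps into `dim`-dimensional sinusoidal features, as
# consumed by `to_timestep_cond` in SpaceTimeUnet below
def _demo_sinusoidal_pos_emb():
    emb = SinusoidalPosEmb(dim=64)
    timesteps = torch.rand(4)   # float timesteps, shape (batch,)
    return emb(timesteps)       # shape (4, 64): first half sin, second half cos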
# layernorm 3d

class ChanLayerNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.g = nn.Parameter(torch.ones(dim, 1, 1, 1))

    def forward(self, x):
        eps = 1e-5 if x.dtype == torch.float32 else 1e-3
        var = torch.var(x, dim=1, unbiased=False, keepdim=True)
        mean = torch.mean(x, dim=1, keepdim=True)
        return (x - mean) * var.clamp(min=eps).rsqrt() * self.g
# feedforward

def shift_token(t):
    t, t_shift = t.chunk(2, dim=1)
    t_shift = F.pad(t_shift, (0, 0, 0, 0, 1, -1), value=0.)
    return torch.cat((t, t_shift), dim=1)

class GEGLU(nn.Module):
    def forward(self, x):
        x, gate = x.chunk(2, dim=1)
        return x * F.gelu(gate)

class FeedForward(nn.Module):
    def __init__(self, dim, mult=4):
        super().__init__()
        inner_dim = int(dim * mult * 2 / 3)
        self.proj_in = nn.Sequential(
            nn.Conv3d(dim, inner_dim * 2, 1, bias=False),
            GEGLU()
        )
        self.proj_out = nn.Sequential(
            ChanLayerNorm(inner_dim),
            nn.Conv3d(inner_dim, dim, 1, bias=False)
        )

    def forward(self, x, enable_time=True):
        x = self.proj_in(x)

        if enable_time:
            x = shift_token(x)

        return self.proj_out(x)
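
# usage sketch (illustrative, not part of the original module; dims chosen arbitrarily):
# shift_token mixes half the channels with the previous frame (token shift along time),
# so the feedforward sees a little temporal context when enable_time=True
def _demo_feed_forward():
    ff = FeedForward(dim=32, mult=4)
    video = torch.randn(1, 32, 4, 8, 8)   # (b, c, f, h, w)
    return ff(video, enable_time=True)    # same shape as the input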
# best relative positional encoding

class ContinuousPositionBias(nn.Module):
    """ from https://arxiv.org/abs/2111.09883 """

    def __init__(
        self,
        *,
        dim,
        heads,
        num_dims=1,
        layers=2
    ):
        super().__init__()
        self.num_dims = num_dims

        self.net = nn.ModuleList([])
        self.net.append(nn.Sequential(nn.Linear(self.num_dims, dim), nn.SiLU()))

        for _ in range(layers - 1):
            self.net.append(nn.Sequential(nn.Linear(dim, dim), nn.SiLU()))

        self.net.append(nn.Linear(dim, heads))

    @property
    def device(self):
        return next(self.parameters()).device

    def forward(self, *dimensions):
        device = self.device

        shape = torch.tensor(dimensions, device=device)
        rel_pos_shape = 2 * shape - 1

        # calculate strides

        strides = torch.flip(rel_pos_shape, (0,)).cumprod(dim=-1)
        strides = torch.flip(F.pad(strides, (1, -1), value=1), (0,))

        # get all positions and calculate all the relative distances

        positions = [torch.arange(d, device=device) for d in dimensions]
        grid = torch.stack(torch.meshgrid(*positions, indexing='ij'), dim=-1)
        grid = rearrange(grid, '... c -> (...) c')
        rel_dist = rearrange(grid, 'i c -> i 1 c') - rearrange(grid, 'j c -> 1 j c')

        # get all relative positions across all dimensions

        rel_positions = [torch.arange(-d + 1, d, device=device) for d in dimensions]
        rel_pos_grid = torch.stack(torch.meshgrid(*rel_positions, indexing='ij'), dim=-1)
        rel_pos_grid = rearrange(rel_pos_grid, '... c -> (...) c')

        # mlp input

        bias = rel_pos_grid.float()

        for layer in self.net:
            bias = layer(bias)

        # convert relative distances to indices of the bias

        rel_dist += (shape - 1)  # make sure all positive
        rel_dist *= strides
        rel_dist_indices = rel_dist.sum(dim=-1)

        # now select the bias for each unique relative position combination

        bias = bias[rel_dist_indices]
        return rearrange(bias, 'i j h -> h i j')
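
# usage sketch (illustrative, not part of the original module; dims chosen arbitrarily):
# the small MLP maps every unique relative offset to one logit per head, and the result
# is then indexed into a full (heads, i, j) bias that is added to attention logits
def _demo_continuous_position_bias():
    rel_pos_bias = ContinuousPositionBias(dim=32, heads=8, num_dims=2)
    return rel_pos_bias(16, 16)   # spatial 16 x 16 grid -> bias of shape (8, 256, 256)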
# helper classes

class CrossAttention(nn.Module):
    def __init__(self, n_heads, d_embed, d_cross, in_proj_bias=True, out_proj_bias=True):
        super().__init__()
        self.q_proj = nn.Linear(d_embed, d_embed, bias=in_proj_bias)
        self.k_proj = nn.Linear(d_cross, d_embed, bias=in_proj_bias)
        self.v_proj = nn.Linear(d_cross, d_embed, bias=in_proj_bias)
        self.out_proj = nn.Linear(d_embed, d_embed, bias=out_proj_bias)
        self.n_heads = n_heads
        self.d_head = d_embed // n_heads

    def forward(self, x, y):
        # x: latent sequence (b, seq_q, d_embed), y: context sequence (b, seq_kv, d_cross)
        input_shape = x.shape
        batch_size, sequence_length, d_embed = input_shape
        interim_shape = (batch_size, -1, self.n_heads, self.d_head)

        q = self.q_proj(x)
        k = self.k_proj(y)
        v = self.v_proj(y)

        q = q.view(interim_shape).transpose(1, 2)
        k = k.view(interim_shape).transpose(1, 2)
        v = v.view(interim_shape).transpose(1, 2)

        weight = q @ k.transpose(-1, -2)
        weight /= math.sqrt(self.d_head)
        weight = F.softmax(weight, dim=-1)

        output = weight @ v
        output = output.transpose(1, 2).contiguous()
        output = output.view(input_shape)
        output = self.out_proj(output)
        return output
class AttentionBlock(nn.Module):
    def __init__(self, n_head: int, n_embd: int, d_context=768):
        super().__init__()
        channels = n_head * n_embd

        # self.groupnorm = nn.GroupNorm(32, channels, eps=1e-6)
        # self.conv_input = PseudoConv3d(channels, channels, 1)
        self.layernorm_2 = nn.LayerNorm(channels)
        self.attention_2 = CrossAttention(n_head, channels, d_context, in_proj_bias=False)
        self.layernorm_3 = nn.LayerNorm(channels)
        self.linear_geglu_1 = nn.Linear(channels, 4 * channels * 2)
        self.linear_geglu_2 = nn.Linear(4 * channels, channels)
        self.conv_output = PseudoConv3d(channels, channels, 1, bias=False)

    def forward(self, x, context):
        # x: video features (b, c, f, h, w), context: conditioning sequence (b, seq, d_context)
        b, c, *_, h, w = x.shape

        # x = self.groupnorm(x)
        # x = self.conv_input(x)
        x = rearrange(x, 'b c f h w -> b (h w f) c')

        # cross attention with a residual connection
        residue_short = x
        x = self.layernorm_2(x)
        x = self.attention_2(x, context)
        x += residue_short

        # gated feedforward (GEGLU) with a residual connection
        residue_short = x
        x = self.layernorm_3(x)
        x, gate = self.linear_geglu_1(x).chunk(2, dim=-1)
        x = x * F.gelu(gate)
        x = self.linear_geglu_2(x)
        x += residue_short

        x = rearrange(x, 'b (h w f) c -> b c f h w', b=b, c=c, h=h, w=w)
        x = self.conv_output(x)
        return x
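
# usage sketch (illustrative, not part of the original module; dims chosen arbitrarily):
# cross-attend a video feature map against a context sequence (e.g. CLIP/VAE embeddings);
# the feature channels must equal n_head * n_embd and the context dim must equal d_context
def _demo_attention_block():
    block = AttentionBlock(n_head=1, n_embd=64)     # 64 channels
    feats = torch.randn(1, 64, 8, 16, 16)           # (b, c, f, h, w)
    context = torch.randn(1, 77, 768)               # (b, seq, d_context)
    return block(feats, context)                    # same shape as feats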
class Attention(nn.Module):
    def __init__(
        self,
        dim,
        dim_head=64,
        heads=8
    ):
        super().__init__()
        self.heads = heads
        self.scale = dim_head ** -0.5
        inner_dim = dim_head * heads

        self.norm = nn.LayerNorm(dim)
        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)

        nn.init.zeros_(self.to_out.weight.data)  # identity with skip connection

    def forward(
        self,
        x,
        rel_pos_bias=None
    ):
        x = self.norm(x)

        q, k, v = self.to_q(x), *self.to_kv(x).chunk(2, dim=-1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=self.heads), (q, k, v))

        q = q * self.scale
        sim = einsum('b h i d, b h j d -> b h i j', q, k)

        if exists(rel_pos_bias):
            sim = sim + rel_pos_bias

        attn = sim.softmax(dim=-1)
        out = einsum('b h i j, b h j d -> b h i d', attn, v)

        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
# main contribution - pseudo 3d conv

class PseudoConv3d(nn.Module):
    def __init__(
        self,
        dim,
        dim_out=None,
        kernel_size=3,
        *,
        temporal_kernel_size=None,
        **kwargs
    ):
        super().__init__()
        dim_out = default(dim_out, dim)
        temporal_kernel_size = default(temporal_kernel_size, kernel_size)

        self.spatial_conv = nn.Conv2d(dim, dim_out, kernel_size=kernel_size, padding=kernel_size // 2)
        self.temporal_conv = nn.Conv1d(
            dim_out, dim_out, kernel_size=temporal_kernel_size,
            padding=temporal_kernel_size // 2
        ) if kernel_size > 1 else None

        if exists(self.temporal_conv):
            nn.init.dirac_(self.temporal_conv.weight.data)  # initialized to be identity
            nn.init.zeros_(self.temporal_conv.bias.data)

    def forward(
        self,
        x,
        enable_time=True
    ):
        b, c, *_, h, w = x.shape

        is_video = x.ndim == 5
        enable_time &= is_video

        if is_video:
            x = rearrange(x, 'b c f h w -> (b f) c h w')

        x = self.spatial_conv(x)

        if is_video:
            x = rearrange(x, '(b f) c h w -> b c f h w', b=b)

        if not enable_time or not exists(self.temporal_conv):
            return x

        x = rearrange(x, 'b c f h w -> (b h w) c f')
        x = self.temporal_conv(x)
        x = rearrange(x, '(b h w) c f -> b c f h w', h=h, w=w)
        return x
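
# usage sketch (illustrative, not part of the original module; dims chosen arbitrarily):
# the pseudo 3d conv runs a 2d conv over every frame, then (for video input with
# enable_time=True) a 1d conv over the frame axis per pixel; images skip the temporal conv
def _demo_pseudo_conv3d():
    conv = PseudoConv3d(dim=32, dim_out=32, kernel_size=3)
    image = torch.randn(2, 32, 64, 64)       # (b, c, h, w) - temporal conv is skipped
    video = torch.randn(2, 32, 8, 64, 64)    # (b, c, f, h, w) - temporal conv applies
    return conv(image), conv(video)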
# factorized spatial temporal attention from Ho et al.

class SpatioTemporalAttention(nn.Module):
    def __init__(
        self,
        dim,
        *,
        dim_head=64,
        heads=8,
        add_feed_forward=True,
        ff_mult=4
    ):
        super().__init__()
        self.spatial_attn = Attention(dim=dim, dim_head=dim_head, heads=heads)
        self.spatial_rel_pos_bias = ContinuousPositionBias(dim=dim // 2, heads=heads, num_dims=2)

        self.temporal_attn = Attention(dim=dim, dim_head=dim_head, heads=heads)
        self.temporal_rel_pos_bias = ContinuousPositionBias(dim=dim // 2, heads=heads, num_dims=1)

        self.has_feed_forward = add_feed_forward
        if not add_feed_forward:
            return

        self.ff = FeedForward(dim=dim, mult=ff_mult)

    def forward(
        self,
        x,
        enable_time=True
    ):
        b, c, *_, h, w = x.shape
        is_video = x.ndim == 5
        enable_time &= is_video

        if is_video:
            x = rearrange(x, 'b c f h w -> (b f) (h w) c')
        else:
            x = rearrange(x, 'b c h w -> b (h w) c')

        space_rel_pos_bias = self.spatial_rel_pos_bias(h, w)
        x = self.spatial_attn(x, rel_pos_bias=space_rel_pos_bias) + x

        if is_video:
            x = rearrange(x, '(b f) (h w) c -> b c f h w', b=b, h=h, w=w)
        else:
            x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)

        if enable_time:
            x = rearrange(x, 'b c f h w -> (b h w) f c')
            time_rel_pos_bias = self.temporal_rel_pos_bias(x.shape[1])
            x = self.temporal_attn(x, rel_pos_bias=time_rel_pos_bias) + x
            x = rearrange(x, '(b h w) f c -> b c f h w', w=w, h=h)

        if self.has_feed_forward:
            x = self.ff(x, enable_time=enable_time) + x

        return x
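
# usage sketch (illustrative, not part of the original module; dims chosen arbitrarily):
# attention is factorized into per-frame spatial attention followed by per-pixel
# temporal attention, each with its own continuous relative position bias
def _demo_spatio_temporal_attention():
    attn = SpatioTemporalAttention(dim=64, dim_head=32, heads=4)
    video = torch.randn(1, 64, 4, 16, 16)     # (b, c, f, h, w)
    return attn(video, enable_time=True)      # same shape as the input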
# resnet block

class Block(nn.Module):
    def __init__(
        self,
        dim,
        dim_out,
        kernel_size=3,
        temporal_kernel_size=None,
        groups=8
    ):
        super().__init__()
        self.project = PseudoConv3d(dim, dim_out, kernel_size, temporal_kernel_size=temporal_kernel_size)
        self.norm = nn.GroupNorm(groups, dim_out)
        self.act = nn.SiLU()

    def forward(
        self,
        x,
        scale_shift=None,
        enable_time=False
    ):
        x = self.project(x, enable_time=enable_time)
        x = self.norm(x)

        if exists(scale_shift):
            scale, shift = scale_shift
            x = x * (scale + 1) + shift

        return self.act(x)
class ResnetBlock(nn.Module):
    def __init__(
        self,
        dim,
        dim_out,
        *,
        timestep_cond_dim=None,
        groups=8
    ):
        super().__init__()
        self.timestep_mlp = None

        if exists(timestep_cond_dim):
            self.timestep_mlp = nn.Sequential(
                nn.SiLU(),
                nn.Linear(timestep_cond_dim, dim_out * 2)
            )

        self.block1 = Block(dim, dim_out, groups=groups)
        self.block2 = Block(dim_out, dim_out, groups=groups)
        self.res_conv = PseudoConv3d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def forward(
        self,
        x,
        timestep_emb=None,
        enable_time=True
    ):
        assert not (exists(timestep_emb) ^ exists(self.timestep_mlp))

        scale_shift = None
        if exists(self.timestep_mlp) and exists(timestep_emb):
            time_emb = self.timestep_mlp(timestep_emb)
            to_einsum_eq = 'b c 1 1 1' if x.ndim == 5 else 'b c 1 1'
            time_emb = rearrange(time_emb, f'b c -> {to_einsum_eq}')
            scale_shift = time_emb.chunk(2, dim=1)

        h = self.block1(x, scale_shift=scale_shift, enable_time=enable_time)
        h = self.block2(h, enable_time=enable_time)
        return h + self.res_conv(x)
# pixelshuffle upsamples and downsamples
# where time dimension can be configured

class Downsample(nn.Module):
    def __init__(
        self,
        dim,
        downsample_space=True,
        downsample_time=False,
        nonlin=False
    ):
        super().__init__()
        assert downsample_space or downsample_time

        self.down_space = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (c p1 p2) h w', p1=2, p2=2),
            nn.Conv2d(dim * 4, dim, 1, bias=False),
            nn.SiLU() if nonlin else nn.Identity()
        ) if downsample_space else None

        self.down_time = nn.Sequential(
            Rearrange('b c (f p) h w -> b (c p) f h w', p=2),
            nn.Conv3d(dim * 2, dim, 1, bias=False),
            nn.SiLU() if nonlin else nn.Identity()
        ) if downsample_time else None

    def forward(
        self,
        x,
        enable_time=True
    ):
        is_video = x.ndim == 5

        if is_video:
            x = rearrange(x, 'b c f h w -> b f c h w')
            x, ps = pack([x], '* c h w')

        if exists(self.down_space):
            x = self.down_space(x)

        if is_video:
            x, = unpack(x, ps, '* c h w')
            x = rearrange(x, 'b f c h w -> b c f h w')

        if not is_video or not exists(self.down_time) or not enable_time:
            return x

        x = self.down_time(x)
        return x
class Upsample(nn.Module):
    def __init__(
        self,
        dim,
        upsample_space=True,
        upsample_time=False,
        nonlin=False
    ):
        super().__init__()
        assert upsample_space or upsample_time

        self.up_space = nn.Sequential(
            nn.Conv2d(dim, dim * 4, 1),
            nn.SiLU() if nonlin else nn.Identity(),
            Rearrange('b (c p1 p2) h w -> b c (h p1) (w p2)', p1=2, p2=2)
        ) if upsample_space else None

        self.up_time = nn.Sequential(
            nn.Conv3d(dim, dim * 2, 1),
            nn.SiLU() if nonlin else nn.Identity(),
            Rearrange('b (c p) f h w -> b c (f p) h w', p=2)
        ) if upsample_time else None

        self.init_()

    def init_(self):
        if exists(self.up_space):
            self.init_conv_(self.up_space[0], 4)

        if exists(self.up_time):
            self.init_conv_(self.up_time[0], 2)

    def init_conv_(self, conv, factor):
        o, *remain_dims = conv.weight.shape
        conv_weight = torch.empty(o // factor, *remain_dims)
        nn.init.kaiming_uniform_(conv_weight)
        conv_weight = repeat(conv_weight, 'o ... -> (o r) ...', r=factor)

        conv.weight.data.copy_(conv_weight)
        nn.init.zeros_(conv.bias.data)

    def forward(
        self,
        x,
        enable_time=True
    ):
        is_video = x.ndim == 5

        if is_video:
            x = rearrange(x, 'b c f h w -> b f c h w')
            x, ps = pack([x], '* c h w')

        if exists(self.up_space):
            x = self.up_space(x)

        if is_video:
            x, = unpack(x, ps, '* c h w')
            x = rearrange(x, 'b f c h w -> b c f h w')

        if not is_video or not exists(self.up_time) or not enable_time:
            return x

        x = self.up_time(x)
        return x
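
# usage sketch (illustrative, not part of the original module; dims chosen arbitrarily):
# Downsample/Upsample use pixel-shuffle style rearranges, always halving/doubling space
# and additionally the frame axis when downsample_time/upsample_time is True
def _demo_resample():
    down = Downsample(dim=32, downsample_time=True)
    up = Upsample(dim=32, upsample_time=True)
    video = torch.randn(1, 32, 8, 64, 64)     # (b, c, f, h, w)
    smaller = down(video)                     # (1, 32, 4, 32, 32)
    return up(smaller)                        # back to (1, 32, 8, 64, 64)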
class SpaceTimeUnet(nn.Module):
    def __init__(
        self,
        *,
        dim,
        channels=4,
        dim_mult=(1, 2, 4, 8),
        self_attns=(False, False, False, True),
        temporal_compression=(False, True, True, True),
        resnet_block_depths=(2, 2, 2, 2),
        attn_dim_head=64,
        attn_heads=8,
        condition_on_timestep=False,
    ):
        super().__init__()
        assert len(dim_mult) == len(self_attns) == len(temporal_compression) == len(resnet_block_depths)

        num_layers = len(dim_mult)

        dims = [dim, *map(lambda mult: mult * dim, dim_mult)]
        dim_in_out = zip(dims[:-1], dims[1:])

        # determine the valid multiples of the image size and frames of the video

        self.frame_multiple = 2 ** sum(tuple(map(int, temporal_compression)))
        self.image_size_multiple = 2 ** num_layers

        # timestep conditioning for DDPM, not to be confused with the time dimension of the video

        self.to_timestep_cond = None
        timestep_cond_dim = (dim * 4) if condition_on_timestep else None

        if condition_on_timestep:
            self.to_timestep_cond = nn.Sequential(
                SinusoidalPosEmb(dim),
                nn.Linear(dim, timestep_cond_dim),
                nn.SiLU()
            )

        # cross attention
        # note: the channel counts below (64, 128, 256, 512) assume dim=64 with dim_mult=(1, 2, 4, 8)

        cross_attention_D1 = AttentionBlock(1, 64)    # 64
        cross_attention_D2 = AttentionBlock(1, 128)   # 128
        cross_attention_D3 = AttentionBlock(2, 128)   # 256
        cross_attention_D4 = AttentionBlock(4, 128)   # 512

        cross_attention_U1 = AttentionBlock(4, 64)    # 256
        cross_attention_U2 = AttentionBlock(2, 64)    # 128
        cross_attention_U3 = AttentionBlock(1, 64)    # 64
        cross_attention_U4 = AttentionBlock(1, 64)    # 64

        cross_attns_down = (cross_attention_D1, cross_attention_D2, cross_attention_D3, cross_attention_D4)
        cross_attns_up = (cross_attention_U4, cross_attention_U3, cross_attention_U2, cross_attention_U1)

        # layers

        self.downs = mlist([])
        self.ups = mlist([])

        attn_kwargs = dict(
            dim_head=attn_dim_head,
            heads=attn_heads
        )

        mid_dim = dims[-1]

        self.mid_block1 = ResnetBlock(mid_dim, mid_dim, timestep_cond_dim=timestep_cond_dim)
        self.mid_attn = SpatioTemporalAttention(dim=mid_dim)
        self.mid_block2 = ResnetBlock(mid_dim, mid_dim, timestep_cond_dim=timestep_cond_dim)

        for _, self_attend, (dim_in, dim_out), compress_time, resnet_block_depth, cross_attns_d, cross_attns_u in zip(
            range(num_layers),
            self_attns,
            dim_in_out,
            temporal_compression,
            resnet_block_depths,
            cross_attns_down,
            cross_attns_up
        ):
            assert resnet_block_depth >= 1

            self.downs.append(mlist([
                ResnetBlock(dim_in, dim_out, timestep_cond_dim=timestep_cond_dim),
                mlist([ResnetBlock(dim_out, dim_out) for _ in range(resnet_block_depth)]),
                SpatioTemporalAttention(dim=dim_out, **attn_kwargs) if self_attend else None,
                Downsample(dim_out, downsample_time=compress_time),
                cross_attns_d if exists(cross_attns_d) else None
            ]))

            self.ups.append(mlist([
                ResnetBlock(dim_out * 2, dim_in, timestep_cond_dim=timestep_cond_dim),
                mlist([ResnetBlock(dim_in + (dim_out if ind == 0 else 0), dim_in) for ind in range(resnet_block_depth)]),
                SpatioTemporalAttention(dim=dim_in, **attn_kwargs) if self_attend else None,
                Upsample(dim_out, upsample_time=compress_time),
                cross_attns_u if exists(cross_attns_u) else None
            ]))

        self.skip_scale = 2 ** -0.5  # paper shows faster convergence

        self.conv_in = PseudoConv3d(dim=channels, dim_out=dim, kernel_size=7, temporal_kernel_size=3)
        self.conv_out = PseudoConv3d(dim=dim, dim_out=channels, kernel_size=3, temporal_kernel_size=3)
    def forward(
        self,
        x,
        clip_vae_embed,
        timestep=None,
        enable_time=True
    ):
        assert not (exists(self.to_timestep_cond) ^ exists(timestep))
        is_video = x.ndim == 5

        if enable_time and is_video:
            frames = x.shape[2]
            assert divisible_by(frames, self.frame_multiple), \
                f'number of frames on the video ({frames}) must be divisible by the frame multiple ({self.frame_multiple})'

        height, width = x.shape[-2:]
        assert divisible_by(height, self.image_size_multiple) and divisible_by(width, self.image_size_multiple), \
            f'height and width of the image or video must be a multiple of {self.image_size_multiple}'

        # main logic

        t = self.to_timestep_cond(rearrange(timestep, '... -> (...)')) if exists(timestep) else None

        x = self.conv_in(x, enable_time=enable_time)

        hiddens = []

        for init_block, blocks, maybe_attention, downsample, cross_attn in self.downs:
            x = init_block(x, t, enable_time=enable_time)
            hiddens.append(x.clone())

            for block in blocks:
                x = block(x, enable_time=enable_time)

            if exists(maybe_attention):
                x = maybe_attention(x, enable_time=enable_time)  # only happens in the last layer

            hiddens.append(x.clone())

            x = downsample(x, enable_time=enable_time)

            if exists(cross_attn):
                x = cross_attn(x, clip_vae_embed)

        x = self.mid_block1(x, t, enable_time=enable_time)
        x = self.mid_attn(x, enable_time=enable_time)
        x = self.mid_block2(x, t, enable_time=enable_time)

        for init_block, blocks, maybe_attention, upsample, cross_attn in reversed(self.ups):
            x = upsample(x, enable_time=enable_time)

            x = torch.cat((hiddens.pop() * self.skip_scale, x), dim=1)
            x = init_block(x, t, enable_time=enable_time)

            x = torch.cat((hiddens.pop() * self.skip_scale, x), dim=1)

            for block in blocks:
                x = block(x, enable_time=enable_time)

            if exists(maybe_attention):
                x = maybe_attention(x, enable_time=enable_time)

            if exists(cross_attn):
                x = cross_attn(x, clip_vae_embed)

        x = self.conv_out(x, enable_time=enable_time)
        return x
if __name__ == '__main__':
    net = SpaceTimeUnet(
        dim=64,
        channels=3,
        dim_mult=(1, 2, 4, 8),
        temporal_compression=(False, False, False, True),
        self_attns=(False, False, False, True),
        condition_on_timestep=False)

    x = torch.randn([1, 8, 3, 32, 32])  # (batch, frames, channels, height, width)

    # dummy cross attention context; the last dim must match d_context (768),
    # the sequence length (77 here) is arbitrary
    clip_vae_embed = torch.randn([1, 77, 768])

    sample_output = net(x.permute(0, 2, 1, 3, 4), clip_vae_embed)  # permute to (b, c, f, h, w)
    print(sample_output.shape)  # torch.Size([1, 3, 8, 32, 32])
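
    # illustrative sketch (assumed usage, not from the original example): the same unet
    # with DDPM-style timestep conditioning enabled; the timestep tensor must be a float
    # tensor (see the dtype assertion in SinusoidalPosEmb)
    cond_net = SpaceTimeUnet(
        dim=64,
        channels=3,
        dim_mult=(1, 2, 4, 8),
        temporal_compression=(False, False, False, True),
        self_attns=(False, False, False, True),
        condition_on_timestep=True)

    timestep = torch.rand(1)
    cond_output = cond_net(x.permute(0, 2, 1, 3, 4), clip_vae_embed, timestep=timestep)
    print(cond_output.shape)  # torch.Size([1, 3, 8, 32, 32])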