"""2+1D video autoencoder: factorized spatial/temporal blocks and the KL model.

The spatial path uses ``(1, 3, 3)`` 3-D convolutions (per-frame 2-D convs);
temporal mixing is added through zero-initialised residual branches that can
be switched off at call time with ``mask_temporal=True``.
"""

import math
from typing import Any, Optional

import torch
import torch.nn as nn

from src.models.autoencoder import AutoencoderKL
from src.modules.ae_modules import Normalize, nonlinearity
from src.modules.attention_temporal_videoae import *
from src.modules.t5 import T5Embedder
from src.distributions import DiagonalGaussianDistribution
from src.models.autoencoder_temporal import EncoderTemporal1DCNN, DecoderTemporal1DCNN

try:
    import xformers
    import xformers.ops as xops

    XFORMERS_IS_AVAILBLE = True
except Exception:
    # Optional dependency: any failure while importing it (missing package,
    # broken binary build, ...) selects the pure-PyTorch attention path.
    # ``Exception`` (not a bare ``except``) keeps KeyboardInterrupt/SystemExit
    # propagating.
    XFORMERS_IS_AVAILBLE = False


class TemporalConvLayer(nn.Module):
    """Norm -> nonlinearity -> 3x3x3 conv, with the conv initialised to zero.

    Because weight and bias start at zero, the layer outputs zeros at
    initialisation; callers in this file add its output residually.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.norm = Normalize(in_channels)
        self.conv = torch.nn.Conv3d(
            in_channels,
            out_channels,
            kernel_size=(3, 3, 3),
            stride=1,
            padding=(1, 1, 1),
        )
        nn.init.constant_(self.conv.weight, 0)
        nn.init.constant_(self.conv.bias, 0)

    def forward(self, x):
        """Return ``conv(nonlinearity(norm(x)))``."""
        return self.conv(nonlinearity(self.norm(x)))


class ResnetBlock2plus1D(nn.Module):
    """Residual block of per-frame (1, 3, 3) convs plus temporal conv branches.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels of the output; defaults to ``in_channels``.
        conv_shortcut: When the channel count changes, use a (1, 3, 3) conv
            shortcut instead of a (1, 1, 1) one.
        dropout: Dropout probability applied before the second conv.
        temb_channels: Width of the optional embedding ``temb``; a projection
            is only created when this is > 0.
        kernel_size_t, padding_t, stride_t: Accepted but not used here.
    """

    def __init__(
        self,
        *,
        in_channels,
        out_channels=None,
        conv_shortcut=False,
        dropout,
        temb_channels=512,
        kernel_size_t=3,
        padding_t=1,
        stride_t=1,
    ):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        self.norm1 = Normalize(in_channels)
        self.conv1 = torch.nn.Conv3d(
            in_channels,
            out_channels,
            kernel_size=(1, 3, 3),
            stride=1,
            padding=(0, 1, 1),
        )
        self.conv1_tmp = TemporalConvLayer(out_channels, out_channels)
        if temb_channels > 0:
            self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
        self.norm2 = Normalize(out_channels)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = torch.nn.Conv3d(
            out_channels,
            out_channels,
            kernel_size=(1, 3, 3),
            stride=1,
            padding=(0, 1, 1),
        )
        self.conv2_tmp = TemporalConvLayer(out_channels, out_channels)
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = torch.nn.Conv3d(
                    in_channels,
                    out_channels,
                    kernel_size=(1, 3, 3),
                    stride=1,
                    padding=(0, 1, 1),
                )
            else:
                self.nin_shortcut = torch.nn.Conv3d(
                    in_channels,
                    out_channels,
                    kernel_size=(1, 1, 1),
                    stride=1,
                    padding=(0, 0, 0),
                )
            self.conv3_tmp = TemporalConvLayer(out_channels, out_channels)

    def forward(self, x, temb, mask_temporal=False):
        """Apply the block.

        Args:
            x: 5-D input (it is fed to ``Conv3d`` layers).
            temb: Optional embedding projected and added after the first conv;
                pass None to skip.
            mask_temporal: When True, every temporal conv branch is skipped.

        Returns:
            ``shortcut(x) + h`` with ``out_channels`` channels.
        """
        h = self.conv1(nonlinearity(self.norm1(x)))
        if not mask_temporal:
            h = self.conv1_tmp(h) + h

        if temb is not None:
            # The activation is 5-D here, so the projected embedding needs
            # three trailing singleton axes to broadcast over it (the former
            # two-axis form only lines up with 4-D image tensors).
            # assumes temb is 2-D (batch, temb_channels) - TODO confirm
            h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None, None]

        h = self.conv2(self.dropout(nonlinearity(self.norm2(h))))
        if not mask_temporal:
            h = self.conv2_tmp(h) + h

        # skip connection
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                x = self.conv_shortcut(x)
            else:
                x = self.nin_shortcut(x)
            if not mask_temporal:
                x = self.conv3_tmp(x) + x

        return x + h


class AttnBlock3D(nn.Module):
    """Single-head spatial self-attention applied independently per frame."""

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels)
        self.q = torch.nn.Conv3d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.k = torch.nn.Conv3d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.v = torch.nn.Conv3d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.proj_out = torch.nn.Conv3d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )

    def forward(self, x):
        """Return ``x`` plus attention over the ``h*w`` positions of each frame.

        ``x`` is unpacked as ``(b, c, t, h, w)``; frames are folded into the
        batch axis, so no information is exchanged across time.
        """
        h_ = self.norm(x)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        b, c, t, h, w = q.shape
        q = rearrange(q, "b c t h w -> (b t) (h w) c")  # (bt, l, c)
        k = rearrange(k, "b c t h w -> (b t) c (h w)")  # (bt, c, l)
        w_ = torch.bmm(q, k)  # (bt, l, l)
        w_ = w_ * (int(c) ** (-0.5))
        w_ = torch.nn.functional.softmax(w_, dim=2)

        # attend to values
        v = rearrange(v, "b c t h w -> (b t) c (h w)")  # (bt, c, l)
        w_ = w_.permute(0, 2, 1)
        h_ = torch.bmm(v, w_)  # (bt, c, l)
        h_ = rearrange(h_, "(b t) c (h w) -> b c t h w", b=b, h=h)

        h_ = self.proj_out(h_)

        return x + h_


# ---------------------------------------------------------------------------------------------------


class CrossAttention(nn.Module):
    """Multi-head attention from spatial patches of a video to a context.

    Each frame is cut into ``patch_size x patch_size`` patches; every patch
    (flattened with its channels) is one query token. Keys and values come
    from ``context``.

    Args:
        query_dim: Channel count of the input feature map.
        patch_size: Side length of the square patches.
        context_dim: Feature width of the context; defaults to ``query_dim``.
        heads: Number of attention heads.
        dim_head: Width of each head.
        dropout: Dropout probability on the output projection.
    """

    def __init__(
        self,
        query_dim,
        patch_size=1,
        context_dim=None,
        heads=8,
        dim_head=64,
        dropout=0.0,
    ):
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = default(context_dim, query_dim)

        self.scale = dim_head**-0.5
        self.heads = heads
        self.dim_head = dim_head

        self.patch_size = patch_size
        patch_dim = query_dim * patch_size * patch_size
        self.norm = nn.LayerNorm(patch_dim)

        self.to_q = nn.Linear(patch_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, patch_dim), nn.Dropout(dropout)
        )
        self.attention_op: Optional[Any] = None

    def forward(self, x, context=None, mask=None):
        """Attend from the patches of ``x`` to ``context``.

        Args:
            x: Tensor unpacked as ``(b, c, t, height, width)``; height and
                width are split into ``patch_size`` patches by ``rearrange``.
            context: Tensor ``(b, n, d)`` shared by all frames of a sample.
                When None, the normalised patch tokens attend to themselves.
            mask: Optional context mask with leading batch axis; entries
                <= 0.5 are masked out.

        Returns:
            Tensor with the same layout as ``x``. The residual is NOT added
            here; callers in this file add it.
        """
        b, c, t, height, width = x.shape

        divide_factor_height = height // self.patch_size
        divide_factor_width = width // self.patch_size
        x = rearrange(
            x,
            "b c t (df1 ph) (df2 pw) -> (b t) (df1 df2) (ph pw c)",
            df1=divide_factor_height,
            df2=divide_factor_width,
            ph=self.patch_size,
            pw=self.patch_size,
        )
        x = self.norm(x)

        if exists(context):
            # One context per sample: replicate it for each of the t frames.
            context = repeat(context, "b n d -> (b t) n d", b=b, t=t)
        else:
            # x is already flattened to (b t); pushing it through the repeat
            # above would fail its batch-size check whenever t > 1.
            context = x

        q = self.to_q(x)
        k = self.to_k(context)
        v = self.to_v(context)

        q, k, v = map(
            lambda proj: rearrange(proj, "b n (h d) -> (b h) n d", h=self.heads),
            (q, k, v),
        )

        if exists(mask):
            mask = rearrange(mask, "b ... -> b (...)")
            mask = repeat(mask, "b j -> (b t h) () j", t=t, h=self.heads)

        if XFORMERS_IS_AVAILBLE:
            out = self._attend_xformers(q, k, v, mask)
        else:
            out = self._attend_vanilla(q, k, v, mask)

        out = rearrange(out, "(b h) n d -> b n (h d)", h=self.heads)

        ret = self.to_out(out)
        ret = rearrange(
            ret,
            "(b t) (df1 df2) (ph pw c) -> b c t (df1 ph) (df2 pw)",
            b=b,
            t=t,
            df1=divide_factor_height,
            df2=divide_factor_width,
            ph=self.patch_size,
            pw=self.patch_size,
        )
        return ret

    def _attend_xformers(self, q, k, v, mask):
        """Attention via ``xops.memory_efficient_attention`` with optional mask."""
        if not exists(mask):
            return xops.memory_efficient_attention(q, k, v, scale=self.scale)

        mask = mask.to(q.dtype)
        attn_bias = torch.zeros_like(mask)
        attn_bias.masked_fill_(mask <= 0.5, -torch.finfo(q.dtype).max)
        attn_bias = attn_bias.expand(-1, q.shape[1], -1)

        # The bias is written into a buffer whose last two sizes are rounded
        # up to multiples of 8 and then sliced back to its true size.
        # NOTE(review): presumably this satisfies xformers' attn_bias
        # alignment requirement - confirm against the xformers docs.
        n_q, n_k = attn_bias.shape[1], attn_bias.shape[2]
        padded_bias = torch.zeros(
            (attn_bias.shape[0], (n_q + 7) // 8 * 8, (n_k + 7) // 8 * 8),
            dtype=attn_bias.dtype,
            device=attn_bias.device,
        )
        padded_bias[:, :n_q, :n_k] = attn_bias
        # Only the sizes are needed from here on; the former
        # ``.detach().cpu()`` copies were device-to-host transfers whose
        # results were read only for ``.shape``.
        del attn_bias, mask

        return xops.memory_efficient_attention(
            q,
            k,
            v,
            attn_bias=padded_bias[:, :n_q, :n_k],
            scale=self.scale,
        )

    def _attend_vanilla(self, q, k, v, mask):
        """Plain softmax(QK^T * scale)V attention with optional mask."""
        sim = einsum("b i d, b j d -> b i j", q, k) * self.scale
        if exists(mask):
            max_neg_value = -torch.finfo(sim.dtype).max
            sim.masked_fill_(~(mask > 0.5), max_neg_value)
        attn = sim.softmax(dim=-1)
        return einsum("b i j, b j d -> b i d", attn, v)


# ---------------------------------------------------------------------------------------------------


class TemporalAttention(nn.Module):
    def __init__(
        self,
        channels,
        num_heads=1,
        num_head_channels=-1,
        max_temporal_length=64,
    ):
        """
        a clean multi-head temporal attention

        Args:
            channels: Channel count of the input.
            num_heads: Head count, used when ``num_head_channels`` is -1.
            num_head_channels: If not -1, per-head width; the head count is
                then ``channels // num_head_channels``.
            max_temporal_length: Passed to ``RelativePosition`` as
                ``max_relative_position``.
        """
        super().__init__()

        if num_head_channels == -1:
            self.num_heads = num_heads
        else:
            assert (
                channels % num_head_channels == 0
            ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
            self.num_heads = channels // num_head_channels

        self.norm = normalization(channels)
        self.qkv = zero_module(conv_nd(1, channels, channels * 3, 1))
        self.attention = QKVAttention(self.num_heads)
        self.relative_position_k = RelativePosition(
            num_units=channels // self.num_heads,
            max_relative_position=max_temporal_length,
        )
        self.relative_position_v = RelativePosition(
            num_units=channels // self.num_heads,
            max_relative_position=max_temporal_length,
        )
        self.proj_out = zero_module(
            conv_nd(1, channels, channels, 1)
        )  # conv_dim, in_channels, out_channels, kernel_size

    def forward(self, x, mask=None):
        """Return ``x`` plus attention along the time axis.

        ``x`` is unpacked as ``(b, c, t, h, w)``; every spatial location is
        folded into the batch so attention runs over ``t`` only. ``mask`` is
        accepted but not used.
        """
        b, c, t, h, w = x.shape
        out = rearrange(x, "b c t h w -> (b h w) c t")

        qkv = self.qkv(self.norm(out))

        len_q = qkv.size()[-1]
        len_k, len_v = len_q, len_q

        k_rp = self.relative_position_k(len_q, len_k)
        v_rp = self.relative_position_v(len_q, len_v)
        out = self.attention(qkv, rp=(k_rp, v_rp))

        out = self.proj_out(out)
        out = rearrange(out, "(b h w) c t -> b c t h w", b=b, h=h, w=w)

        return x + out


# ---------------------------------------------------------------------------------------------------


class Downsample2plus1D(nn.Module):
    """spatial downsample, in a factorized way"""

    def __init__(self, in_channels, with_conv, temp_down):
        super().__init__()
        self.with_conv = with_conv
        self.in_channels = in_channels
        self.temp_down = temp_down

        if self.with_conv:
            # no asymmetric padding in torch conv, must do it ourselves
            self.conv = torch.nn.Conv3d(
                in_channels,
                in_channels,
                kernel_size=(1, 3, 3),
                stride=(1, 2, 2),
                padding=0,
            )

    def forward(self, x, mask_temporal):
        """Halve the two spatial axes with a stride-2 conv; time is untouched.

        ``mask_temporal`` is accepted but not used here.

        Raises:
            NotImplementedError: If the module was built with
                ``with_conv=False``.
        """
        if not self.with_conv:
            raise NotImplementedError

        # Zero-pad one element on the right/bottom of the two last axes only.
        pad = (0, 1, 0, 1, 0, 0)
        x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
        return self.conv(x)


class Upsample2plus1D(nn.Module):
    """Nearest-neighbour x2 upsampling, spatial and optionally temporal."""

    def __init__(self, in_channels, with_conv, temp_up):
        super().__init__()
        self.with_conv = with_conv
        self.in_channels = in_channels
        self.temp_up = temp_up
        if self.with_conv:
            self.conv = torch.nn.Conv3d(
                in_channels,
                in_channels,
                kernel_size=(1, 3, 3),
                stride=1,
                padding=(0, 1, 1),
            )

    def forward(self, x, mask_temporal):
        """Upsample ``x`` (``b c t h w``) by 2.

        Time is doubled as well only when ``temp_up`` is set and
        ``mask_temporal`` is False; otherwise only h and w are doubled.
        """
        if self.temp_up and not mask_temporal:
            x = torch.nn.functional.interpolate(
                x, scale_factor=(2.0, 2.0, 2.0), mode="nearest"
            )
        else:
            # Fold time into channels so a 2-D interpolate leaves it untouched.
            t = x.shape[2]
            x = rearrange(x, "b c t h w -> b (c t) h w")
            x = torch.nn.functional.interpolate(
                x, scale_factor=(2.0, 2.0), mode="nearest"
            )
            x = rearrange(x, "b (c t) h w -> b c t h w", t=t)
        if self.with_conv:
            x = self.conv(x)
        return x


# ---------------------------------------------------------------------------------------------------


class Encoder2plus1D(nn.Module):
    """Encoder: conv-in, resolution levels of res blocks (+ cross-attention),
    a middle stage with spatial and temporal attention, and a conv-out.

    Args:
        ch: Base channel count.
        temporal_down_factor: ``log2`` of it is stored as ``n_temporal_down``.
        ch_mult: Channel multiplier per resolution level.
        num_res_blocks: Res blocks per level.
        attn_resolutions: Resolutions at which cross-attention is inserted.
        dropout: Dropout probability of the res blocks.
        resamp_with_conv: Forwarded to ``Downsample2plus1D``.
        in_channels: Channels of the input video.
        resolution: Input spatial resolution, used to track ``attn_resolutions``.
        z_channels: Latent channels.
        double_z: Emit ``2 * z_channels`` output channels when True.
        out_ch, use_linear_attn, attn_type, mask_temporal, **ignore_kwargs:
            Accepted but not used by this constructor.
    """

    def __init__(
        self,
        *,
        ch,
        out_ch,
        temporal_down_factor,
        ch_mult=(1, 2, 4, 8),
        num_res_blocks,
        attn_resolutions,
        dropout=0.0,
        resamp_with_conv=True,
        in_channels,
        resolution,
        z_channels,
        double_z=True,
        use_linear_attn=False,
        attn_type="vanilla",
        mask_temporal=False,
        **ignore_kwargs,
    ):
        super().__init__()

        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)  # spatial resolutions
        self.n_temporal_down = int(
            math.log2(temporal_down_factor)
        )  # temporal resolutions
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels

        # downsampling
        self.conv_in = torch.nn.Conv3d(
            in_channels, self.ch, kernel_size=(1, 3, 3), stride=1, padding=(0, 1, 1)
        )

        curr_res = resolution
        in_ch_mult = (1,) + tuple(ch_mult)
        self.in_ch_mult = in_ch_mult
        self.down = nn.ModuleList()
        # Patch size of the cross-attention; halved together with the resolution.
        cur_patch_size = 8
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(
                    ResnetBlock2plus1D(
                        in_channels=block_in,
                        out_channels=block_out,
                        temb_channels=self.temb_ch,
                        dropout=dropout,
                    )
                )
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(
                        CrossAttention(
                            query_dim=block_in,
                            patch_size=cur_patch_size,
                            context_dim=1024,
                        )
                    )
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                temp_down = i_level <= self.n_temporal_down - 1
                down.downsample = Downsample2plus1D(
                    block_in, resamp_with_conv, temp_down
                )
                curr_res = curr_res // 2
                cur_patch_size //= 2
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock2plus1D(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
        )
        self.mid.attn_1 = AttnBlock3D(block_in)
        self.mid.attn_1_tmp = TemporalAttention(block_in, num_heads=1)
        self.mid.block_2 = ResnetBlock2plus1D(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
        )

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv3d(
            block_in,
            2 * z_channels if double_z else z_channels,
            kernel_size=(1, 3, 3),
            stride=1,
            padding=(0, 1, 1),
        )

    def forward(
        self, x, text_embeddings=None, text_attn_mask=None, mask_temporal=False
    ):
        """Encode ``x``.

        Args:
            x: Input video tensor.
            text_embeddings: Context handed to every cross-attention layer.
            text_attn_mask: Mask handed to every cross-attention layer.
            mask_temporal: When True, temporal conv branches and the temporal
                attention of the middle stage are skipped.

        Returns:
            Output of ``conv_out``.
        """
        # timestep embedding
        temb = None

        # downsampling; only the running activation is kept because nothing
        # reads earlier feature maps back (there are no skip connections).
        h = self.conv_in(x)
        for i_level in range(self.num_resolutions):
            level = self.down[i_level]
            for i_block in range(self.num_res_blocks):
                h = level.block[i_block](h, temb, mask_temporal)
                if len(level.attn) > 0:
                    h = h + level.attn[i_block](
                        h, context=text_embeddings, mask=text_attn_mask
                    )
            if i_level != self.num_resolutions - 1:
                h = level.downsample(h, mask_temporal)

        # middle
        h = self.mid.block_1(h, temb, mask_temporal)
        h = self.mid.attn_1(h)
        if not mask_temporal:
            h = self.mid.attn_1_tmp(h)
        h = self.mid.block_2(h, temb, mask_temporal)

        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h


class Decoder2plus1D(nn.Module):
    """Decoder mirroring ``Encoder2plus1D``: conv-in, middle stage, upsampling
    levels of res blocks (+ cross-attention), and a conv-out.

    Args:
        ch, ch_mult, num_res_blocks, attn_resolutions, dropout,
        resamp_with_conv, resolution, z_channels: As in ``Encoder2plus1D``.
        out_ch: Channels of the decoded output.
        temporal_down_factor: ``log2`` of it is stored as ``n_temporal_up``
            and decides which upsample layers may also upsample time.
        in_channels: Stored on the instance only.
        give_pre_end: Return features before the final norm/conv when True.
        tanh_out: Apply ``tanh`` to the output when True.
        use_linear_attn, attn_type, mask_temporal, **ignorekwargs: Accepted
            but not used by this constructor.
    """

    def __init__(
        self,
        *,
        ch,
        out_ch,
        temporal_down_factor,
        ch_mult=(1, 2, 4, 8),
        num_res_blocks,
        attn_resolutions,
        dropout=0.0,
        resamp_with_conv=True,
        in_channels,
        resolution,
        z_channels,
        give_pre_end=False,
        tanh_out=False,
        use_linear_attn=False,
        attn_type="vanilla",
        mask_temporal=False,
        **ignorekwargs,
    ):
        super().__init__()

        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)  # spatial resolutions
        self.n_temporal_up = int(
            math.log2(temporal_down_factor)
        )  # temporal resolutions
        self.n_spatial_up = self.num_resolutions - 1
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.give_pre_end = give_pre_end
        self.tanh_out = tanh_out

        # compute in_ch_mult, block_in and curr_res at lowest res
        in_ch_mult = (1,) + tuple(ch_mult)
        block_in = ch * ch_mult[self.num_resolutions - 1]
        curr_res = resolution // 2 ** (self.num_resolutions - 1)
        self.z_shape = (1, z_channels, curr_res, curr_res)

        # z to block_in
        self.conv_in = torch.nn.Conv3d(
            z_channels, block_in, kernel_size=(1, 3, 3), stride=1, padding=(0, 1, 1)
        )

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock2plus1D(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
        )
        self.mid.attn_1 = AttnBlock3D(block_in)
        self.mid.attn_1_tmp = TemporalAttention(block_in, num_heads=1)
        self.mid.block_2 = ResnetBlock2plus1D(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
        )

        # upsampling
        self.up = nn.ModuleList()
        # Patch size of the cross-attention; doubled together with the resolution.
        cur_patch_size = 1
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks + 1):
                block.append(
                    ResnetBlock2plus1D(
                        in_channels=block_in,
                        out_channels=block_out,
                        temb_channels=self.temb_ch,
                        dropout=dropout,
                    )
                )
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(
                        CrossAttention(
                            query_dim=block_in,
                            patch_size=cur_patch_size,
                            context_dim=1024,
                        )
                    )
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                temp_up = i_level <= self.num_resolutions - 1 - (
                    self.n_spatial_up - self.n_temporal_up
                )
                up.upsample = Upsample2plus1D(block_in, resamp_with_conv, temp_up)
                curr_res = curr_res * 2
                cur_patch_size *= 2
            self.up.insert(0, up)  # prepend to get consistent order

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv3d(
            block_in, out_ch, kernel_size=(1, 3, 3), stride=1, padding=(0, 1, 1)
        )

    def forward(
        self, z, text_embeddings=None, text_attn_mask=None, mask_temporal=False
    ):
        """Decode latent ``z``.

        Args:
            z: Latent tensor; its shape is recorded in ``self.last_z_shape``.
            text_embeddings: Context handed to every cross-attention layer.
            text_attn_mask: Mask handed to every cross-attention layer.
            mask_temporal: When True, temporal branches, temporal attention
                and temporal upsampling are skipped.

        Returns:
            Decoded tensor, or the pre-output features if ``give_pre_end``.
        """
        self.last_z_shape = z.shape

        # timestep embedding
        temb = None

        # z to block_in
        h = self.conv_in(z)

        # middle
        h = self.mid.block_1(h, temb, mask_temporal)
        h = self.mid.attn_1(h)
        if not mask_temporal:
            h = self.mid.attn_1_tmp(h)
        h = self.mid.block_2(h, temb, mask_temporal)

        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            level = self.up[i_level]
            for i_block in range(self.num_res_blocks + 1):
                h = level.block[i_block](h, temb, mask_temporal)
                if len(level.attn) > 0:
                    h = h + level.attn[i_block](
                        h, context=text_embeddings, mask=text_attn_mask
                    )
            if i_level != 0:
                h = level.upsample(h, mask_temporal)

        # end
        if self.give_pre_end:
            return h

        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        if self.tanh_out:
            h = torch.tanh(h)
        return h


class AutoencoderKL2plus1D_1dcnn(AutoencoderKL):
    """KL autoencoder chaining a 2+1D encoder/decoder with a temporal 1-D CNN
    encoder/decoder.

    Args:
        ddconfig: Keyword arguments for ``Encoder2plus1D``/``Decoder2plus1D``.
        ppconfig: Keyword arguments for the temporal 1-D CNN encoder/decoder.
        lossconfig: Forwarded to the parent constructor.
        embed_dim: Latent width of the quant convs; must be non-zero when
            ``use_quant_conv`` is True.
        use_quant_conv: Create and use ``quant_conv``/``post_quant_conv``.
        test: Call ``init_test()`` at the end of construction.
        ckpt_path: Optional checkpoint for ``init_from_ckpt``.
        ckpt_path_2d: Optional 2-D checkpoint for ``init_from_2dckpt``.
        ckpt_path_4temporal: Optional checkpoint for ``init_from_4temporal``.
        ignore_keys_3d: Keys to ignore when loading; None means none.
        img_video_joint_train: Route training/logging to the joint variants.
        video_key: Batch key of the input tensor.
        caption_guide: Condition on T5 embeddings of ``batch["caption"]``.
        t5_model_max_length: Forwarded to ``T5Embedder``.
    """

    def __init__(
        self,
        ddconfig,
        ppconfig,
        lossconfig,
        embed_dim=0,
        use_quant_conv=True,
        test=False,
        ckpt_path=None,
        ckpt_path_2d=None,
        ckpt_path_4temporal=None,
        ignore_keys_3d=None,
        img_video_joint_train=False,
        video_key="",
        caption_guide=False,
        t5_model_max_length=120,
        *args,
        **kwargs,
    ):
        # ``test=False`` is always sent to the parent; ``init_test`` is called
        # below, after this class has built its modules and loaded checkpoints.
        super().__init__(
            ddconfig,
            lossconfig,
            embed_dim,
            use_quant_conv,
            *args,
            test=False,
            **kwargs,
        )
        # A fresh list per instance instead of a shared mutable default.
        ignore_keys_3d = [] if ignore_keys_3d is None else ignore_keys_3d

        self.img_video_joint_train = img_video_joint_train
        self.caption_guide = caption_guide
        self.video_key = video_key
        self.t5_model_max_length = t5_model_max_length
        self.use_quant_conv = use_quant_conv

        self.encoder_temporal = EncoderTemporal1DCNN(**ppconfig)
        self.decoder_temporal = DecoderTemporal1DCNN(**ppconfig)
        self.encoder = Encoder2plus1D(**ddconfig)
        self.decoder = Decoder2plus1D(**ddconfig)

        if use_quant_conv:
            assert embed_dim, "embed_dim must be non-zero when use_quant_conv is True"
            self.embed_dim = embed_dim
            self.quant_conv = torch.nn.Conv3d(
                2 * ddconfig["z_channels"], 2 * embed_dim, 1
            )
            self.post_quant_conv = torch.nn.Conv3d(embed_dim, ddconfig["z_channels"], 1)

        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys_3d)
        if ckpt_path_2d is not None:
            self.init_from_2dckpt(ckpt_path_2d)
        if ckpt_path_4temporal is not None:
            self.init_from_4temporal(ckpt_path_4temporal, ignore_keys=ignore_keys_3d)
        if test:
            self.init_test()

        # The T5 embedder is built lazily on first use.
        self.enable_text_embedder = False

    def init_from_2dckpt(self, ckpt_path_2d):
        """Load weights of a 2-D autoencoder checkpoint into this model.

        Only keys starting with ``first_stage_model.`` are used (with that
        prefix removed); 4-D tensors get a singleton axis inserted at
        position 2 so conv kernels fit the 3-D convs. Loading is non-strict.
        """
        # SECURITY: ``torch.load`` unpickles the file; only load checkpoints
        # from trusted sources.
        sd = torch.load(ckpt_path_2d, map_location="cpu")
        if "state_dict" in sd:
            sd = sd["state_dict"]

        prefix = "first_stage_model."
        sd_new = {}
        for key, value in sd.items():
            if not key.startswith(prefix):
                continue
            if value.dim() == 4:
                value = value.unsqueeze(2)
            sd_new[key[len(prefix):]] = value

        self.load_state_dict(sd_new, strict=False)
        print(f"Restored from {ckpt_path_2d}")

    def get_text_embeddings(self, captions):
        """Embed ``captions`` with T5, creating the embedder on first call."""
        if not self.enable_text_embedder:
            self.enable_text_embedder = True
            self.text_embedder = T5Embedder(
                device=self.device, model_max_length=self.t5_model_max_length
            )
        return self.text_embedder.get_text_embeddings(captions)

    def configure_optimizers(self):
        """Return ``([opt_ae, opt_disc], [])``: two Adam optimizers sharing
        ``self.learning_rate`` and betas ``(0.5, 0.9)``.
        """
        lr = self.learning_rate
        model_params = (
            list(self.encoder_temporal.parameters())
            + list(self.decoder_temporal.parameters())
            + list(self.encoder.parameters())
            + list(self.decoder.parameters())
        )
        if self.use_quant_conv:
            model_params += list(self.quant_conv.parameters()) + list(
                self.post_quant_conv.parameters()
            )
        opt_ae = torch.optim.Adam(model_params, lr=lr, betas=(0.5, 0.9))
        opt_disc = torch.optim.Adam(
            self.loss.discriminator.parameters(), lr=lr, betas=(0.5, 0.9)
        )

        return [opt_ae, opt_disc], []

    def encode_temporal(self, x, text_embeddings=None, text_attn_mask=None):
        """Run the temporal encoder and wrap its output as a posterior."""
        moments = self.encoder_temporal(x, text_embeddings, text_attn_mask)
        return DiagonalGaussianDistribution(moments)

    def decode_temporal(self, z, text_embeddings=None, text_attn_mask=None):
        """Run the temporal decoder on ``z``."""
        return self.decoder_temporal(z, text_embeddings, text_attn_mask)

    def encode_2plus1d(
        self, x, text_embeddings=None, text_attn_mask=None, mask_temporal=False
    ):
        """Run the 2+1D encoder (and ``quant_conv`` if enabled); return the
        resulting ``DiagonalGaussianDistribution``.
        """
        h = self.encoder(
            x, text_embeddings, text_attn_mask, mask_temporal=mask_temporal
        )
        if self.use_quant_conv:
            h = self.quant_conv(h)
        return DiagonalGaussianDistribution(h)

    def decode_2plus1d(
        self, z, text_embeddings=None, text_attn_mask=None, mask_temporal=False
    ):
        """Apply ``post_quant_conv`` (if enabled) and the 2+1D decoder."""
        if self.use_quant_conv:
            z = self.post_quant_conv(z)
        return self.decoder(
            z, text_embeddings, text_attn_mask, mask_temporal=mask_temporal
        )

    def encode(
        self,
        x,
        text_embeddings=None,
        text_attn_mask=None,
        sample_posterior=True,
        mask_temporal=False,
    ):
        """Encode ``x`` into a latent.

        The 2+1D encoder always runs; the temporal encoder runs on its latent
        unless ``mask_temporal`` is True.

        Returns:
            ``(z, posterior)`` of the last stage that ran. ``z`` is a sample
            if ``sample_posterior`` else the mode, moved to the model's
            device and dtype.
        """
        posterior = self.encode_2plus1d(
            x, text_embeddings, text_attn_mask, mask_temporal=mask_temporal
        )
        z = posterior.sample() if sample_posterior else posterior.mode()
        z = z.to(device=self.device, dtype=self.dtype)

        if not mask_temporal:
            posterior = self.encode_temporal(z, text_embeddings, text_attn_mask)
            z = posterior.sample() if sample_posterior else posterior.mode()
            z = z.to(device=self.device, dtype=self.dtype)

        return z, posterior

    def decode(self, z, text_embeddings=None, text_attn_mask=None, mask_temporal=False):
        """Decode ``z``: temporal decoder (unless masked), then 2+1D decoder."""
        if not mask_temporal:
            z = self.decode_temporal(z, text_embeddings, text_attn_mask)
        return self.decode_2plus1d(
            z, text_embeddings, text_attn_mask, mask_temporal=mask_temporal
        )

    def forward(
        self,
        inputs,
        text_embeddings=None,
        text_attn_mask=None,
        sample_posterior=True,
        mask_temporal=False,
    ):
        """Encode then decode ``inputs``; return ``(reconstruction, posterior)``."""
        z, posterior = self.encode(
            inputs,
            text_embeddings,
            text_attn_mask,
            sample_posterior,
            mask_temporal=mask_temporal,
        )
        dec = self.decode(
            z, text_embeddings, text_attn_mask, mask_temporal=mask_temporal
        )
        return dec, posterior

    def _loss_and_log(self, inputs, reconstructions, posterior, optimizer_idx):
        """Compute and log the train loss for one optimizer.

        ``optimizer_idx`` 0 logs ``"aeloss"``, 1 logs ``"discloss"``; any
        other index does nothing and returns None.
        """
        if optimizer_idx not in (0, 1):
            return None

        loss, log_dict = self.loss(
            inputs,
            reconstructions,
            posterior,
            optimizer_idx,
            self.global_step,
            last_layer=self.get_last_layer(),
            split="train",
        )
        self.log(
            "aeloss" if optimizer_idx == 0 else "discloss",
            loss,
            prog_bar=True,
            logger=True,
            on_step=True,
            on_epoch=True,
        )
        self.log_dict(
            log_dict, prog_bar=False, logger=True, on_step=True, on_epoch=False
        )
        return loss

    def training_step(self, *args):
        """Dispatch to the joint or video-only training step."""
        if self.img_video_joint_train:
            return self.training_step_joint(*args)
        else:
            return self.training_step_video(*args)

    def training_step_video(self, batch, batch_idx, optimizer_idx):
        """Video-only training step; returns the loss of ``optimizer_idx``."""
        inputs = self.get_input(batch, self.video_key)
        if self.caption_guide:
            text_embeddings, text_attn_mask = self.get_text_embeddings(batch["caption"])
            reconstructions, posterior = self(inputs, text_embeddings, text_attn_mask)
            del text_embeddings, text_attn_mask
        else:
            reconstructions, posterior = self(inputs)

        return self._loss_and_log(inputs, reconstructions, posterior, optimizer_idx)

    def training_step_joint(self, batch, batch_idx, optimizer_idx):
        """Joint image/video training step.

        If not every entry of ``batch["is_video"]`` is true, the batch is run
        with ``mask_temporal=True`` after swapping axes 0 and 2 of the input.
        """
        inputs = self.get_input(batch, self.video_key)
        is_video = self.get_input(batch, "is_video")
        is_mask = not is_video.all()
        if is_mask:
            # Swap axes 0 and 2.
            # presumably (b, c, 16, h, w) -> (16, c, 1, h, w), i.e. every
            # frame becomes its own single-frame sample - confirm with the
            # data loader
            inputs = inputs.permute(2, 1, 0, 3, 4)

        if self.caption_guide:
            text_embeddings, text_attn_mask = self.get_text_embeddings(batch["caption"])
            reconstructions, posterior = self(
                inputs, text_embeddings, text_attn_mask, mask_temporal=is_mask
            )
            del text_embeddings, text_attn_mask
        else:
            reconstructions, posterior = self(inputs, mask_temporal=is_mask)

        return self._loss_and_log(inputs, reconstructions, posterior, optimizer_idx)

    def validation_step(self, batch, batch_idx):
        """Compute and log the ``"val"`` split autoencoder and discriminator
        losses for one batch.
        """
        torch.cuda.empty_cache()
        inputs = self.get_input(batch, self.video_key)
        is_video = self.get_input(batch, "is_video")
        is_mask = not is_video.all()

        if self.caption_guide:
            text_embeddings, text_attn_mask = self.get_text_embeddings(batch["caption"])
            reconstructions, posterior = self(
                inputs, text_embeddings, text_attn_mask, mask_temporal=is_mask
            )
            del text_embeddings, text_attn_mask
        else:
            reconstructions, posterior = self(inputs, mask_temporal=is_mask)

        aeloss, log_dict_ae = self.loss(
            inputs,
            reconstructions,
            posterior,
            0,
            self.global_step,
            last_layer=self.get_last_layer(),
            split="val",
        )

        discloss, log_dict_disc = self.loss(
            inputs,
            reconstructions,
            posterior,
            1,
            self.global_step,
            last_layer=self.get_last_layer(),
            split="val",
        )
        del reconstructions

        self.log("val/rec_loss", log_dict_ae["val/rec_loss"])
        self.log_dict(log_dict_ae)
        self.log_dict(log_dict_disc)
        torch.cuda.empty_cache()
        # NOTE(review): this returns the bound ``log_dict`` method itself, not
        # a dictionary; kept unchanged for compatibility - confirm intent.
        return self.log_dict

    def _log_reconstructions(self, x, batch, only_inputs, mask_temporal):
        """Build the logging dict shared by the image/video loggers.

        Always contains ``"inputs"``; unless ``only_inputs`` is set it also
        contains ``"samples"`` (decoded Gaussian noise shaped like a posterior
        sample) and ``"reconstructions"``.
        """
        log = dict()
        if not only_inputs:
            cond = ()
            if self.caption_guide:
                text_embeddings, text_attn_mask = self.get_text_embeddings(
                    batch["caption"]
                )
                cond = (
                    text_embeddings.to(device=self.device, dtype=self.dtype),
                    text_attn_mask.to(device=self.device, dtype=self.dtype),
                )

            xrec, posterior = self(x, *cond, mask_temporal=mask_temporal)
            if x.shape[1] > 3:
                # colorize with random projection
                assert xrec.shape[1] > 3
                x = self.to_rgb(x)
                xrec = self.to_rgb(xrec)
            log["samples"] = self.decode(
                torch.randn_like(posterior.sample()),
                *cond,
                mask_temporal=mask_temporal,
            )
            log["reconstructions"] = xrec.cpu().detach()

        log["inputs"] = x.cpu().detach()
        return log

    @torch.no_grad()
    def log_images_joint(self, batch, only_inputs=False, **kwargs):
        """Logging for joint training; handles the non-video case like
        ``training_step_joint`` (axis swap and ``mask_temporal=True``).
        """
        x = self.get_input(batch, self.video_key)
        x = x.to(self.device)
        is_video = self.get_input(batch, "is_video")
        is_mask = not is_video.all()
        if is_mask:
            # Same axis swap as in ``training_step_joint``.
            x = x.permute(2, 1, 0, 3, 4)

        return self._log_reconstructions(x, batch, only_inputs, is_mask)

    @torch.no_grad()
    def log_video(self, batch, only_inputs=False, **kwargs):
        """Logging for video-only training (temporal path always enabled)."""
        x = self.get_input(batch, self.video_key)
        x = x.to(self.device, dtype=self.dtype)

        return self._log_reconstructions(x, batch, only_inputs, False)

    @torch.no_grad()
    def log_images(self, batch, only_inputs=False, **kwargs):
        """Dispatch to the joint or video logger, forwarding ``only_inputs``
        (it used to be overridden with False).
        """
        if self.img_video_joint_train:
            return self.log_images_joint(batch, only_inputs=only_inputs, **kwargs)
        else:
            return self.log_video(batch, only_inputs=only_inputs, **kwargs)