import torch
from torch import nn
import einops
from typing import Tuple
import random
import numpy as np
from tqdm import tqdm
import torch.nn.functional as F

from .modules import DuoFrameDownEncoder, Upsampler, MapConv, MotionDownEncoder
from .loss import l1, l2
from .transformer import (MotionTransformer,
                          AMDDiffusionTransformerModel,
                          MotionEncoderLearnTokenTransformer,
                          AMDReconstructTransformerModel,
                          AMDDiffusionTransformerModelDualStream,
                          AMDDiffusionTransformerModelImgSpatial,
                          AMDDiffusionTransformerModelImgSpatialDoubleRef,
                          AMDReconstructTransformerModelSpatial)
from .rectified_flow import RectifiedFlow
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.resnet import ResnetBlock2D
from diffusers.utils import export_to_gif


class AMDModel(ModelMixin, ConfigMixin):
    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(self,
                 image_inchannel: int = 4,
                 image_height: int = 32,
                 image_width: int = 32,
                 video_frames: int = 16,
                 scheduler_num_step: int = 1000,
                 # ----------- MotionEncoder -----------
                 motion_token_num: int = 12,
                 motion_token_channel: int = 128,
                 enc_num_layers: int = 8,
                 enc_nhead: int = 8,
                 enc_ndim: int = 64,
                 enc_dropout: float = 0.0,
                 motion_need_norm_out: bool = False,
                 # ----------- MotionTransformer ---------
                 need_motion_transformer: bool = False,
                 motion_transformer_attn_head_dim: int = 64,
                 motion_transformer_attn_num_heads: int = 16,
                 motion_transformer_num_layers: int = 4,
                 # ----------- Diffusion Transformer -----------
                 diffusion_model_type: str = 'default',  # 'default', 'dual', 'spatial' or 'doubleref'
                 diffusion_attn_head_dim: int = 64,
                 diffusion_attn_num_heads: int = 16,
                 diffusion_out_channels: int = 4,
                 diffusion_num_layers: int = 16,
                 image_patch_size: int = 2,
                 motion_patch_size: int = 1,
                 motion_drop_ratio: float = 0.0,
                 refimg_drop: bool = False,
                 # ----------- Sample --------------
                 extract_motion_with_motion_transformer=False,
                 **kwargs,
                 ):
        super().__init__()

        # setting
        self.num_step = scheduler_num_step
        self.scheduler = RectifiedFlow(num_steps=scheduler_num_step)
        self.need_motion_transformer = need_motion_transformer
        self.extract_motion_with_motion_transformer = extract_motion_with_motion_transformer
        self.diffusion_model_type = diffusion_model_type
        self.target_frame = video_frames
        self.refimg_drop = refimg_drop

        # motion encoder
        self.motion_encoder = MotionEncoderLearnTokenTransformer(
            img_height=image_height,
            img_width=image_width,
            img_inchannel=image_inchannel,
            img_patch_size=image_patch_size,
            motion_token_num=motion_token_num,
            motion_channel=motion_token_channel,
            need_norm_out=motion_need_norm_out,
            # ----- attention
            num_attention_heads=enc_nhead,
            attention_head_dim=enc_ndim,
            num_layers=enc_num_layers,
            dropout=enc_dropout,
            attention_bias=True,
        )

        # motion transformer
        if need_motion_transformer:
            self.motion_transformer = MotionTransformer(
                motion_token_num=motion_token_num,
                motion_token_channel=motion_token_channel,
                attention_head_dim=motion_transformer_attn_head_dim,
                num_attention_heads=motion_transformer_attn_num_heads,
                num_layers=motion_transformer_num_layers,
            )

        # diffusion transformer
        if diffusion_model_type == 'default':
            dit_image_inchannel = image_inchannel * 2  # zi + zt
            self.diffusion_transformer = AMDDiffusionTransformerModel(
                num_attention_heads=diffusion_attn_num_heads,
                attention_head_dim=diffusion_attn_head_dim,
                out_channels=diffusion_out_channels,
                num_layers=diffusion_num_layers,
                # ----- img
                image_width=image_width,
                image_height=image_height,
                image_patch_size=image_patch_size,
                image_in_channels=dit_image_inchannel,
                # ----- motion
                motion_token_num=motion_token_num,
                motion_in_channels=motion_token_channel,
            )
        elif diffusion_model_type == 'dual':
            dit_image_inchannel = image_inchannel * 2  # zi + zt
            self.diffusion_transformer = AMDDiffusionTransformerModelDualStream(
                num_attention_heads=diffusion_attn_num_heads,
                attention_head_dim=diffusion_attn_head_dim,
                out_channels=diffusion_out_channels,
                num_layers=diffusion_num_layers,
                # ----- img
                image_width=image_width,
                image_height=image_height,
                image_patch_size=image_patch_size,
                image_in_channels=dit_image_inchannel,
                # ----- motion
                motion_token_num=motion_token_num,
                motion_in_channels=motion_token_channel,
                motion_target_num_frame=video_frames,
            )
        elif diffusion_model_type == 'spatial':
            dit_image_inchannel = image_inchannel * 2  # zi + zt
            self.diffusion_transformer = AMDDiffusionTransformerModelImgSpatial(
                num_attention_heads=diffusion_attn_num_heads,
                attention_head_dim=diffusion_attn_head_dim,
                out_channels=diffusion_out_channels,
                num_layers=diffusion_num_layers,
                # ----- img
                image_width=image_width,
                image_height=image_height,
                image_patch_size=image_patch_size,
                image_in_channels=dit_image_inchannel,
                # ----- motion
                motion_token_num=motion_token_num,
                motion_in_channels=motion_token_channel,
                motion_target_num_frame=video_frames,
            )
        elif diffusion_model_type == 'doubleref':
            dit_image_inchannel = image_inchannel
            self.diffusion_transformer = AMDDiffusionTransformerModelImgSpatialDoubleRef(
                num_attention_heads=diffusion_attn_num_heads,
                attention_head_dim=diffusion_attn_head_dim,
                out_channels=diffusion_out_channels,
                num_layers=diffusion_num_layers,
                # ----- img
                image_width=image_width,
                image_height=image_height,
                image_patch_size=image_patch_size,
                image_in_channels=dit_image_inchannel,
                # ----- motion
                motion_token_num=motion_token_num,
                motion_in_channels=motion_token_channel,
                motion_target_num_frame=video_frames,
            )
        else:
            raise ValueError(f'unsupported diffusion_model_type: {diffusion_model_type}')

    def forward(self,
                video: torch.Tensor,
                ref_img: torch.Tensor,
                randomref_img: torch.Tensor = None,
                time_step: torch.Tensor = None,
                return_meta_info=False,
                mask_ratio=None,
                **kwargs,):
        """
        Args:
            video: (N,T,C,H,W)
            ref_img: (N,T,C,H,W)
            randomref_img: (N,T,C,H,W)
        """
        device = video.device
        n, t, c, h, w = video.shape
        assert video.shape == ref_img.shape, f'video.shape:{video.shape} should be equal to ref_img.shape:{ref_img.shape}'
        if self.diffusion_model_type == 'doubleref':
            assert randomref_img is not None, "when diffusion_model_type == doubleref, randomref_img should be given"

        # motion encoder
        if mask_ratio is not None:
            mask_ratio = torch.rand(1).item() * mask_ratio
        if self.diffusion_model_type == 'doubleref' and randomref_img is not None:
            if randomref_img.dim() == 4:
                randomref_img = randomref_img.unsqueeze(1).repeat(1, t, 1, 1, 1)
            refimg_and_video = torch.cat([randomref_img, video], dim=1)  # (n,t+t,C,H,W)
        else:
            refimg_and_video = torch.cat([ref_img, video], dim=1)  # (n,t+t,C,H,W)
        motion = self.motion_encoder(refimg_and_video, mask_ratio)  # (n,t+t,l,d)
        source_motion = motion[:, :t].flatten(0, 1)  # (NT,motion_token,d)
        target_motion = motion[:, t:].flatten(0, 1)  # (NT,motion_token,d)
        assert source_motion.shape == target_motion.shape, f'source_motion.shape {source_motion.shape} != target_motion.shape {target_motion.shape}'

        # motion transformer
        if self.need_motion_transformer:
            target_motion = einops.rearrange(target_motion, '(n f) l d -> n f l d', n=n)
            target_motion = self.motion_transformer(target_motion)
            target_motion = einops.rearrange(target_motion, 'n f l d -> (n f) l d', n=n)

        # prepare for Diffusion Transformer
        zi = ref_img.flatten(0, 1)  # (NT,C,H,W)
        zj = video.flatten(0, 1)    # (NT,C,H,W)
        if self.diffusion_model_type == 'doubleref' and randomref_img is not None:
            randomref_img = randomref_img.flatten(0, 1)  # (NT,C,H,W)
        if time_step is None:
            time_step = self.prepare_timestep(batch_size=zj.shape[0], device=device)  # (b,)
            if self.diffusion_model_type != 'default':
                time_step = self.prepare_timestep(batch_size=n, device=device)  # (n,)
                time_step = time_step.repeat_interleave(t)  # (b,)
        zt, vel = self.scheduler.get_train_tuple(z1=zj, time_step=time_step)  # (NT,C,H,W),(NT,C,H,W)

        # dit forward
        if self.refimg_drop:
            zi = torch.zeros_like(zi).to(video.device)
        image_hidden_states = torch.cat((zi, zt), dim=1)  # (b,2C,H,W)
        pre = self.diffusion_transformer(
            motion_source_hidden_states=source_motion,
            motion_target_hidden_states=target_motion,
            image_hidden_states=image_hidden_states,
            randomref_image_hidden_states=randomref_img,
            timestep=time_step,
        )

        # loss
        diff_loss = l2(pre, vel)
        rec_zj = self.scheduler.get_target_with_zt_vel(zt, pre, time_step)
        rec_loss = l2(rec_zj, zj)
        loss = diff_loss
        loss_dict = {'loss': loss, 'diff_loss': diff_loss, 'rec_loss': rec_loss}

        if return_meta_info:
            return {'motion': motion,        # (n,t,l,d) , output of motion encoder
                    'zi': zi,                # (b,C,H,W) | b = n * t
                    'zj': zj,                # (b,C,H,W)
                    'zt': zt,                # (b,C,H,W)
                    'gt': vel,               # (b,C,H,W)
                    'pre': pre,              # (b,C,H,W)
                    'time_step': time_step,  # (b,)
                    }
        else:
            return pre, vel, loss_dict  # (b,C,H,W)

    def get_noise_latent_pair(self,
                              video: torch.Tensor,
                              ref_img: torch.Tensor,
                              randomref_img: torch.Tensor,
                              sample_step: int = 50,
                              ):
        pass

    @torch.no_grad()
    def sample(self,
               video: torch.Tensor,
               ref_img: torch.Tensor,
               randomref_img: torch.Tensor = None,
               sample_step: int = 50,
               mask_ratio=None,
               start_step: int = None,
               return_meta_info=False,
               **kwargs,):
        device = video.device
        n, t, c, h, w = video.shape
        if start_step is None:
            start_step = self.scheduler.num_step
        assert start_step <= self.scheduler.num_step, 'start_step cant be larger than scheduler.num_step'
        if self.diffusion_model_type == 'doubleref':
            assert randomref_img is not None, "when diffusion_model_type == doubleref, randomref_img should be given"
        if ref_img.dim() == 4:
            ref_img = ref_img.unsqueeze(1).repeat(1, t, 1, 1, 1)

        # motion encoder
        if mask_ratio is not None:
            print(f'* Sampling with Mask_Ratio = {mask_ratio}')
        if self.diffusion_model_type == 'doubleref' and randomref_img is not None:
            if randomref_img.dim() == 4:
                randomref_img = randomref_img.unsqueeze(1).repeat(1, t, 1, 1, 1)
            refimg_and_video = torch.cat([randomref_img, video], dim=1)  # (n,t+t,C,H,W)
        else:
            refimg_and_video = torch.cat([ref_img, video], dim=1)  # (n,t+t,C,H,W)
        motion = self.motion_encoder(refimg_and_video, mask_ratio)  # (n,t+t,l,d)
        source_motion = motion[:, :t].flatten(0, 1)  # (NT,motion_token,d)
        target_motion = motion[:, t:].flatten(0, 1)  # (NT,motion_token,d)
        assert source_motion.shape == target_motion.shape, f'source_motion.shape {source_motion.shape} != target_motion.shape {target_motion.shape}'

        # motion transformer
        if self.need_motion_transformer:
            target_motion = einops.rearrange(target_motion, '(n f) l d -> n f l d', n=n)
            target_motion = self.motion_transformer(target_motion)
            target_motion = einops.rearrange(target_motion, 'n f l d -> (n f) l d', n=n)

        # prepare for Diffusion Transformer
        time_step = torch.ones((source_motion.shape[0],)).to(device)
        time_step = time_step * start_step
        zi = ref_img.flatten(0, 1)  # (NT,C,H,W)
        zj = video.flatten(0, 1)    # (NT,C,H,W)
        if self.diffusion_model_type == 'doubleref' and randomref_img is not None:
            randomref_img = randomref_img.flatten(0, 1)  # (NT,C,H,W)
        zt, vel = self.scheduler.get_train_tuple(z1=zj, time_step=time_step)  # (NT,C,H,W),(NT,C,H,W)
        noise = zj - vel

        # Sample Loop
        pre_cache = []
        sample_cache = []
        # 1. step_seq
        step_seq = np.linspace(0, start_step, num=sample_step + 1, endpoint=True, dtype=int)  # [0,5,10,15,...,start_step]
        step_seq = list(reversed(step_seq[1:]))  # drop step 0: [start_step,...,15,10,5]
        # 2. Euler step
        dt = 1. / sample_step
        if self.refimg_drop:
            zi = torch.zeros_like(zi).to(video.device)
        for i in tqdm(step_seq):
            # time_step
            time_step = torch.ones((zt.shape[0],)).to(zt.device)
            time_step = time_step * i
            # input
            zt = zt.to(video.dtype)
            image_hidden_states = torch.cat((zi, zt), dim=1)  # (b,2C,H,W)
            # forward
            pre = self.diffusion_transformer(
                motion_source_hidden_states=source_motion,
                motion_target_hidden_states=target_motion,
                image_hidden_states=image_hidden_states,
                randomref_image_hidden_states=randomref_img,
                timestep=time_step,
            )
            zt = zt + pre * dt
            pre_cache.append(pre)
            sample_cache.append(zt)

        zi = einops.rearrange(zi, '(n t) c h w -> n t c h w', n=n)
        zt = einops.rearrange(zt, '(n t) c h w -> n t c h w', n=n)
        zj = einops.rearrange(zj, '(n t) c h w -> n t c h w', n=n)

        if return_meta_info:
            return {'zi': zi,                      # (b,1,c,h,w)
                    'zj': zj,                      # (b,1,c,h,w)
                    'sample': zt,                  # (b,1,c,h,w)
                    'pre_cache': pre_cache,        # [(b,c,h,w),...]
                    'sample_cache': sample_cache,  # [(b,c,h,w),...]
                    'step_seq': step_seq,
                    'motion': target_motion,       # (b,C,H,W)
                    'noise': noise,
                    }
        else:
            return zi, zt, zj  # (n,t,c,h,w)

    @torch.no_grad()
    def sample_with_refimg_motion(self,
                                  ref_img: torch.Tensor,
                                  motion: torch.Tensor,
                                  randomref_img: torch.Tensor = None,
                                  sample_step: int = 10,
                                  mask_ratio=None,
                                  return_meta_info=False,
                                  **kwargs,):
        """
        Args:
            ref_img : (N,C,H,W)
            randomref_img : (N,C,H,W)
            motion : (N,F,L,D)
        Return:
            video : (N,T,C,H,W)
        """
        device = motion.device
        n, t, l, d = motion.shape
        start_step = self.scheduler.num_step

        # motion encoder
        refimg = ref_img.unsqueeze(1)  # (N,1,C,H,W)
        if self.diffusion_model_type == 'doubleref':
            assert randomref_img is not None, "when diffusion_model_type == doubleref, randomref_img should be given"
        if self.diffusion_model_type == 'doubleref' and randomref_img is not None:
            print('* Warning * diffusion_model_type: doubleref')
            if randomref_img.dim() == 4:
                randomref_img = randomref_img.unsqueeze(1)  # (N,1,C,H,W)
            source_motion = self.motion_encoder(randomref_img, mask_ratio)  # (n,1,motion_token,d)
        else:
            source_motion = self.motion_encoder(refimg, mask_ratio)  # (n,1,motion_token,d)
        source_motion = source_motion.repeat(1, t, 1, 1).flatten(0, 1)  # (NT,l,d)
        target_motion = motion.flatten(0, 1)  # (NT,l,d)
        assert source_motion.shape == target_motion.shape, f'source_motion.shape {source_motion.shape} != target_motion.shape {target_motion.shape}'

        # motion transformer
        if self.need_motion_transformer and not self.extract_motion_with_motion_transformer:
            target_motion = einops.rearrange(target_motion, '(n f) l d -> n f l d', n=n)
            target_motion = self.motion_transformer(target_motion)
            target_motion = einops.rearrange(target_motion, 'n f l d -> (n f) l d', n=n)

        # prepare for Diffusion Transformer
        time_step = torch.ones((source_motion.shape[0],)).to(device)
        time_step = time_step * start_step
        zi = refimg.repeat(1, t, 1, 1, 1).flatten(0, 1)  # (NT,C,H,W)
        zj = zi
        if self.diffusion_model_type == 'doubleref' and randomref_img is not None:
            randomref_img = randomref_img.repeat(1, t, 1, 1, 1)
            randomref_img = randomref_img.flatten(0, 1)  # (NT,C,H,W)
        # zt,vel = self.scheduler.get_train_tuple(z1=zj,time_step=time_step) # (NT,C,H,W),(NT,C,H,W)
        zt = torch.randn_like(zj)

        # Sample Loop
        pre_cache = []
        sample_cache = []
        # 1. step_seq
        step_seq = np.linspace(0, start_step, num=sample_step + 1, endpoint=True, dtype=int)  # [0,5,10,15,...,start_step]
        step_seq = list(reversed(step_seq[1:]))  # drop step 0: [start_step,...,15,10,5]
        # 2. Euler step
        dt = 1. / sample_step
        if self.refimg_drop:
            zi = torch.zeros_like(zi).to(ref_img.device)
        for i in tqdm(step_seq):
            # time_step
            time_step = torch.ones((zt.shape[0],)).to(zt.device)
            time_step = time_step * i
            # input
            zt = zt.to(ref_img.dtype)
            image_hidden_states = torch.cat((zi, zt), dim=1)  # (b,2C,H,W)
            # forward
            pre = self.diffusion_transformer(
                motion_source_hidden_states=source_motion,
                motion_target_hidden_states=target_motion,
                image_hidden_states=image_hidden_states,
                randomref_image_hidden_states=randomref_img,
                timestep=time_step,
            )
            zt = zt + pre * dt

        # unsqueeze: (n,1,c,h,w) means images, (n,t,c,h,w) means video with t > 1
        zi = einops.rearrange(zi, '(n t) c h w -> n t c h w', n=n, t=t)
        zt = einops.rearrange(zt, '(n t) c h w -> n t c h w', n=n, t=t)

        if return_meta_info:
            return {'zi': zi,                      # (b,1,c,h,w)
                    'zj': zj,                      # (b,1,c,h,w)
                    'sample': zt,                  # (b,1,c,h,w)
                    'pre_cache': pre_cache,        # [(b,c,h,w),...]
                    'sample_cache': sample_cache,  # [(b,c,h,w),...]
                    'step_seq': step_seq,
                    'motion': target_motion,       # (b,C,H,W)
                    }
        else:
            return zi, zt, zj  # (b,1,c,h,w)

    def extract_motion(self, video: torch.Tensor, mask_ratio=None):
        # video : (N,T,C,H,W)
        n, t, c, h, w = video.shape
        motion = self.motion_encoder(video, mask_ratio)  # (N,T,L,D)
        if self.need_motion_transformer and self.extract_motion_with_motion_transformer:
            motion = self.motion_transformer(motion)  # (N,T,L,D)
        return motion

    def prepare_timestep(self, batch_size: int, device, time_step=None):
        if time_step is not None:
            return time_step.to(device)
        else:
            return torch.randint(0, self.num_step + 1, (batch_size,)).to(device)

    def prepare_encoder_input(self, video: torch.Tensor):
        assert len(video.shape) == 5, f'only support video data : 5D tensor , but got {video.shape}'
        # cat neighbouring frames along the channel dim
        pre = video[:, :-1, :, :, :]
        post = video[:, 1:, :, :, :]
        duo_frame_mix = torch.cat([pre, post], dim=2)  # (b,t-1,2c,h,w)
        duo_frame_mix = einops.rearrange(duo_frame_mix, 'b t c h w -> (b t) c h w')
        return duo_frame_mix  # (b*(t-1),2c,h,w)

    def unpatchify(self, x, patch_size):
        """
        x: (N, S, patch_size**2 * C)
        imgs: (N, C, H, W)
        """
        p = patch_size
        h = w = int(x.shape[1] ** .5)
        c = x.shape[2] // (p ** 2)
        assert h * w == x.shape[1]

        x = x.reshape(shape=(x.shape[0], h, w, p, p, c))      # (N, h, w, p, p, c)
        x = torch.einsum('nhwpqc->nchpwq', x)                  # (N, c, h, p, w, p)
        imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))  # width uses w * p
        return imgs  # (N,C,H,W)

    def reset_infer_num_frame(self, num: int):
        old_num = self.diffusion_transformer.target_frame
        self.diffusion_transformer.target_frame = num
        print(f'* Reset infer frame from {old_num} to {self.diffusion_transformer.target_frame} *')


class AMDModel_Rec(ModelMixin, ConfigMixin):
    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(self,
                 image_inchannel: int = 4,
                 image_height: int = 32,
                 image_width: int = 32,
                 video_frames: int = 16,
                 scheduler_num_step: int = 1000,
                 # ----------- MotionEncoder -----------
                 motion_token_num: int = 12,
                 motion_token_channel: int = 128,
                 enc_num_layers: int = 8,
                 enc_nhead: int = 8,
                 enc_ndim: int = 64,
                 enc_dropout: float = 0.0,
                 motion_need_norm_out: bool = True,
                 # ----------- MotionTransformer ---------
                 need_motion_transformer: bool = False,
                 motion_transformer_attn_head_dim: int = 64,
                 motion_transformer_attn_num_heads: int = 16,
                 motion_transformer_num_layers: int = 4,
                 # ----------- Diffusion Transformer -----------
                 diffusion_model_type: str = 'default',  # 'default' or 'spatial'
                 diffusion_attn_head_dim: int = 64,
                 diffusion_attn_num_heads: int = 16,
                 diffusion_out_channels: int = 4,
                 diffusion_num_layers: int = 16,
                 image_patch_size: int = 2,
                 motion_patch_size: int = 1,
                 motion_drop_ratio: float = 0.0,
                 **kwargs,
                 ):
        super().__init__()

        # setting
        self.num_step = scheduler_num_step
        self.scheduler = RectifiedFlow(num_steps=scheduler_num_step)
        self.need_motion_transformer = need_motion_transformer

        # zt token
        INIT_CONST = 0.02
        self.zt_token = nn.Parameter(torch.randn(1, image_inchannel, image_height, image_width) * INIT_CONST)

        # motion encoder
        self.motion_encoder = MotionEncoderLearnTokenTransformer(
            img_height=image_height,
            img_width=image_width,
            img_inchannel=image_inchannel,
            img_patch_size=image_patch_size,
            motion_token_num=motion_token_num,
            motion_channel=motion_token_channel,
            need_norm_out=motion_need_norm_out,
            # ----- attention
            num_attention_heads=enc_nhead,
            attention_head_dim=enc_ndim,
            num_layers=enc_num_layers,
            dropout=enc_dropout,
            attention_bias=True,
        )

        # motion transformer
        if need_motion_transformer:
            self.motion_transformer = MotionTransformer(
                motion_token_num=motion_token_num,
                motion_token_channel=motion_token_channel,
                attention_head_dim=motion_transformer_attn_head_dim,
                num_attention_heads=motion_transformer_attn_num_heads,
                num_layers=motion_transformer_num_layers,
            )

        # reconstruction transformer
        if diffusion_model_type == 'default':
            dit_image_inchannel = image_inchannel * 2  # zi + zt
            self.transformer = AMDReconstructTransformerModel(
                num_attention_heads=diffusion_attn_num_heads,
                attention_head_dim=diffusion_attn_head_dim,
                out_channels=diffusion_out_channels,
                num_layers=diffusion_num_layers,
                # ----- img
                image_width=image_width,
                image_height=image_height,
                image_patch_size=image_patch_size,
                image_in_channels=dit_image_inchannel,
                # ----- motion
                motion_token_num=motion_token_num,
                motion_in_channels=motion_token_channel,
            )
        elif diffusion_model_type == 'spatial':
            dit_image_inchannel = image_inchannel * 2  # zi + zt
            self.transformer = AMDReconstructTransformerModelSpatial(
                num_attention_heads=diffusion_attn_num_heads,
                attention_head_dim=diffusion_attn_head_dim,
                out_channels=diffusion_out_channels,
                num_layers=diffusion_num_layers,
                # ----- img
                image_width=image_width,
                image_height=image_height,
                image_patch_size=image_patch_size,
                image_in_channels=dit_image_inchannel,
                # ----- motion
                motion_token_num=motion_token_num,
                motion_in_channels=motion_token_channel,
                motion_target_num_frame=video_frames,
            )

    def forward(self,
                video: torch.Tensor,
                ref_img: torch.Tensor,
                time_step: torch.Tensor = None,
                return_meta_info=False,
                **kwargs,):
        """
        Args:
            video: (N,T,C,H,W)
            ref_img: (N,T,C,H,W)
        """
        device = video.device
        n, t, c, h, w = video.shape
        assert video.shape == ref_img.shape, f'video.shape:{video.shape} should be equal to ref_img.shape:{ref_img.shape}'

        # motion encoder
        refimg_and_video = torch.cat([ref_img, video], dim=1)  # (n,t+t,C,H,W)
        motion = self.motion_encoder(refimg_and_video)  # (n,t+t,l,d)
        source_motion = motion[:, :t].flatten(0, 1)  # (NT,motion_token,d)
        target_motion = motion[:, t:].flatten(0, 1)  # (NT,motion_token,d)
        assert source_motion.shape == target_motion.shape, f'source_motion.shape {source_motion.shape} != target_motion.shape {target_motion.shape}'

        # motion transformer
        if self.need_motion_transformer:
            target_motion = einops.rearrange(target_motion, '(n f) l d -> n f l d', n=n)
            target_motion = self.motion_transformer(target_motion)
            target_motion = einops.rearrange(target_motion, 'n f l d -> (n f) l d', n=n)

        # prepare for Diffusion Transformer
        zi = ref_img.flatten(0, 1)  # (NT,C,H,W)
        zj = video.flatten(0, 1)    # (NT,C,H,W)
        zt = self.zt_token.repeat(zj.shape[0], 1, 1, 1)  # (NT,C,H,W)

        # dit forward
        image_hidden_states = torch.cat((zi, zt), dim=1)  # (b,2C,H,W)
        pre = self.transformer(
            motion_source_hidden_states=source_motion,
            motion_target_hidden_states=target_motion,
            image_hidden_states=image_hidden_states,
        )

        # loss
        rec_loss = l2(pre, zj)
        loss = rec_loss
        loss_dict = {'loss': loss, 'rec_loss': rec_loss}

        if return_meta_info:
            return {'motion': motion,        # (n,t,l,d) , output of motion encoder
                    'zi': zi,                # (b,C,H,W) | b = n * t
                    'zj': zj,                # (b,C,H,W)
                    'zt': zt,                # (b,C,H,W)
                    'pre': pre,              # (b,C,H,W)
                    'time_step': time_step,  # (b,)
                    }
        else:
            return pre, zj, loss_dict  # (b,C,H,W)

    @torch.no_grad()
    def sample(self,
               video: torch.Tensor,
               ref_img: torch.Tensor,
               sample_step: int = 50,
               start_step: int = None,
               return_meta_info=False,
               **kwargs,):
        device = video.device
        n, t, c, h, w = video.shape
        if start_step is None:
            start_step = self.scheduler.num_step
        assert start_step <= self.scheduler.num_step, 'start_step cant be larger than scheduler.num_step'

        # motion encoder
        refimg_and_video = torch.cat([ref_img, video], dim=1)  # (n,t+t,C,H,W)
        motion = self.motion_encoder(refimg_and_video)  # (n,t+t,l,d)
        source_motion = motion[:, :t].flatten(0, 1)  # (NT,motion_token,d)
        target_motion = motion[:, t:].flatten(0, 1)  # (NT,motion_token,d)
        assert source_motion.shape == target_motion.shape, f'source_motion.shape {source_motion.shape} != target_motion.shape {target_motion.shape}'

        # motion transformer
        if self.need_motion_transformer:
            target_motion = einops.rearrange(target_motion, '(n f) l d -> n f l d', n=n)
            target_motion = self.motion_transformer(target_motion)
            target_motion = einops.rearrange(target_motion, 'n f l d -> (n f) l d', n=n)

        zi = ref_img.flatten(0, 1)  # (NT,C,H,W)
        zj = video.flatten(0, 1)    # (NT,C,H,W)
        zt = self.zt_token.repeat(zj.shape[0], 1, 1, 1)  # (NT,C,H,W)

        # input
        zt = zt.to(video.dtype)
        image_hidden_states = torch.cat((zi, zt), dim=1)  # (b,2C,H,W)
        # forward
        pre = self.transformer(
            motion_source_hidden_states=source_motion,
            motion_target_hidden_states=target_motion,
            image_hidden_states=image_hidden_states,
        )

        zi = einops.rearrange(zi, '(n t) c h w -> n t c h w', n=n)
        zt = einops.rearrange(pre, '(n t) c h w -> n t c h w', n=n)
        zj = einops.rearrange(zj, '(n t) c h w -> n t c h w', n=n)

        if return_meta_info:
            return {'zi': zi,  # (b,1,c,h,w)
                    'zj': zj,  # (b,1,c,h,w)
                    }
        else:
            return zi, zt, zj  # (n,t,c,h,w)

    def sample_with_refimg_motion(self,
                                  ref_img: torch.Tensor,
                                  motion: torch.Tensor,
                                  sample_step: int = 10,
                                  return_meta_info=False,
                                  **kwargs,):
        """
        Args:
            ref_img : (N,C,H,W)
            motion : (N,F,L,D)
        Return:
            video : (N,T,C,H,W)
        """
        device = motion.device
        n, t, l, d = motion.shape
        start_step = self.scheduler.num_step

        # motion encoder
        refimg = ref_img.unsqueeze(1)  # (N,1,C,H,W)
        source_motion = self.motion_encoder(refimg)  # (n,1,motion_token,d)
        source_motion = source_motion.repeat(1, t, 1, 1).flatten(0, 1)  # (NT,l,d)
        target_motion = motion.flatten(0, 1)  # (NT,l,d)
        assert source_motion.shape == target_motion.shape, f'source_motion.shape {source_motion.shape} != target_motion.shape {target_motion.shape}'

        # motion transformer
        if self.need_motion_transformer:
            target_motion = einops.rearrange(target_motion, '(n f) l d -> n f l d', n=n)
            target_motion = self.motion_transformer(target_motion)
            target_motion = einops.rearrange(target_motion, 'n f l d -> (n f) l d', n=n)

        # prepare for Diffusion Transformer
        time_step = torch.ones((source_motion.shape[0],)).to(device)
        time_step = time_step * start_step
        zi = refimg.repeat(1, t, 1, 1, 1).flatten(0, 1)  # (NT,C,H,W)
        zj = zi
        zt = self.zt_token.repeat(zj.shape[0], 1, 1, 1)  # (NT,C,H,W)

        # input
        zt = zt.to(zj.dtype)
        image_hidden_states = torch.cat((zi, zt), dim=1)  # (b,2C,H,W)
        # forward
        pre = self.transformer(
            motion_source_hidden_states=source_motion,
            motion_target_hidden_states=target_motion,
            image_hidden_states=image_hidden_states,
        )

        zi = einops.rearrange(zi, '(n t) c h w -> n t c h w', n=n)
        zt = einops.rearrange(pre, '(n t) c h w -> n t c h w', n=n)
        zj = einops.rearrange(zj, '(n t) c h w -> n t c h w', n=n)

        if return_meta_info:
            return {'zi': zi,  # (b,1,c,h,w)
                    'zj': zj,  # (b,1,c,h,w)
                    }
        else:
            return zi, zt, zj  # (b,1,c,h,w)

    def extract_motion(self, video: torch.Tensor):
        # video : (N,T,C,H,W)
        # motion encoder
        motion = self.motion_encoder(video)  # (N,T,L,D)
        if self.need_motion_transformer:
            motion = self.motion_transformer(motion)  # (N,T,L,D)
        return motion


def AMD_S(**kwargs) -> AMDModel:
    return AMDModel(
        # ----------- motion encoder -----------
        enc_num_layers=8,
        enc_nhead=8,
        enc_ndim=64,
        # ----------- Diffusion Transformer -----------
        diffusion_attn_head_dim=64,
        diffusion_attn_num_heads=16,
        diffusion_out_channels=4,
        diffusion_num_layers=12,
        **kwargs)


def AMD_L(**kwargs) -> AMDModel:
    return AMDModel(
        # ----------- motion encoder -----------
        enc_num_layers=8,
        enc_nhead=16,
        enc_ndim=64,
        # ----------- Diffusion Transformer -----------
        diffusion_attn_head_dim=96,
        diffusion_attn_num_heads=16,
        diffusion_out_channels=4,
        diffusion_num_layers=16,
        **kwargs)


def AMD_S_Rec(**kwargs) -> AMDModel_Rec:
    return AMDModel_Rec(
        # ----------- motion encoder -----------
        enc_num_layers=8,
        enc_nhead=8,
        enc_ndim=64,
        # ----------- Diffusion Transformer -----------
        diffusion_attn_head_dim=64,
        diffusion_attn_num_heads=16,
        diffusion_out_channels=4,
        diffusion_num_layers=12,
        **kwargs)


def AMD_S_RecSplit(**kwargs) -> AMDModel_Rec:
    return AMDModel_Rec(
        # ----------- motion encoder -----------
        enc_num_layers=8,
        enc_nhead=8,
        enc_ndim=64,
        # ----------- Diffusion Transformer -----------
        diffusion_attn_head_dim=64,
        diffusion_attn_num_heads=16,
        diffusion_out_channels=4,
        diffusion_num_layers=12,
        is_split=True,
        **kwargs)


AMD_models = {
    "AMD_S": AMD_S,                    # 250M
    "AMD_L": AMD_L,                    # 700M
    "AMD_S_Rec": AMD_S_Rec,            # 250M
    "AMD_S_RecSplit": AMD_S_RecSplit,  # 250M
}

# S 206  B 333  M 642  L 1053
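# ----------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the module API).
# Assumptions: the default config (4-channel 32x32 latents, 16 frames,
# the 'default' diffusion transformer) and randomly initialized weights;
# real use would load a trained checkpoint, e.g. via AMDModel.from_pretrained,
# and run inside the package so the relative imports resolve.
if __name__ == "__main__":
    model = AMD_models["AMD_S"](video_frames=16).eval()

    n, t, c, h, w = 1, 16, 4, 32, 32
    video = torch.randn(n, t, c, h, w)             # target latent clip
    ref_img = video[:, :1].repeat(1, t, 1, 1, 1)   # reference frame broadcast over time

    # training-style forward: velocity prediction, target velocity, and losses
    pre, vel, loss_dict = model(video=video, ref_img=ref_img)
    print({k: v.item() for k, v in loss_dict.items()})

    # Euler sampling conditioned on the motion tokens extracted from `video`
    zi, sample, zj = model.sample(video=video, ref_img=ref_img, sample_step=10)
    print(sample.shape)  # (n, t, c, h, w)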