import torch
from torch import nn
import einops
import numpy as np
from tqdm import tqdm
from .loss import l2
from .transformer import (MotionTransformer,
                          AMDDiffusionTransformerModel,
                          MotionEncoderLearnTokenTransformer,
                          AMDReconstructTransformerModel,
                          AMDDiffusionTransformerModelDualStream,
                          AMDDiffusionTransformerModelImgSpatial,
                          AMDDiffusionTransformerModelImgSpatialDoubleRef,
                          AMDReconstructTransformerModelSpatial)
from .rectified_flow import RectifiedFlow
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin
class AMDModel(ModelMixin, ConfigMixin):
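    """Motion-conditioned rectified-flow model.

    A transformer motion encoder compresses (reference, target) frame pairs into
    compact per-frame motion tokens; a diffusion transformer is then trained with
    a rectified-flow objective to predict the velocity that carries a noisy
    target frame back to the clean one, conditioned on the reference frame and
    the motion tokens. (High-level summary inferred from the code below.)
    """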
_supports_gradient_checkpointing = True
@register_to_config
def __init__(self,
image_inchannel :int = 4,
image_height :int = 32,
image_width :int = 32,
video_frames :int = 16,
scheduler_num_step :int = 1000,
# ----------- MotionEncoder -----------
motion_token_num:int = 12,
motion_token_channel: int = 128,
enc_num_layers:int = 8,
enc_nhead:int = 8,
enc_ndim:int = 64,
enc_dropout:float = 0.0,
motion_need_norm_out:bool = False,
# ----------- MotionTransformer ---------
need_motion_transformer :bool = False,
motion_transformer_attn_head_dim:int = 64,
motion_transformer_attn_num_heads:int = 16,
motion_transformer_num_layers:int = 4,
# ----------- Diffusion Transformer -----------
                 diffusion_model_type : str = 'default', # one of: 'default', 'dual', 'spatial', 'doubleref'
diffusion_attn_head_dim : int = 64,
diffusion_attn_num_heads : int = 16,
diffusion_out_channels : int = 4,
diffusion_num_layers : int = 16,
image_patch_size : int = 2,
motion_patch_size : int = 1,
motion_drop_ratio: float = 0.0,
refimg_drop: bool = False,
# ----------- Sample --------------
                 extract_motion_with_motion_transformer:bool = False,
**kwargs,
):
super().__init__()
# setting
self.num_step = scheduler_num_step
self.scheduler = RectifiedFlow(num_steps=scheduler_num_step)
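        # the scheduler discretizes the flow-matching time axis [0,1] into
        # `scheduler_num_step` integer steps (assumption based on its use below)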
self.need_motion_transformer = need_motion_transformer
self.extract_motion_with_motion_transformer = extract_motion_with_motion_transformer
self.diffusion_model_type = diffusion_model_type
self.target_frame = video_frames
self.refimg_drop = refimg_drop
# motion Encoder
self.motion_encoder = MotionEncoderLearnTokenTransformer(img_height = image_height,
img_width=image_width,
img_inchannel=image_inchannel,
img_patch_size = image_patch_size,
motion_token_num = motion_token_num,
motion_channel = motion_token_channel,
need_norm_out = motion_need_norm_out,
# ----- attention
num_attention_heads=enc_nhead,
attention_head_dim=enc_ndim,
num_layers=enc_num_layers,
dropout=enc_dropout,
attention_bias= True,)
# motion transformer
if need_motion_transformer:
self.motion_transformer = MotionTransformer(motion_token_num=motion_token_num,
motion_token_channel=motion_token_channel,
attention_head_dim=motion_transformer_attn_head_dim,
num_attention_heads=motion_transformer_attn_num_heads,
num_layers=motion_transformer_num_layers,)
# diffusion transformer
if diffusion_model_type == 'default':
dit_image_inchannel = image_inchannel * 2 # zi + zt
self.diffusion_transformer = AMDDiffusionTransformerModel(num_attention_heads= diffusion_attn_num_heads,
attention_head_dim= diffusion_attn_head_dim,
out_channels = diffusion_out_channels,
num_layers= diffusion_num_layers,
# ----- img
image_width= image_width,
image_height= image_height,
image_patch_size= image_patch_size,
image_in_channels = dit_image_inchannel,
# ----- motion
motion_token_num = motion_token_num,
motion_in_channels = motion_token_channel,)
elif diffusion_model_type == 'dual':
dit_image_inchannel = image_inchannel * 2 # zi + zt
self.diffusion_transformer = AMDDiffusionTransformerModelDualStream(num_attention_heads= diffusion_attn_num_heads,
attention_head_dim= diffusion_attn_head_dim,
out_channels = diffusion_out_channels,
num_layers= diffusion_num_layers,
# ----- img
image_width= image_width,
image_height= image_height,
image_patch_size= image_patch_size,
image_in_channels = dit_image_inchannel,
# ----- motion
motion_token_num = motion_token_num,
motion_in_channels = motion_token_channel,
motion_target_num_frame = video_frames)
elif diffusion_model_type == 'spatial':
dit_image_inchannel = image_inchannel * 2 # zi + zt
self.diffusion_transformer = AMDDiffusionTransformerModelImgSpatial(num_attention_heads= diffusion_attn_num_heads,
attention_head_dim= diffusion_attn_head_dim,
out_channels = diffusion_out_channels,
num_layers= diffusion_num_layers,
# ----- img
image_width= image_width,
image_height= image_height,
image_patch_size= image_patch_size,
image_in_channels = dit_image_inchannel,
# ----- motion
motion_token_num = motion_token_num,
motion_in_channels = motion_token_channel,
motion_target_num_frame = video_frames)
elif diffusion_model_type == 'doubleref':
dit_image_inchannel = image_inchannel
self.diffusion_transformer = AMDDiffusionTransformerModelImgSpatialDoubleRef(num_attention_heads= diffusion_attn_num_heads,
attention_head_dim= diffusion_attn_head_dim,
out_channels = diffusion_out_channels,
num_layers= diffusion_num_layers,
# ----- img
image_width= image_width,
image_height= image_height,
image_patch_size= image_patch_size,
image_in_channels = dit_image_inchannel,
# ----- motion
motion_token_num = motion_token_num,
motion_in_channels = motion_token_channel,
motion_target_num_frame = video_frames)
        else:
            raise ValueError(f'unknown diffusion_model_type: {diffusion_model_type}, '
                             "expected one of 'default', 'dual', 'spatial', 'doubleref'")
    def forward(self,
                video:torch.Tensor,
                ref_img:torch.Tensor,
                randomref_img:torch.Tensor = None,
                time_step:torch.Tensor = None,
                return_meta_info=False,
                mask_ratio=None,
                **kwargs,):
"""
Args:
video: (N,T,C,H,W)
ref_img: (N,T,C,H,W)
randomref_img : (N,T,C,H,W)
"""
device = video.device
n,t,c,h,w = video.shape
        assert video.shape == ref_img.shape , f'video.shape {video.shape} should be equal to ref_img.shape {ref_img.shape}'
if self.diffusion_model_type == 'doubleref' :
assert randomref_img is not None, "when diffusion_model_type == doubleref, randomref_img should be given"
# motion encoder
if mask_ratio is not None:
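            # jitter the masking strength: draw an effective ratio uniformly in [0, mask_ratio)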
mask_ratio = torch.rand(1).item() * mask_ratio
if self.diffusion_model_type == 'doubleref' and randomref_img is not None:
if randomref_img.dim()==4:
randomref_img = randomref_img.unsqueeze(1).repeat(1,t,1,1,1)
refimg_and_video = torch.cat([randomref_img,video],dim=1)# (n,t+t,C,H,W)
else:
refimg_and_video = torch.cat([ref_img,video],dim=1)# (n,t+t,C,H,W)
motion = self.motion_encoder(refimg_and_video,mask_ratio) # (n,t+t,l,d)
source_motion = motion[:,:t].flatten(0,1) # (NT,motion_token,d)
target_motion = motion[:,t:].flatten(0,1) # (NT,motion_token,d)
assert source_motion.shape == target_motion.shape , f'source_motion.shape {source_motion.shape} != target_motion.shape {target_motion.shape}'
# motion transformer
if self.need_motion_transformer:
target_motion = einops.rearrange(target_motion,'(n f) l d -> n f l d',n=n)
target_motion = self.motion_transformer(target_motion)
target_motion = einops.rearrange(target_motion,'n f l d -> (n f) l d',n=n)
# prepare for Diffusion Transformer
zi = ref_img.flatten(0,1) # (NT,C,H,W)
zj = video.flatten(0,1) # (NT,C,H,W)
if self.diffusion_model_type == 'doubleref' and randomref_img is not None:
randomref_img = randomref_img.flatten(0,1) # (NT,C,H,W)
if time_step is None:
time_step = self.prepare_timestep(batch_size= zj.shape[0],device= device) #(b,)
if self.diffusion_model_type != 'default':
time_step = self.prepare_timestep(batch_size= n,device= device) # (n,)
time_step = time_step.repeat_interleave(t) # (b,)
zt,vel = self.scheduler.get_train_tuple(z1=zj,time_step=time_step) # (NT,C,H,W),(NT,C,H,W)
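        # Rectified flow (assumed convention): get_train_tuple above pairs a noise
        # sample z0 with the target z1 so that z_t interpolates linearly between
        # them and vel = z1 - z0; the exact time normalization lives in RectifiedFlow.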
# dit forward
if self.refimg_drop:
zi = torch.zeros_like(zi).to(video.device)
image_hidden_states = torch.cat((zi,zt),dim=1) # (b,2C,H,W)
pre = self.diffusion_transformer(motion_source_hidden_states = source_motion,
motion_target_hidden_states = target_motion,
image_hidden_states = image_hidden_states,
randomref_image_hidden_states = randomref_img,
timestep = time_step,)
# loss
diff_loss = l2(pre,vel)
rec_zj = self.scheduler.get_target_with_zt_vel(zt,pre,time_step)
rec_loss = l2(rec_zj,zj)
loss = diff_loss
loss_dict = {'loss':loss,'diff_loss':diff_loss,'rec_loss':rec_loss}
if return_meta_info:
            return {'motion' : motion, # (n,t+t,l,d), output of the motion encoder
'zi' : zi, # (b,C,H,W) | b = n * t
'zj' : zj, # (b,C,H,W)
'zt' : zt, # (b,C,H,W)
'gt' : vel, # (b,C,H,W)
'pre': pre, # (b,C,H,W)
'time_step': time_step, # (b,)
}
else:
return pre,vel,loss_dict # (b,C,H,W)
def get_noise_latent_pair(self,
video:torch.Tensor,
ref_img:torch.Tensor ,
randomref_img:torch.Tensor,
sample_step:int = 50,
):
pass
@torch.no_grad()
def sample(self,video:torch.Tensor,
ref_img:torch.Tensor ,
randomref_img:torch.Tensor = None,
sample_step:int = 50,
mask_ratio = None,
start_step:int = None,
return_meta_info=False,
**kwargs,):
device = video.device
n,t,c,h,w = video.shape
if start_step is None:
start_step = self.scheduler.num_step
        assert start_step <= self.scheduler.num_step , 'start_step cannot be larger than scheduler.num_step'
if self.diffusion_model_type == 'doubleref' :
assert randomref_img is not None, "when diffusion_model_type == doubleref, randomref_img should be given"
if ref_img.dim()==4:
ref_img = ref_img.unsqueeze(1).repeat(1,t,1,1,1)
# motion encoder
        if mask_ratio is not None:
            print(f'* Sampling with mask_ratio = {mask_ratio}')
if self.diffusion_model_type == 'doubleref' and randomref_img is not None:
if randomref_img.dim()==4:
randomref_img = randomref_img.unsqueeze(1).repeat(1,t,1,1,1)
refimg_and_video = torch.cat([randomref_img,video],dim=1)# (n,t+t,C,H,W)
else:
refimg_and_video = torch.cat([ref_img,video],dim=1)# (n,t+t,C,H,W)
        motion = self.motion_encoder(refimg_and_video,mask_ratio) # (n,t+t,l,d)
source_motion = motion[:,:t].flatten(0,1) # (NT,motion_token,d)
target_motion = motion[:,t:].flatten(0,1) # (NT,motion_token,d)
assert source_motion.shape == target_motion.shape , f'source_motion.shape {source_motion.shape} != target_motion.shape {target_motion.shape}'
# motion transformer
if self.need_motion_transformer:
target_motion = einops.rearrange(target_motion,'(n f) l d -> n f l d',n=n)
target_motion = self.motion_transformer(target_motion)
target_motion = einops.rearrange(target_motion,'n f l d -> (n f) l d',n=n)
# prepare for Diffusion Transformer
time_step = torch.ones((source_motion.shape[0],)).to(device)
time_step = time_step * start_step
zi = ref_img.flatten(0,1) # (NT,C,H,W)
zj = video.flatten(0,1) # (NT,C,H,W)
if self.diffusion_model_type == 'doubleref' and randomref_img is not None:
randomref_img = randomref_img.flatten(0,1) # (NT,C,H,W)
zt,vel = self.scheduler.get_train_tuple(z1=zj,time_step=time_step) # (NT,C,H,W),(NT,C,H,W)
noise = zj - vel
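        # since vel = z1 - z0 under the rectified-flow pairing above, zj - vel recovers z0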
# Sample Loop
pre_cache = []
sample_cache = []
# 1.step_seq
step_seq = np.linspace(0, start_step, num=sample_step+1, endpoint=True,dtype=int) # [0,5,10,15,....,start_step]
step_seq = list(reversed(step_seq[1:])) # delete step:0 [start_step,.....,15,10,5]
# 2.Euler step
dt = 1./sample_step
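        # Euler integration of dz/dt = v: each loop iteration below advances zt by pre*dt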
if self.refimg_drop:
zi = torch.zeros_like(zi).to(video.device)
for i in tqdm(step_seq):
# time_step
time_step = torch.ones((zt.shape[0],)).to(zt.device)
time_step = time_step * i
# input
zt = zt.to(video.dtype)
image_hidden_states = torch.cat((zi,zt),dim=1) # (b,2C,H,W)
# forward
pre = self.diffusion_transformer(motion_source_hidden_states = source_motion,
motion_target_hidden_states = target_motion,
image_hidden_states = image_hidden_states,
randomref_image_hidden_states = randomref_img,
timestep = time_step,)
zt = zt + pre * dt
pre_cache.append(pre)
sample_cache.append(zt)
zi = einops.rearrange(zi,'(n t) c h w -> n t c h w',n=n)
zt = einops.rearrange(zt,'(n t) c h w -> n t c h w',n=n)
zj = einops.rearrange(zj,'(n t) c h w -> n t c h w',n=n)
        if return_meta_info:
            return {'zi' : zi, # (n,t,c,h,w)
                    'zj' : zj, # (n,t,c,h,w)
                    'sample' : zt, # (n,t,c,h,w)
                    'pre_cache' : pre_cache, # [(n*t,c,h,w), ...]
                    'sample_cache' : sample_cache, # [(n*t,c,h,w), ...]
                    'step_seq' : step_seq,
                    'motion' : target_motion, # (n*t,l,d)
                    "noise" : noise
                    }
else:
return zi,zt,zj # (n,t,c,h,w)
@torch.no_grad()
def sample_with_refimg_motion(self,
ref_img:torch.Tensor,
                                  motion:torch.Tensor,
randomref_img:torch.Tensor = None,
sample_step:int = 10,
mask_ratio = None,
return_meta_info=False,
**kwargs,):
"""
Args:
ref_img : (N,C,H,W)
randomref_img : (N,C,H,W)
motion : (N,F,L,D)
Return:
video : (N,T,C,H,W)
"""
device = motion.device
n,t,l,d = motion.shape
start_step = self.scheduler.num_step
# motion encoder
refimg = ref_img.unsqueeze(1) # (N,1,C,H,W)
if self.diffusion_model_type == 'doubleref' :
assert randomref_img is not None, "when diffusion_model_type == doubleref, randomref_img should be given"
if self.diffusion_model_type == 'doubleref' and randomref_img is not None:
            print('* Warning * diffusion_model_type: doubleref')
if randomref_img.dim()==4:
randomref_img = randomref_img.unsqueeze(1) # (N,1,C,H,W)
source_motion = self.motion_encoder(randomref_img,mask_ratio) # (n,1,motion_token,d)
else:
source_motion = self.motion_encoder(refimg,mask_ratio) # (n,1,motion_token,d)
source_motion = source_motion.repeat(1,t,1,1).flatten(0,1) # (NT,l,d)
target_motion = motion.flatten(0,1) # (NT,l,d)
assert source_motion.shape == target_motion.shape , f'source_motion.shape {source_motion.shape} != target_motion.shape {target_motion.shape}'
# motion transformer
if self.need_motion_transformer and not self.extract_motion_with_motion_transformer:
target_motion = einops.rearrange(target_motion,'(n f) l d -> n f l d',n=n)
target_motion = self.motion_transformer(target_motion)
target_motion = einops.rearrange(target_motion,'n f l d -> (n f) l d',n=n)
# prepare for Diffusion Transformer
time_step = torch.ones((source_motion.shape[0],)).to(device)
time_step = time_step * start_step
zi = refimg.repeat(1,t,1,1,1).flatten(0,1) # (NT,C,H,W)
zj = zi
if self.diffusion_model_type == 'doubleref' and randomref_img is not None:
randomref_img = randomref_img.repeat(1,t,1,1,1)
randomref_img = randomref_img.flatten(0,1) # (NT,C,H,W)
# zt,vel = self.scheduler.get_train_tuple(z1=zj,time_step=time_step) # (NT,C,H,W),(NT,C,H,W)
zt = torch.randn_like(zj)
# Sample Loop
pre_cache = []
sample_cache = []
# 1.step_seq
step_seq = np.linspace(0, start_step, num=sample_step+1, endpoint=True,dtype=int) # [0,5,10,15,....,start_step]
step_seq = list(reversed(step_seq[1:])) # delete step:0 [start_step,.....,15,10,5]
# 2.Euler step
dt = 1./sample_step
if self.refimg_drop:
zi = torch.zeros_like(zi).to(ref_img.device)
for i in tqdm(step_seq):
# time_step
time_step = torch.ones((zt.shape[0],)).to(zt.device)
time_step = time_step * i
# input
zt = zt.to(ref_img.dtype)
image_hidden_states = torch.cat((zi,zt),dim=1) # (b,2C,H,W)
# forward
pre = self.diffusion_transformer(motion_source_hidden_states = source_motion,
motion_target_hidden_states = target_motion,
image_hidden_states = image_hidden_states,
randomref_image_hidden_states = randomref_img,
timestep = time_step,)
zt = zt + pre * dt
        # reshape back to (n,t,c,h,w); t == 1 corresponds to a single image, t > 1 to a video
        zi = einops.rearrange(zi,'(n t) c h w -> n t c h w',n=n,t=t)
        zt = einops.rearrange(zt,'(n t) c h w -> n t c h w',n=n,t=t)
        zj = einops.rearrange(zj,'(n t) c h w -> n t c h w',n=n,t=t)
        if return_meta_info:
            return {'zi' : zi, # (n,t,c,h,w)
                    'zj' : zj, # (n,t,c,h,w)
                    'sample' : zt, # (n,t,c,h,w)
                    'pre_cache' : pre_cache, # [(n*t,c,h,w), ...]
                    'sample_cache' : sample_cache, # [(n*t,c,h,w), ...]
                    'step_seq' : step_seq,
                    'motion' : target_motion, # (n*t,l,d)
                    }
else:
            return zi,zt,zj # (n,t,c,h,w)
    def extract_motion(self,video:torch.Tensor,mask_ratio=None):
# video : (N,T,C,H,W)
n,t,c,h,w = video.shape
motion = self.motion_encoder(video,mask_ratio) # (N,T,L,D)
if self.need_motion_transformer and self.extract_motion_with_motion_transformer:
motion = self.motion_transformer(motion) # (N,T,L,D)
return motion
def prepare_timestep(self,batch_size:int,device,time_step = None):
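        """Return `time_step` moved to `device` if provided; otherwise draw one
        uniform integer timestep in [0, num_step] per batch element."""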
if time_step is not None:
return time_step.to(device)
else:
return torch.randint(0,self.num_step+1,(batch_size,)).to(device)
    def prepare_encoder_input(self,video:torch.Tensor):
assert len(video.shape) == 5 , f'only support video data : 5D tensor , but got {video.shape}'
# cat
pre = video[:,:-1,:,:,:]
post= video[:,1:,:,:,:]
duo_frame_mix = torch.cat([pre,post],dim=2) # (b,t-1,2c,h,w)
duo_frame_mix = einops.rearrange(duo_frame_mix,'b t c h w -> (b t) c h w')
        return duo_frame_mix # (b*(t-1),2c,h,w)
def unpatchify(self, x ,patch_size):
"""
x: (N, S, patch_size**2 *C)
imgs: (N, C, H, W)
"""
p = patch_size
h = w = int(x.shape[1]**.5)
# c = self.in_chans
c = x.shape[2] // (p**2)
assert h * w == x.shape[1]
x = x.reshape(shape=(x.shape[0], h, w, p, p, c)) # (N, h, w, p, p, c)
x = torch.einsum('nhwpqc->nchpwq', x) # (N, c, h, p, w, p)
        imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
return imgs #(N,C,H,W)
def reset_infer_num_frame(self, num:int):
old_num = self.diffusion_transformer.target_frame
self.diffusion_transformer.target_frame = num
print(f'* Reset infer frame from {old_num} to {self.diffusion_transformer.target_frame} *')
class AMDModel_Rec(ModelMixin, ConfigMixin):
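    """Reconstruction variant of AMDModel.

    Instead of integrating a rectified flow, the transformer receives a learned
    `zt_token` in place of the noisy latent and regresses the target frame
    directly with an L2 loss, so no sampling loop is needed. (High-level
    summary inferred from the code below.)
    """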
_supports_gradient_checkpointing = True
@register_to_config
def __init__(self,
image_inchannel :int = 4,
image_height :int = 32,
image_width :int = 32,
video_frames :int = 16,
scheduler_num_step :int = 1000,
# ----------- MotionEncoder -----------
motion_token_num:int = 12,
motion_token_channel: int = 128,
enc_num_layers:int = 8,
enc_nhead:int = 8,
enc_ndim:int = 64,
enc_dropout:float = 0.0,
motion_need_norm_out:bool = True,
# ----------- MotionTransformer ---------
need_motion_transformer :bool = False,
motion_transformer_attn_head_dim:int = 64,
motion_transformer_attn_num_heads:int = 16,
motion_transformer_num_layers:int = 4,
# ----------- Diffusion Transformer -----------
                 diffusion_model_type : str = 'default', # one of: 'default', 'spatial'
diffusion_attn_head_dim : int = 64,
diffusion_attn_num_heads : int = 16,
diffusion_out_channels : int = 4,
diffusion_num_layers : int = 16,
image_patch_size : int = 2,
motion_patch_size : int = 1,
motion_drop_ratio: float = 0.0,
**kwargs,
):
super().__init__()
# setting
self.num_step = scheduler_num_step
self.scheduler = RectifiedFlow(num_steps=scheduler_num_step)
self.need_motion_transformer = need_motion_transformer
# zt token
INIT_CONST = 0.02
self.zt_token = nn.Parameter(torch.randn(1, image_inchannel, image_height,image_width) * INIT_CONST)
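        # learned placeholder that stands in for the noisy latent z_t of the diffusion
        # variant, so the reconstruction transformer keeps the same (zi, zt) input layout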
# motion Encoder
self.motion_encoder = MotionEncoderLearnTokenTransformer(img_height = image_height,
img_width=image_width,
img_inchannel=image_inchannel,
img_patch_size = image_patch_size,
motion_token_num = motion_token_num,
motion_channel = motion_token_channel,
need_norm_out = motion_need_norm_out,
# ----- attention
num_attention_heads=enc_nhead,
attention_head_dim=enc_ndim,
num_layers=enc_num_layers,
dropout=enc_dropout,
attention_bias= True,)
# motion transformer
if need_motion_transformer:
self.motion_transformer = MotionTransformer(motion_token_num=motion_token_num,
motion_token_channel=motion_token_channel,
attention_head_dim=motion_transformer_attn_head_dim,
num_attention_heads=motion_transformer_attn_num_heads,
num_layers=motion_transformer_num_layers,)
# diffusion transformer
if diffusion_model_type == 'default':
dit_image_inchannel = image_inchannel * 2 # zi + zt
self.transformer = AMDReconstructTransformerModel(num_attention_heads= diffusion_attn_num_heads,
attention_head_dim= diffusion_attn_head_dim,
out_channels = diffusion_out_channels,
num_layers= diffusion_num_layers,
# ----- img
image_width= image_width,
image_height= image_height,
image_patch_size= image_patch_size,
image_in_channels = dit_image_inchannel,
# ----- motion
motion_token_num = motion_token_num,
motion_in_channels = motion_token_channel,)
elif diffusion_model_type == 'spatial':
dit_image_inchannel = image_inchannel * 2 # zi + zt
self.transformer = AMDReconstructTransformerModelSpatial(num_attention_heads= diffusion_attn_num_heads,
attention_head_dim= diffusion_attn_head_dim,
out_channels = diffusion_out_channels,
num_layers= diffusion_num_layers,
# ----- img
image_width= image_width,
image_height= image_height,
image_patch_size= image_patch_size,
image_in_channels = dit_image_inchannel,
# ----- motion
motion_token_num = motion_token_num,
motion_in_channels = motion_token_channel,
motion_target_num_frame = video_frames)
    def forward(self,
                video:torch.Tensor,
                ref_img:torch.Tensor,
                time_step:torch.Tensor = None,
                return_meta_info=False,
                **kwargs,):
"""
Args:
video: (N,T,C,H,W)
ref_img: (N,T,C,H,W)
"""
device = video.device
n,t,c,h,w = video.shape
        assert video.shape == ref_img.shape , f'video.shape {video.shape} should be equal to ref_img.shape {ref_img.shape}'
# motion encoder
refimg_and_video = torch.cat([ref_img,video],dim=1)# (n,t+t,C,H,W)
motion = self.motion_encoder(refimg_and_video) # (n,t+t,l,d)
source_motion = motion[:,:t].flatten(0,1) # (NT,motion_token,d)
target_motion = motion[:,t:].flatten(0,1) # (NT,motion_token,d)
assert source_motion.shape == target_motion.shape , f'source_motion.shape {source_motion.shape} != target_motion.shape {target_motion.shape}'
# motion transformer
if self.need_motion_transformer:
target_motion = einops.rearrange(target_motion,'(n f) l d -> n f l d',n=n)
target_motion = self.motion_transformer(target_motion)
target_motion = einops.rearrange(target_motion,'n f l d -> (n f) l d',n=n)
# prepare for Diffusion Transformer
zi = ref_img.flatten(0,1) # (NT,C,H,W)
zj = video.flatten(0,1) # (NT,C,H,W)
zt = self.zt_token.repeat(zj.shape[0],1,1,1) # (NT,C,H,W)
# dit forward
image_hidden_states = torch.cat((zi,zt),dim=1) # (b,2C,H,W)
pre = self.transformer(motion_source_hidden_states = source_motion,
motion_target_hidden_states = target_motion,
image_hidden_states = image_hidden_states,)
# loss
rec_loss = l2(pre,zj)
loss = rec_loss
loss_dict = {'loss':loss,'rec_loss':rec_loss}
if return_meta_info:
            return {'motion' : motion, # (n,t+t,l,d), output of the motion encoder
'zi' : zi, # (b,C,H,W) | b = n * t
'zj' : zj, # (b,C,H,W)
'zt' : zt, # (b,C,H,W)
'pre': pre, # (b,C,H,W)
'time_step': time_step, # (b,)
}
else:
return pre,zj,loss_dict # (b,C,H,W)
@torch.no_grad()
def sample(self,
               video:torch.Tensor,
ref_img:torch.Tensor ,
sample_step:int = 50,
start_step:int = None,
return_meta_info=False,
**kwargs,):
device = video.device
n,t,c,h,w = video.shape
if start_step is None:
start_step = self.scheduler.num_step
        assert start_step <= self.scheduler.num_step , 'start_step cannot be larger than scheduler.num_step'
# motion encoder
refimg_and_video = torch.cat([ref_img,video],dim=1)# (n,t+t,C,H,W)
        motion = self.motion_encoder(refimg_and_video) # (n,t+t,l,d)
source_motion = motion[:,:t].flatten(0,1) # (NT,motion_token,d)
target_motion = motion[:,t:].flatten(0,1) # (NT,motion_token,d)
assert source_motion.shape == target_motion.shape , f'source_motion.shape {source_motion.shape} != target_motion.shape {target_motion.shape}'
# motion transformer
if self.need_motion_transformer:
target_motion = einops.rearrange(target_motion,'(n f) l d -> n f l d',n=n)
target_motion = self.motion_transformer(target_motion)
target_motion = einops.rearrange(target_motion,'n f l d -> (n f) l d',n=n)
zi = ref_img.flatten(0,1) # (NT,C,H,W)
zj = video.flatten(0,1) # (NT,C,H,W)
zt = self.zt_token.repeat(zj.shape[0],1,1,1) # (NT,C,H,W)
# input
zt = zt.to(video.dtype)
image_hidden_states = torch.cat((zi,zt),dim=1) # (b,2C,H,W)
# forward
pre = self.transformer(motion_source_hidden_states = source_motion,
motion_target_hidden_states = target_motion,
image_hidden_states = image_hidden_states,)
zi = einops.rearrange(zi,'(n t) c h w -> n t c h w',n=n)
zt = einops.rearrange(pre,'(n t) c h w -> n t c h w',n=n)
zj = einops.rearrange(zj,'(n t) c h w -> n t c h w',n=n)
if return_meta_info:
            return {'zi' : zi, # (n,t,c,h,w)
                    'zj' : zj, # (n,t,c,h,w)
}
else:
return zi,zt,zj # (n,t,c,h,w)
def sample_with_refimg_motion(self,
ref_img:torch.Tensor,
                                  motion:torch.Tensor,
sample_step:int = 10,
return_meta_info=False,
**kwargs,):
"""
Args:
ref_img : (N,C,H,W)
motion : (N,F,L,D)
Return:
video : (N,T,C,H,W)
"""
device = motion.device
n,t,l,d = motion.shape
start_step = self.scheduler.num_step
# motion encoder
refimg = ref_img.unsqueeze(1) # (N,1,C,H,W)
source_motion = self.motion_encoder(refimg) # (n,1,motion_token,d)
source_motion = source_motion.repeat(1,t,1,1).flatten(0,1) # (NT,l,d)
target_motion = motion.flatten(0,1) # (NT,l,d)
assert source_motion.shape == target_motion.shape , f'source_motion.shape {source_motion.shape} != target_motion.shape {target_motion.shape}'
# motion transformer
if self.need_motion_transformer:
target_motion = einops.rearrange(target_motion,'(n f) l d -> n f l d',n=n)
target_motion = self.motion_transformer(target_motion)
target_motion = einops.rearrange(target_motion,'n f l d -> (n f) l d',n=n)
# prepare for Diffusion Transformer
time_step = torch.ones((source_motion.shape[0],)).to(device)
time_step = time_step * start_step
zi = refimg.repeat(1,t,1,1,1).flatten(0,1) # (NT,C,H,W)
zj = zi
zt = self.zt_token.repeat(zj.shape[0],1,1,1) # (NT,C,H,W)
# input
zt = zt.to(zj.dtype)
image_hidden_states = torch.cat((zi,zt),dim=1) # (b,2C,H,W)
# forward
pre = self.transformer(motion_source_hidden_states = source_motion,
motion_target_hidden_states = target_motion,
image_hidden_states = image_hidden_states,)
zi = einops.rearrange(zi,'(n t) c h w -> n t c h w',n=n)
zt = einops.rearrange(pre,'(n t) c h w -> n t c h w',n=n)
zj = einops.rearrange(zj,'(n t) c h w -> n t c h w',n=n)
if return_meta_info:
            return {'zi' : zi, # (n,t,c,h,w)
                    'zj' : zj, # (n,t,c,h,w)
}
else:
            return zi,zt,zj # (n,t,c,h,w)
    def extract_motion(self,video:torch.Tensor):
# video : (N,T,C,H,W)
# motion Encoder
motion = self.motion_encoder(video) # (N,T,L,D)
if self.need_motion_transformer:
motion = self.motion_transformer(motion) # (N,T,L,D)
return motion
def AMD_S(**kwargs) -> AMDModel:
return AMDModel(
# ----------- motion encoder -----------
enc_num_layers = 8,
enc_nhead = 8,
enc_ndim = 64,
# ----------- Diffusion Transformer -----------
diffusion_attn_head_dim = 64,
diffusion_attn_num_heads = 16,
diffusion_out_channels = 4,
diffusion_num_layers = 12,
**kwargs)
def AMD_L(**kwargs) -> AMDModel:
return AMDModel(
# ----------- motion encoder -----------
enc_num_layers = 8,
enc_nhead = 16,
enc_ndim = 64,
# ----------- Diffusion Transformer -----------
diffusion_attn_head_dim = 96,
diffusion_attn_num_heads = 16,
diffusion_out_channels = 4,
diffusion_num_layers = 16,
**kwargs)
def AMD_S_Rec(**kwargs) -> AMDModel_Rec:
return AMDModel_Rec(
# ----------- motion encoder -----------
enc_num_layers = 8,
enc_nhead = 8,
enc_ndim = 64,
# ----------- Diffusion Transformer -----------
diffusion_attn_head_dim = 64,
diffusion_attn_num_heads = 16,
diffusion_out_channels = 4,
diffusion_num_layers = 12,
**kwargs)
def AMD_S_RecSplit(**kwargs) -> AMDModel_Rec:
return AMDModel_Rec(
# ----------- motion encoder -----------
enc_num_layers = 8,
enc_nhead = 8,
enc_ndim = 64,
# ----------- Diffusion Transformer -----------
diffusion_attn_head_dim = 64,
diffusion_attn_num_heads = 16,
diffusion_out_channels = 4,
diffusion_num_layers = 12,
is_split = True,
**kwargs)
AMD_models = {
"AMD_S": AMD_S, # 250M
"AMD_L": AMD_L, # 700M
"AMD_S_Rec": AMD_S_Rec, # 250M
"AMD_S_RecSplit" : AMD_S_RecSplit, # 250M
} # parameter counts: S 206M, B 333M, M 642M, L 1053M
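# ---------------------------------------------------------------------------
# Minimal smoke-test sketch (illustration only, not part of the training code).
# Assumptions: the package-relative imports above resolve, the default config
# shapes apply (4-channel 32x32 inputs), and the loss values are scalars.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    model = AMD_models['AMD_S'](video_frames=16)
    video = torch.randn(2, 16, 4, 32, 32)          # (N,T,C,H,W)
    ref_img = video[:, :1].repeat(1, 16, 1, 1, 1)  # first frame broadcast as reference
    pre, vel, loss_dict = model(video=video, ref_img=ref_img)
    print(pre.shape, {k: float(v) for k, v in loss_dict.items()})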