|
|
import torch |
|
|
from torch import nn |
|
|
import einops |
|
|
from typing import Tuple |
|
|
import random |
|
|
import numpy as np |
|
|
from tqdm import tqdm |
|
|
from .modules import DuoFrameDownEncoder,Upsampler,MapConv,MotionDownEncoder |
|
|
from .loss import l1,l2 |
|
|
from .transformer import (MotionTransformer, |
|
|
AMDDiffusionTransformerModel, |
|
|
MotionEncoderLearnTokenTransformer, |
|
|
AMDReconstructTransformerModel, |
|
|
AMDDiffusionTransformerModelDualStream, |
|
|
AMDDiffusionTransformerModelImgSpatial, |
|
|
AMDDiffusionTransformerModelImgSpatialDoubleRef, |
|
|
AMDReconstructTransformerModelSpatial) |
|
|
from .rectified_flow import RectifiedFlow |
|
|
from diffusers.configuration_utils import ConfigMixin, register_to_config |
|
|
from diffusers.models.modeling_utils import ModelMixin |
|
|
from diffusers.models.resnet import ResnetBlock2D |
|
|
import einops |
|
|
import torch.nn.functional as F |
|
|
|
|
|
from diffusers.utils import export_to_gif |
|
|
|
|
|
class AMDModel(ModelMixin, ConfigMixin):
    """Motion-conditioned rectified-flow video diffusion model.

    Pipeline (as visible in this file):
      1. A ``MotionEncoderLearnTokenTransformer`` compresses a concatenation of
         reference frames and target video frames into per-frame motion tokens.
      2. A diffusion transformer (one of four variants, selected by
         ``diffusion_model_type``) predicts the rectified-flow velocity of the
         video latents, conditioned on source/target motion tokens and the
         reference latents.
      3. ``RectifiedFlow`` supplies the (zt, velocity) training pairs and the
         Euler integration targets used in :meth:`sample`.

    Inputs are assumed to be latent images (e.g. VAE latents with 4 channels at
    32x32 by default) — TODO confirm against the data pipeline.
    """

    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(self,
                # ---- latent image geometry ----
                image_inchannel :int = 4,
                image_height :int = 32,
                image_width :int = 32,
                video_frames :int = 16,
                scheduler_num_step :int = 1000,

                # ---- motion encoder ----
                motion_token_num:int = 12,
                motion_token_channel: int = 128,
                enc_num_layers:int = 8,
                enc_nhead:int = 8,
                enc_ndim:int = 64,
                enc_dropout:float = 0.0,
                motion_need_norm_out:bool = False,

                # ---- optional motion transformer (post-processes target motion) ----
                need_motion_transformer :bool = False,
                motion_transformer_attn_head_dim:int = 64,
                motion_transformer_attn_num_heads:int = 16,
                motion_transformer_num_layers:int = 4,

                # ---- diffusion transformer ----
                diffusion_model_type : str = 'default',
                diffusion_attn_head_dim : int = 64,
                diffusion_attn_num_heads : int = 16,
                diffusion_out_channels : int = 4,
                diffusion_num_layers : int = 16,
                image_patch_size : int = 2,
                motion_patch_size : int = 1,
                motion_drop_ratio: float = 0.0,   # NOTE(review): accepted but never read in this class
                refimg_drop: bool = False,        # when True, the reference latent zi is zeroed in forward/sample

                extract_motion_with_motion_transformer = False,
                **kwargs,
                ):
        super().__init__()

        self.num_step = scheduler_num_step
        self.scheduler = RectifiedFlow(num_steps=scheduler_num_step)
        self.need_motion_transformer = need_motion_transformer
        self.extract_motion_with_motion_transformer = extract_motion_with_motion_transformer
        self.diffusion_model_type = diffusion_model_type
        self.target_frame = video_frames
        self.refimg_drop = refimg_drop

        # Encoder maps (N, 2T, C, H, W) frame stacks to (N, 2T, L, D) motion tokens.
        self.motion_encoder = MotionEncoderLearnTokenTransformer(img_height = image_height,
                                                                img_width=image_width,
                                                                img_inchannel=image_inchannel,
                                                                img_patch_size = image_patch_size,
                                                                motion_token_num = motion_token_num,
                                                                motion_channel = motion_token_channel,
                                                                need_norm_out = motion_need_norm_out,

                                                                num_attention_heads=enc_nhead,
                                                                attention_head_dim=enc_ndim,
                                                                num_layers=enc_num_layers,
                                                                dropout=enc_dropout,
                                                                attention_bias= True,)

        if need_motion_transformer:
            self.motion_transformer = MotionTransformer(motion_token_num=motion_token_num,
                                                        motion_token_channel=motion_token_channel,
                                                        attention_head_dim=motion_transformer_attn_head_dim,
                                                        num_attention_heads=motion_transformer_attn_num_heads,
                                                        num_layers=motion_transformer_num_layers,)

        # Backbone dispatch. All variants except 'doubleref' receive the
        # ref latent concatenated channel-wise with zt, hence in_channels * 2.
        if diffusion_model_type == 'default':
            dit_image_inchannel = image_inchannel * 2
            self.diffusion_transformer = AMDDiffusionTransformerModel(num_attention_heads= diffusion_attn_num_heads,
                                                                    attention_head_dim= diffusion_attn_head_dim,
                                                                    out_channels = diffusion_out_channels,
                                                                    num_layers= diffusion_num_layers,

                                                                    image_width= image_width,
                                                                    image_height= image_height,
                                                                    image_patch_size= image_patch_size,
                                                                    image_in_channels = dit_image_inchannel,

                                                                    motion_token_num = motion_token_num,
                                                                    motion_in_channels = motion_token_channel,)
        elif diffusion_model_type == 'dual':
            dit_image_inchannel = image_inchannel * 2
            self.diffusion_transformer = AMDDiffusionTransformerModelDualStream(num_attention_heads= diffusion_attn_num_heads,
                                                                    attention_head_dim= diffusion_attn_head_dim,
                                                                    out_channels = diffusion_out_channels,
                                                                    num_layers= diffusion_num_layers,

                                                                    image_width= image_width,
                                                                    image_height= image_height,
                                                                    image_patch_size= image_patch_size,
                                                                    image_in_channels = dit_image_inchannel,

                                                                    motion_token_num = motion_token_num,
                                                                    motion_in_channels = motion_token_channel,
                                                                    motion_target_num_frame = video_frames)
        elif diffusion_model_type == 'spatial':
            dit_image_inchannel = image_inchannel * 2
            self.diffusion_transformer = AMDDiffusionTransformerModelImgSpatial(num_attention_heads= diffusion_attn_num_heads,
                                                                    attention_head_dim= diffusion_attn_head_dim,
                                                                    out_channels = diffusion_out_channels,
                                                                    num_layers= diffusion_num_layers,

                                                                    image_width= image_width,
                                                                    image_height= image_height,
                                                                    image_patch_size= image_patch_size,
                                                                    image_in_channels = dit_image_inchannel,

                                                                    motion_token_num = motion_token_num,
                                                                    motion_in_channels = motion_token_channel,
                                                                    motion_target_num_frame = video_frames)
        elif diffusion_model_type == 'doubleref':
            # doubleref passes the random reference separately, so no channel doubling.
            dit_image_inchannel = image_inchannel
            self.diffusion_transformer = AMDDiffusionTransformerModelImgSpatialDoubleRef(num_attention_heads= diffusion_attn_num_heads,
                                                                    attention_head_dim= diffusion_attn_head_dim,
                                                                    out_channels = diffusion_out_channels,
                                                                    num_layers= diffusion_num_layers,

                                                                    image_width= image_width,
                                                                    image_height= image_height,
                                                                    image_patch_size= image_patch_size,
                                                                    image_in_channels = dit_image_inchannel,

                                                                    motion_token_num = motion_token_num,
                                                                    motion_in_channels = motion_token_channel,
                                                                    motion_target_num_frame = video_frames)
        else:
            # NOTE(review): IndexError is a surprising type for a bad config
            # value — a ValueError with the offending string would be clearer.
            raise IndexError

    def forward(self,
                video:torch.tensor,
                ref_img:torch.Tensor ,
                randomref_img:torch.Tensor = None,
                time_step:torch.tensor = None,
                return_meta_info=False,
                mask_ratio=None,
                **kwargs,):
        """
        Training step: encode motion, form a rectified-flow pair and predict velocity.

        Args:
            video: (N,T,C,H,W)
            ref_img: (N,T,C,H,W)
            randomref_img : (N,T,C,H,W); required (and used) only for 'doubleref'
            time_step: optional explicit timesteps; sampled randomly when None
            return_meta_info: return intermediate tensors instead of the loss tuple
            mask_ratio: upper bound for a uniformly sampled token-mask ratio

        Returns:
            (pre, vel, loss_dict) — prediction, velocity target, and
            {'loss','diff_loss','rec_loss'}; or a meta-info dict when
            return_meta_info is True.
        """

        device = video.device
        n,t,c,h,w = video.shape

        assert video.shape == ref_img.shape ,f'video.shape:{video.shape}should be equal to ref_img.shape:{ref_img.shape}'
        if self.diffusion_model_type == 'doubleref' :
            assert randomref_img is not None, "when diffusion_model_type == doubleref, randomref_img should be given"

        # Randomize the effective mask ratio in [0, mask_ratio) per forward call.
        if mask_ratio is not None:
            mask_ratio = torch.rand(1).item() * mask_ratio

        # Motion encoder sees reference frames followed by target frames along T.
        if self.diffusion_model_type == 'doubleref' and randomref_img is not None:
            if randomref_img.dim()==4:
                randomref_img = randomref_img.unsqueeze(1).repeat(1,t,1,1,1)
            refimg_and_video = torch.cat([randomref_img,video],dim=1)
        else:
            refimg_and_video = torch.cat([ref_img,video],dim=1)
        motion = self.motion_encoder(refimg_and_video,mask_ratio)

        # Split tokens back into reference-half and video-half; fold N,T into batch.
        source_motion = motion[:,:t].flatten(0,1)
        target_motion = motion[:,t:].flatten(0,1)

        assert source_motion.shape == target_motion.shape , f'source_motion.shape {source_motion.shape} != target_motion.shape {target_motion.shape}'

        if self.need_motion_transformer:
            target_motion = einops.rearrange(target_motion,'(n f) l d -> n f l d',n=n)
            target_motion = self.motion_transformer(target_motion)
            target_motion = einops.rearrange(target_motion,'n f l d -> (n f) l d',n=n)

        zi = ref_img.flatten(0,1)   # reference latents, (N*T, C, H, W)
        zj = video.flatten(0,1)     # clean target latents
        if self.diffusion_model_type == 'doubleref' and randomref_img is not None:
            randomref_img = randomref_img.flatten(0,1)

        if time_step is None:
            time_step = self.prepare_timestep(batch_size= zj.shape[0],device= device)
            # Non-default backbones share one timestep across all frames of a video.
            if self.diffusion_model_type != 'default':
                time_step = self.prepare_timestep(batch_size= n,device= device)
                time_step = time_step.repeat_interleave(t)
        zt,vel = self.scheduler.get_train_tuple(z1=zj,time_step=time_step)

        # Optionally drop the reference conditioning entirely.
        # NOTE(review): this zeroes zi unconditionally whenever refimg_drop is
        # set — it is not a stochastic classifier-free-guidance drop.
        if self.refimg_drop:
            zi = torch.zeros_like(zi).to(video.device)
        image_hidden_states = torch.cat((zi,zt),dim=1)

        pre = self.diffusion_transformer(motion_source_hidden_states = source_motion,
                                        motion_target_hidden_states = target_motion,
                                        image_hidden_states = image_hidden_states,
                                        randomref_image_hidden_states = randomref_img,
                                        timestep = time_step,)

        diff_loss = l2(pre,vel)

        # Reconstruction loss is computed for logging only; it does not
        # contribute to 'loss' below.
        rec_zj = self.scheduler.get_target_with_zt_vel(zt,pre,time_step)
        rec_loss = l2(rec_zj,zj)

        loss = diff_loss

        loss_dict = {'loss':loss,'diff_loss':diff_loss,'rec_loss':rec_loss}

        if return_meta_info:
            return {'motion' : motion,
                    'zi' : zi,
                    'zj' : zj,
                    'zt' : zt,
                    'gt' : vel,
                    'pre': pre,
                    'time_step': time_step,
                    }
        else:
            return pre,vel,loss_dict

    def get_noise_latent_pair(self,
                            video:torch.Tensor,
                            ref_img:torch.Tensor ,
                            randomref_img:torch.Tensor,
                            sample_step:int = 50,
                            ):
        # TODO: not implemented — placeholder kept for interface compatibility.
        pass

    @torch.no_grad()
    def sample(self,video:torch.Tensor,
                ref_img:torch.Tensor ,
                randomref_img:torch.Tensor = None,
                sample_step:int = 50,
                mask_ratio = None,
                start_step:int = None,
                return_meta_info=False,
                **kwargs,):
        """Euler-integrate the learned velocity field from noise to latents.

        Motion is extracted from the provided ``video`` (i.e. this is a
        reconstruction-style sampling path). Returns (zi, zt, zj) reshaped to
        (N,T,C,H,W), or a meta-info dict with per-step caches.
        """

        device = video.device
        n,t,c,h,w = video.shape

        if start_step is None:
            start_step = self.scheduler.num_step
        assert start_step <= self.scheduler.num_step , 'start_step cant be larger than scheduler.num_step'

        if self.diffusion_model_type == 'doubleref' :
            assert randomref_img is not None, "when diffusion_model_type == doubleref, randomref_img should be given"

        # Broadcast a single reference frame across all T target frames.
        if ref_img.dim()==4:
            ref_img = ref_img.unsqueeze(1).repeat(1,t,1,1,1)

        if mask_ratio is not None:
            print(f'* Sampling with Mask_Ratio = {mask_ratio}')
            # NOTE(review): no-op self-assignment — unlike forward(), sampling
            # uses the given mask_ratio directly (no random rescaling).
            mask_ratio = mask_ratio

        if self.diffusion_model_type == 'doubleref' and randomref_img is not None:
            if randomref_img.dim()==4:
                randomref_img = randomref_img.unsqueeze(1).repeat(1,t,1,1,1)
            refimg_and_video = torch.cat([randomref_img,video],dim=1)
        else:
            refimg_and_video = torch.cat([ref_img,video],dim=1)

        motion = self.motion_encoder(refimg_and_video,mask_ratio)

        source_motion = motion[:,:t].flatten(0,1)
        target_motion = motion[:,t:].flatten(0,1)

        assert source_motion.shape == target_motion.shape , f'source_motion.shape {source_motion.shape} != target_motion.shape {target_motion.shape}'

        if self.need_motion_transformer:
            target_motion = einops.rearrange(target_motion,'(n f) l d -> n f l d',n=n)
            target_motion = self.motion_transformer(target_motion)
            target_motion = einops.rearrange(target_motion,'n f l d -> (n f) l d',n=n)

        # Initial timestep = start_step for every flattened frame.
        time_step = torch.ones((source_motion.shape[0],)).to(device)
        time_step = time_step * start_step

        zi = ref_img.flatten(0,1)
        zj = video.flatten(0,1)
        if self.diffusion_model_type == 'doubleref' and randomref_img is not None:
            randomref_img = randomref_img.flatten(0,1)
        # get_train_tuple at start_step yields the initial noisy state zt.
        zt,vel = self.scheduler.get_train_tuple(z1=zj,time_step=time_step)
        noise = zj - vel

        pre_cache = []
        sample_cache = []

        # Descending timestep schedule: start_step -> (exclusive) 0.
        step_seq = np.linspace(0, start_step, num=sample_step+1, endpoint=True,dtype=int)
        step_seq = list(reversed(step_seq[1:]))

        dt = 1./sample_step

        if self.refimg_drop:
            zi = torch.zeros_like(zi).to(video.device)

        for i in tqdm(step_seq):
            time_step = torch.ones((zt.shape[0],)).to(zt.device)
            time_step = time_step * i

            zt = zt.to(video.dtype)
            image_hidden_states = torch.cat((zi,zt),dim=1)

            pre = self.diffusion_transformer(motion_source_hidden_states = source_motion,
                                            motion_target_hidden_states = target_motion,
                                            image_hidden_states = image_hidden_states,
                                            randomref_image_hidden_states = randomref_img,
                                            timestep = time_step,)
            # Forward Euler step along the predicted velocity.
            zt = zt + pre * dt
            pre_cache.append(pre)
            sample_cache.append(zt)

        zi = einops.rearrange(zi,'(n t) c h w -> n t c h w',n=n)
        zt = einops.rearrange(zt,'(n t) c h w -> n t c h w',n=n)
        zj = einops.rearrange(zj,'(n t) c h w -> n t c h w',n=n)

        if return_meta_info:
            return {'zi' : zi,
                    'zj' : zj,
                    'sample' : zt,
                    'pre_cache' : pre_cache,
                    'sample_cache' : sample_cache,
                    'step_seq' : step_seq,
                    'motion' : target_motion,
                    "noise" : noise
                    }
        else:
            return zi,zt,zj

    @torch.no_grad()
    def sample_with_refimg_motion(self,
                                ref_img:torch.Tensor,
                                motion=torch.Tensor,
                                randomref_img:torch.Tensor = None,
                                sample_step:int = 10,
                                mask_ratio = None,
                                return_meta_info=False,
                                **kwargs,):
        """
        Generate video latents from a single reference image plus precomputed
        motion tokens (e.g. from :meth:`extract_motion`), starting from pure noise.

        Args:
            ref_img : (N,C,H,W)
            randomref_img : (N,C,H,W)
            motion : (N,F,L,D)
                NOTE(review): the signature says ``motion=torch.Tensor`` — a
                default of the *class object*, almost certainly meant to be the
                annotation ``motion:torch.Tensor``; callers must always pass it.
        Return:
            video : (N,T,C,H,W)
        """
        device = motion.device
        n,t,l,d = motion.shape

        start_step = self.scheduler.num_step

        refimg = ref_img.unsqueeze(1)
        if self.diffusion_model_type == 'doubleref' :
            assert randomref_img is not None, "when diffusion_model_type == doubleref, randomref_img should be given"

        # Source motion comes from encoding the (single-frame) reference alone.
        if self.diffusion_model_type == 'doubleref' and randomref_img is not None:
            print('* Warnning * diffusion_model_type:doubleref')
            if randomref_img.dim()==4:
                randomref_img = randomref_img.unsqueeze(1)
            source_motion = self.motion_encoder(randomref_img,mask_ratio)
        else:
            source_motion = self.motion_encoder(refimg,mask_ratio)

        # Repeat the single-frame source motion across all F target frames.
        source_motion = source_motion.repeat(1,t,1,1).flatten(0,1)
        target_motion = motion.flatten(0,1)

        assert source_motion.shape == target_motion.shape , f'source_motion.shape {source_motion.shape} != target_motion.shape {target_motion.shape}'

        # Skip the motion transformer when the caller already applied it
        # during extract_motion (extract_motion_with_motion_transformer=True).
        if self.need_motion_transformer and not self.extract_motion_with_motion_transformer:
            target_motion = einops.rearrange(target_motion,'(n f) l d -> n f l d',n=n)
            target_motion = self.motion_transformer(target_motion)
            target_motion = einops.rearrange(target_motion,'n f l d -> (n f) l d',n=n)

        time_step = torch.ones((source_motion.shape[0],)).to(device)
        time_step = time_step * start_step

        zi = refimg.repeat(1,t,1,1,1).flatten(0,1)
        zj = zi   # no ground-truth video here; zj mirrors the reference
        if self.diffusion_model_type == 'doubleref' and randomref_img is not None:
            randomref_img = randomref_img.repeat(1,t,1,1,1)
            randomref_img = randomref_img.flatten(0,1)

        # Start integration from pure Gaussian noise.
        zt = torch.randn_like(zj)

        pre_cache = []
        sample_cache = []

        step_seq = np.linspace(0, start_step, num=sample_step+1, endpoint=True,dtype=int)
        step_seq = list(reversed(step_seq[1:]))

        dt = 1./sample_step

        if self.refimg_drop:
            zi = torch.zeros_like(zi).to(ref_img.device)

        for i in tqdm(step_seq):
            time_step = torch.ones((zt.shape[0],)).to(zt.device)
            time_step = time_step * i

            zt = zt.to(ref_img.dtype)
            image_hidden_states = torch.cat((zi,zt),dim=1)

            pre = self.diffusion_transformer(motion_source_hidden_states = source_motion,
                                            motion_target_hidden_states = target_motion,
                                            image_hidden_states = image_hidden_states,
                                            randomref_image_hidden_states = randomref_img,
                                            timestep = time_step,)

            zt = zt + pre * dt

        zi = einops.rearrange(zi,'(n t) c h w -> n t c h w',n=n,t=t)
        zt = einops.rearrange(zt,'(n t) c h w -> n t c h w',n=n,t=t)

        if return_meta_info:
            return {'zi' : zi,
                    'zj' : zj,
                    'sample' : zt,
                    'pre_cache' : pre_cache,   # NOTE(review): never appended to in this path
                    'sample_cache' : sample_cache,
                    'step_seq' : step_seq,
                    'motion' : target_motion,
                    }
        else:
            return zi,zt,zj

    def extract_motion(self,video:torch.tensor,mask_ratio=None):
        """Encode a (N,T,C,H,W) clip into motion tokens; optionally apply the
        motion transformer when configured to bake it into extraction."""

        n,t,c,h,w = video.shape

        motion = self.motion_encoder(video,mask_ratio)

        if self.need_motion_transformer and self.extract_motion_with_motion_transformer:
            motion = self.motion_transformer(motion)

        return motion

    def prepare_timestep(self,batch_size:int,device,time_step = None):
        """Return the given timestep moved to device, or sample uniform random
        integer timesteps in [0, num_step] (upper bound inclusive)."""
        if time_step is not None:
            return time_step.to(device)
        else:
            return torch.randint(0,self.num_step+1,(batch_size,)).to(device)

    def prepare_encoder_input(self,video:torch.tensor):
        """Build consecutive-frame pairs: channel-concat frame t with frame t+1
        and flatten batch/time, yielding ((B*(T-1)), 2C, H, W)."""
        assert len(video.shape) == 5 , f'only support video data : 5D tensor , but got {video.shape}'

        pre = video[:,:-1,:,:,:]
        post= video[:,1:,:,:,:]
        duo_frame_mix = torch.cat([pre,post],dim=2)
        duo_frame_mix = einops.rearrange(duo_frame_mix,'b t c h w -> (b t) c h w')

        return duo_frame_mix

    def unpatchify(self, x ,patch_size):
        """
        Reassemble a patch-token sequence into images (assumes a square grid).

        x: (N, S, patch_size**2 *C)
        imgs: (N, C, H, W)
        """
        p = patch_size
        h = w = int(x.shape[1]**.5)

        c = x.shape[2] // (p**2)
        assert h * w == x.shape[1]

        x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
        x = torch.einsum('nhwpqc->nchpwq', x)
        # NOTE(review): width uses h * p rather than w * p — harmless only
        # because h == w is enforced above.
        imgs = x.reshape(shape=(x.shape[0], c, h * p, h * p))
        return imgs

    def reset_infer_num_frame(self, num:int):
        """Override the diffusion transformer's target frame count at inference."""
        old_num = self.diffusion_transformer.target_frame
        self.diffusion_transformer.target_frame = num
        print(f'* Reset infer frame from {old_num} to {self.diffusion_transformer.target_frame} *')
|
|
|
|
|
|
|
|
class AMDModel_Rec(ModelMixin, ConfigMixin):
    """Single-pass reconstruction variant of AMDModel.

    Instead of iterative rectified-flow sampling, the transformer directly
    regresses the target latents in one forward pass. The "noisy" slot of the
    image conditioning is filled by a learned token (``zt_token``) rather than
    a diffused latent, and training minimizes an L2 reconstruction loss.
    """

    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(self,
                # ---- latent image geometry ----
                image_inchannel :int = 4,
                image_height :int = 32,
                image_width :int = 32,
                video_frames :int = 16,
                scheduler_num_step :int = 1000,

                # ---- motion encoder ----
                motion_token_num:int = 12,
                motion_token_channel: int = 128,
                enc_num_layers:int = 8,
                enc_nhead:int = 8,
                enc_ndim:int = 64,
                enc_dropout:float = 0.0,
                motion_need_norm_out:bool = True,   # differs from AMDModel's default (False)

                # ---- optional motion transformer ----
                need_motion_transformer :bool = False,
                motion_transformer_attn_head_dim:int = 64,
                motion_transformer_attn_num_heads:int = 16,
                motion_transformer_num_layers:int = 4,

                # ---- reconstruction transformer ----
                diffusion_model_type : str = 'default',
                diffusion_attn_head_dim : int = 64,
                diffusion_attn_num_heads : int = 16,
                diffusion_out_channels : int = 4,
                diffusion_num_layers : int = 16,
                image_patch_size : int = 2,
                motion_patch_size : int = 1,
                motion_drop_ratio: float = 0.0,
                **kwargs,
                ):
        super().__init__()

        # Scheduler is kept for num_step bookkeeping in sample(); no diffusion
        # integration is performed in this class.
        self.num_step = scheduler_num_step
        self.scheduler = RectifiedFlow(num_steps=scheduler_num_step)
        self.need_motion_transformer = need_motion_transformer

        # Learned placeholder latent standing in for the noisy input zt.
        INIT_CONST = 0.02
        self.zt_token = nn.Parameter(torch.randn(1, image_inchannel, image_height,image_width) * INIT_CONST)

        self.motion_encoder = MotionEncoderLearnTokenTransformer(img_height = image_height,
                                                                img_width=image_width,
                                                                img_inchannel=image_inchannel,
                                                                img_patch_size = image_patch_size,
                                                                motion_token_num = motion_token_num,
                                                                motion_channel = motion_token_channel,
                                                                need_norm_out = motion_need_norm_out,

                                                                num_attention_heads=enc_nhead,
                                                                attention_head_dim=enc_ndim,
                                                                num_layers=enc_num_layers,
                                                                dropout=enc_dropout,
                                                                attention_bias= True,)

        if need_motion_transformer:
            self.motion_transformer = MotionTransformer(motion_token_num=motion_token_num,
                                                        motion_token_channel=motion_token_channel,
                                                        attention_head_dim=motion_transformer_attn_head_dim,
                                                        num_attention_heads=motion_transformer_attn_num_heads,
                                                        num_layers=motion_transformer_num_layers,)

        # NOTE(review): unlike AMDModel, there is no else-branch here — an
        # unrecognized diffusion_model_type silently leaves self.transformer
        # unset and fails later with AttributeError.
        if diffusion_model_type == 'default':
            dit_image_inchannel = image_inchannel * 2
            self.transformer = AMDReconstructTransformerModel(num_attention_heads= diffusion_attn_num_heads,
                                                            attention_head_dim= diffusion_attn_head_dim,
                                                            out_channels = diffusion_out_channels,
                                                            num_layers= diffusion_num_layers,

                                                            image_width= image_width,
                                                            image_height= image_height,
                                                            image_patch_size= image_patch_size,
                                                            image_in_channels = dit_image_inchannel,

                                                            motion_token_num = motion_token_num,
                                                            motion_in_channels = motion_token_channel,)
        elif diffusion_model_type == 'spatial':
            dit_image_inchannel = image_inchannel * 2
            self.transformer = AMDReconstructTransformerModelSpatial(num_attention_heads= diffusion_attn_num_heads,
                                                            attention_head_dim= diffusion_attn_head_dim,
                                                            out_channels = diffusion_out_channels,
                                                            num_layers= diffusion_num_layers,

                                                            image_width= image_width,
                                                            image_height= image_height,
                                                            image_patch_size= image_patch_size,
                                                            image_in_channels = dit_image_inchannel,

                                                            motion_token_num = motion_token_num,
                                                            motion_in_channels = motion_token_channel,
                                                            motion_target_num_frame = video_frames)

    def forward(self,
                video:torch.tensor,
                ref_img:torch.Tensor ,
                time_step:torch.tensor = None,
                return_meta_info=False,
                **kwargs,):
        """
        Single-pass reconstruction training step.

        Args:
            video: (N,T,C,H,W)
            ref_img: (N,T,C,H,W)
            time_step: unused here; echoed back in meta info only

        Returns:
            (pre, zj, loss_dict) with loss_dict = {'loss','rec_loss'},
            or a meta-info dict when return_meta_info is True.
        """

        device = video.device
        n,t,c,h,w = video.shape

        assert video.shape == ref_img.shape ,f'video.shape:{video.shape}should be equal to ref_img.shape:{ref_img.shape}'

        # Encode reference + video frames jointly into motion tokens.
        refimg_and_video = torch.cat([ref_img,video],dim=1)
        motion = self.motion_encoder(refimg_and_video)

        source_motion = motion[:,:t].flatten(0,1)
        target_motion = motion[:,t:].flatten(0,1)

        assert source_motion.shape == target_motion.shape , f'source_motion.shape {source_motion.shape} != target_motion.shape {target_motion.shape}'

        if self.need_motion_transformer:
            target_motion = einops.rearrange(target_motion,'(n f) l d -> n f l d',n=n)
            target_motion = self.motion_transformer(target_motion)
            target_motion = einops.rearrange(target_motion,'n f l d -> (n f) l d',n=n)

        zi = ref_img.flatten(0,1)
        zj = video.flatten(0,1)
        # Learned token replaces the diffused latent of the diffusion variant.
        zt = self.zt_token.repeat(zj.shape[0],1,1,1)

        image_hidden_states = torch.cat((zi,zt),dim=1)
        pre = self.transformer(motion_source_hidden_states = source_motion,
                               motion_target_hidden_states = target_motion,
                               image_hidden_states = image_hidden_states,)

        rec_loss = l2(pre,zj)

        loss = rec_loss

        loss_dict = {'loss':loss,'rec_loss':rec_loss}

        if return_meta_info:
            return {'motion' : motion,
                    'zi' : zi,
                    'zj' : zj,
                    'zt' : zt,
                    'pre': pre,
                    'time_step': time_step,
                    }
        else:
            return pre,zj,loss_dict

    @torch.no_grad()
    def sample(self,
                video:torch.tensor,
                ref_img:torch.Tensor ,
                sample_step:int = 50,
                start_step:int = None,
                return_meta_info=False,
                **kwargs,):
        """One-shot reconstruction 'sampling' — a single transformer pass.

        sample_step/start_step are accepted for interface parity with
        AMDModel.sample; start_step is only range-checked, never iterated.
        Returns (zi, zt, zj) where zt holds the prediction.
        """

        device = video.device
        n,t,c,h,w = video.shape

        if start_step is None:
            start_step = self.scheduler.num_step
        assert start_step <= self.scheduler.num_step , 'start_step cant be larger than scheduler.num_step'

        refimg_and_video = torch.cat([ref_img,video],dim=1)
        motion = self.motion_encoder(refimg_and_video)

        source_motion = motion[:,:t].flatten(0,1)
        target_motion = motion[:,t:].flatten(0,1)

        assert source_motion.shape == target_motion.shape , f'source_motion.shape {source_motion.shape} != target_motion.shape {target_motion.shape}'

        if self.need_motion_transformer:
            target_motion = einops.rearrange(target_motion,'(n f) l d -> n f l d',n=n)
            target_motion = self.motion_transformer(target_motion)
            target_motion = einops.rearrange(target_motion,'n f l d -> (n f) l d',n=n)

        zi = ref_img.flatten(0,1)
        zj = video.flatten(0,1)
        zt = self.zt_token.repeat(zj.shape[0],1,1,1)

        zt = zt.to(video.dtype)
        image_hidden_states = torch.cat((zi,zt),dim=1)

        pre = self.transformer(motion_source_hidden_states = source_motion,
                               motion_target_hidden_states = target_motion,
                               image_hidden_states = image_hidden_states,)

        zi = einops.rearrange(zi,'(n t) c h w -> n t c h w',n=n)
        # The prediction itself is returned in the zt slot.
        zt = einops.rearrange(pre,'(n t) c h w -> n t c h w',n=n)
        zj = einops.rearrange(zj,'(n t) c h w -> n t c h w',n=n)

        if return_meta_info:
            # NOTE(review): the prediction ('sample') is not included here,
            # unlike AMDModel.sample's meta dict.
            return {'zi' : zi,
                    'zj' : zj,
                    }
        else:
            return zi,zt,zj

    def sample_with_refimg_motion(self,
                                ref_img:torch.Tensor,
                                motion=torch.Tensor,
                                sample_step:int = 10,
                                return_meta_info=False,
                                **kwargs,):
        """
        Reconstruct video latents from one reference image plus precomputed
        motion tokens, in a single transformer pass.

        Args:
            ref_img : (N,C,H,W)
            motion : (N,F,L,D)
                NOTE(review): ``motion=torch.Tensor`` defaults to the class
                object — almost certainly meant as the annotation
                ``motion:torch.Tensor``; callers must always pass it.
        Return:
            video : (N,T,C,H,W)
        """
        device = motion.device
        n,t,l,d = motion.shape

        start_step = self.scheduler.num_step

        refimg = ref_img.unsqueeze(1)
        source_motion = self.motion_encoder(refimg)

        # Broadcast single-frame source motion across the F target frames.
        source_motion = source_motion.repeat(1,t,1,1).flatten(0,1)
        target_motion = motion.flatten(0,1)

        assert source_motion.shape == target_motion.shape , f'source_motion.shape {source_motion.shape} != target_motion.shape {target_motion.shape}'

        if self.need_motion_transformer:
            target_motion = einops.rearrange(target_motion,'(n f) l d -> n f l d',n=n)
            target_motion = self.motion_transformer(target_motion)
            target_motion = einops.rearrange(target_motion,'n f l d -> (n f) l d',n=n)

        # NOTE(review): time_step is computed but never used by the
        # reconstruction transformer below.
        time_step = torch.ones((source_motion.shape[0],)).to(device)
        time_step = time_step * start_step

        zi = refimg.repeat(1,t,1,1,1).flatten(0,1)
        zj = zi   # no ground truth available; zj mirrors the reference
        zt = self.zt_token.repeat(zj.shape[0],1,1,1)

        zt = zt.to(zj.dtype)
        image_hidden_states = torch.cat((zi,zt),dim=1)

        pre = self.transformer(motion_source_hidden_states = source_motion,
                               motion_target_hidden_states = target_motion,
                               image_hidden_states = image_hidden_states,)

        zi = einops.rearrange(zi,'(n t) c h w -> n t c h w',n=n)
        zt = einops.rearrange(pre,'(n t) c h w -> n t c h w',n=n)
        zj = einops.rearrange(zj,'(n t) c h w -> n t c h w',n=n)

        if return_meta_info:
            return {'zi' : zi,
                    'zj' : zj,
                    }
        else:
            return zi,zt,zj

    def extract_motion(self,video:torch.tensor):
        """Encode a (N,T,C,H,W) clip into motion tokens; always applies the
        motion transformer when one is configured (no opt-out flag here,
        unlike AMDModel.extract_motion)."""

        motion = self.motion_encoder(video)

        if self.need_motion_transformer:
            motion = self.motion_transformer(motion)

        return motion
|
|
|
|
|
|
|
|
def AMD_S(**kwargs) -> AMDModel:
    """Build the small AMD diffusion model (8-layer encoder, 12 DiT layers)."""
    preset = dict(
        enc_num_layers=8,
        enc_nhead=8,
        enc_ndim=64,
        diffusion_attn_head_dim=64,
        diffusion_attn_num_heads=16,
        diffusion_out_channels=4,
        diffusion_num_layers=12,
    )
    return AMDModel(**preset, **kwargs)
|
|
|
|
|
def AMD_L(**kwargs) -> AMDModel:
    """Build the large AMD diffusion model (wider heads, 16 DiT layers)."""
    preset = dict(
        enc_num_layers=8,
        enc_nhead=16,
        enc_ndim=64,
        diffusion_attn_head_dim=96,
        diffusion_attn_num_heads=16,
        diffusion_out_channels=4,
        diffusion_num_layers=16,
    )
    return AMDModel(**preset, **kwargs)
|
|
|
|
|
def AMD_S_Rec(**kwargs) -> AMDModel:
    """Build the small single-pass reconstruction variant (AMDModel_Rec)."""
    preset = dict(
        enc_num_layers=8,
        enc_nhead=8,
        enc_ndim=64,
        diffusion_attn_head_dim=64,
        diffusion_attn_num_heads=16,
        diffusion_out_channels=4,
        diffusion_num_layers=12,
    )
    return AMDModel_Rec(**preset, **kwargs)
|
|
|
|
|
def AMD_S_RecSplit(**kwargs) -> AMDModel:
    """Build the small reconstruction variant with is_split=True.

    NOTE: AMDModel_Rec.__init__ does not declare an ``is_split`` parameter;
    the flag is absorbed by its **kwargs (and by register_to_config).
    """
    preset = dict(
        enc_num_layers=8,
        enc_nhead=8,
        enc_ndim=64,
        diffusion_attn_head_dim=64,
        diffusion_attn_num_heads=16,
        diffusion_out_channels=4,
        diffusion_num_layers=12,
        is_split=True,
    )
    return AMDModel_Rec(**preset, **kwargs)
|
|
|
|
|
|
|
|
# Registry mapping config-string model names to their factory functions.
AMD_models = {
    "AMD_S": AMD_S,
    "AMD_L": AMD_L,
    "AMD_S_Rec": AMD_S_Rec,
    "AMD_S_RecSplit" : AMD_S_RecSplit,
}