# models/flowsep/latent_diffusion/modules/diffusers_unet.py
import sys
sys.path.append("/mnt/bn/lqhaoheliu/project/audio_generation_diffusion/src")
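# NOTE: the sys.path entry above points at a machine-specific checkout; the
# UNet2DConditionModel imported below is assumed to be a patched copy of
# diffusers, since the stock class accepts neither `global_additional_cond_dim`
# nor `global_condition`.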
from diffusers.models.unet_2d_condition import UNet2DConditionModel
import torch
import torch.nn as nn


class DiffusersUNet(nn.Module):
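    """Thin wrapper around (a patched copy of) diffusers' UNet2DConditionModel.

    Two optional conditioning paths are supported:
      * cross-attention over `context` (encoder hidden states), enabled by
        setting `cross_attention_dim`;
      * a global FiLM-style condition `y`, enabled by setting
        `global_additional_cond_dim` (assumed to be an extension provided by
        the project-local diffusers fork).
    """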
    def __init__(
        self,
        # Model itself
        in_channels=4,
        out_channels=4,
        attention_head_dim=8,
        block_out_channels=320,
        # Cross-attention condition
        cross_attention_dim=None,  # e.g. 768
        encoder_hid_dim=None,  # e.g. 1024
        # FiLM (global) condition
        global_additional_cond_dim=None,
        **kwargs,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.cross_attention_dim = cross_attention_dim
        # Channel widths follow the usual 1x/2x/4x/8x UNet widening scheme.
        self.block_out_channels = (
            block_out_channels,
            block_out_channels * 2,
            block_out_channels * 4,
            block_out_channels * 8,
        )
        self.attention_head_dim = attention_head_dim
        self.global_additional_cond_dim = global_additional_cond_dim
        self.encoder_hid_dim = encoder_hid_dim
        if self.cross_attention_dim is not None:
            # Cross-attention blocks consume `context` at every resolution but the coarsest.
            self.down_block_types = ('CrossAttnDownBlock2D', 'CrossAttnDownBlock2D', 'CrossAttnDownBlock2D', 'DownBlock2D')
            # self.down_block_types = ('SimpleCrossAttnDownBlock2D', 'SimpleCrossAttnDownBlock2D', 'SimpleCrossAttnDownBlock2D', 'DownBlock2D')
            self.mid_block_type = 'UNetMidBlock2DCrossAttn'
            self.up_block_types = ('UpBlock2D', 'CrossAttnUpBlock2D', 'CrossAttnUpBlock2D', 'CrossAttnUpBlock2D')
            # self.up_block_types = ('UpBlock2D', 'SimpleCrossAttnUpBlock2D', 'SimpleCrossAttnUpBlock2D', 'SimpleCrossAttnUpBlock2D')
        else:
            # Self-attention-only blocks; no encoder hidden states are consumed.
            self.down_block_types = ('DownBlock2D', 'AttnDownBlock2D', 'AttnDownBlock2D', 'AttnDownBlock2D')
            self.mid_block_type = 'UNetMidBlock2DCrossAttn'
            self.up_block_types = ('AttnUpBlock2D', 'AttnUpBlock2D', 'AttnUpBlock2D', 'UpBlock2D')
        # NOTE: `global_additional_cond_dim` is not an argument of the stock
        # diffusers UNet2DConditionModel; it is assumed to come from the
        # patched copy on sys.path above.
        self.model = UNet2DConditionModel(
            in_channels=self.in_channels,
            out_channels=self.out_channels,
            block_out_channels=self.block_out_channels,
            cross_attention_dim=self.cross_attention_dim,
            attention_head_dim=self.attention_head_dim,
            global_additional_cond_dim=self.global_additional_cond_dim,
            encoder_hid_dim=self.encoder_hid_dim,
            down_block_types=self.down_block_types,
            mid_block_type=self.mid_block_type,
            up_block_types=self.up_block_types,
            **kwargs,
        )
        print(self.model)

    def forward(self, x, timesteps, context=None, context_attn_mask=None, y=None, **kwargs):
        """Run the wrapped UNet.

        x: noisy latents [B, C, H, W]; timesteps: [B]; context: encoder hidden
        states [B, T, encoder_hid_dim] with attention mask [B, T]; y: global
        condition [B, global_additional_cond_dim] (or [B, 1, D], squeezed below).
        """
        if self.cross_attention_dim is None:
            assert context is None, "The cross-attention dimension is None, so you are not allowed to pass `context` as a condition."
        else:
            assert context is not None and context_attn_mask is not None, "You need to provide a context matrix and its attention mask."
        if self.global_additional_cond_dim is None:
            assert y is None, "The global additional cond dimension is None, so you are not allowed to pass `y` as a condition."
        else:
            assert y is not None, "You need to provide a global additional cond."
        # Collapse a [B, 1, D] global condition to [B, D].
        if y is not None and len(y.size()) == 3:
            y = y.squeeze(1)
        return self.model(
            sample=x,
            timestep=timesteps,
            global_condition=y,
            encoder_hidden_states=context,
            encoder_attention_mask=context_attn_mask,
        ).sample
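

# The four configurations below exercise each conditioning mode. The condition
# dims are assumptions inferred from the hints above (cross_attention_dim=768,
# encoder_hid_dim=1024) and from the test tensors (global dim 512);
# `global_condition` support is assumed to come from the patched diffusers.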
def test():
    ###################################################
    # Both global cond and encoder hidden states
    unet = DiffusersUNet(cross_attention_dim=768, encoder_hid_dim=1024, global_additional_cond_dim=512).cuda()
    sample_input = torch.randn(3, 4, 256, 16).cuda()
    timestep = torch.tensor([1, 2, 3]).cuda()
    global_input = torch.randn(3, 512).cuda()
    encoder_hidden_states = torch.randn((3, 17, 1024)).cuda()
    attention_mask = torch.ones((3, 17)).cuda()  # 1 = token is attended to
    output = unet(x=sample_input, timesteps=timestep, y=global_input, context=encoder_hidden_states, context_attn_mask=attention_mask)
    print(output.size())
    ###################################################
    # No global cond and no encoder hidden states
    unet = DiffusersUNet(cross_attention_dim=None, global_additional_cond_dim=None).cuda()
    sample_input = torch.randn(3, 4, 256, 16).cuda()
    timestep = torch.tensor([1, 2, 3]).cuda()
    output = unet(x=sample_input, timesteps=timestep)
    print(output.size())
    ###################################################
    # No encoder hidden states (global cond only)
    unet = DiffusersUNet(cross_attention_dim=None, global_additional_cond_dim=512).cuda()
    sample_input = torch.randn(3, 4, 256, 16).cuda()
    timestep = torch.tensor([1, 2, 3]).cuda()
    global_input = torch.randn(3, 512).cuda()
    output = unet(x=sample_input, timesteps=timestep, y=global_input)
    print(output.size())
    ###################################################
    # No global cond (cross-attention only)
    unet = DiffusersUNet(cross_attention_dim=768, encoder_hid_dim=1024, global_additional_cond_dim=None).cuda()
    sample_input = torch.randn(3, 4, 256, 16).cuda()
    timestep = torch.tensor([1, 2, 3]).cuda()
    encoder_hidden_states = torch.randn((3, 17, 1024)).cuda()
    attention_mask = torch.ones((3, 17)).cuda()  # 1 = token is attended to
    output = unet(x=sample_input, timesteps=timestep, context=encoder_hidden_states, context_attn_mask=attention_mask)
    print(output.size())


if __name__ == "__main__":
    test()
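
# A rough CPU-only sketch of the same call pattern (assumes the patched
# diffusers above is importable; equivalent to test() minus the .cuda() calls):
#
#   unet = DiffusersUNet(cross_attention_dim=None, global_additional_cond_dim=None)
#   out = unet(x=torch.randn(1, 4, 256, 16), timesteps=torch.tensor([0]))
#   print(out.size())  # expected: torch.Size([1, 4, 256, 16])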