Spaces:
Sleeping
Sleeping
| import sys | |
| sys.path.append("/mnt/bn/lqhaoheliu/project/audio_generation_diffusion/src") | |
| from diffusers.models.unet_2d_condition import UNet2DConditionModel | |
| from diffusers import UNet2DModel | |
| import torch | |
| import torch.nn as nn | |
class DiffusersUNet(nn.Module):
    """Thin wrapper around ``UNet2DConditionModel`` with optional conditioning.

    Two independent kinds of conditioning can be switched on at construction
    time:

    * cross-attention conditioning, enabled by ``cross_attention_dim``;
    * a global ("film") condition, enabled by ``global_additional_cond_dim``.

    ``forward`` checks that the conditions supplied by the caller match the
    way the model was configured.
    """

    def __init__(
        self,
        # Model itself
        in_channels=4,
        out_channels=4,
        attention_head_dim=8,
        block_out_channels=320,
        # cross attention condition
        cross_attention_dim=None,  # 768
        encoder_hid_dim=None,  # 1024
        # film condition
        global_additional_cond_dim=None,
        **kwargs,
    ):
        """Build the underlying ``UNet2DConditionModel``.

        Args:
            in_channels: Forwarded to the UNet as ``in_channels``.
            out_channels: Forwarded to the UNet as ``out_channels``.
            attention_head_dim: Forwarded to the UNet as ``attention_head_dim``.
            block_out_channels: Base width; the four UNet stages use this value
                multiplied by 1, 2, 4 and 8.
            cross_attention_dim: When not ``None``, cross-attention block types
                are selected and ``forward`` requires ``context``.
            encoder_hid_dim: Forwarded to the UNet as ``encoder_hid_dim``.
            global_additional_cond_dim: When not ``None``, ``forward`` requires
                the global condition ``y``.
            **kwargs: Passed through unchanged to ``UNet2DConditionModel``.
        """
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.cross_attention_dim = cross_attention_dim
        # Four stages, doubling the channel count at every stage.
        self.block_out_channels = tuple(
            block_out_channels * factor for factor in (1, 2, 4, 8)
        )
        self.attention_head_dim = attention_head_dim
        self.global_additional_cond_dim = global_additional_cond_dim
        self.encoder_hid_dim = encoder_hid_dim

        if self.cross_attention_dim is not None:
            # Cross-attention variant. The 'SimpleCrossAttn*' block types were
            # an alternative considered by the original author.
            self.down_block_types = (
                'CrossAttnDownBlock2D',
                'CrossAttnDownBlock2D',
                'CrossAttnDownBlock2D',
                'DownBlock2D',
            )
            self.mid_block_type = 'UNetMidBlock2DCrossAttn'
            self.up_block_types = (
                'UpBlock2D',
                'CrossAttnUpBlock2D',
                'CrossAttnUpBlock2D',
                'CrossAttnUpBlock2D',
            )
        else:
            # Self-attention-only variant.
            self.down_block_types = (
                'DownBlock2D',
                'AttnDownBlock2D',
                'AttnDownBlock2D',
                'AttnDownBlock2D',
            )
            # NOTE(review): the mid block is still the cross-attention type even
            # though no cross-attention dimension is configured here -- confirm
            # this is intended for the diffusers build in use.
            self.mid_block_type = 'UNetMidBlock2DCrossAttn'
            self.up_block_types = (
                'AttnUpBlock2D',
                'AttnUpBlock2D',
                'AttnUpBlock2D',
                'UpBlock2D',
            )

        self.model = UNet2DConditionModel(
            in_channels=self.in_channels,
            out_channels=self.out_channels,
            block_out_channels=self.block_out_channels,
            cross_attention_dim=self.cross_attention_dim,
            attention_head_dim=self.attention_head_dim,
            global_additional_cond_dim=self.global_additional_cond_dim,
            encoder_hid_dim=self.encoder_hid_dim,
            down_block_types=self.down_block_types,
            mid_block_type=self.mid_block_type,
            up_block_types=self.up_block_types,
            **kwargs)
        print(self.model)

    def forward(self, x, timesteps, context=None, context_attn_mask=None, y=None, **kwargs):
        """Run the wrapped UNet and return the ``sample`` attribute of its output.

        Args:
            x: Passed to the UNet as ``sample``.
            timesteps: Passed to the UNet as ``timestep``.
            context: Passed as ``encoder_hidden_states``. Must be given if and
                only if ``cross_attention_dim`` was configured.
            context_attn_mask: Passed as ``encoder_attention_mask``. Required
                whenever ``context`` is required.
            y: Global condition, passed as ``global_condition``. Must be given
                if and only if ``global_additional_cond_dim`` was configured.
            **kwargs: Accepted for call-site compatibility and ignored.

        Raises:
            AssertionError: If the supplied conditions do not match the
                configuration chosen at construction time.
        """
        if self.cross_attention_dim is None:
            assert context is None, "The cross attention dimension is None. So you are not allowed to use context as condition"
        else:
            assert context is not None and context_attn_mask is not None, "You need to provide context matrix"

        if self.global_additional_cond_dim is None:
            assert y is None, "The global additional cond dimension is None. So you are not allowed to use y as condition"
        else:
            assert y is not None, "You need to provide a global additional cond"

        # A 3-D global condition has its axis 1 squeezed away (a no-op unless
        # that axis has size 1) -- presumably (batch, 1, dim); confirm upstream.
        if y is not None and y.dim() == 3:
            y = y.squeeze(1)

        return self.model(
            sample=x,
            timestep=timesteps,
            global_condition=y,
            encoder_hidden_states=context,
            encoder_attention_mask=context_attn_mask,
        ).sample
def test():
    """Smoke-test ``DiffusersUNet`` under its four conditioning configurations.

    Each scenario builds a model with exactly the conditioning it then feeds
    into ``forward`` (the constructor defaults disable both kinds of
    conditioning), runs one forward pass and prints the output size.
    """
    # Prefer the GPU as before, but fall back to CPU when CUDA is unavailable.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Feature sizes of the random test inputs below.
    global_cond_dim = 512
    encoder_hid_dim = 1024
    # 768 follows the hint next to ``cross_attention_dim`` in the constructor
    # signature -- NOTE(review): confirm against the real training config.
    cross_attention_dim = 768

    def make_inputs(with_global, with_context):
        """Build the keyword arguments for one forward pass."""
        inputs = {
            "x": torch.randn(3, 4, 256, 16).to(device),
            "timesteps": torch.tensor([1, 2, 3]).to(device),
        }
        if with_global:
            inputs["y"] = torch.randn(3, global_cond_dim).to(device)
        if with_context:
            inputs["context"] = torch.randn((3, 17, encoder_hid_dim)).to(device)
            inputs["context_attn_mask"] = torch.zeros((3, 17)).to(device)
        return inputs

    ###################################################
    # Have both global cond and encoder hidden state
    unet = DiffusersUNet(
        cross_attention_dim=cross_attention_dim,
        encoder_hid_dim=encoder_hid_dim,
        global_additional_cond_dim=global_cond_dim,
    ).to(device)
    output = unet(**make_inputs(with_global=True, with_context=True))
    print(output.size())

    ###################################################
    # No global cond and no encoder hidden state
    unet = DiffusersUNet(cross_attention_dim=None, global_additional_cond_dim=None).to(device)
    output = unet(**make_inputs(with_global=False, with_context=False))
    print(output.size())

    ###################################################
    # No encoder_hidden_state
    unet = DiffusersUNet(
        cross_attention_dim=None,
        global_additional_cond_dim=global_cond_dim,
    ).to(device)
    output = unet(**make_inputs(with_global=True, with_context=False))
    print(output.size())

    ###################################################
    # No global cond
    unet = DiffusersUNet(
        cross_attention_dim=cross_attention_dim,
        encoder_hid_dim=encoder_hid_dim,
        global_additional_cond_dim=None,
    ).to(device)
    output = unet(**make_inputs(with_global=False, with_context=True))
    print(output.size())
# Run the smoke tests only when this file is executed as a script.
if __name__ == "__main__":
    test()