import sys
sys.path.append("/mnt/bn/lqhaoheliu/project/audio_generation_diffusion/src")
from diffusers.models.unet_2d_condition import UNet2DConditionModel
from diffusers import UNet2DModel
import torch
import torch.nn as nn
class DiffusersUNet(nn.Module):
    """Thin wrapper around a (forked) diffusers ``UNet2DConditionModel``.

    Selects cross-attention vs. self-attention block types based on whether
    ``cross_attention_dim`` is given, and forwards an optional global (FiLM)
    condition ``y`` and optional text/context condition to the inner model.

    NOTE(review): ``global_additional_cond_dim`` is not an argument of the
    upstream diffusers ``UNet2DConditionModel``; this relies on the forked
    copy inserted via ``sys.path`` at the top of the file — confirm.
    """

    def __init__(self,
                 # Model itself
                 in_channels=4,
                 out_channels=4,
                 attention_head_dim=8,
                 block_out_channels=320,
                 # cross attention condition
                 cross_attention_dim=None,  # e.g. 768
                 encoder_hid_dim=None,      # e.g. 1024
                 # film condition
                 global_additional_cond_dim=None,
                 **kwargs):
        """Build the inner UNet.

        Args:
            in_channels: channels of the input latent.
            out_channels: channels of the predicted output.
            attention_head_dim: per-head dimension for the attention blocks.
            block_out_channels: base width; expanded to (1x, 2x, 4x, 8x).
            cross_attention_dim: if not None, use cross-attention down/mid/up
                blocks conditioned on ``context`` in ``forward``.
            encoder_hid_dim: raw hidden size of the text/context encoder,
                projected to ``cross_attention_dim`` by the inner model.
            global_additional_cond_dim: if not None, the inner model expects a
                global (B, D) condition ``y`` in ``forward``.
            **kwargs: passed through to ``UNet2DConditionModel``.
        """
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.cross_attention_dim = cross_attention_dim
        # Standard 4-level UNet widening schedule: 1x, 2x, 4x, 8x the base width.
        self.block_out_channels = (block_out_channels, block_out_channels * 2,
                                   block_out_channels * 4, block_out_channels * 8)
        self.attention_head_dim = attention_head_dim
        self.global_additional_cond_dim = global_additional_cond_dim
        self.encoder_hid_dim = encoder_hid_dim
        if self.cross_attention_dim is not None:
            # Context conditioning: cross-attention in the three innermost levels.
            self.down_block_types = ('CrossAttnDownBlock2D', 'CrossAttnDownBlock2D',
                                     'CrossAttnDownBlock2D', 'DownBlock2D')
            self.mid_block_type = 'UNetMidBlock2DCrossAttn'
            self.up_block_types = ('UpBlock2D', 'CrossAttnUpBlock2D',
                                   'CrossAttnUpBlock2D', 'CrossAttnUpBlock2D')
        else:
            # No context: plain self-attention blocks instead.
            self.down_block_types = ('DownBlock2D', 'AttnDownBlock2D',
                                     'AttnDownBlock2D', 'AttnDownBlock2D')
            self.mid_block_type = 'UNetMidBlock2DCrossAttn'
            self.up_block_types = ('AttnUpBlock2D', 'AttnUpBlock2D',
                                   'AttnUpBlock2D', 'UpBlock2D')
        self.model = UNet2DConditionModel(
            in_channels=self.in_channels,
            out_channels=self.out_channels,
            block_out_channels=self.block_out_channels,
            cross_attention_dim=self.cross_attention_dim,
            attention_head_dim=self.attention_head_dim,
            global_additional_cond_dim=self.global_additional_cond_dim,
            encoder_hid_dim=self.encoder_hid_dim,
            down_block_types=self.down_block_types,
            mid_block_type=self.mid_block_type,
            up_block_types=self.up_block_types,
            **kwargs)
        # Debug: dump the full architecture at construction time.
        print(self.model)

    def forward(self, x, timesteps, context=None, context_attn_mask=None, y=None, **kwargs):
        """Denoise ``x`` at ``timesteps`` under the configured conditions.

        Args:
            x: input latent, assumed (B, C, H, W) — TODO confirm.
            timesteps: (B,) diffusion timesteps.
            context: (B, T, encoder_hid_dim) encoder hidden states; required
                iff ``cross_attention_dim`` is set.
            context_attn_mask: (B, T) attention mask for ``context``.
            y: (B, D) or (B, 1, D) global condition; required iff
                ``global_additional_cond_dim`` is set.

        Returns:
            The inner model's ``.sample`` tensor.
        """
        # Enforce that the provided conditions match the configured ones.
        # (Fixed messages: the originals said "now allowed" for "not allowed".)
        if self.cross_attention_dim is None:
            assert context is None, "The cross attention dimension is None. So you are not allowed to use context as condition"
        else:
            assert context is not None and context_attn_mask is not None, "You need to provide context matrix"
        if self.global_additional_cond_dim is None:
            assert y is None, "The global additional cond dimension is None. So you are not allowed to use y as condition"
        else:
            assert y is not None, "You need to provide a global additional cond"
        # Accept (B, 1, D) and collapse to the (B, D) the inner model expects.
        if y is not None and len(y.size()) == 3:
            y = y.squeeze(1)
        return self.model(
            sample=x,
            timestep=timesteps,
            global_condition=y,
            encoder_hidden_states=context,
            encoder_attention_mask=context_attn_mask,
        ).sample
def test():
    """Smoke-test the four conditioning configurations (requires CUDA).

    Fix over the original: the constructor defaults leave
    ``cross_attention_dim`` and ``global_additional_cond_dim`` as None, so
    cases 1, 3 and 4 tripped the module's own assertions when fed
    ``context``/``y``. Each case now states its configuration explicitly
    (dims inferred from the tensors below: context hid dim 1024, global 512,
    cross attention 768 per the ``__init__`` comments — verify against the
    forked diffusers model).
    """
    ###################################################
    # Both global cond and encoder hidden states
    unet = DiffusersUNet(cross_attention_dim=768, encoder_hid_dim=1024,
                         global_additional_cond_dim=512).cuda()
    sample_input = torch.randn(3, 4, 256, 16).cuda()
    timestep = torch.tensor([1, 2, 3]).cuda()
    global_input = torch.randn(3, 512).cuda()
    encoder_hidden_states = torch.randn((3, 17, 1024)).cuda()
    attention_mask = torch.zeros((3, 17)).cuda()
    output = unet(x=sample_input, timesteps=timestep, y=global_input,
                  context=encoder_hidden_states, context_attn_mask=attention_mask)
    print(output.size())
    ###################################################
    # No global cond and no encoder hidden state
    unet = DiffusersUNet(cross_attention_dim=None, global_additional_cond_dim=None).cuda()
    sample_input = torch.randn(3, 4, 256, 16).cuda()
    timestep = torch.tensor([1, 2, 3]).cuda()
    output = unet(x=sample_input, timesteps=timestep)
    print(output.size())
    ###################################################
    # No encoder hidden state (global cond only)
    unet = DiffusersUNet(cross_attention_dim=None, global_additional_cond_dim=512).cuda()
    sample_input = torch.randn(3, 4, 256, 16).cuda()
    timestep = torch.tensor([1, 2, 3]).cuda()
    global_input = torch.randn(3, 512).cuda()
    output = unet(x=sample_input, timesteps=timestep, y=global_input)
    print(output.size())
    ###################################################
    # No global cond (encoder hidden states only)
    unet = DiffusersUNet(cross_attention_dim=768, encoder_hid_dim=1024,
                         global_additional_cond_dim=None).cuda()
    sample_input = torch.randn(3, 4, 256, 16).cuda()
    timestep = torch.tensor([1, 2, 3]).cuda()
    encoder_hidden_states = torch.randn((3, 17, 1024)).cuda()
    attention_mask = torch.zeros((3, 17)).cuda()
    output = unet(x=sample_input, timesteps=timestep,
                  context=encoder_hidden_states, context_attn_mask=attention_mask)
    print(output.size())
# Run the smoke tests only when executed as a script, not on import.
if __name__ == "__main__":
    test()