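"""Wrapper exposing a diffusers ``UNet2DConditionModel`` through an LDM-style
``forward(x, timesteps, context, context_attn_mask, y)`` interface.

Two optional conditions are supported: a cross-attention sequence (``context``)
and a global FiLM-style vector (``y``). Note that ``global_additional_cond_dim``
and ``global_condition`` are not arguments of the stock diffusers UNet; they
appear to rely on a patched diffusers provided by the project path appended
below.
"""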
import sys

# Project source tree (expected to provide the patched diffusers imported below).
sys.path.append("/mnt/bn/lqhaoheliu/project/audio_generation_diffusion/src")

import torch
import torch.nn as nn

from diffusers.models.unet_2d_condition import UNet2DConditionModel
class DiffusersUNet(nn.Module):
    def __init__(
        self,
        # Model itself
        in_channels=4,
        out_channels=4,
        attention_head_dim=8,
        block_out_channels=320,
        # Cross-attention condition
        cross_attention_dim=None,  # e.g. 768
        encoder_hid_dim=None,  # e.g. 1024
        # FiLM (global) condition
        global_additional_cond_dim=None,
        **kwargs,
    ):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.cross_attention_dim = cross_attention_dim
        # Channel multipliers (1, 2, 4, 8) across the four resolution levels.
        self.block_out_channels = (
            block_out_channels,
            block_out_channels * 2,
            block_out_channels * 4,
            block_out_channels * 8,
        )
        self.attention_head_dim = attention_head_dim
        self.global_additional_cond_dim = global_additional_cond_dim
        self.encoder_hid_dim = encoder_hid_dim

        if self.cross_attention_dim is not None:
            # Cross-attend to the condition sequence at every level except the innermost.
            self.down_block_types = ('CrossAttnDownBlock2D', 'CrossAttnDownBlock2D', 'CrossAttnDownBlock2D', 'DownBlock2D')
            # self.down_block_types = ('SimpleCrossAttnDownBlock2D', 'SimpleCrossAttnDownBlock2D', 'SimpleCrossAttnDownBlock2D', 'DownBlock2D')
            self.mid_block_type = 'UNetMidBlock2DCrossAttn'
            self.up_block_types = ('UpBlock2D', 'CrossAttnUpBlock2D', 'CrossAttnUpBlock2D', 'CrossAttnUpBlock2D')
            # self.up_block_types = ('UpBlock2D', 'SimpleCrossAttnUpBlock2D', 'SimpleCrossAttnUpBlock2D', 'SimpleCrossAttnUpBlock2D')
        else:
            # No conditioning sequence: fall back to self-attention blocks.
            self.down_block_types = ('DownBlock2D', 'AttnDownBlock2D', 'AttnDownBlock2D', 'AttnDownBlock2D')
            self.mid_block_type = 'UNetMidBlock2DCrossAttn'
            self.up_block_types = ('AttnUpBlock2D', 'AttnUpBlock2D', 'AttnUpBlock2D', 'UpBlock2D')
        self.model = UNet2DConditionModel(
            in_channels=self.in_channels,
            out_channels=self.out_channels,
            block_out_channels=self.block_out_channels,
            cross_attention_dim=self.cross_attention_dim,
            attention_head_dim=self.attention_head_dim,
            global_additional_cond_dim=self.global_additional_cond_dim,
            encoder_hid_dim=self.encoder_hid_dim,
            down_block_types=self.down_block_types,
            mid_block_type=self.mid_block_type,
            up_block_types=self.up_block_types,
            **kwargs,
        )

        # Log the full architecture once at construction time.
        print(self.model)

    def forward(self, x, timesteps, context=None, context_attn_mask=None, y=None, **kwargs):
        if self.cross_attention_dim is None:
            assert context is None, "cross_attention_dim is None, so context cannot be used as a condition"
        else:
            assert context is not None and context_attn_mask is not None, "You need to provide a context matrix and its attention mask"

        if self.global_additional_cond_dim is None:
            assert y is None, "global_additional_cond_dim is None, so y cannot be used as a condition"
        else:
            assert y is not None, "You need to provide a global additional condition"

        # Squeeze a singleton sequence dimension: [B, 1, D] -> [B, D].
        if y is not None and y.dim() == 3:
            y = y.squeeze(1)

        return self.model(
            sample=x,
            timestep=timesteps,
            global_condition=y,
            encoder_hidden_states=context,
            encoder_attention_mask=context_attn_mask,
        ).sample

def test():
    ###################################################
    # Both global cond and encoder hidden states
    unet = DiffusersUNet(cross_attention_dim=768, encoder_hid_dim=1024, global_additional_cond_dim=512).cuda()

    sample_input = torch.randn(3, 4, 256, 16).cuda()
    timestep = torch.tensor([1, 2, 3]).cuda()
    global_input = torch.randn(3, 512).cuda()
    encoder_hidden_states = torch.randn((3, 17, 1024)).cuda()
    attention_mask = torch.ones((3, 17)).cuda()  # 1 = attend; an all-zero mask would drop every token

    output = unet(x=sample_input, timesteps=timestep, y=global_input, context=encoder_hidden_states, context_attn_mask=attention_mask)
    print(output.size())

    ###################################################
    # No global cond and no encoder hidden states
    unet = DiffusersUNet(cross_attention_dim=None, global_additional_cond_dim=None).cuda()

    sample_input = torch.randn(3, 4, 256, 16).cuda()
    timestep = torch.tensor([1, 2, 3]).cuda()

    output = unet(x=sample_input, timesteps=timestep)
    print(output.size())

    ###################################################
    # No encoder hidden states, only global cond
    unet = DiffusersUNet(cross_attention_dim=None, global_additional_cond_dim=512).cuda()

    sample_input = torch.randn(3, 4, 256, 16).cuda()
    timestep = torch.tensor([1, 2, 3]).cuda()
    global_input = torch.randn(3, 512).cuda()

    output = unet(x=sample_input, timesteps=timestep, y=global_input)
    print(output.size())

    ###################################################
    # No global cond, only encoder hidden states
    unet = DiffusersUNet(cross_attention_dim=768, encoder_hid_dim=1024, global_additional_cond_dim=None).cuda()

    sample_input = torch.randn(3, 4, 256, 16).cuda()
    timestep = torch.tensor([1, 2, 3]).cuda()
    encoder_hidden_states = torch.randn((3, 17, 1024)).cuda()
    attention_mask = torch.ones((3, 17)).cuda()

    output = unet(x=sample_input, timesteps=timestep, context=encoder_hidden_states, context_attn_mask=attention_mask)
    print(output.size())
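

# Illustrative sketch of a single denoising training step built on this
# wrapper, assuming the standard diffusers DDPMScheduler API. Shapes mirror
# test() above; the helper name and dimensions here are illustrative, not
# part of the project's training code.
def _training_step_sketch():
    import torch.nn.functional as F
    from diffusers import DDPMScheduler

    scheduler = DDPMScheduler(num_train_timesteps=1000)
    unet = DiffusersUNet(cross_attention_dim=None, global_additional_cond_dim=None).cuda()

    latents = torch.randn(3, 4, 256, 16).cuda()  # clean latents, e.g. VAE-encoded spectrograms
    noise = torch.randn_like(latents)
    timesteps = torch.randint(0, scheduler.config.num_train_timesteps, (3,)).cuda()

    # Forward-diffuse the latents, predict the injected noise, regress onto it.
    noisy_latents = scheduler.add_noise(latents, noise, timesteps)
    noise_pred = unet(x=noisy_latents, timesteps=timesteps)
    loss = F.mse_loss(noise_pred, noise)
    loss.backward()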


if __name__ == "__main__":
    test()