File size: 3,697 Bytes
1fdf2ca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
from transformers import PretrainedConfig, PreTrainedModel
from torch import nn, tensor, concat
from diffusers.models.embeddings import get_timestep_embedding
import torch

class T5DiffusionXLTextEncoderMergerConfig(PretrainedConfig):
    """Hyper-parameters for ``T5DiffusionXLTextEncoderMerger``.

    Args:
        num_layers: Number of ``nn.Linear`` layers in the model's
            ``block_forward2`` MLP (``num_layers - 1`` hidden Linear+SiLU
            pairs plus one output projection).
        dim_timestep_embeds: Width of the sinusoidal timestep embedding fed
            to the timestep modulation head.
        seq_len: Token sequence length; fixed because the model's LayerNorm
            and modulation tensors are shaped per token.
        channels_sdxl: Channel width of the SDXL token embeddings (and of the
            model output).
        channels_t5: Channel width of the T5 token embeddings.
        channels_pooled: Channel width of the pooled SDXL embedding.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    def __init__(self,
                 num_layers: int = 4,
                 dim_timestep_embeds: int = 16,
                 seq_len: int = 77,
                 channels_sdxl: int = 2048,
                 channels_t5: int = 4096,
                 channels_pooled: int = 1280,
                 **kwargs):
        super().__init__(**kwargs)

        # Embedding widths of the three inputs.
        self.channels_sdxl = channels_sdxl
        self.channels_t5 = channels_t5
        self.channels_pooled = channels_pooled

        # Architecture / sequence geometry.
        self.seq_len = seq_len
        self.num_layers = num_layers
        self.dim_timestep_embeds = dim_timestep_embeds


class T5DiffusionXLTextEncoderMerger(PreTrainedModel, nn.Module):
    """Fuse T5 text embeddings into SDXL text-encoder embeddings.

    ``forward`` concatenates the T5 and SDXL token embeddings on the channel
    axis and applies ``block_forward1`` (Linear + LayerNorm without affine
    parameters). It scales that result by ``1 + gamma`` and shifts it by
    ``beta``, projects it back to ``channels_sdxl`` with ``block_forward2``
    (an MLP ending in ``tanh``), scales by ``1 + zeta`` and finally adds the
    original SDXL embeddings as a residual. ``gamma``, ``beta`` and ``zeta``
    are predicted from the pooled SDXL embedding plus a sinusoidal embedding
    of the current timestep.

    The timestep is not a ``forward`` argument: it is stored by
    ``set_timestep`` and read by every subsequent ``forward`` call.
    """

    # Lets ``from_pretrained`` construct the config when none is passed in.
    config_class = T5DiffusionXLTextEncoderMergerConfig

    def __init__(self, config: T5DiffusionXLTextEncoderMergerConfig):
        super().__init__(config)
        # Timestep consumed by the next forward(); updated via set_timestep().
        self._last_timestep = 0
        channels_concat = config.channels_sdxl + config.channels_t5

        # NOTE: attribute names and Sequential indices below define the
        # state_dict keys; keep them stable for checkpoint compatibility.
        self.block_forward1 = nn.Sequential(
            nn.Linear(channels_concat, channels_concat),
            nn.LayerNorm([config.seq_len, channels_concat],
                         elementwise_affine=False))

        # (num_layers - 1) Linear+SiLU pairs, then a projection back to the
        # SDXL width, squashed into (-1, 1) by tanh.
        layers = []
        for _ in range(config.num_layers - 1):
            layers.append(nn.Linear(channels_concat, channels_concat))
            layers.append(nn.SiLU())
        layers.append(nn.Linear(channels_concat, config.channels_sdxl))
        layers.append(nn.Tanh())
        self.block_forward2 = nn.Sequential(*layers)

        # Each modulation head emits gamma | beta | zeta for every token,
        # flattened: seq_len * (2 * channels_concat + channels_sdxl) values.
        self.block_modulate_by_pooled = nn.Sequential(
            nn.Linear(config.channels_pooled, 512, bias=False), nn.SiLU(),
            nn.Linear(512,
                      config.seq_len *
                      (channels_concat * 2 + config.channels_sdxl),
                      bias=False))

        self.block_modulate_by_timestep = nn.Sequential(
            nn.Linear(config.dim_timestep_embeds, 512, bias=False), nn.SiLU(),
            nn.Linear(512,
                      config.seq_len *
                      (channels_concat * 2 + config.channels_sdxl),
                      bias=False))
        # NOTE(review): this __init__ does not call self.post_init(); in
        # transformers that is what normally triggers _init_weights, so the
        # custom init below may not run for freshly built models — confirm
        # whether PyTorch's default init is intended here.

    def _init_weights(self, module):
        """Initialise ``nn.Linear`` layers: weight ~ N(0, 0.1^2), bias = 0."""
        if isinstance(module, nn.Linear):
            # In-place ops on leaf parameters that require grad must run
            # without autograd tracking.
            with torch.no_grad():
                module.weight.normal_(0, 0.1)
                if module.bias is not None:
                    module.bias.zero_()

    def forward(self, embeds_t5, embeds_sdxl, pooled_embeds_sdxl):
        """Merge T5 embeddings into SDXL embeddings for the stored timestep.

        Args:
            embeds_t5: T5 token embeddings, concatenated with ``embeds_sdxl``
                on dim 2; expected ``(batch, seq_len, channels_t5)``.
            embeds_sdxl: SDXL token embeddings,
                ``(batch, seq_len, channels_sdxl)``; also the residual.
            pooled_embeds_sdxl: Pooled SDXL embedding,
                ``(batch, channels_pooled)``. Its device and dtype are used
                for the timestep embedding.

        Returns:
            ``{"output": tensor}`` with shape ``(batch, seq_len, channels_sdxl)``.

        Raises:
            ValueError: If the three inputs disagree on batch size.
        """
        batch_size = embeds_sdxl.size(0)
        if not (embeds_t5.size(0) == batch_size == pooled_embeds_sdxl.size(0)):
            raise ValueError(
                "batch size mismatch: "
                f"embeds_t5={embeds_t5.size(0)}, "
                f"embeds_sdxl={batch_size}, "
                f"pooled_embeds_sdxl={pooled_embeds_sdxl.size(0)}")
        channels_sdxl = self.config.channels_sdxl
        channels_concat = self.config.channels_t5 + channels_sdxl
        seq_len = self.config.seq_len

        # Build the 1-element timestep tensor on the inputs' device (a bare
        # CPU tensor breaks on GPU). as_tensor also accepts 0-d tensor
        # timesteps. The embedding comes back float32, so match the inputs'
        # dtype for half-precision models.
        timesteps = torch.as_tensor(
            self._last_timestep, device=pooled_embeds_sdxl.device).reshape(1)
        timestep_embeds = get_timestep_embedding(
            timesteps, embedding_dim=self.config.dim_timestep_embeds)
        timestep_embeds = timestep_embeds.to(
            dtype=pooled_embeds_sdxl.dtype).repeat(batch_size, 1)

        modulation = (self.block_modulate_by_timestep(timestep_embeds) +
                      self.block_modulate_by_pooled(pooled_embeds_sdxl))
        gamma, beta, zeta = [
            m.view(batch_size, seq_len, -1) for m in modulation.split([
                seq_len * channels_concat,
                seq_len * channels_concat,
                seq_len * channels_sdxl,
            ], dim=1)
        ]

        merged = concat((embeds_t5, embeds_sdxl), dim=2)
        hidden = (gamma + 1) * self.block_forward1(merged) + beta
        output = (zeta + 1) * self.block_forward2(hidden) + embeds_sdxl
        return {"output": output}

    def set_timestep(self, timestep: int):
        """Store the timestep used by subsequent ``forward`` calls.

        The value is converted with ``torch.as_tensor`` inside ``forward``, so a
        Python number or a single-element tensor is accepted.
        """
        self._last_timestep = timestep