File size: 3,697 Bytes
1fdf2ca | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 | from transformers import PretrainedConfig, PreTrainedModel
from torch import nn, tensor, concat
from diffusers.models.embeddings import get_timestep_embedding
import torch
class T5DiffusionXLTextEncoderMergerConfig(PretrainedConfig):
    """Hyper-parameters for ``T5DiffusionXLTextEncoderMerger``.

    Every constructor argument is stored unchanged as an attribute of the
    same name; remaining keyword arguments are handed to
    ``PretrainedConfig.__init__``.

    Args:
        num_layers: Number of ``nn.Linear`` layers in the merger's second
            feed-forward block (``num_layers - 1`` hidden layers followed
            by one output projection).
        dim_timestep_embeds: Width of the timestep embedding fed to the
            merger's timestep modulation block.
        seq_len: Sequence length the merger's layer norm and modulation
            tensors are built for.
        channels_sdxl: Channel width of the SDXL embeddings; also the
            channel width of the merger's output.
        channels_t5: Channel width of the T5 embeddings.
        channels_pooled: Width of the pooled SDXL embedding fed to the
            merger's pooled modulation block.
        **kwargs: Forwarded to ``PretrainedConfig``.
    """

    def __init__(
        self,
        num_layers: int = 4,
        dim_timestep_embeds: int = 16,
        seq_len: int = 77,
        channels_sdxl: int = 2048,
        channels_t5: int = 4096,
        channels_pooled: int = 1280,
        **kwargs,
    ) -> None:
        # Base-class setup runs first, exactly as before, so the explicit
        # fields below are assigned after any generic kwargs handling.
        super().__init__(**kwargs)

        # Depth of the feed-forward stack.
        self.num_layers = num_layers

        # Conditioning / sequence geometry.
        self.dim_timestep_embeds = dim_timestep_embeds
        self.seq_len = seq_len

        # Channel widths of the three inputs.
        self.channels_sdxl = channels_sdxl
        self.channels_t5 = channels_t5
        self.channels_pooled = channels_pooled
class T5DiffusionXLTextEncoderMerger(PreTrainedModel, nn.Module):
    """Fuse T5 and SDXL text embeddings into an SDXL-width embedding.

    The two per-token embeddings are concatenated along the channel axis,
    passed through two feed-forward blocks, and added back onto the SDXL
    embeddings as a residual. Both blocks are modulated (scale ``gamma`` /
    shift ``beta`` after the first block, scale ``zeta`` after the second)
    by the sum of two projections: one of the pooled SDXL embedding and one
    of an embedding of the timestep last given to :meth:`set_timestep`.
    """

    # Lets ``from_pretrained`` construct the matching config on its own when
    # the caller does not pass one in.
    config_class = T5DiffusionXLTextEncoderMergerConfig

    def __init__(self, config: T5DiffusionXLTextEncoderMergerConfig):
        """Build the feed-forward and modulation blocks described by *config*."""
        super().__init__(config)
        # Timestep used by ``forward`` until ``set_timestep`` is called.
        self._last_timestep = 0

        channels_concat = config.channels_sdxl + config.channels_t5
        # Per sequence position the modulation carries gamma and beta
        # (concat width each) plus zeta (SDXL width).
        modulation_width = config.seq_len * (
            channels_concat * 2 + config.channels_sdxl)

        # NOTE: sub-modules are created in the same order and under the same
        # names as before so state-dict keys and seeded init stay unchanged.
        self.block_forward1 = nn.Sequential(
            nn.Linear(channels_concat, channels_concat),
            # Normalises jointly over the (seq_len, channels) plane, so the
            # inputs must have exactly ``config.seq_len`` positions.
            nn.LayerNorm([config.seq_len, channels_concat],
                         elementwise_affine=False),
        )

        layers = []
        for _ in range(config.num_layers - 1):
            layers.append(nn.Linear(channels_concat, channels_concat))
            layers.append(nn.SiLU())
        layers.append(nn.Linear(channels_concat, config.channels_sdxl))
        layers.append(nn.Tanh())
        self.block_forward2 = nn.Sequential(*layers)

        self.block_modulate_by_pooled = nn.Sequential(
            nn.Linear(config.channels_pooled, 512, bias=False),
            nn.SiLU(),
            nn.Linear(512, modulation_width, bias=False),
        )
        self.block_modulate_by_timestep = nn.Sequential(
            nn.Linear(config.dim_timestep_embeds, 512, bias=False),
            nn.SiLU(),
            nn.Linear(512, modulation_width, bias=False),
        )

    def _init_weights(self, module):
        """Initialise ``nn.Linear`` weights from N(0, 0.1) and zero the biases.

        ``nn.init`` helpers are used instead of in-place tensor methods
        because they run without gradient tracking; calling ``normal_``
        directly on a parameter that requires grad raises outside a
        ``no_grad`` context.
        """
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=0.1)
            if module.bias is not None:
                nn.init.zeros_(module.bias)

    def forward(self, embeds_t5, embeds_sdxl, pooled_embeds_sdxl):
        """Merge the embeddings for the current timestep.

        Args:
            embeds_t5: T5 embeddings, ``(batch, seq_len, channels_t5)``.
            embeds_sdxl: SDXL embeddings, ``(batch, seq_len, channels_sdxl)``.
            pooled_embeds_sdxl: Pooled SDXL embedding,
                ``(batch, channels_pooled)``.

        Returns:
            A dict with a single key ``"output"`` holding a tensor of shape
            ``(batch, seq_len, channels_sdxl)``.

        Raises:
            AssertionError: If the three inputs disagree on batch size.
        """
        batch_size = embeds_sdxl.size(0)
        # The T5 batch is checked too; comparing the SDXL batch against
        # itself would let a mismatched T5 input through to ``concat``.
        assert (embeds_t5.size(0) == batch_size
                == pooled_embeds_sdxl.size(0)), (
            "batch size mismatch: "
            f"embeds_t5={embeds_t5.size(0)}, "
            f"embeds_sdxl={batch_size}, "
            f"pooled_embeds_sdxl={pooled_embeds_sdxl.size(0)}")

        channels_sdxl = self.config.channels_sdxl
        channels_concat = self.config.channels_t5 + channels_sdxl
        seq_len = self.config.seq_len

        # Build the timestep on the inputs' device and match their dtype so
        # the modulation projection also works on accelerators and in
        # reduced precision (the embedding helper returns float32).
        timesteps = tensor([self._last_timestep],
                           device=pooled_embeds_sdxl.device)
        timestep_embeds = get_timestep_embedding(
            timesteps, embedding_dim=self.config.dim_timestep_embeds)
        timestep_embeds = timestep_embeds.to(
            dtype=pooled_embeds_sdxl.dtype).repeat(batch_size, 1)

        modulation = (self.block_modulate_by_timestep(timestep_embeds)
                      + self.block_modulate_by_pooled(pooled_embeds_sdxl))
        split_sizes = [
            seq_len * channels_concat,  # gamma: scale after block 1
            seq_len * channels_concat,  # beta: shift after block 1
            seq_len * channels_sdxl,    # zeta: scale after block 2
        ]
        gamma, beta, zeta = (
            part.view(batch_size, seq_len, -1)
            for part in modulation.split(split_sizes, dim=1))

        merged = concat((embeds_t5, embeds_sdxl), dim=2)
        # "+ 1" makes a zero modulation act as the identity scale.
        output = (gamma + 1) * self.block_forward1(merged) + beta
        output = (zeta + 1) * self.block_forward2(output)
        # Residual connection; ``output`` is a fresh tensor here, so the
        # in-place add does not touch the caller's inputs.
        output += embeds_sdxl
        return {"output": output}

    def set_timestep(self, timestep: int):
        """Store the timestep that subsequent ``forward`` calls condition on."""
        self._last_timestep = timestep
|