| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import math |
| |
|
| | import numpy as np |
| | import torch |
| | import torch.nn as nn |
| |
|
| | from ...configuration_utils import ConfigMixin, register_to_config |
| | from ...models.modeling_utils import ModelMixin |
| | from .modeling_wuerstchen_common import AttnBlock, GlobalResponseNorm, TimestepBlock, WuerstchenLayerNorm |
| |
|
| |
|
class WuerstchenDiffNeXt(ModelMixin, ConfigMixin):
    """Würstchen Stage B ("DiffNeXt") U-Net.

    The input is patchified (``PixelUnshuffle`` + 1x1 conv), passed through a down path and a
    mirrored up path built from ``ResBlockStageB`` / ``AttnBlock`` / ``TimestepBlock`` layers,
    and projected back with ``PixelShuffle`` to two ``c_out``-channel tensors ``a`` and ``b``.
    ``effnet`` features are resized to each level's resolution and fed as an extra skip input
    to the residual blocks of levels whose ``inject_effnet`` flag is set; the projected
    ``clip`` features are passed to the attention blocks.
    """

    @register_to_config
    def __init__(
        self,
        c_in=4,
        c_out=4,
        c_r=64,
        patch_size=2,
        c_cond=1024,
        # The list defaults below are never mutated, and ``register_to_config`` records them
        # as-is; ``inject_effnet`` is also concatenated with a list, so they stay lists.
        c_hidden=[320, 640, 1280, 1280],
        nhead=[-1, 10, 20, 20],
        blocks=[4, 4, 14, 4],
        level_config=["CT", "CTA", "CTA", "CTA"],
        inject_effnet=[False, True, True, True],
        effnet_embd=16,
        clip_embd=1024,
        kernel_size=3,
        dropout=0.1,
    ):
        """Build the network.

        Args:
            c_in: channels of the tensor fed to the patch embedding, i.e. after ``x_cat`` has
                been concatenated onto ``x`` in ``forward``.
            c_out: channels of each of the two outputs ``a`` and ``b``.
            c_r: size of the sinusoidal ``r`` embedding consumed by the ``TimestepBlock``s.
            patch_size: pixel (un)shuffle factor of the input embedding and the output head.
            c_cond: width of the projected conditioning (clip and effnet) features.
            c_hidden: channels at each U-Net level.
            nhead: attention heads per level (only used by ``"A"`` blocks).
            blocks: how many times each level's ``level_config`` pattern is repeated.
            level_config: per-level string of block codes: ``"C"`` residual block,
                ``"A"`` attention block, ``"T"`` timestep block.
            inject_effnet: per-level flag; when set, that level gets a 1x1 conv mapping
                ``effnet`` to ``c_cond`` channels (the flags are mirrored for the up path).
            effnet_embd: channels of the ``effnet`` input.
            clip_embd: last-dimension size of the ``clip`` input.
            kernel_size: depthwise kernel size inside ``ResBlockStageB``.
            dropout: dropout probability, either one value or a per-level list.

        Raises:
            ValueError: if ``level_config`` contains a letter other than C, A or T.
        """
        super().__init__()
        self.c_r = c_r
        self.c_cond = c_cond
        if not isinstance(dropout, list):
            dropout = [dropout] * len(c_hidden)

        # Conditioning projections. effnet_mappers holds one entry per down level followed by
        # one per up level (deepest first); levels without injection hold None.
        self.clip_mapper = nn.Linear(clip_embd, c_cond)
        self.effnet_mappers = nn.ModuleList(
            [
                nn.Conv2d(effnet_embd, c_cond, kernel_size=1) if inject else None
                for inject in inject_effnet + list(reversed(inject_effnet))
            ]
        )
        self.seq_norm = nn.LayerNorm(c_cond, elementwise_affine=False, eps=1e-6)

        self.embedding = nn.Sequential(
            nn.PixelUnshuffle(patch_size),
            nn.Conv2d(c_in * (patch_size**2), c_hidden[0], kernel_size=1),
            WuerstchenLayerNorm(c_hidden[0], elementwise_affine=False, eps=1e-6),
        )

        def get_block(block_type, c_hidden, nhead, c_skip=0, dropout=0):
            # Map one level_config letter to a layer; c_skip only matters for "C" blocks.
            if block_type == "C":
                return ResBlockStageB(c_hidden, c_skip, kernel_size=kernel_size, dropout=dropout)
            elif block_type == "A":
                return AttnBlock(c_hidden, c_cond, nhead, self_attn=True, dropout=dropout)
            elif block_type == "T":
                return TimestepBlock(c_hidden, c_r)
            else:
                raise ValueError(f"Block type {block_type} not supported")

        # Down path: every level after the first starts with norm + stride-2 conv
        # (halves H and W, rounding down).
        self.down_blocks = nn.ModuleList()
        for i in range(len(c_hidden)):
            down_block = nn.ModuleList()
            if i > 0:
                down_block.append(
                    nn.Sequential(
                        WuerstchenLayerNorm(c_hidden[i - 1], elementwise_affine=False, eps=1e-6),
                        nn.Conv2d(c_hidden[i - 1], c_hidden[i], kernel_size=2, stride=2),
                    )
                )
            for _ in range(blocks[i]):
                for block_type in level_config[i]:
                    c_skip = c_cond if inject_effnet[i] else 0
                    down_block.append(get_block(block_type, c_hidden[i], nhead[i], c_skip=c_skip, dropout=dropout[i]))
            self.down_blocks.append(down_block)

        # Up path, deepest level first. The first block (j == k == 0) of every level except the
        # deepest is built to also accept that level's down-path output as a skip input.
        self.up_blocks = nn.ModuleList()
        for i in reversed(range(len(c_hidden))):
            up_block = nn.ModuleList()
            for j in range(blocks[i]):
                for k, block_type in enumerate(level_config[i]):
                    c_skip = c_hidden[i] if i < len(c_hidden) - 1 and j == k == 0 else 0
                    c_skip += c_cond if inject_effnet[i] else 0
                    up_block.append(get_block(block_type, c_hidden[i], nhead[i], c_skip=c_skip, dropout=dropout[i]))
            if i > 0:
                up_block.append(
                    nn.Sequential(
                        WuerstchenLayerNorm(c_hidden[i], elementwise_affine=False, eps=1e-6),
                        nn.ConvTranspose2d(c_hidden[i], c_hidden[i - 1], kernel_size=2, stride=2),
                    )
                )
            self.up_blocks.append(up_block)

        # Output head: 2 * c_out channels after PixelShuffle, split into (a, b) in forward().
        self.clf = nn.Sequential(
            WuerstchenLayerNorm(c_hidden[0], elementwise_affine=False, eps=1e-6),
            nn.Conv2d(c_hidden[0], 2 * c_out * (patch_size**2), kernel_size=1),
            nn.PixelShuffle(patch_size),
        )

        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialise module ``m``; intended to be used as ``self.apply(self._init_weights)``.

        Every ``nn.Conv2d`` / ``nn.Linear`` gets Xavier-uniform weights and a zero bias. The
        model-wide overrides (conditioning mappers, input/output projections, residual-output
        scaling, zeroed timestep mappers) run only when ``m`` is the model itself.
        """
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

        # nn.Module.apply visits children before their parent, so the root call comes last and
        # the overrides below win. Running them on every visited module would multiply each
        # ResBlockStageB output projection by the scale once per later module, not once.
        if m is not self:
            return

        for mapper in self.effnet_mappers:
            if mapper is not None:
                nn.init.normal_(mapper.weight, std=0.02)  # conditionings
        nn.init.normal_(self.clip_mapper.weight, std=0.02)  # conditionings
        nn.init.xavier_uniform_(self.embedding[1].weight, gain=0.02)  # inputs
        nn.init.constant_(self.clf[1].weight, 0)  # outputs

        # Scale each residual branch by sqrt(1 / total number of repeats).
        out_scale = math.sqrt(1 / sum(self.config.blocks))
        for level_block in [*self.down_blocks, *self.up_blocks]:
            for block in level_block:
                if isinstance(block, ResBlockStageB):
                    with torch.no_grad():
                        block.channelwise[-1].weight.mul_(out_scale)
                elif isinstance(block, TimestepBlock):
                    nn.init.constant_(block.mapper.weight, 0)

    def gen_r_embedding(self, r, max_positions=10000):
        """Sinusoidal embedding of ``r`` with ``self.c_r`` features.

        ``r`` is indexed as ``r[:, None]`` (presumably a (batch,) tensor of values in [0, 1];
        TODO confirm) and scaled by ``max_positions`` before the sin/cos features are taken.
        An odd ``c_r`` is zero-padded by one column. The result is cast to ``r.dtype``.
        """
        r = r * max_positions
        half_dim = self.c_r // 2
        emb = math.log(max_positions) / (half_dim - 1)
        emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp()
        emb = r[:, None] * emb[None, :]
        emb = torch.cat([emb.sin(), emb.cos()], dim=1)
        if self.c_r % 2 == 1:
            emb = nn.functional.pad(emb, (0, 1), mode="constant")
        return emb.to(dtype=r.dtype)

    def gen_c_embeddings(self, clip):
        """Project ``clip`` to ``c_cond`` features and apply a non-affine LayerNorm."""
        clip = self.clip_mapper(clip)
        clip = self.seq_norm(clip)
        return clip

    def _embed_effnet(self, mapper, effnet, size):
        """Resize ``effnet`` to spatial ``size`` and project it with ``mapper``.

        Interpolation (bicubic, antialiased) is done in float32 and the result is cast back
        to ``effnet``'s dtype before the projection.
        """
        resized = nn.functional.interpolate(
            effnet.float(), size=size, mode="bicubic", antialias=True, align_corners=True
        )
        return mapper(resized.to(effnet.dtype))

    def _down_encode(self, x, r_embed, effnet, clip=None):
        """Run the down path.

        Returns:
            The output of every level, deepest level first.
        """
        level_outputs = []
        for i, down_block in enumerate(self.down_blocks):
            mapper = self.effnet_mappers[i]
            effnet_c = None  # computed lazily at the level's first residual block
            for block in down_block:
                if isinstance(block, ResBlockStageB):
                    if effnet_c is None and mapper is not None:
                        effnet_c = self._embed_effnet(mapper, effnet, x.shape[-2:])
                    # effnet_c stays None for levels without an effnet mapper.
                    x = block(x, effnet_c)
                elif isinstance(block, AttnBlock):
                    x = block(x, clip)
                elif isinstance(block, TimestepBlock):
                    x = block(x, r_embed)
                else:
                    x = block(x)
            level_outputs.insert(0, x)
        return level_outputs

    def _up_decode(self, level_outputs, r_embed, effnet, clip=None):
        """Run the up path starting from the deepest feature map.

        ``level_outputs`` is the output of ``_down_encode`` (deepest first). For ``i > 0``,
        entry ``i`` is fed as a skip to the first block of up level ``i``, concatenated with
        that level's effnet features when the level injects them.
        """
        x = level_outputs[0]
        n_down = len(self.down_blocks)
        for i, up_block in enumerate(self.up_blocks):
            mapper = self.effnet_mappers[n_down + i]
            effnet_c = None
            for j, block in enumerate(up_block):
                if isinstance(block, ResBlockStageB):
                    if effnet_c is None and mapper is not None:
                        effnet_c = self._embed_effnet(mapper, effnet, x.shape[-2:])
                    skip = level_outputs[i] if j == 0 and i > 0 else None
                    if effnet_c is not None:
                        skip = effnet_c if skip is None else torch.cat([skip, effnet_c], dim=1)
                    x = block(x, skip)
                elif isinstance(block, AttnBlock):
                    x = block(x, clip)
                elif isinstance(block, TimestepBlock):
                    x = block(x, r_embed)
                else:
                    x = block(x)
        return x

    def forward(self, x, r, effnet, clip=None, x_cat=None, eps=1e-3, return_noise=True):
        """Run the U-Net.

        Args:
            x: input tensor with channels along dim 1.
            r: conditioning value(s) embedded by ``gen_r_embedding``.
            effnet: effnet conditioning features (``effnet_embd`` channels, any spatial size).
            clip: optional clip features, projected by ``gen_c_embeddings`` and passed to the
                attention blocks.
            x_cat: optional tensor concatenated onto ``x`` along dim 1 before the embedding.
            eps: ``b`` is squashed into ``(eps, 1 - eps)``.
            return_noise: select the return form (see below).

        Returns:
            ``(x - a) / b`` when ``return_noise`` is true, otherwise the tuple ``(a, b)``.
        """
        # Keep the un-concatenated input: ``a`` and ``b`` have c_out channels, so the noise
        # estimate must use ``x`` itself. Capturing it after concatenating ``x_cat`` gave a
        # tensor whose channel count generally cannot broadcast against ``a``.
        x_in = x
        if x_cat is not None:
            x = torch.cat([x, x_cat], dim=1)

        r_embed = self.gen_r_embedding(r)
        if clip is not None:
            clip = self.gen_c_embeddings(clip)

        x = self.embedding(x)
        level_outputs = self._down_encode(x, r_embed, effnet, clip)
        x = self._up_decode(level_outputs, r_embed, effnet, clip)
        a, b = self.clf(x).chunk(2, dim=1)
        b = b.sigmoid() * (1 - eps * 2) + eps
        if return_noise:
            return (x_in - a) / b
        return a, b
| |
|
| |
|
class ResBlockStageB(nn.Module):
    """Residual block used by WuerstchenDiffNeXt.

    Depthwise conv -> WuerstchenLayerNorm -> optional channel-concat with a skip tensor ->
    per-position MLP (Linear, GELU, GlobalResponseNorm, Dropout, Linear) -> residual add.

    Args:
        c: channels of the input and output.
        c_skip: channels of the optional skip tensor concatenated before the MLP.
        kernel_size: depthwise kernel size (padding keeps the spatial size for odd kernels).
        dropout: dropout probability inside the MLP.
    """

    def __init__(self, c, c_skip=0, kernel_size=3, dropout=0.0):
        super().__init__()
        expanded = c * 4
        self.depthwise = nn.Conv2d(c, c, kernel_size=kernel_size, padding=kernel_size // 2, groups=c)
        self.norm = WuerstchenLayerNorm(c, elementwise_affine=False, eps=1e-6)
        # Layers are created in this order so parameter initialisation consumes the RNG
        # in the same sequence, and Sequential indices (state-dict keys) stay fixed.
        mlp_layers = [
            nn.Linear(c + c_skip, expanded),
            nn.GELU(),
            GlobalResponseNorm(expanded),
            nn.Dropout(dropout),
            nn.Linear(expanded, c),
        ]
        self.channelwise = nn.Sequential(*mlp_layers)

    def forward(self, x, x_skip=None):
        """Apply the block; ``x_skip`` (if given) is concatenated along dim 1 before the MLP."""
        residual = x
        hidden = self.depthwise(x)
        hidden = self.norm(hidden)
        if x_skip is not None:
            hidden = torch.cat([hidden, x_skip], dim=1)
        # Move channels last so each nn.Linear mixes channels independently per position.
        channels_last = hidden.permute(0, 2, 3, 1)
        mixed = self.channelwise(channels_last)
        return mixed.permute(0, 3, 1, 2) + residual
| |
|