import logging
import math

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

from .modules import (
    film_modulate,
    unpatchify,
    PatchEmbed,
    PE_wrapper,
    TimestepEmbedder,
    FeedForward,
    RMSNorm,
)
from .span_mask import compute_mask_indices
from .attention import Attention

# Use the standard logger hierarchy; instantiating logging.Logger directly would
# detach this logger from any handlers configured on the root logger.
logger = logging.getLogger(__name__)


class AdaLN(nn.Module):
    """Adaptive LayerNorm conditioning head.

    Produces the six FiLM tensors (shift/scale/gate for attention and MLP) used by
    a DiTBlock, either from a per-block projection ('ada'), a shared projection
    plus a per-block bias table ('ada_single'), or a shared projection plus a
    low-rank per-block update ('ada_sola' / 'ada_sola_bias').
    """
    def __init__(self, dim, ada_mode='ada', r=None, alpha=None):
        super().__init__()
        self.ada_mode = ada_mode
        self.scale_shift_table = None
        if ada_mode == 'ada':
            # per-block conditioning projection
            self.time_ada = nn.Linear(dim, 6 * dim, bias=True)
        elif ada_mode == 'ada_single':
            # adaLN-single: the shared projection lives in the backbone,
            # each block only learns a bias table
            self.scale_shift_table = nn.Parameter(torch.zeros(6, dim))
        elif ada_mode in ['ada_sola', 'ada_sola_bias']:
            self.lora_a = nn.Linear(dim, r * 6, bias=False)
            self.lora_b = nn.Linear(r * 6, dim * 6, bias=False)
            self.scaling = alpha / r
            if ada_mode == 'ada_sola_bias':
                # extra per-block bias table, as in adaLN-single
                self.scale_shift_table = nn.Parameter(torch.zeros(6, dim))
        else:
            raise NotImplementedError

    def forward(self, time_token=None, time_ada=None):
        if self.ada_mode == 'ada':
            assert time_ada is None
            B = time_token.shape[0]
            time_ada = self.time_ada(time_token).reshape(B, 6, -1)
        elif self.ada_mode == 'ada_single':
            B = time_ada.shape[0]
            time_ada = time_ada.reshape(B, 6, -1)
            time_ada = self.scale_shift_table[None] + time_ada
        elif self.ada_mode in ['ada_sola', 'ada_sola_bias']:
            B = time_ada.shape[0]
            time_ada_lora = self.lora_b(self.lora_a(time_token)) * self.scaling
            time_ada = time_ada + time_ada_lora
            time_ada = time_ada.reshape(B, 6, -1)
            if self.scale_shift_table is not None:
                time_ada = self.scale_shift_table[None] + time_ada
        else:
            raise NotImplementedError
        return time_ada

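# Illustrative sketch (toy dimensions, not part of the module): in 'ada' mode an
# AdaLN head maps a pooled timestep token to the six FiLM tensors consumed by a
# DiTBlock.
#
#     adaln = AdaLN(dim=32, ada_mode='ada')
#     time_ada = adaln(time_token=torch.randn(2, 32))              # (2, 6, 32)
#     (shift_msa, scale_msa, gate_msa,
#      shift_mlp, scale_mlp, gate_mlp) = time_ada.chunk(6, dim=1)  # each (2, 1, 32)
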
class DiTBlock(nn.Module):
    """
    A modified PixArt block with adaptive layer norm (adaLN-single) conditioning.
    """
    def __init__(
        self,
        dim,
        context_dim=None,
        num_heads=8,
        mlp_ratio=4.,
        qkv_bias=False,
        qk_scale=None,
        qk_norm=None,
        act_layer='gelu',
        norm_layer=nn.LayerNorm,
        time_fusion='none',
        ada_sola_rank=None,
        ada_sola_alpha=None,
        skip=False,
        skip_norm=False,
        rope_mode='none',
        context_norm=False,
        use_checkpoint=False
    ):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim=dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            qk_norm=qk_norm,
            rope_mode=rope_mode
        )

        if context_dim is not None:
            self.use_context = True
            self.cross_attn = Attention(
                dim=dim,
                num_heads=num_heads,
                context_dim=context_dim,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                qk_norm=qk_norm,
                rope_mode='none'
            )
            self.norm2 = norm_layer(dim)
            if context_norm:
                self.norm_context = norm_layer(context_dim)
            else:
                self.norm_context = nn.Identity()
        else:
            self.use_context = False

        self.norm3 = norm_layer(dim)
        self.mlp = FeedForward(
            dim=dim, mult=mlp_ratio, activation_fn=act_layer, dropout=0
        )

        # every time-fusion mode except 'token' modulates the block with adaLN
        self.use_adanorm = True if time_fusion != 'token' else False
        if self.use_adanorm:
            self.adaln = AdaLN(
                dim,
                ada_mode=time_fusion,
                r=ada_sola_rank,
                alpha=ada_sola_alpha
            )
        if skip:
            self.skip_norm = norm_layer(2 * dim) if skip_norm else nn.Identity()
            self.skip_linear = nn.Linear(2 * dim, dim)
        else:
            self.skip_linear = None

        self.use_checkpoint = use_checkpoint

    def forward(
        self,
        x,
        time_token=None,
        time_ada=None,
        skip=None,
        context=None,
        x_mask=None,
        context_mask=None,
        extras=None
    ):
        if self.use_checkpoint:
            return checkpoint(
                self._forward,
                x,
                time_token,
                time_ada,
                skip,
                context,
                x_mask,
                context_mask,
                extras,
                use_reentrant=False
            )
        else:
            return self._forward(
                x, time_token, time_ada, skip, context, x_mask, context_mask,
                extras
            )

    def _forward(
        self,
        x,
        time_token=None,
        time_ada=None,
        skip=None,
        context=None,
        x_mask=None,
        context_mask=None,
        extras=None
    ):
        B, T, C = x.shape
        if self.skip_linear is not None:
            assert skip is not None
            cat = torch.cat([x, skip], dim=-1)
            cat = self.skip_norm(cat)
            x = self.skip_linear(cat)

        if self.use_adanorm:
            time_ada = self.adaln(time_token, time_ada)
            (shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp,
             gate_mlp) = time_ada.chunk(6, dim=1)

        # self-attention
        if self.use_adanorm:
            x_norm = film_modulate(
                self.norm1(x), shift=shift_msa, scale=scale_msa
            )
            # gated residual; note the (1 - gate) convention, so zero-initialized
            # gates leave the branch fully active
            x = x + (1 - gate_msa) * self.attn(
                x_norm, context=None, context_mask=x_mask, extras=extras
            )
        else:
            x = x + self.attn(
                self.norm1(x),
                context=None,
                context_mask=x_mask,
                extras=extras
            )

        # cross-attention over the (optional) context sequence
        if self.use_context:
            assert context is not None
            x = x + self.cross_attn(
                x=self.norm2(x),
                context=self.norm_context(context),
                context_mask=context_mask,
                extras=extras
            )

        # feed-forward
        if self.use_adanorm:
            x_norm = film_modulate(
                self.norm3(x), shift=shift_mlp, scale=scale_mlp
            )
            x = x + (1 - gate_mlp) * self.mlp(x_norm)
        else:
            x = x + self.mlp(self.norm3(x))

        return x

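# Illustrative sketch (toy dimensions, not part of the module): a block in 'ada'
# time-fusion mode derives its own FiLM parameters from the pooled time token, so a
# self-attention-only block needs just `x` and `time_token`.
#
#     blk = DiTBlock(dim=32, num_heads=4, time_fusion='ada')
#     y = blk(x=torch.randn(2, 16, 32), time_token=torch.randn(2, 32))   # (2, 16, 32)
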
class FinalBlock(nn.Module):
    """Final norm + linear projection back to patch space, with optional conv refinement."""
    def __init__(
        self,
        embed_dim,
        patch_size,
        in_chans,
        img_size,
        input_type='2d',
        norm_layer=nn.LayerNorm,
        use_conv=True,
        use_adanorm=True
    ):
        super().__init__()
        self.in_chans = in_chans
        self.img_size = img_size
        self.input_type = input_type

        self.norm = norm_layer(embed_dim)
        self.use_adanorm = use_adanorm

        if input_type == '2d':
            self.patch_dim = patch_size**2 * in_chans
            self.linear = nn.Linear(embed_dim, self.patch_dim, bias=True)
            if use_conv:
                self.final_layer = nn.Conv2d(
                    self.in_chans, self.in_chans, 3, padding=1
                )
            else:
                self.final_layer = nn.Identity()

        elif input_type == '1d':
            self.patch_dim = patch_size * in_chans
            self.linear = nn.Linear(embed_dim, self.patch_dim, bias=True)
            if use_conv:
                self.final_layer = nn.Conv1d(
                    self.in_chans, self.in_chans, 3, padding=1
                )
            else:
                self.final_layer = nn.Identity()

    def forward(self, x, time_ada=None, extras=0):
        B, T, C = x.shape
        # drop prepended tokens (time/cls/context) before unpatchifying
        x = x[:, extras:, :]

        if self.use_adanorm:
            shift, scale = time_ada.reshape(B, 2, -1).chunk(2, dim=1)
            x = film_modulate(self.norm(x), shift, scale)
        else:
            x = self.norm(x)
        x = self.linear(x)
        x = unpatchify(x, self.in_chans, self.input_type, self.img_size)
        x = self.final_layer(x)
        return x

class UDiT(nn.Module):
    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        input_type='2d',
        out_chans=None,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.,
        qkv_bias=False,
        qk_scale=None,
        qk_norm=None,
        act_layer='gelu',
        norm_layer='layernorm',
        context_norm=False,
        use_checkpoint=False,
        # time fusion
        time_fusion='token',
        ada_sola_rank=None,
        ada_sola_alpha=None,
        cls_dim=None,
        # context fusion
        context_dim=768,
        context_fusion='concat',
        context_max_length=128,
        context_pe_method='sinu',
        pe_method='abs',
        rope_mode='none',
        use_conv=True,
        skip=True,
        skip_norm=True
    ):
        super().__init__()
        self.num_features = self.embed_dim = embed_dim

        # input patchification
        self.in_chans = in_chans
        self.input_type = input_type
        if self.input_type == '2d':
            num_patches = (img_size[0] // patch_size) * (img_size[1] // patch_size)
        elif self.input_type == '1d':
            num_patches = img_size // patch_size
        self.patch_embed = PatchEmbed(
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
            input_type=input_type
        )
        out_chans = in_chans if out_chans is None else out_chans
        self.out_chans = out_chans

        # position embedding
        self.rope = rope_mode
        self.x_pe = PE_wrapper(
            dim=embed_dim, method=pe_method, length=num_patches
        )

        logger.info(f'x position embedding: {pe_method}')
        logger.info(f'rope mode: {self.rope}')

        # time embedding
        self.time_embed = TimestepEmbedder(embed_dim)
        self.time_fusion = time_fusion
        self.use_adanorm = False

        # cls embedding
        if cls_dim is not None:
            self.cls_embed = nn.Sequential(
                nn.Linear(cls_dim, embed_dim, bias=True),
                nn.SiLU(),
                nn.Linear(embed_dim, embed_dim, bias=True),
            )
        else:
            self.cls_embed = None

        # time fusion
        if time_fusion == 'token':
            # prepend time (and cls) tokens to the sequence
            self.extras = 2 if self.cls_embed else 1
            self.time_pe = PE_wrapper(
                dim=embed_dim, method='abs', length=self.extras
            )
        elif time_fusion in ['ada', 'ada_single', 'ada_sola', 'ada_sola_bias']:
            self.use_adanorm = True
            # shared activation and final-block projection for all adaLN variants
            self.time_act = nn.SiLU()
            self.extras = 0
            self.time_ada_final = nn.Linear(
                embed_dim, 2 * embed_dim, bias=True
            )
            if time_fusion in ['ada_single', 'ada_sola', 'ada_sola_bias']:
                # shared 6-way projection (adaLN-single); blocks add their own tables/LoRA
                self.time_ada = nn.Linear(embed_dim, 6 * embed_dim, bias=True)
            else:
                self.time_ada = None
        else:
            raise NotImplementedError
        logger.info(f'time fusion mode: {self.time_fusion}')

        # context fusion
        self.use_context = False
        self.context_cross = False
        self.context_max_length = context_max_length
        self.context_fusion = 'none'
        if context_dim is not None:
            self.use_context = True
            self.context_embed = nn.Sequential(
                nn.Linear(context_dim, embed_dim, bias=True),
                nn.SiLU(),
                nn.Linear(embed_dim, embed_dim, bias=True),
            )
            self.context_fusion = context_fusion
            if context_fusion == 'concat' or context_fusion == 'joint':
                self.extras += context_max_length
                self.context_pe = PE_wrapper(
                    dim=embed_dim,
                    method=context_pe_method,
                    length=context_max_length
                )
                # context is prepended to x, so blocks need no cross-attention
                context_dim = None
            elif context_fusion == 'cross':
                self.context_pe = PE_wrapper(
                    dim=embed_dim,
                    method=context_pe_method,
                    length=context_max_length
                )
                self.context_cross = True
                context_dim = embed_dim
            else:
                raise NotImplementedError
        logger.info(f'context fusion mode: {context_fusion}')
        logger.info(f'context position embedding: {context_pe_method}')

        self.use_skip = skip

        # norm layer used inside the transformer blocks
        if norm_layer == 'layernorm':
            norm_layer = nn.LayerNorm
        elif norm_layer == 'rmsnorm':
            norm_layer = RMSNorm
        else:
            raise NotImplementedError

        logger.info(f'use long skip connection: {skip}')
        self.in_blocks = nn.ModuleList([
            DiTBlock(
                dim=embed_dim,
                context_dim=context_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                qk_norm=qk_norm,
                act_layer=act_layer,
                norm_layer=norm_layer,
                time_fusion=time_fusion,
                ada_sola_rank=ada_sola_rank,
                ada_sola_alpha=ada_sola_alpha,
                skip=False,
                skip_norm=False,
                rope_mode=self.rope,
                context_norm=context_norm,
                use_checkpoint=use_checkpoint
            ) for _ in range(depth // 2)
        ])

        self.mid_block = DiTBlock(
            dim=embed_dim,
            context_dim=context_dim,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            qk_norm=qk_norm,
            act_layer=act_layer,
            norm_layer=norm_layer,
            time_fusion=time_fusion,
            ada_sola_rank=ada_sola_rank,
            ada_sola_alpha=ada_sola_alpha,
            skip=False,
            skip_norm=False,
            rope_mode=self.rope,
            context_norm=context_norm,
            use_checkpoint=use_checkpoint
        )

        self.out_blocks = nn.ModuleList([
            DiTBlock(
                dim=embed_dim,
                context_dim=context_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                qk_norm=qk_norm,
                act_layer=act_layer,
                norm_layer=norm_layer,
                time_fusion=time_fusion,
                ada_sola_rank=ada_sola_rank,
                ada_sola_alpha=ada_sola_alpha,
                skip=skip,
                skip_norm=skip_norm,
                rope_mode=self.rope,
                context_norm=context_norm,
                use_checkpoint=use_checkpoint
            ) for _ in range(depth // 2)
        ])

        self.use_conv = use_conv
        self.final_block = FinalBlock(
            embed_dim=embed_dim,
            patch_size=patch_size,
            img_size=img_size,
            in_chans=out_chans,
            input_type=input_type,
            norm_layer=norm_layer,
            use_conv=use_conv,
            use_adanorm=self.use_adanorm
        )
        self.initialize_weights()

    def _init_ada(self):
        if self.time_fusion == 'ada':
            nn.init.constant_(self.time_ada_final.weight, 0)
            nn.init.constant_(self.time_ada_final.bias, 0)
            for block in self.in_blocks:
                nn.init.constant_(block.adaln.time_ada.weight, 0)
                nn.init.constant_(block.adaln.time_ada.bias, 0)
            nn.init.constant_(self.mid_block.adaln.time_ada.weight, 0)
            nn.init.constant_(self.mid_block.adaln.time_ada.bias, 0)
            for block in self.out_blocks:
                nn.init.constant_(block.adaln.time_ada.weight, 0)
                nn.init.constant_(block.adaln.time_ada.bias, 0)
        elif self.time_fusion == 'ada_single':
            nn.init.constant_(self.time_ada.weight, 0)
            nn.init.constant_(self.time_ada.bias, 0)
            nn.init.constant_(self.time_ada_final.weight, 0)
            nn.init.constant_(self.time_ada_final.bias, 0)
        elif self.time_fusion in ['ada_sola', 'ada_sola_bias']:
            nn.init.constant_(self.time_ada.weight, 0)
            nn.init.constant_(self.time_ada.bias, 0)
            nn.init.constant_(self.time_ada_final.weight, 0)
            nn.init.constant_(self.time_ada_final.bias, 0)
            for block in self.in_blocks:
                nn.init.kaiming_uniform_(
                    block.adaln.lora_a.weight, a=math.sqrt(5)
                )
                nn.init.constant_(block.adaln.lora_b.weight, 0)
            nn.init.kaiming_uniform_(
                self.mid_block.adaln.lora_a.weight, a=math.sqrt(5)
            )
            nn.init.constant_(self.mid_block.adaln.lora_b.weight, 0)
            for block in self.out_blocks:
                nn.init.kaiming_uniform_(
                    block.adaln.lora_a.weight, a=math.sqrt(5)
                )
                nn.init.constant_(block.adaln.lora_b.weight, 0)

    def initialize_weights(self):
        # basic xavier init for all linear layers
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

        self.apply(_basic_init)

        # initialize patch_embed projection like an nn.Linear
        w = self.patch_embed.proj.weight.data
        nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
        nn.init.constant_(self.patch_embed.proj.bias, 0)

        # zero-out the adaLN conditioning projections
        if self.use_adanorm:
            self._init_ada()

        # zero-out cross-attention output projections
        if self.context_cross:
            for block in self.in_blocks:
                nn.init.constant_(block.cross_attn.proj.weight, 0)
                nn.init.constant_(block.cross_attn.proj.bias, 0)
            nn.init.constant_(self.mid_block.cross_attn.proj.weight, 0)
            nn.init.constant_(self.mid_block.cross_attn.proj.bias, 0)
            for block in self.out_blocks:
                nn.init.constant_(block.cross_attn.proj.weight, 0)
                nn.init.constant_(block.cross_attn.proj.bias, 0)

        # zero-out the cls embedding output when it feeds the adaLN conditioning
        if self.cls_embed:
            if self.use_adanorm:
                nn.init.constant_(self.cls_embed[-1].weight, 0)
                nn.init.constant_(self.cls_embed[-1].bias, 0)

        # xavier init for the final refinement conv
        if self.use_conv:
            nn.init.xavier_uniform_(self.final_block.final_layer.weight)
            nn.init.constant_(self.final_block.final_layer.bias, 0)

    def _concat_x_context(self, x, context, x_mask=None, context_mask=None):
        assert context.shape[-2] == self.context_max_length

        B = x.shape[0]

        if x_mask is None:
            x_mask = torch.ones(B, x.shape[-2], device=x.device).bool()
        if context_mask is None:
            context_mask = torch.ones(
                B, context.shape[-2], device=context.device
            ).bool()

        x_mask = torch.cat([context_mask, x_mask], dim=1)

        x = torch.cat((context, x), dim=1)
        return x, x_mask

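    # Illustrative note (not part of the model logic): context tokens are prepended,
    # so the returned sequence is (B, context_max_length + L, D) and the joint mask
    # is (B, context_max_length + L); `self.extras` counts these prepended tokens so
    # FinalBlock can drop them again before unpatchifying.
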
    def forward(
        self,
        x,
        timesteps,
        context,
        x_mask=None,
        context_mask=None,
        cls_token=None,
        controlnet_skips=None,
    ):
        # broadcast a scalar timestep across the batch
        if timesteps.dim() == 0:
            timesteps = timesteps.expand(x.shape[0]).to(x.device, dtype=torch.long)

        # patchify + position embedding
        x = self.patch_embed(x)
        x = self.x_pe(x)

        B, L, D = x.shape

        # context embedding and fusion
        if self.use_context:
            context_token = self.context_embed(context)
            context_token = self.context_pe(context_token)
            if self.context_fusion == 'concat' or self.context_fusion == 'joint':
                x, x_mask = self._concat_x_context(
                    x=x,
                    context=context_token,
                    x_mask=x_mask,
                    context_mask=context_mask
                )
                context_token, context_mask = None, None
        else:
            context_token, context_mask = None, None

        # time (and cls) conditioning
        time_token = self.time_embed(timesteps)
        if self.cls_embed:
            cls_token = self.cls_embed(cls_token)
        time_ada = None
        time_ada_final = None
        if self.use_adanorm:
            if self.cls_embed:
                time_token = time_token + cls_token
            time_token = self.time_act(time_token)
            time_ada_final = self.time_ada_final(time_token)
            if self.time_ada is not None:
                time_ada = self.time_ada(time_token)
        else:
            # token mode: prepend time (and cls) tokens to the sequence
            time_token = time_token.unsqueeze(dim=1)
            if self.cls_embed:
                cls_token = cls_token.unsqueeze(dim=1)
                time_token = torch.cat([time_token, cls_token], dim=1)
            time_token = self.time_pe(time_token)
            x = torch.cat((time_token, x), dim=1)
            if x_mask is not None:
                x_mask = torch.cat(
                    [
                        torch.ones(B, time_token.shape[1],
                                   device=x_mask.device).bool(),
                        x_mask
                    ],
                    dim=1
                )
            time_token = None

        # U-shaped transformer with long skip connections
        skips = []
        for blk in self.in_blocks:
            x = blk(
                x=x,
                time_token=time_token,
                time_ada=time_ada,
                skip=None,
                context=context_token,
                x_mask=x_mask,
                context_mask=context_mask,
                extras=self.extras
            )
            if self.use_skip:
                skips.append(x)

        x = self.mid_block(
            x=x,
            time_token=time_token,
            time_ada=time_ada,
            skip=None,
            context=context_token,
            x_mask=x_mask,
            context_mask=context_mask,
            extras=self.extras
        )
        for blk in self.out_blocks:
            if self.use_skip:
                skip = skips.pop()
                if controlnet_skips:
                    # add controlnet residuals onto the long skip path
                    skip = skip + controlnet_skips.pop()
            else:
                skip = None
                if controlnet_skips:
                    # no long skips: add controlnet residuals directly to x
                    x = x + controlnet_skips.pop()

            x = blk(
                x=x,
                time_token=time_token,
                time_ada=time_ada,
                skip=skip,
                context=context_token,
                x_mask=x_mask,
                context_mask=context_mask,
                extras=self.extras
            )

        x = self.final_block(x, time_ada=time_ada_final, extras=self.extras)

        return x

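# Illustrative usage sketch (hypothetical configuration; exact shapes depend on the
# helpers imported from .modules and are not guaranteed by this file):
#
#     model = UDiT(
#         img_size=256, patch_size=1, in_chans=8, input_type='1d',
#         embed_dim=256, depth=4, num_heads=4,
#         context_dim=512, context_fusion='cross', time_fusion='ada_single',
#     )
#     noise_pred = model(
#         x=torch.randn(2, 8, 256),                # noisy 1d latent
#         timesteps=torch.randint(0, 1000, (2,)),  # diffusion steps
#         context=torch.randn(2, 128, 512),        # e.g. text-encoder states
#     )
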
class MaskDiT(nn.Module):
    """Wraps a UDiT backbone with span-masked (MAE-style) conditioning on the clean latent."""
    def __init__(
        self,
        model: UDiT,
        mae=False,
        mae_prob=0.5,
        mask_ratio=[0.25, 1.0],
        mask_span=10,
    ):
        super().__init__()
        self.model = model
        self.mae = mae
        if self.mae:
            out_channel = model.out_chans
            self.mask_embed = nn.Parameter(torch.zeros((out_channel)))
            self.mae_prob = mae_prob
            self.mask_ratio = mask_ratio
            self.mask_span = mask_span

    def random_masking(self, gt, mask_ratios, mae_mask_infer=None):
        B, D, L = gt.shape
        if mae_mask_infer is None:
            # draw random spans to mask during training
            mask_ratios = mask_ratios.cpu().numpy()
            mask = compute_mask_indices(
                shape=[B, L],
                padding_mask=None,
                mask_prob=mask_ratios,
                mask_length=self.mask_span,
                mask_type="static",
                mask_other=0.0,
                min_masks=1,
                no_overlap=False,
                min_space=0,
            )
            mask = mask.unsqueeze(1).expand_as(gt)
        else:
            mask = mae_mask_infer
            mask = mask.expand_as(gt)
        # replace masked positions of the clean latent with the learned mask embedding
        gt[mask] = self.mask_embed.view(1, D, 1).expand_as(gt)[mask]
        return gt, mask.type_as(gt)

    def forward(
        self,
        x,
        timesteps,
        context,
        x_mask=None,
        context_mask=None,
        cls_token=None,
        gt=None,
        mae_mask_infer=None,
        forward_model=True
    ):
        # default: everything counts as masked (no ground-truth conditioning)
        mae_mask = torch.ones_like(x)
        if self.mae:
            if gt is not None:
                B, D, L = gt.shape
                mask_ratios = torch.FloatTensor(B).uniform_(
                    *self.mask_ratio
                ).to(gt.device)
                gt, mae_mask = self.random_masking(
                    gt, mask_ratios, mae_mask_infer
                )
                # apply the MAE conditioning only to a random subset of the batch
                if mae_mask_infer is None:
                    mae_batch = torch.rand(B) < self.mae_prob
                    gt[~mae_batch] = self.mask_embed.view(
                        1, D, 1
                    ).expand_as(gt)[~mae_batch]
                    mae_mask[~mae_batch] = 1.0
            else:
                B, D, L = x.shape
                gt = self.mask_embed.view(1, D, 1).expand_as(x)
            # stack [noisy latent, masked ground truth, mask channel] along channels
            x = torch.cat([x, gt, mae_mask[:, 0:1, :]], dim=1)

        if forward_model:
            x = self.model(
                x=x,
                timesteps=timesteps,
                context=context,
                x_mask=x_mask,
                context_mask=context_mask,
                cls_token=cls_token
            )

        return x, mae_mask
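
# Illustrative note (not enforced anywhere in this file): with mae=True the channel
# concatenation above means the wrapped UDiT must be built with
# in_chans = 2 * latent_channels + 1 (noisy latent + masked ground truth + mask
# channel) while out_chans stays at latent_channels. A hypothetical wiring:
#
#     backbone = UDiT(img_size=256, patch_size=1, input_type='1d',
#                     in_chans=2 * 8 + 1, out_chans=8,
#                     embed_dim=256, depth=4, num_heads=4,
#                     context_dim=512, context_fusion='cross',
#                     time_fusion='ada_single')
#     maskdit = MaskDiT(backbone, mae=True)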