import torch
from torch import nn, sigmoid
from torch.nn import (
    LayerNorm,
    Linear,
    Module,
    ModuleList,
    Sequential,
)

from .vb_layers_attentionv2 import AttentionPairBias
from .vb_modules_utils import LinearNoBias, SwiGLU, default


class AdaLN(Module):
    """Algorithm 26"""

    def __init__(self, dim, dim_single_cond):
        super().__init__()
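        # The activations a carry no learnable affine of their own (elementwise_affine=False);
        # their scale and shift come entirely from the conditioning signal s below.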
        self.a_norm = LayerNorm(dim, elementwise_affine=False, bias=False)
        self.s_norm = LayerNorm(dim_single_cond, bias=False)
        self.s_scale = Linear(dim_single_cond, dim)
        self.s_bias = LinearNoBias(dim_single_cond, dim)

    def forward(self, a, s):
        a = self.a_norm(a)
        s = self.s_norm(s)
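        # Condition a on s: a sigmoid-gated, per-channel scale plus a learned shift.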
        a = sigmoid(self.s_scale(s)) * a + self.s_bias(s)
        return a


class ConditionedTransitionBlock(Module):
    """Algorithm 25"""

    def __init__(self, dim_single, dim_single_cond, expansion_factor=2):
        super().__init__()

        self.adaln = AdaLN(dim_single, dim_single_cond)

        dim_inner = int(dim_single * expansion_factor)
        self.swish_gate = Sequential(
            LinearNoBias(dim_single, dim_inner * 2),
            SwiGLU(),
        )
        self.a_to_b = LinearNoBias(dim_single, dim_inner)
        self.b_to_a = LinearNoBias(dim_inner, dim_single)
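
        # The output gate starts with zero weights and a bias of -2.0, so the sigmoid
        # begins near sigmoid(-2) ~= 0.12 and the block's contribution is small at the
        # start of training.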
        output_projection_linear = Linear(dim_single_cond, dim_single)
        nn.init.zeros_(output_projection_linear.weight)
        nn.init.constant_(output_projection_linear.bias, -2.0)

        self.output_projection = nn.Sequential(output_projection_linear, nn.Sigmoid())

    def forward(
        self,
        a,
        s,
    ):
        a = self.adaln(a, s)
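        # SwiGLU-style transition: a gated hidden representation, projected back to
        # dim_single and scaled by a sigmoid gate computed from the conditioning s.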
        b = self.swish_gate(a) * self.a_to_b(a)
        a = self.output_projection(s) * self.b_to_a(b)
        return a


class DiffusionTransformer(Module):
    """Algorithm 23"""

    def __init__(
        self,
        depth,
        heads,
        dim=384,
        dim_single_cond=None,
        pair_bias_attn=True,
        activation_checkpointing=False,
        post_layer_norm=False,
    ):
        super().__init__()
        self.activation_checkpointing = activation_checkpointing
        dim_single_cond = default(dim_single_cond, dim)
        self.pair_bias_attn = pair_bias_attn

        self.layers = ModuleList()
        for _ in range(depth):
            self.layers.append(
                DiffusionTransformerLayer(
                    heads,
                    dim,
                    dim_single_cond,
                    post_layer_norm,
                )
            )

    def forward(
        self,
        a,
        s,
        bias=None,
        mask=None,
        to_keys=None,
        multiplicity=1,
    ):
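        # The pair bias is expected with a channel dimension divisible by the number of
        # layers; each layer receives its own slice of those channels.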
        if self.pair_bias_attn:
            B, N, M, D = bias.shape
            L = len(self.layers)
            bias = bias.view(B, N, M, L, D // L)

        for i, layer in enumerate(self.layers):
            if self.pair_bias_attn:
                bias_l = bias[:, :, :, i]
            else:
                bias_l = None

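            # With activation checkpointing, layer activations are recomputed in the
            # backward pass to save memory (training only).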
            if self.activation_checkpointing and self.training:
                a = torch.utils.checkpoint.checkpoint(
                    layer,
                    a,
                    s,
                    bias_l,
                    mask,
                    to_keys,
                    multiplicity,
                )
            else:
                a = layer(
                    a,
                    s,
                    bias_l,
                    mask,
                    to_keys,
                    multiplicity,
                )
        return a


class DiffusionTransformerLayer(Module):
    """Algorithm 23"""

    def __init__(
        self,
        heads,
        dim=384,
        dim_single_cond=None,
        post_layer_norm=False,
    ):
        super().__init__()

        dim_single_cond = default(dim_single_cond, dim)

        self.adaln = AdaLN(dim, dim_single_cond)
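        # compute_pair_bias=False presumably means the attention module takes the pair
        # bias precomputed via z rather than deriving it from a pair representation.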
        self.pair_bias_attn = AttentionPairBias(
            c_s=dim, num_heads=heads, compute_pair_bias=False
        )

        self.output_projection_linear = Linear(dim_single_cond, dim)
        nn.init.zeros_(self.output_projection_linear.weight)
        nn.init.constant_(self.output_projection_linear.bias, -2.0)

        self.output_projection = nn.Sequential(
            self.output_projection_linear, nn.Sigmoid()
        )
        self.transition = ConditionedTransitionBlock(
            dim_single=dim, dim_single_cond=dim_single_cond
        )

        if post_layer_norm:
            self.post_lnorm = nn.LayerNorm(dim)
        else:
            self.post_lnorm = nn.Identity()

    def forward(
        self,
        a,
        s,
        bias=None,
        mask=None,
        to_keys=None,
        multiplicity=1,
    ):
        b = self.adaln(a, s)
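
        # to_keys maps a query-side tensor to the key-side layout (AtomTransformer uses
        # this for its windowed keys); the mask is remapped the same way.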
        k_in = b
        if to_keys is not None:
            k_in = to_keys(b)
            mask = to_keys(mask.unsqueeze(-1)).squeeze(-1)

        b = self.pair_bias_attn(
            s=b,
            z=bias,
            mask=mask,
            multiplicity=multiplicity,
            k_in=k_in,
        )

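        # Gate the attention update with a sigmoid computed from s, then apply both
        # residual updates: the attention branch and the conditioned transition.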
        b = self.output_projection(s) * b

        a = a + b
        a = a + self.transition(a, s)

        a = self.post_lnorm(a)
        return a


class AtomTransformer(Module):
    """Algorithm 7"""

    def __init__(
        self,
        attn_window_queries,
        attn_window_keys,
        **diffusion_transformer_kwargs,
    ):
        super().__init__()
        self.attn_window_queries = attn_window_queries
        self.attn_window_keys = attn_window_keys
        self.diffusion_transformer = DiffusionTransformer(
            **diffusion_transformer_kwargs
        )

    def forward(
        self,
        q,
        c,
        bias,
        to_keys,
        mask,
        multiplicity=1,
    ):
        W = self.attn_window_queries
        H = self.attn_window_keys

        B, N, D = q.shape
        NW = N // W
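
        # Split the atom sequence into NW non-overlapping query windows of width W
        # (assumes N is divisible by W); keys are gathered into windows of width H.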
        q = q.view((B * NW, W, -1))
        c = c.view((B * NW, W, -1))
        mask = mask.view(B * NW, W)
        bias = bias.repeat_interleave(multiplicity, 0)
        bias = bias.view((bias.shape[0] * NW, W, H, -1))

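        # to_keys_new flattens the query windows back to the full sequence, applies the
        # caller's to_keys to produce the key-side tensor, and re-windows it to width H.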
        to_keys_new = lambda x: to_keys(x.view(B, NW * W, -1)).view(B * NW, H, -1)

        q = self.diffusion_transformer(
            a=q,
            s=c,
            bias=bias,
            mask=mask.float(),
            multiplicity=1,
            to_keys=to_keys_new,
        )

        q = q.view((B, NW * W, D))
        return q