import math
import torch
import torch.nn as nn
from typing import Optional

from michelangelo.models.modules.checkpoint import checkpoint
from michelangelo.models.modules.transformer_blocks import (
    init_linear,
    MLP,
    MultiheadCrossAttention,
    MultiheadAttention,
    ResidualAttentionBlock
)


class AdaLayerNorm(nn.Module):
    """Adaptive LayerNorm (adaLN): a parameter-free LayerNorm whose per-channel
    scale and shift are predicted from the timestep embedding."""

    def __init__(self,
                 device: torch.device,
                 dtype: torch.dtype,
                 width: int):
        super().__init__()

        self.silu = nn.SiLU(inplace=True)
        self.linear = nn.Linear(width, width * 2, device=device, dtype=dtype)
        self.layernorm = nn.LayerNorm(width, elementwise_affine=False, device=device, dtype=dtype)

    def forward(self, x, timestep):
        # Project the timestep embedding to a scale and shift, then modulate
        # the normalized activations: x <- LN(x) * (1 + scale) + shift.
        emb = self.linear(timestep)
        scale, shift = torch.chunk(emb, 2, dim=2)
        x = self.layernorm(x) * (1 + scale) + shift
        return x


class DitBlock(nn.Module):
    def __init__(
        self,
        *,
        device: torch.device,
        dtype: torch.dtype,
        n_ctx: int,
        width: int,
        heads: int,
        context_dim: Optional[int],
        qkv_bias: bool = False,
        init_scale: float = 1.0,
        use_checkpoint: bool = False
    ):
        super().__init__()

        self.use_checkpoint = use_checkpoint

        self.attn = MultiheadAttention(
            device=device,
            dtype=dtype,
            n_ctx=n_ctx,
            width=width,
            heads=heads,
            init_scale=init_scale,
            qkv_bias=qkv_bias
        )
        self.ln_1 = AdaLayerNorm(device, dtype, width)

        # Cross-attention over conditioning tokens is only built when a
        # context dimension is provided.
        if context_dim is not None:
            self.ln_2 = AdaLayerNorm(device, dtype, width)
            self.cross_attn = MultiheadCrossAttention(
                device=device,
                dtype=dtype,
                width=width,
                heads=heads,
                data_width=context_dim,
                init_scale=init_scale,
                qkv_bias=qkv_bias
            )

        self.mlp = MLP(device=device, dtype=dtype, width=width, init_scale=init_scale)
        self.ln_3 = AdaLayerNorm(device, dtype, width)

    def forward(self, x: torch.Tensor, t: torch.Tensor, context: Optional[torch.Tensor] = None):
        return checkpoint(self._forward, (x, t, context), self.parameters(), self.use_checkpoint)

    def _forward(self, x: torch.Tensor, t: torch.Tensor, context: Optional[torch.Tensor] = None):
        # Pre-norm residual sub-blocks: self-attention, optional cross-attention, MLP.
        x = x + self.attn(self.ln_1(x, t))
        if context is not None:
            x = x + self.cross_attn(self.ln_2(x, t), context)
        x = x + self.mlp(self.ln_3(x, t))
        return x


class DiT(nn.Module):
    """Stack of adaLN-modulated DitBlocks applied to a fixed-length token sequence."""

    def __init__(
        self,
        *,
        device: Optional[torch.device],
        dtype: Optional[torch.dtype],
        n_ctx: int,
        width: int,
        layers: int,
        heads: int,
        context_dim: Optional[int],
        init_scale: float = 0.25,
        qkv_bias: bool = False,
        use_checkpoint: bool = False
    ):
        super().__init__()
        self.n_ctx = n_ctx
        self.width = width
        self.layers = layers

        self.resblocks = nn.ModuleList(
            [
                DitBlock(
                    device=device,
                    dtype=dtype,
                    n_ctx=n_ctx,
                    width=width,
                    heads=heads,
                    context_dim=context_dim,
                    qkv_bias=qkv_bias,
                    init_scale=init_scale,
                    use_checkpoint=use_checkpoint
                )
                for _ in range(layers)
            ]
        )

    def forward(self, x: torch.Tensor, t: torch.Tensor, context: Optional[torch.Tensor] = None):
        for block in self.resblocks:
            x = block(x, t, context)
        return x


class UNetDiffusionTransformer(nn.Module):
    """U-Net-style transformer: an encoder and decoder of residual attention
    blocks joined by channel-concatenation skip connections around a middle block."""

    def __init__(
        self,
        *,
        device: Optional[torch.device],
        dtype: Optional[torch.dtype],
        n_ctx: int,
        width: int,
        layers: int,
        heads: int,
        init_scale: float = 0.25,
        qkv_bias: bool = False,
        skip_ln: bool = False,
        use_checkpoint: bool = False
    ):
        super().__init__()

        self.n_ctx = n_ctx
        self.width = width
        self.layers = layers

        self.encoder = nn.ModuleList()
        for _ in range(layers):
            resblock = ResidualAttentionBlock(
                device=device,
                dtype=dtype,
                n_ctx=n_ctx,
                width=width,
                heads=heads,
                init_scale=init_scale,
                qkv_bias=qkv_bias,
                use_checkpoint=use_checkpoint
            )
            self.encoder.append(resblock)

        self.middle_block = ResidualAttentionBlock(
            device=device,
            dtype=dtype,
            n_ctx=n_ctx,
            width=width,
            heads=heads,
            init_scale=init_scale,
            qkv_bias=qkv_bias,
            use_checkpoint=use_checkpoint
        )

        self.decoder = nn.ModuleList()
        for _ in range(layers):
            resblock = ResidualAttentionBlock(
                device=device,
                dtype=dtype,
                n_ctx=n_ctx,
                width=width,
                heads=heads,
                init_scale=init_scale,
                qkv_bias=qkv_bias,
                use_checkpoint=use_checkpoint
            )
            # Each decoder stage fuses the matching encoder output (skip
            # connection) back down to `width` channels before its attention block.
            linear = nn.Linear(width * 2, width, device=device, dtype=dtype)
            init_linear(linear, init_scale)

            layer_norm = nn.LayerNorm(width, device=device, dtype=dtype) if skip_ln else None

            self.decoder.append(nn.ModuleList([resblock, linear, layer_norm]))

    def forward(self, x: torch.Tensor):
        enc_outputs = []
        for block in self.encoder:
            x = block(x)
            enc_outputs.append(x)

        x = self.middle_block(x)

        for resblock, linear, layer_norm in self.decoder:
            # Skip connection: concatenate the matching encoder output
            # (deepest first) along the channel dimension and project back to `width`.
            x = torch.cat([enc_outputs.pop(), x], dim=-1)
            x = linear(x)

            if layer_norm is not None:
                x = layer_norm(x)

            x = resblock(x)

        return x
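

# A minimal usage sketch / smoke test, assuming the `michelangelo` package is
# importable and that the timestep argument is a pre-computed embedding of
# shape [batch, 1, width] that AdaLayerNorm broadcasts over the sequence
# dimension. All hyperparameters and tensor shapes below are illustrative
# assumptions, not values from the original configuration.
if __name__ == "__main__":
    device = torch.device("cpu")
    dtype = torch.float32

    dit = DiT(
        device=device,
        dtype=dtype,
        n_ctx=256,
        width=64,
        layers=2,
        heads=4,
        context_dim=32,
    )
    x = torch.randn(2, 256, 64, device=device, dtype=dtype)       # noised latent tokens
    t = torch.randn(2, 1, 64, device=device, dtype=dtype)         # timestep embedding (assumed shape)
    context = torch.randn(2, 77, 32, device=device, dtype=dtype)  # conditioning tokens
    print(dit(x, t, context).shape)                               # expected: torch.Size([2, 256, 64])

    unet = UNetDiffusionTransformer(
        device=device,
        dtype=dtype,
        n_ctx=256,
        width=64,
        layers=2,
        heads=4,
        skip_ln=True,
    )
    print(unet(x).shape)                                          # expected: torch.Size([2, 256, 64])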