| |
|
|
| import torch |
| import torch.nn as nn |
| from typing import Optional |
| from diffusers.models.embeddings import Timesteps |
| import math |
|
|
| from michelangelo.models.modules.transformer_blocks import MLP |
| from michelangelo.models.modules.diffusion_transformer import UNetDiffusionTransformer |
|
|
|
|
class ConditionalASLUDTDenoiser(nn.Module):
    """Conditional denoiser for ASL latent diffusion.

    Wraps a :class:`UNetDiffusionTransformer` backbone with input/output
    projections, a sinusoidal timestep embedding, and a context (conditioning)
    embedding. The timestep token and the context tokens are prepended to the
    data tokens before running the backbone; only the data positions are
    projected back out.
    """

    def __init__(self, *,
                 device: Optional[torch.device],
                 dtype: Optional[torch.dtype],
                 input_channels: int,
                 output_channels: int,
                 n_ctx: int,
                 width: int,
                 layers: int,
                 heads: int,
                 context_dim: int,
                 context_ln: bool = True,
                 skip_ln: bool = False,
                 init_scale: float = 0.25,
                 flip_sin_to_cos: bool = False,
                 use_checkpoint: bool = False):
        """
        Args:
            device: device for all parameters.
            dtype: dtype for all parameters.
            input_channels: channel count of the incoming data tokens.
            output_channels: channel count of the predicted sample.
            n_ctx: number of data tokens seen by the backbone.
            width: transformer hidden width.
            layers: number of transformer layers.
            heads: number of attention heads.
            context_dim: channel count of the conditioning tokens.
            context_ln: if True, layer-normalize the context before projecting it.
            skip_ln: whether the backbone applies LayerNorm on skip connections.
            init_scale: base init scale; rescaled by 1/sqrt(width) below.
            flip_sin_to_cos: ordering flag forwarded to the sinusoidal embedding.
            use_checkpoint: enable gradient checkpointing in the backbone.
        """
        super().__init__()

        self.use_checkpoint = use_checkpoint

        # Scale initialization with width so deeper/wider models start stable.
        init_scale = init_scale * math.sqrt(1.0 / width)

        self.backbone = UNetDiffusionTransformer(
            device=device,
            dtype=dtype,
            n_ctx=n_ctx,
            width=width,
            layers=layers,
            heads=heads,
            skip_ln=skip_ln,
            init_scale=init_scale,
            use_checkpoint=use_checkpoint
        )
        self.ln_post = nn.LayerNorm(width, device=device, dtype=dtype)
        self.input_proj = nn.Linear(input_channels, width, device=device, dtype=dtype)
        self.output_proj = nn.Linear(width, output_channels, device=device, dtype=dtype)

        # timestep embedding: sinusoidal features followed by an MLP projection
        self.time_embed = Timesteps(width, flip_sin_to_cos=flip_sin_to_cos, downscale_freq_shift=0)
        self.time_proj = MLP(
            device=device, dtype=dtype, width=width, init_scale=init_scale
        )

        # Conditioning embedding. (A previous revision unconditionally built a
        # Sequential here and then immediately overwrote it below — that dead
        # duplicate assignment has been removed.)
        if context_ln:
            self.context_embed = nn.Sequential(
                nn.LayerNorm(context_dim, device=device, dtype=dtype),
                nn.Linear(context_dim, width, device=device, dtype=dtype),
            )
        else:
            self.context_embed = nn.Linear(context_dim, width, device=device, dtype=dtype)

    def forward(self,
                model_input: torch.FloatTensor,
                timestep: torch.LongTensor,
                context: torch.FloatTensor):
        r"""
        Args:
            model_input (torch.FloatTensor): [bs, n_data, c]
            timestep (torch.LongTensor): [bs,]
            context (torch.FloatTensor): [bs, context_tokens, c]

        Returns:
            sample (torch.FloatTensor): [bs, n_data, c]
        """
        _, n_data, _ = model_input.shape

        # 1. time embedding -> single token per batch element
        t_emb = self.time_proj(self.time_embed(timestep)).unsqueeze(dim=1)

        # 2. project conditioning tokens to the transformer width
        context = self.context_embed(context)

        # 3. denoise: prepend [time, context] tokens, run backbone, keep only
        #    the trailing n_data positions (the data tokens) for the output.
        x = self.input_proj(model_input)
        x = torch.cat([t_emb, context, x], dim=1)
        x = self.backbone(x)
        x = self.ln_post(x)
        x = x[:, -n_data:]
        sample = self.output_proj(x)

        return sample
|
|
|
|
|
|