# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved

import math
from functools import partial
from typing import List, Optional, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

from .config import TransformerConfig
from .patcher import Patcher
from .rope import RotaryEmbedding


# adaLN-style modulation helpers used by DiTBlock: `gate` scales a branch
# output elementwise, `modulate` applies a learned scale and shift.
def gate(x, gate):
    return x * gate


def modulate(x, shift, scale):
    return x * (1 + scale) + shift


def get_nonlinearity(kind: str):
    return {
        "relu": F.relu,
        "gelu": F.gelu,
        "swiglu": None,
        "approx_gelu": partial(F.gelu, approximate="tanh"),
        "srelu": lambda x: F.relu(x) ** 2,
        "silu": F.silu,
    }[kind]


class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = torch.nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        output = self._norm(x.float())
        return (output * self.weight).type_as(x)


class ProjectionLayer(torch.nn.Module):
    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        non_linearity: str,
        dropout: float,
        fc_bias: bool = False,
    ):
        super().__init__()
        self.swiglu = non_linearity == "swiglu"
        self.dropout = dropout
        self.w1 = torch.nn.Linear(in_dim, out_dim, bias=fc_bias)
        self.w2 = torch.nn.Linear(out_dim, out_dim, bias=fc_bias)
        if self.swiglu:
            self.w3 = torch.nn.Linear(in_dim, out_dim, bias=fc_bias)
        # non-linearity
        self.non_linearity = get_nonlinearity(non_linearity)

    def forward(self, x):
        hidden1 = self.w1(x)
        if self.swiglu:
            hidden3 = self.w3(x)
            hidden = F.silu(hidden1) * hidden3
        else:
            hidden = self.non_linearity(hidden1)
        hidden = F.dropout(hidden, p=self.dropout, training=self.training)
        return self.w2(hidden)


class Attention(nn.Module):
    def __init__(
        self,
        dim: int,
        head_dim: int,
        n_heads: int,
        n_kv_heads: int,
        norm_eps: float = 1e-5,
        use_qk_norm: bool = False,
        fc_bias: bool = False,
    ):
        super().__init__()
        assert n_heads % n_kv_heads == 0
        self.head_dim = head_dim
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.use_qk_norm = use_qk_norm

        self.wq = torch.nn.Linear(dim, n_heads * head_dim, bias=fc_bias)
        self.wk, self.wv = [
            torch.nn.Linear(
                dim,
                n_kv_heads * head_dim,
                bias=fc_bias,
            )
            for _ in range(2)
        ]
        self.wo = torch.nn.Linear(
            n_heads * head_dim,
            dim,
            bias=fc_bias,
        )

        if self.use_qk_norm:
            self.q_norm = RMSNorm(head_dim, eps=norm_eps)
            self.k_norm = RMSNorm(head_dim, eps=norm_eps)

    def reshape_heads(self, x: torch.Tensor, heads: int) -> torch.Tensor:
        B, T, C = x.shape
        # B x T x C -> B x T x C/H x H
        x = x.reshape(B, T, C // heads, heads)
        # B x T x C/H x H -> B x H x T x C/H
        return x.permute(0, 3, 1, 2)

    def forward(
        self,
        x: torch.Tensor,
        cross_x: Optional[torch.Tensor] = None,
        key_padding_mask: Optional[torch.Tensor] = None,
        rope: Optional[RotaryEmbedding] = None,
    ):
        # x: B, T, E
        xq = self.wq(x)
        # keys/values come from `cross_x` for cross-attention, else from `x`
        if cross_x is not None:
            xk, xv = self.wk(cross_x), self.wv(cross_x)
        else:
            xk, xv = self.wk(x), self.wv(x)

        xk = self.reshape_heads(xk, self.n_kv_heads)
        xv = self.reshape_heads(xv, self.n_kv_heads)
        xq = self.reshape_heads(xq, self.n_heads)

        if self.use_qk_norm:
            xq = self.q_norm(xq)
            xk = self.k_norm(xk)

        if rope is not None:
            xq = rope(xq, bhle=True)
            xk = rope(xk, bhle=True)

        attn_mask = None
        if key_padding_mask is not None:
            attn_mask = key_padding_mask[:, None, None, :]

        output = F.scaled_dot_product_attention(xq, xk, xv, attn_mask=attn_mask)
        output = rearrange(output, "b h n d -> b n (h d)")
        return self.wo(output)


class FeedForward(torch.nn.Module):
    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        ffn_dim_multiplier: float,
        multiple_of: int,
        dropout: float,
        non_linearity: str = "swiglu",
        fc_bias: bool = False,
    ):
        super().__init__()
        self.dropout = dropout
        self.swiglu = non_linearity == "swiglu"

        # swiglu hidden dim factor multiplier (same #params as relu / gelu)
        if self.swiglu:
            hidden_dim = int(2 * hidden_dim / 3)
        # custom dim factor multiplier
        hidden_dim = int(ffn_dim_multiplier * hidden_dim)
        # round hidden dimension up to the nearest `multiple_of`
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

        # layers
        self.w1 = torch.nn.Linear(dim, hidden_dim, bias=fc_bias)
        self.w2 = torch.nn.Linear(hidden_dim, dim, bias=fc_bias)
        if self.swiglu:
            self.w3 = torch.nn.Linear(dim, hidden_dim, bias=fc_bias)

        # non-linearity
        self.non_linearity = get_nonlinearity(non_linearity)

    def forward(self, x):
        hidden1 = self.w1(x)
        if self.swiglu:
            hidden3 = self.w3(x)
            hidden = F.silu(hidden1) * hidden3
        else:
            hidden = self.non_linearity(hidden1)
        hidden = F.dropout(hidden, p=self.dropout, training=self.training)
        return self.w2(hidden)


class TimestepEmbedder(torch.nn.Module):
    def __init__(
        self,
        dim: int,
        frequency_embedding_dim: int,
        non_linearity: str,
        dropout: float,
        fc_bias: bool,
        max_period: int = 10000,
    ):
        super().__init__()
        self.frequency_embedding_size = frequency_embedding_dim
        self.projection = ProjectionLayer(
            in_dim=frequency_embedding_dim,
            out_dim=dim,
            non_linearity=non_linearity,
            dropout=dropout,
            fc_bias=fc_bias,
        )
        half = frequency_embedding_dim // 2
        freqs = torch.exp(
            -math.log(max_period)
            * torch.arange(start=0, end=half, dtype=torch.float32)
            / half
        )
        self.register_buffer("freqs", freqs, persistent=False)

    def timestep_embedding(self, t, dim):
        """
        Create sinusoidal timestep embeddings.

        :param t: a 1-D Tensor of N indices, one per batch element. These may be fractional.
        :param dim: the dimension of the output.
        :return: an (N, D) Tensor of positional embeddings. The minimum frequency is
            controlled by ``max_period``, which is fixed at construction time and
            baked into ``self.freqs``.
        """
        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
        self.freqs = self.freqs.to(device=t.device)
        args = t[:, None].float() * self.freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat(
                [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
            )
        return embedding.to(t)

    def forward(self, t):
        x = self.timestep_embedding(t, self.frequency_embedding_size)
        return self.projection(x)


class ContextEmbedder(torch.nn.Module):
    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        non_linearity: str,
        dropout: float,
        fc_bias: bool,
        norm_eps: float = 1e-5,
        context_norm: bool = False,
    ):
        super().__init__()
        self.context_norm = context_norm
        if context_norm:
            self.norm = RMSNorm(in_dim, norm_eps)
        self.projection = ProjectionLayer(
            in_dim=in_dim,
            out_dim=out_dim,
            non_linearity=non_linearity,
            dropout=dropout,
            fc_bias=fc_bias,
        )

    def forward(self, x):
        if self.context_norm:
            x = self.norm(x)
        h = self.projection(x)
        return h


class DiTBlock(torch.nn.Module):
    def __init__(
        self,
        dim: int,
        n_heads: int,
        n_kv_heads: Optional[int] = None,
        dropout: float = 0.0,
        norm_eps: float = 1e-5,
        qk_norm: bool = False,
        fc_bias: bool = False,
        ffn_exp: int = 1,
        ffn_dim_multiplier: int = 4,
        multiple_of: int = 64,
        non_linearity: str = "silu",
        no_cross_attention: bool = False,
    ):
        super().__init__()
        assert dim % n_heads == 0
        self.n_heads = n_heads
        self.n_kv_heads = n_heads if n_kv_heads is None else n_kv_heads
        self.dim = dim
        self.dropout = dropout
        self.head_dim = dim // n_heads
        assert self.n_heads % self.n_kv_heads == 0

        self.attention = Attention(
            dim=dim,
            head_dim=self.head_dim,
            n_heads=self.n_heads,
            n_kv_heads=self.n_kv_heads,
            norm_eps=norm_eps,
            use_qk_norm=qk_norm,
            fc_bias=fc_bias,
        )
        self.feed_forward = FeedForward(
            dim=dim,
            hidden_dim=int(ffn_exp * dim),
            ffn_dim_multiplier=ffn_dim_multiplier,
            multiple_of=multiple_of,
            dropout=dropout,
            non_linearity=non_linearity,
            fc_bias=fc_bias,
        )
        self.attention_norm, self.ffn_norm = [RMSNorm(dim, norm_eps) for _ in range(2)]

        self.cross_attention = None
        if not no_cross_attention:
            self.cross_attention = Attention(
                dim=dim,
                head_dim=self.head_dim,
                n_heads=self.n_heads,
                n_kv_heads=self.n_heads,
                norm_eps=norm_eps,
                use_qk_norm=qk_norm,
                fc_bias=fc_bias,
            )

        self.scale_shift_table = nn.Parameter(
            torch.randn(6, self.dim) / self.dim**0.5,
        )

    def forward(
        self,
        x: torch.Tensor,
        cross_x: Optional[torch.Tensor],
        t: torch.Tensor,
        padding_mask: Optional[torch.Tensor],
        memory_padding_mask: Optional[torch.Tensor],
        rope: Optional[RotaryEmbedding] = None,
    ):
        # `t` is the (B, 6 * dim) conditioning vector from DiT.t_block; combined
        # with the learned table it yields shift/scale/gate for attention and MLP.
        biases = self.scale_shift_table[None] + t.reshape(x.size(0), 6, -1)
        (
            shift_msa,
            scale_msa,
            gate_msa,
            shift_mlp,
            scale_mlp,
            gate_mlp,
        ) = biases.chunk(6, dim=1)

        assert self.attention is not None and self.attention_norm is not None
        h_attn = self.attention(
            modulate(self.attention_norm(x), shift_msa, scale_msa),
            key_padding_mask=padding_mask,
            rope=rope,
        )
        h = x + gate(h_attn, gate_msa)

        if self.cross_attention is not None:
            h_cross = self.cross_attention(
                x=h,
                cross_x=cross_x,
                key_padding_mask=memory_padding_mask,
            )
            h = h + h_cross  # residual

        h_ff = self.feed_forward(modulate(self.ffn_norm(h), shift_mlp, scale_mlp))
        out = h + gate(h_ff, gate_mlp)
        return out


class DiT(torch.nn.Module):
    def __init__(self, config: TransformerConfig):
        super().__init__()
        self.dropout = config.dropout
        if config.in_channels is not None:
            self.data_proj = torch.nn.Linear(config.in_channels, config.dim)

        # embeddings
        self.rope_embeddings = None
        # rotary embeddings
        if config.use_rope:
            self.rope_embeddings = RotaryEmbedding(
                theta=max(10000, 2 * config.max_positions),
                head_dim=config.dim // config.n_heads,
                max_seqlen=config.max_positions,
            )
            self.rope_embeddings.reset_parameters()

        # transformer blocks
        self.layers = nn.ModuleList()
        for _ in range(config.n_layers):
            self.layers.append(
                DiTBlock(
                    dim=config.dim,
                    n_heads=config.n_heads,
                    dropout=config.dropout,
                    norm_eps=config.norm_eps,
                    qk_norm=config.qk_norm,
                    fc_bias=config.fc_bias,
                    ffn_exp=config.ffn_exp,
                    ffn_dim_multiplier=config.ffn_dim_multiplier,
                    multiple_of=config.multiple_of,
                    non_linearity=config.non_linearity,
                )
            )

        self.norm = RMSNorm(config.dim, config.norm_eps)

        # output layer
        self.output = torch.nn.Linear(
            config.dim, config.out_channels, bias=config.fc_bias
        )

        self.x_embedder = Patcher(
            in_channels=config.dim,
            out_channels=config.dim,
            patch_size=1,
        )
        self.y_embedder = ContextEmbedder(
            in_dim=config.context_dim,
            out_dim=config.dim,
            non_linearity=config.context_non_linearity,
            dropout=config.context_embedder_dropout,
            fc_bias=config.fc_bias,
            norm_eps=config.norm_eps,
            context_norm=config.context_norm,
        )
        self.t_embedder = TimestepEmbedder(
            config.dim,
            config.frequency_embedding_dim,
            non_linearity=config.timestep_non_linearity,
            dropout=config.dropout,
            fc_bias=config.fc_bias,
            max_period=10000,
        )
        self.t_block_non_linearity = get_nonlinearity(config.t_block_non_linearity)
        self.t_block = torch.nn.Linear(
            config.dim,
            config.dim * 6,
            bias=config.t_block_bias,
        )

        self.final_layer_scale_shift_table = nn.Parameter(
            torch.randn(2, config.dim) / config.dim**0.5,
        )

    def forward(
        self,
        x: torch.Tensor,
        time: torch.Tensor,
        *,
        padding_mask: Optional[torch.Tensor] = None,
        memory: Optional[torch.Tensor] = None,
        memory_padding_mask: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, List[torch.Tensor]]:
        # x: (B, L, C) noisy input; time: (B,) timesteps; memory: (B, S, context_dim)
        x = rearrange(x, "b l c -> b c l")
        h = self.x_embedder(x)
        h = rearrange(h, "b c l -> b l c")
        original_N = h.shape[1]
        h = F.dropout(h, p=self.dropout, training=self.training)

        t = self.t_embedder(time)  # B -> B D
        t0 = self.t_block_non_linearity(t)
        t0 = self.t_block(t0)  # B D -> B 6D

        y = self.y_embedder(memory)

        for layer in self.layers:
            h = layer(
                x=h,
                cross_x=y,
                t=t0,
                padding_mask=padding_mask,
                memory_padding_mask=memory_padding_mask,
                rope=self.rope_embeddings,
            )

        shift, scale = (self.final_layer_scale_shift_table[None] + t[:, None]).chunk(
            2, dim=1
        )

        # output layer
        if self.norm is not None:
            h = self.norm(h)
        h = modulate(h, shift, scale)
        h = F.dropout(h, p=self.dropout, training=self.training)
        output = self.output(h)

        N = output.shape[1]
        if original_N != N:
            output = output[:, -original_N:]
        return output
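

# ---------------------------------------------------------------------------
# Minimal smoke-test sketch (illustrative only, not part of the model code).
# It exercises TimestepEmbedder and DiTBlock with random tensors and arbitrary
# sizes; the tiled `t6` tensor is a stand-in for the learned (B, 6 * dim)
# output of DiT.t_block, and rope is omitted since RotaryEmbedding lives in a
# separate module. Because of the relative imports above, run it as a module
# (`python -m <package>.<this_module>`) rather than as a plain script.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)
    B, L, S, D = 2, 16, 8, 64  # batch, sequence length, context length, model dim

    block = DiTBlock(dim=D, n_heads=4, non_linearity="swiglu")
    t_embedder = TimestepEmbedder(
        dim=D,
        frequency_embedding_dim=256,
        non_linearity="silu",
        dropout=0.0,
        fc_bias=False,
    )

    t = t_embedder(torch.rand(B))  # (B, D) timestep embedding
    t6 = t.repeat(1, 6)  # stand-in for the t_block projection: (B, 6 * D)

    x = torch.randn(B, L, D)  # token states
    cross = torch.randn(B, S, D)  # cross-attention memory, already projected to D

    out = block(
        x=x,
        cross_x=cross,
        t=t6,
        padding_mask=None,
        memory_padding_mask=None,
        rope=None,
    )
    print(out.shape)  # expected: torch.Size([2, 16, 64])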