import math from typing import Optional, Tuple import torch import torch.nn as nn import torch.nn.functional as F class RMSNorm(nn.Module): def __init__(self, dim: int, eps: float = 1e-6): super().__init__() self.eps = eps self.weight = nn.Parameter(torch.ones(dim)) def forward(self, x: torch.Tensor) -> torch.Tensor: var = x.pow(2).mean(dim=-1, keepdim=True) x = x * torch.rsqrt(var + self.eps) return self.weight * x class RotaryEmbedding(nn.Module): def __init__(self, dim: int, base: int = 10000): super().__init__() self.dim = dim self.base = base self._cache = {} def get_sin_cos(self, seq_len: int, device, dtype): key = (seq_len, device, dtype) cached = self._cache.get(key, None) if cached is not None and cached[0].device == device: return cached inv_freq = 1.0 / ( self.base ** (torch.arange(0, self.dim, 2, device=device, dtype=dtype) / self.dim) ) t = torch.arange(seq_len, device=device, dtype=dtype) freqs = torch.einsum("i,j->ij", t, inv_freq) sin = freqs.sin() cos = freqs.cos() self._cache[key] = (sin, cos) return sin, cos def apply_rotary( self, x: torch.Tensor, sin: torch.Tensor, cos: torch.Tensor ) -> torch.Tensor: x1, x2 = x[..., : self.dim // 2], x[..., self.dim // 2 : self.dim] # Interleave sin/cos across pairs x_rot = torch.stack((-x2, x1), dim=-1).reshape_as(x[..., : self.dim]) return (x[..., : self.dim] * cos.unsqueeze(-1)).reshape_as( x[..., : self.dim] ) + (x_rot * sin.unsqueeze(-1)).reshape_as(x[..., : self.dim]) class LlamaAttention(nn.Module): def __init__( self, dim: int, n_heads: int, head_dim: int, bias: bool = False, dropout: float = 0.0, rope_dim: Optional[int] = None, cross_attention_dim: Optional[int] = None, use_sdpa: bool = True, ): super().__init__() self.dim = dim self.n_heads = n_heads self.head_dim = head_dim self.inner_dim = n_heads * head_dim self.cross_attention_dim = cross_attention_dim self.q_proj = nn.Linear(dim, self.inner_dim, bias=bias) k_in = dim if cross_attention_dim is None else cross_attention_dim self.k_proj = nn.Linear(k_in, self.inner_dim, bias=bias) self.v_proj = nn.Linear(k_in, self.inner_dim, bias=bias) self.o_proj = nn.Linear(self.inner_dim, dim, bias=bias) self.dropout = dropout self.rope_dim = rope_dim if rope_dim is not None else head_dim self.rope = RotaryEmbedding(self.rope_dim) self.use_sdpa = use_sdpa self._has_sdpa = hasattr(F, "scaled_dot_product_attention") def _shape(self, x: torch.Tensor, b: int, t: int) -> torch.Tensor: return x.view(b, t, self.n_heads, self.head_dim).transpose(1, 2) def forward( self, x: torch.Tensor, encoder_hidden_states: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, ) -> torch.Tensor: b, t, c = x.shape q = self._shape(self.q_proj(x), b, t) if encoder_hidden_states is None: k = self._shape(self.k_proj(x), b, t) v = self._shape(self.v_proj(x), b, t) else: bt, tk, ck = encoder_hidden_states.shape k = self._shape(self.k_proj(encoder_hidden_states), b, tk) v = self._shape(self.v_proj(encoder_hidden_states), b, tk) # RoPE on first rope_dim of head_dim rope_dim = min(self.rope_dim, self.head_dim) seq_len_for_rope = k.shape[-2] sin, cos = self.rope.get_sin_cos( seq_len_for_rope, device=x.device, dtype=x.dtype ) def apply_rope_vec(tensor): head = tensor[..., :rope_dim] tail = tensor[..., rope_dim:] b, h, tt, _ = head.shape head = head.view(b, h, tt, rope_dim // 2, 2) sin_ = sin.view(1, 1, tt, rope_dim // 2, 1) cos_ = cos.view(1, 1, tt, rope_dim // 2, 1) x1 = head[..., 0:1] x2 = head[..., 1:2] rot = torch.cat( [x1 * cos_ - x2 * sin_, x1 * sin_ + x2 * cos_], dim=-1 ).view(b, h, tt, rope_dim) return torch.cat([rot, tail], dim=-1) q = apply_rope_vec(q) k = apply_rope_vec(k) # Prefer PyTorch SDPA (can enable FlashAttention kernel on supported GPUs) if self.use_sdpa and self._has_sdpa: s = k.shape[-2] attn_mask_sdpa = None if attention_mask is not None: m = attention_mask if m.dim() == 2 and m.shape == (b, s): # [b, s] m = m[:, None, None, :] # [b,1,1,s] elif m.dim() == 3 and m.shape[-2] == 1: # [b,1,s] m = m[:, None, :, :] # [b,1,1,s] elif m.dim() == 3 and m.shape[-2] == t: # [b,t,s] m = m[:, None, :, :] # [b,1,t,s] elif m.dim() == 4 and m.shape[1] == 1: # [b,1,t,s] or [b,1,1,s] pass attn_mask_sdpa = m out = F.scaled_dot_product_attention( q, k, v, attn_mask=attn_mask_sdpa, dropout_p=self.dropout if self.training else 0.0, is_causal=False, ) out = out.transpose(1, 2).contiguous().view(b, t, self.inner_dim) return self.o_proj(out) else: attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt( self.head_dim ) if attention_mask is not None: attn_scores = attn_scores + attention_mask attn = attn_scores.softmax(dim=-1) attn = F.dropout(attn, p=self.dropout, training=self.training) out = torch.matmul(attn, v) out = out.transpose(1, 2).contiguous().view(b, t, self.inner_dim) return self.o_proj(out) class LlamaMLP(nn.Module): def __init__( self, dim: int, hidden_dim: Optional[int] = None, multiple_of: int = 256, dropout: float = 0.0, ): super().__init__() hidden_dim = hidden_dim or 4 * dim # align to multiple_of like Llama hidden_dim = int(2 * hidden_dim / 3) hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) self.gate = nn.Linear(dim, hidden_dim, bias=False) self.up = nn.Linear(dim, hidden_dim, bias=False) self.down = nn.Linear(hidden_dim, dim, bias=False) self.dropout = dropout def forward(self, x: torch.Tensor) -> torch.Tensor: x = F.silu(self.gate(x)) * self.up(x) x = F.dropout(x, p=self.dropout, training=self.training) return self.down(x) class LlamaTransformerBlock(nn.Module): def __init__( self, dim: int, n_heads: int, head_dim: int, mlp_multiple_of: int = 256, dropout: float = 0.0, attention_bias: bool = False, cross_attention_dim: Optional[int] = None, use_ada_layer_norm_single: bool = False, ): super().__init__() self.attn_norm = RMSNorm(dim, 1e-6) self.attn = LlamaAttention( dim, n_heads, head_dim, bias=attention_bias, dropout=dropout, rope_dim=head_dim, cross_attention_dim=None, ) self.cross_attn = None if cross_attention_dim is not None: self.cross_attn_norm = RMSNorm(dim, 1e-6) self.cross_attn = LlamaAttention( dim, n_heads, head_dim, bias=attention_bias, dropout=dropout, rope_dim=head_dim, cross_attention_dim=cross_attention_dim, ) self.mlp_norm = RMSNorm(dim, 1e-6) self.mlp = LlamaMLP(dim, multiple_of=mlp_multiple_of, dropout=dropout) self.use_ada_layer_norm_single = use_ada_layer_norm_single if self.use_ada_layer_norm_single: self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5) def forward( self, x: torch.Tensor, encoder_hidden_states: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, timestep: Optional[torch.Tensor] = None, ) -> torch.Tensor: if self.use_ada_layer_norm_single: batch_size = x.shape[0] # timestep: [B, 6*D] shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1) ).chunk(6, dim=1) # Self-Attention with modulation and gating norm_hidden_states = self.attn_norm(x) norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa h = self.attn(norm_hidden_states, attention_mask=attention_mask) h = gate_msa * h x = x + h # MLP with modulation and gating norm_hidden_states = self.mlp_norm(x) norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp h = self.mlp(norm_hidden_states) h = gate_mlp * h x = x + h return x else: h = self.attn(self.attn_norm(x), attention_mask=attention_mask) x = x + h h = self.mlp(self.mlp_norm(x)) x = x + h return x class ProjectLayer(nn.Module): def __init__(self, hidden_size, filter_size, kernel_size=1, dropout=0.0): super().__init__() self.kernel_size = kernel_size self.dropout = dropout self.ffn_1 = nn.Conv1d( hidden_size, filter_size, kernel_size, padding=kernel_size // 2 ) self.ffn_2 = nn.Linear(filter_size, filter_size) def forward(self, x): x = self.ffn_1(x.transpose(1, 2)).transpose(1, 2) x = x * self.kernel_size**-0.5 x = self.ffn_2(x) return x class LlamaTransformer(nn.Module): def __init__( self, num_attention_heads: int, attention_head_dim: int, in_channels: int, out_channels: int, num_layers: int = 12, num_layers_2: int = 2, dropout: float = 0.0, cross_attention_dim: Optional[int] = None, norm_type: str = "layer_norm", ): super().__init__() inner_dim = num_attention_heads * attention_head_dim inner_dim_2 = inner_dim * 2 self.in_channels = in_channels self.out_channels = out_channels self.inner_dim = inner_dim self.inner_dim_2 = inner_dim_2 self.dropout = dropout self.proj_in = ProjectLayer(in_channels, inner_dim, kernel_size=3) use_ada_single = norm_type == "ada_norm_single" self.transformer_blocks = nn.ModuleList( [ LlamaTransformerBlock( dim=inner_dim, n_heads=num_attention_heads, head_dim=attention_head_dim, dropout=dropout, attention_bias=False, cross_attention_dim=cross_attention_dim, use_ada_layer_norm_single=use_ada_single, ) for _ in range(num_layers) ] ) self.transformer_blocks_2 = nn.ModuleList( [ LlamaTransformerBlock( dim=inner_dim_2, n_heads=num_attention_heads, head_dim=attention_head_dim * 2, dropout=dropout, attention_bias=False, cross_attention_dim=cross_attention_dim, use_ada_layer_norm_single=use_ada_single, ) for _ in range(num_layers_2) ] ) self.connection_proj = ProjectLayer( in_channels + inner_dim, inner_dim_2, kernel_size=3 ) self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6) self.norm_out_2 = nn.LayerNorm(inner_dim_2, elementwise_affine=False, eps=1e-6) self.scale_shift_table = nn.Parameter( torch.randn(2, inner_dim) / inner_dim**0.5 ) self.scale_shift_table_2 = nn.Parameter( torch.randn(2, inner_dim_2) / inner_dim_2**0.5 ) self.proj_out = ProjectLayer(inner_dim_2, out_channels, kernel_size=3) self.adaln_single = AdaLayerNormSingleFlow(inner_dim) self.adaln_single_2 = AdaLayerNormSingleFlow(inner_dim_2) def forward( self, hidden_states: torch.Tensor, timestep: Optional[torch.LongTensor] = None, ): s = self.proj_in(hidden_states) embedded_timestep = None timestep_mod = None if self.adaln_single is not None and timestep is not None: batch_size = s.shape[0] timestep_mod, embedded_timestep = self.adaln_single( timestep, hidden_dtype=s.dtype ) for blk in self.transformer_blocks: s = blk(s, timestep=timestep_mod) if embedded_timestep is None: embedded_timestep = torch.zeros( s.size(0), s.size(-1), device=s.device, dtype=s.dtype ) shift, scale = ( self.scale_shift_table[None] + embedded_timestep[:, None] ).chunk(2, dim=1) s = self.norm_out(s) s = s * (1 + scale) + shift x = torch.cat([hidden_states, s], dim=-1) x = self.connection_proj(x) embedded_timestep_2 = None timestep_mod_2 = None if self.adaln_single_2 is not None and timestep is not None: batch_size = x.shape[0] timestep_mod_2, embedded_timestep_2 = self.adaln_single_2( timestep, hidden_dtype=x.dtype ) for blk in self.transformer_blocks_2: x = blk(x, timestep=timestep_mod_2) if embedded_timestep_2 is None: embedded_timestep_2 = torch.zeros( x.size(0), x.size(-1), device=x.device, dtype=x.dtype ) shift_2, scale_2 = ( self.scale_shift_table_2[None] + embedded_timestep_2[:, None] ).chunk(2, dim=1) x = self.norm_out_2(x) x = x * (1 + scale_2) + shift_2 out = self.proj_out(x) return out class PixArtAlphaCombinedFlowEmbeddings(nn.Module): def __init__(self, embedding_dim: int, size_emb_dim: int): super().__init__() self.flow_t_size = 512 self.outdim = size_emb_dim self.timestep_embedder = TimestepEmbedding( in_channels=self.flow_t_size, time_embed_dim=embedding_dim ) def timestep_embedding(self, timesteps, max_period=10000, scale=1000): half = self.flow_t_size // 2 freqs = torch.exp( -math.log(max_period) * torch.arange(start=0, end=half, device=timesteps.device) / half ).type(timesteps.type()) args = timesteps[:, None] * freqs[None] * scale embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) if self.flow_t_size % 2: embedding = torch.cat( [embedding, torch.zeros_like(embedding[:, :1])], dim=-1 ) return embedding def forward(self, timestep, hidden_dtype): timesteps_proj = self.timestep_embedding(timestep) timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) conditioning = timesteps_emb return conditioning class AdaLayerNormSingleFlow(nn.Module): def __init__(self, embedding_dim: int): super().__init__() self.emb = PixArtAlphaCombinedFlowEmbeddings( embedding_dim, size_emb_dim=embedding_dim // 3 ) self.silu = nn.SiLU() self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True) def forward( self, timestep: torch.Tensor, hidden_dtype: Optional[torch.dtype] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: embedded_timestep = self.emb(timestep, hidden_dtype=hidden_dtype) return self.linear(self.silu(embedded_timestep)), embedded_timestep class TimestepEmbedding(nn.Module): def __init__(self, in_channels: int, time_embed_dim: int): super().__init__() self.linear_1 = nn.Linear(in_channels, time_embed_dim) self.act = nn.SiLU() self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim) def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.linear_1(x) x = self.act(x) x = self.linear_2(x) return x class Timesteps(nn.Module): def __init__( self, num_channels: int, flip_sin_to_cos: bool = True, downscale_freq_shift: float = 0, ): super().__init__() self.num_channels = num_channels self.flip_sin_to_cos = flip_sin_to_cos self.downscale_freq_shift = downscale_freq_shift def forward(self, timesteps: torch.Tensor) -> torch.Tensor: half_dim = self.num_channels // 2 exponent = ( -math.log(10000) * torch.arange(0, half_dim, device=timesteps.device) / (half_dim - self.downscale_freq_shift) ) emb = torch.exp(exponent)[None, :] * timesteps[:, None] if self.flip_sin_to_cos: emb = torch.cat([torch.cos(emb), torch.sin(emb)], dim=-1) else: emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1) if self.num_channels % 2 == 1: emb = torch.nn.functional.pad(emb, (0, 1)) return emb