"""FOFPred Transformer modified from OmniGen2 DiT.""" import importlib.util import itertools import math import sys import warnings from dataclasses import dataclass from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from diffusers.configuration_utils import ConfigMixin, register_to_config from diffusers.loaders import PeftAdapterMixin from diffusers.loaders.single_file_model import FromOriginalModelMixin from diffusers.models.activations import get_activation from diffusers.models.attention_processor import Attention from diffusers.models.embeddings import Timesteps, get_1d_rotary_pos_embed from diffusers.models.modeling_outputs import Transformer2DModelOutput from diffusers.models.modeling_utils import ModelMixin from diffusers.utils import ( USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers, ) from einops import rearrange, repeat # The package importlib_metadata is in a different place, depending on the python version. if sys.version_info < (3, 8): import importlib_metadata else: import importlib.metadata as importlib_metadata def _is_package_available(pkg_name: str): pkg_exists = importlib.util.find_spec(pkg_name) is not None pkg_version = "N/A" if pkg_exists: try: pkg_version = importlib_metadata.version(pkg_name) except (ImportError, importlib_metadata.PackageNotFoundError): pkg_exists = False return pkg_exists, pkg_version _triton_available, _triton_version = _is_package_available("triton") _flash_attn_available, _flash_attn_version = _is_package_available("flash_attn") def is_triton_available(): return _triton_available def is_flash_attn_available(): return _flash_attn_available if is_triton_available(): import triton import triton.language as tl def custom_amp_decorator(dec: Callable, cuda_amp_deprecated: bool): def decorator(*args, **kwargs): if cuda_amp_deprecated: kwargs["device_type"] = "cuda" return dec(*args, **kwargs) return decorator if hasattr(torch.amp, "custom_fwd"): # type: ignore[attr-defined] deprecated = True from torch.amp import custom_bwd, custom_fwd # type: ignore[attr-defined] else: deprecated = False from torch.cuda.amp import custom_bwd, custom_fwd custom_fwd = custom_amp_decorator(custom_fwd, deprecated) custom_bwd = custom_amp_decorator(custom_bwd, deprecated) def triton_autotune_configs(): # Return configs with a valid warp count for the current device configs = [] # Maximum threads per block is architecture-dependent in theory, but in reality all are 1024 max_threads_per_block = 1024 # Default to warp size 32 if not defined by device warp_size = getattr( torch.cuda.get_device_properties(torch.cuda.current_device()), "warp_size", 32, ) # Autotune for warp counts which are powers of 2 and do not exceed thread per block limit warp_count = 1 while warp_count * warp_size <= max_threads_per_block: configs.append(triton.Config({}, num_warps=warp_count)) warp_count *= 2 return configs def layer_norm_ref( x, weight, bias, residual=None, x1=None, weight1=None, bias1=None, eps=1e-6, dropout_p=0.0, rowscale=None, prenorm=False, zero_centered_weight=False, dropout_mask=None, dropout_mask1=None, upcast=False, ): dtype = x.dtype if upcast: x = x.float() weight = weight.float() bias = bias.float() if bias is not None else None residual = residual.float() if residual is not None else residual x1 = x1.float() if x1 is not None else None weight1 = weight1.float() if weight1 is not None else None bias1 = bias1.float() if bias1 is not None else None if 
zero_centered_weight: weight = weight + 1.0 if weight1 is not None: weight1 = weight1 + 1.0 if x1 is not None: assert rowscale is None, "rowscale is not supported with parallel LayerNorm" if rowscale is not None: x = x * rowscale[..., None] if dropout_p > 0.0: if dropout_mask is not None: x = x.masked_fill(~dropout_mask, 0.0) / (1.0 - dropout_p) else: x = F.dropout(x, p=dropout_p) if x1 is not None: if dropout_mask1 is not None: x1 = x1.masked_fill(~dropout_mask1, 0.0) / (1.0 - dropout_p) else: x1 = F.dropout(x1, p=dropout_p) if x1 is not None: x = x + x1 if residual is not None: x = (x + residual).to(x.dtype) out = F.layer_norm( x.to(weight.dtype), x.shape[-1:], weight=weight, bias=bias, eps=eps ).to(dtype) if weight1 is None: return out if not prenorm else (out, x) else: out1 = F.layer_norm( x.to(weight1.dtype), x.shape[-1:], weight=weight1, bias=bias1, eps=eps ).to(dtype) return (out, out1) if not prenorm else (out, out1, x) def rms_norm_ref( x, weight, bias, residual=None, x1=None, weight1=None, bias1=None, eps=1e-6, dropout_p=0.0, rowscale=None, prenorm=False, zero_centered_weight=False, dropout_mask=None, dropout_mask1=None, upcast=False, ): dtype = x.dtype if upcast: x = x.float() weight = weight.float() bias = bias.float() if bias is not None else None residual = residual.float() if residual is not None else residual x1 = x1.float() if x1 is not None else None weight1 = weight1.float() if weight1 is not None else None bias1 = bias1.float() if bias1 is not None else None if zero_centered_weight: weight = weight + 1.0 if weight1 is not None: weight1 = weight1 + 1.0 if x1 is not None: assert rowscale is None, "rowscale is not supported with parallel LayerNorm" if rowscale is not None: x = x * rowscale[..., None] if dropout_p > 0.0: if dropout_mask is not None: x = x.masked_fill(~dropout_mask, 0.0) / (1.0 - dropout_p) else: x = F.dropout(x, p=dropout_p) if x1 is not None: if dropout_mask1 is not None: x1 = x1.masked_fill(~dropout_mask1, 0.0) / (1.0 - dropout_p) else: x1 = F.dropout(x1, p=dropout_p) if x1 is not None: x = x + x1 if residual is not None: x = (x + residual).to(x.dtype) rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps) out = ( (x * rstd * weight) + bias if bias is not None else (x * rstd * weight) ).to(dtype) if weight1 is None: return out if not prenorm else (out, x) else: out1 = ( (x * rstd * weight1) + bias1 if bias1 is not None else (x * rstd * weight1) ).to(dtype) return (out, out1) if not prenorm else (out, out1, x) @triton.autotune( configs=triton_autotune_configs(), key=["N", "HAS_RESIDUAL", "STORE_RESIDUAL_OUT", "IS_RMS_NORM", "HAS_BIAS"], ) # @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None}) # @triton.heuristics({"HAS_RESIDUAL": lambda args: args["RESIDUAL"] is not None}) @triton.heuristics({"HAS_X1": lambda args: args["X1"] is not None}) @triton.heuristics({"HAS_W1": lambda args: args["W1"] is not None}) @triton.heuristics({"HAS_B1": lambda args: args["B1"] is not None}) @triton.jit def _layer_norm_fwd_1pass_kernel( X, # pointer to the input Y, # pointer to the output W, # pointer to the weights B, # pointer to the biases RESIDUAL, # pointer to the residual X1, W1, B1, Y1, RESIDUAL_OUT, # pointer to the residual ROWSCALE, SEEDS, # Dropout seeds for each row DROPOUT_MASK, Mean, # pointer to the mean Rstd, # pointer to the 1/std stride_x_row, # how much to increase the pointer when moving by 1 row stride_y_row, stride_res_row, stride_res_out_row, stride_x1_row, stride_y1_row, M, # number of rows in X N, # number of columns in 
X eps, # epsilon to avoid division by zero dropout_p, # Dropout probability zero_centered_weight, # If true, add 1.0 to the weight IS_RMS_NORM: tl.constexpr, BLOCK_N: tl.constexpr, HAS_RESIDUAL: tl.constexpr, STORE_RESIDUAL_OUT: tl.constexpr, HAS_BIAS: tl.constexpr, HAS_DROPOUT: tl.constexpr, STORE_DROPOUT_MASK: tl.constexpr, HAS_ROWSCALE: tl.constexpr, HAS_X1: tl.constexpr, HAS_W1: tl.constexpr, HAS_B1: tl.constexpr, ): # Map the program id to the row of X and Y it should compute. row = tl.program_id(0) X += row * stride_x_row Y += row * stride_y_row if HAS_RESIDUAL: RESIDUAL += row * stride_res_row if STORE_RESIDUAL_OUT: RESIDUAL_OUT += row * stride_res_out_row if HAS_X1: X1 += row * stride_x1_row if HAS_W1: Y1 += row * stride_y1_row # Compute mean and variance cols = tl.arange(0, BLOCK_N) x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32) if HAS_ROWSCALE: rowscale = tl.load(ROWSCALE + row).to(tl.float32) x *= rowscale if HAS_DROPOUT: # Compute dropout mask # 7 rounds is good enough, and reduces register pressure keep_mask = ( tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p ) x = tl.where(keep_mask, x / (1.0 - dropout_p), 0.0) if STORE_DROPOUT_MASK: tl.store(DROPOUT_MASK + row * N + cols, keep_mask, mask=cols < N) if HAS_X1: x1 = tl.load(X1 + cols, mask=cols < N, other=0.0).to(tl.float32) if HAS_ROWSCALE: rowscale = tl.load(ROWSCALE + M + row).to(tl.float32) x1 *= rowscale if HAS_DROPOUT: # Compute dropout mask # 7 rounds is good enough, and reduces register pressure keep_mask = ( tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7) > dropout_p ) x1 = tl.where(keep_mask, x1 / (1.0 - dropout_p), 0.0) if STORE_DROPOUT_MASK: tl.store( DROPOUT_MASK + (M + row) * N + cols, keep_mask, mask=cols < N ) x += x1 if HAS_RESIDUAL: residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32) x += residual if STORE_RESIDUAL_OUT: tl.store(RESIDUAL_OUT + cols, x, mask=cols < N) if not IS_RMS_NORM: mean = tl.sum(x, axis=0) / N tl.store(Mean + row, mean) xbar = tl.where(cols < N, x - mean, 0.0) var = tl.sum(xbar * xbar, axis=0) / N else: xbar = tl.where(cols < N, x, 0.0) var = tl.sum(xbar * xbar, axis=0) / N rstd = 1 / tl.sqrt(var + eps) tl.store(Rstd + row, rstd) # Normalize and apply linear transformation mask = cols < N w = tl.load(W + cols, mask=mask).to(tl.float32) if zero_centered_weight: w += 1.0 if HAS_BIAS: b = tl.load(B + cols, mask=mask).to(tl.float32) x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd y = x_hat * w + b if HAS_BIAS else x_hat * w # Write output tl.store(Y + cols, y, mask=mask) if HAS_W1: w1 = tl.load(W1 + cols, mask=mask).to(tl.float32) if zero_centered_weight: w1 += 1.0 if HAS_B1: b1 = tl.load(B1 + cols, mask=mask).to(tl.float32) y1 = x_hat * w1 + b1 if HAS_B1 else x_hat * w1 tl.store(Y1 + cols, y1, mask=mask) def _layer_norm_fwd( x, weight, bias, eps, residual=None, x1=None, weight1=None, bias1=None, dropout_p=0.0, rowscale=None, out_dtype=None, residual_dtype=None, zero_centered_weight=False, is_rms_norm=False, return_dropout_mask=False, out=None, residual_out=None, ): if residual is not None: residual_dtype = residual.dtype M, N = x.shape assert x.stride(-1) == 1 if residual is not None: assert residual.stride(-1) == 1 assert residual.shape == (M, N) assert weight.shape == (N,) assert weight.stride(-1) == 1 if bias is not None: assert bias.stride(-1) == 1 assert bias.shape == (N,) if x1 is not None: assert x1.shape == x.shape assert rowscale is None assert x1.stride(-1) == 1 if weight1 is not None: 
assert weight1.shape == (N,) assert weight1.stride(-1) == 1 if bias1 is not None: assert bias1.shape == (N,) assert bias1.stride(-1) == 1 if rowscale is not None: assert rowscale.is_contiguous() assert rowscale.shape == (M,) # allocate output if out is None: out = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype) else: assert out.shape == x.shape assert out.stride(-1) == 1 if weight1 is not None: y1 = torch.empty_like(out) assert y1.stride(-1) == 1 else: y1 = None if ( residual is not None or (residual_dtype is not None and residual_dtype != x.dtype) or dropout_p > 0.0 or rowscale is not None or x1 is not None ): if residual_out is None: residual_out = torch.empty( M, N, device=x.device, dtype=residual_dtype if residual_dtype is not None else x.dtype, ) else: assert residual_out.shape == x.shape assert residual_out.stride(-1) == 1 else: residual_out = None mean = ( torch.empty((M,), dtype=torch.float32, device=x.device) if not is_rms_norm else None ) rstd = torch.empty((M,), dtype=torch.float32, device=x.device) if dropout_p > 0.0: seeds = torch.randint( 2**32, (M if x1 is None else 2 * M,), device=x.device, dtype=torch.int64 ) else: seeds = None if return_dropout_mask and dropout_p > 0.0: dropout_mask = torch.empty( M if x1 is None else 2 * M, N, device=x.device, dtype=torch.bool ) else: dropout_mask = None # Less than 64KB per feature: enqueue fused kernel MAX_FUSED_SIZE = 65536 // x.element_size() BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) if N > BLOCK_N: raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") with torch.cuda.device(x.device.index): _layer_norm_fwd_1pass_kernel[(M,)]( x, out, weight, bias, residual, x1, weight1, bias1, y1, residual_out, rowscale, seeds, dropout_mask, mean, rstd, x.stride(0), out.stride(0), residual.stride(0) if residual is not None else 0, residual_out.stride(0) if residual_out is not None else 0, x1.stride(0) if x1 is not None else 0, y1.stride(0) if y1 is not None else 0, M, N, eps, dropout_p, zero_centered_weight, is_rms_norm, BLOCK_N, residual is not None, residual_out is not None, bias is not None, dropout_p > 0.0, dropout_mask is not None, rowscale is not None, ) # residual_out is None if residual is None and residual_dtype == input_dtype and dropout_p == 0.0 if dropout_mask is not None and x1 is not None: dropout_mask, dropout_mask1 = dropout_mask.tensor_split(2, dim=0) else: dropout_mask1 = None return ( out, y1, mean, rstd, residual_out if residual_out is not None else x, seeds, dropout_mask, dropout_mask1, ) @triton.autotune( configs=triton_autotune_configs(), key=[ "N", "HAS_DRESIDUAL", "STORE_DRESIDUAL", "IS_RMS_NORM", "HAS_BIAS", "HAS_DROPOUT", ], ) # @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None}) # @triton.heuristics({"HAS_DRESIDUAL": lambda args: args["DRESIDUAL"] is not None}) # @triton.heuristics({"STORE_DRESIDUAL": lambda args: args["DRESIDUAL_IN"] is not None}) @triton.heuristics({"HAS_ROWSCALE": lambda args: args["ROWSCALE"] is not None}) @triton.heuristics({"HAS_DY1": lambda args: args["DY1"] is not None}) @triton.heuristics({"HAS_DX1": lambda args: args["DX1"] is not None}) @triton.heuristics({"HAS_B1": lambda args: args["DB1"] is not None}) @triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None}) @triton.jit def _layer_norm_bwd_kernel( X, # pointer to the input W, # pointer to the weights B, # pointer to the biases Y, # pointer to the output to be recomputed DY, # pointer to the output gradient DX, # pointer to the input gradient DW, # 
pointer to the partial sum of weights gradient DB, # pointer to the partial sum of biases gradient DRESIDUAL, W1, DY1, DX1, DW1, DB1, DRESIDUAL_IN, ROWSCALE, SEEDS, Mean, # pointer to the mean Rstd, # pointer to the 1/std stride_x_row, # how much to increase the pointer when moving by 1 row stride_y_row, stride_dy_row, stride_dx_row, stride_dres_row, stride_dy1_row, stride_dx1_row, stride_dres_in_row, M, # number of rows in X N, # number of columns in X eps, # epsilon to avoid division by zero dropout_p, zero_centered_weight, rows_per_program, IS_RMS_NORM: tl.constexpr, BLOCK_N: tl.constexpr, HAS_DRESIDUAL: tl.constexpr, STORE_DRESIDUAL: tl.constexpr, HAS_BIAS: tl.constexpr, HAS_DROPOUT: tl.constexpr, HAS_ROWSCALE: tl.constexpr, HAS_DY1: tl.constexpr, HAS_DX1: tl.constexpr, HAS_B1: tl.constexpr, RECOMPUTE_OUTPUT: tl.constexpr, ): # Map the program id to the elements of X, DX, and DY it should compute. row_block_id = tl.program_id(0) row_start = row_block_id * rows_per_program # Do not early exit if row_start >= M, because we need to write DW and DB cols = tl.arange(0, BLOCK_N) mask = cols < N X += row_start * stride_x_row if HAS_DRESIDUAL: DRESIDUAL += row_start * stride_dres_row if STORE_DRESIDUAL: DRESIDUAL_IN += row_start * stride_dres_in_row DY += row_start * stride_dy_row DX += row_start * stride_dx_row if HAS_DY1: DY1 += row_start * stride_dy1_row if HAS_DX1: DX1 += row_start * stride_dx1_row if RECOMPUTE_OUTPUT: Y += row_start * stride_y_row w = tl.load(W + cols, mask=mask).to(tl.float32) if zero_centered_weight: w += 1.0 if RECOMPUTE_OUTPUT and HAS_BIAS: b = tl.load(B + cols, mask=mask, other=0.0).to(tl.float32) if HAS_DY1: w1 = tl.load(W1 + cols, mask=mask).to(tl.float32) if zero_centered_weight: w1 += 1.0 dw = tl.zeros((BLOCK_N,), dtype=tl.float32) if HAS_BIAS: db = tl.zeros((BLOCK_N,), dtype=tl.float32) if HAS_DY1: dw1 = tl.zeros((BLOCK_N,), dtype=tl.float32) if HAS_B1: db1 = tl.zeros((BLOCK_N,), dtype=tl.float32) row_end = min((row_block_id + 1) * rows_per_program, M) for row in range(row_start, row_end): # Load data to SRAM x = tl.load(X + cols, mask=mask, other=0).to(tl.float32) dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32) if HAS_DY1: dy1 = tl.load(DY1 + cols, mask=mask, other=0).to(tl.float32) if not IS_RMS_NORM: mean = tl.load(Mean + row) rstd = tl.load(Rstd + row) # Compute dx xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd xhat = tl.where(mask, xhat, 0.0) if RECOMPUTE_OUTPUT: y = xhat * w + b if HAS_BIAS else xhat * w tl.store(Y + cols, y, mask=mask) wdy = w * dy dw += dy * xhat if HAS_BIAS: db += dy if HAS_DY1: wdy += w1 * dy1 dw1 += dy1 * xhat if HAS_B1: db1 += dy1 if not IS_RMS_NORM: c1 = tl.sum(xhat * wdy, axis=0) / N c2 = tl.sum(wdy, axis=0) / N dx = (wdy - (xhat * c1 + c2)) * rstd else: c1 = tl.sum(xhat * wdy, axis=0) / N dx = (wdy - xhat * c1) * rstd if HAS_DRESIDUAL: dres = tl.load(DRESIDUAL + cols, mask=mask, other=0).to(tl.float32) dx += dres # Write dx if STORE_DRESIDUAL: tl.store(DRESIDUAL_IN + cols, dx, mask=mask) if HAS_DX1: if HAS_DROPOUT: keep_mask = ( tl.rand( tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7 ) > dropout_p ) dx1 = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0) else: dx1 = dx tl.store(DX1 + cols, dx1, mask=mask) if HAS_DROPOUT: keep_mask = ( tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p ) dx = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0) if HAS_ROWSCALE: rowscale = tl.load(ROWSCALE + row).to(tl.float32) dx *= rowscale tl.store(DX + cols, dx, mask=mask) X += stride_x_row if 
HAS_DRESIDUAL: DRESIDUAL += stride_dres_row if STORE_DRESIDUAL: DRESIDUAL_IN += stride_dres_in_row if RECOMPUTE_OUTPUT: Y += stride_y_row DY += stride_dy_row DX += stride_dx_row if HAS_DY1: DY1 += stride_dy1_row if HAS_DX1: DX1 += stride_dx1_row tl.store(DW + row_block_id * N + cols, dw, mask=mask) if HAS_BIAS: tl.store(DB + row_block_id * N + cols, db, mask=mask) if HAS_DY1: tl.store(DW1 + row_block_id * N + cols, dw1, mask=mask) if HAS_B1: tl.store(DB1 + row_block_id * N + cols, db1, mask=mask) def _layer_norm_bwd( dy, x, weight, bias, eps, mean, rstd, dresidual=None, dy1=None, weight1=None, bias1=None, seeds=None, dropout_p=0.0, rowscale=None, has_residual=False, has_x1=False, zero_centered_weight=False, is_rms_norm=False, x_dtype=None, recompute_output=False, ): M, N = x.shape assert x.stride(-1) == 1 assert dy.stride(-1) == 1 assert dy.shape == (M, N) if dresidual is not None: assert dresidual.stride(-1) == 1 assert dresidual.shape == (M, N) assert weight.shape == (N,) assert weight.stride(-1) == 1 if bias is not None: assert bias.stride(-1) == 1 assert bias.shape == (N,) if dy1 is not None: assert weight1 is not None assert dy1.shape == dy.shape assert dy1.stride(-1) == 1 if weight1 is not None: assert weight1.shape == (N,) assert weight1.stride(-1) == 1 if bias1 is not None: assert bias1.shape == (N,) assert bias1.stride(-1) == 1 if seeds is not None: assert seeds.is_contiguous() assert seeds.shape == (M if not has_x1 else M * 2,) if rowscale is not None: assert rowscale.is_contiguous() assert rowscale.shape == (M,) # allocate output dx = ( torch.empty_like(x) if x_dtype is None else torch.empty(M, N, dtype=x_dtype, device=x.device) ) dresidual_in = ( torch.empty_like(x) if has_residual and ( dx.dtype != x.dtype or dropout_p > 0.0 or rowscale is not None or has_x1 ) else None ) dx1 = torch.empty_like(dx) if (has_x1 and dropout_p > 0.0) else None y = ( torch.empty(M, N, dtype=dy.dtype, device=dy.device) if recompute_output else None ) if recompute_output: assert weight1 is None, ( "recompute_output is not supported with parallel LayerNorm" ) # Less than 64KB per feature: enqueue fused kernel MAX_FUSED_SIZE = 65536 // x.element_size() BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) if N > BLOCK_N: raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") # Increasing the multiple (e.g. 8) will allow more thread blocks to be launched and hide the # latency of the gmem reads/writes, but will increase the time of summing up dw / db. 
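    # Concretely: each of the `sm_count` programs launched below accumulates
    # `rows_per_program = ceil(M / sm_count)` rows into one fp32 partial row of `_dw` / `_db`,
    # and the final weight/bias gradients are the column-wise sums `_dw.sum(0)` / `_db.sum(0)`.
    # Illustration (assumed device figures, not from the original source): a GPU with 108 SMs
    # gives sm_count = 108 * 8 = 864, so for M = 4096 rows each program handles at most
    # ceil(4096 / 864) = 5 rows.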
sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count * 8 _dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device) _db = ( torch.empty((sm_count, N), dtype=torch.float32, device=bias.device) if bias is not None else None ) _dw1 = torch.empty_like(_dw) if weight1 is not None else None _db1 = torch.empty_like(_db) if bias1 is not None else None rows_per_program = math.ceil(M / sm_count) grid = (sm_count,) with torch.cuda.device(x.device.index): _layer_norm_bwd_kernel[grid]( x, weight, bias, y, dy, dx, _dw, _db, dresidual, weight1, dy1, dx1, _dw1, _db1, dresidual_in, rowscale, seeds, mean, rstd, x.stride(0), 0 if not recompute_output else y.stride(0), dy.stride(0), dx.stride(0), dresidual.stride(0) if dresidual is not None else 0, dy1.stride(0) if dy1 is not None else 0, dx1.stride(0) if dx1 is not None else 0, dresidual_in.stride(0) if dresidual_in is not None else 0, M, N, eps, dropout_p, zero_centered_weight, rows_per_program, is_rms_norm, BLOCK_N, dresidual is not None, dresidual_in is not None, bias is not None, dropout_p > 0.0, ) dw = _dw.sum(0).to(weight.dtype) db = _db.sum(0).to(bias.dtype) if bias is not None else None dw1 = _dw1.sum(0).to(weight1.dtype) if weight1 is not None else None db1 = _db1.sum(0).to(bias1.dtype) if bias1 is not None else None # Don't need to compute dresidual_in separately in this case if ( has_residual and dx.dtype == x.dtype and dropout_p == 0.0 and rowscale is None ): dresidual_in = dx if has_x1 and dropout_p == 0.0: dx1 = dx return ( (dx, dw, db, dresidual_in, dx1, dw1, db1) if not recompute_output else (dx, dw, db, dresidual_in, dx1, dw1, db1, y) ) class LayerNormFn(torch.autograd.Function): @staticmethod def forward( ctx, x, weight, bias, residual=None, x1=None, weight1=None, bias1=None, eps=1e-6, dropout_p=0.0, rowscale=None, prenorm=False, residual_in_fp32=False, zero_centered_weight=False, is_rms_norm=False, return_dropout_mask=False, out=None, residual_out=None, ): x_shape_og = x.shape # Check for zero sequence length if x.numel() == 0: ctx.zero_seq_length = True # Only save minimal required tensors for backward # ctx.save_for_backward(weight, bias, weight1, bias1) ctx.x_shape_og = x_shape_og ctx.weight_shape = weight.shape ctx.weight_dtype = weight.dtype ctx.weight_device = weight.device ctx.has_bias = bias is not None ctx.bias_shape = bias.shape if bias is not None else None ctx.bias_dtype = bias.dtype if bias is not None else None ctx.bias_device = bias.device if bias is not None else None ctx.has_weight1 = weight1 is not None ctx.weight1_shape = weight1.shape if weight1 is not None else None ctx.weight1_dtype = weight1.dtype if weight1 is not None else None ctx.weight1_device = weight1.device if weight1 is not None else None ctx.has_bias1 = bias1 is not None ctx.bias1_shape = bias1.shape if bias1 is not None else None ctx.bias1_dtype = bias1.dtype if bias1 is not None else None ctx.bias1_device = bias1.device if bias1 is not None else None ctx.has_residual = residual is not None ctx.has_x1 = x1 is not None ctx.dropout_p = dropout_p # Handle output tensors with correct dtype y = x # Preserve input tensor properties y1 = torch.empty_like(x) if x1 is not None else None # Only create residual_out if prenorm is True residual_out = ( torch.empty( x.shape, dtype=torch.float32 if residual_in_fp32 else x.dtype, device=x.device, ) if prenorm else None ) # Handle dropout masks dropout_mask = None dropout_mask1 = None if return_dropout_mask: dropout_mask = torch.empty_like(x, dtype=torch.uint8) if x1 is not None: 
dropout_mask1 = torch.empty_like(x, dtype=torch.uint8) # Return based on configuration if not return_dropout_mask: if weight1 is None: return y if not prenorm else (y, residual_out) else: return (y, y1) if not prenorm else (y, y1, residual_out) else: if weight1 is None: return ( (y, dropout_mask, dropout_mask1) if not prenorm else (y, residual_out, dropout_mask, dropout_mask1) ) else: return ( (y, y1, dropout_mask, dropout_mask1) if not prenorm else (y, y1, residual_out, dropout_mask, dropout_mask1) ) ctx.zero_seq_length = False # reshape input data into 2D tensor x = x.reshape(-1, x.shape[-1]) if x.stride(-1) != 1: x = x.contiguous() if residual is not None: assert residual.shape == x_shape_og residual = residual.reshape(-1, residual.shape[-1]) if residual.stride(-1) != 1: residual = residual.contiguous() if x1 is not None: assert x1.shape == x_shape_og assert rowscale is None, ( "rowscale is not supported with parallel LayerNorm" ) x1 = x1.reshape(-1, x1.shape[-1]) if x1.stride(-1) != 1: x1 = x1.contiguous() weight = weight.contiguous() if bias is not None: bias = bias.contiguous() if weight1 is not None: weight1 = weight1.contiguous() if bias1 is not None: bias1 = bias1.contiguous() if rowscale is not None: rowscale = rowscale.reshape(-1).contiguous() residual_dtype = ( residual.dtype if residual is not None else (torch.float32 if residual_in_fp32 else None) ) if out is not None: out = out.reshape(-1, out.shape[-1]) if residual_out is not None: residual_out = residual_out.reshape(-1, residual_out.shape[-1]) y, y1, mean, rstd, residual_out, seeds, dropout_mask, dropout_mask1 = ( _layer_norm_fwd( x, weight, bias, eps, residual, x1, weight1, bias1, dropout_p=dropout_p, rowscale=rowscale, residual_dtype=residual_dtype, zero_centered_weight=zero_centered_weight, is_rms_norm=is_rms_norm, return_dropout_mask=return_dropout_mask, out=out, residual_out=residual_out, ) ) ctx.save_for_backward( residual_out, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd ) ctx.x_shape_og = x_shape_og ctx.eps = eps ctx.dropout_p = dropout_p ctx.is_rms_norm = is_rms_norm ctx.has_residual = residual is not None ctx.has_x1 = x1 is not None ctx.prenorm = prenorm ctx.x_dtype = x.dtype ctx.zero_centered_weight = zero_centered_weight y = y.reshape(x_shape_og) y1 = y1.reshape(x_shape_og) if y1 is not None else None residual_out = ( residual_out.reshape(x_shape_og) if residual_out is not None else None ) dropout_mask = ( dropout_mask.reshape(x_shape_og) if dropout_mask is not None else None ) dropout_mask1 = ( dropout_mask1.reshape(x_shape_og) if dropout_mask1 is not None else None ) if not return_dropout_mask: if weight1 is None: return y if not prenorm else (y, residual_out) else: return (y, y1) if not prenorm else (y, y1, residual_out) else: if weight1 is None: return ( (y, dropout_mask, dropout_mask1) if not prenorm else (y, residual_out, dropout_mask, dropout_mask1) ) else: return ( (y, y1, dropout_mask, dropout_mask1) if not prenorm else (y, y1, residual_out, dropout_mask, dropout_mask1) ) @staticmethod def backward(ctx, dy, *args): if ctx.zero_seq_length: return ( torch.zeros(ctx.x_shape_og, dtype=dy.dtype, device=dy.device), torch.zeros( ctx.weight_shape, dtype=ctx.weight_dtype, device=ctx.weight_device, ), torch.zeros( ctx.bias_shape, dtype=ctx.bias_dtype, device=ctx.bias_device ) if ctx.has_bias else None, torch.zeros(ctx.x_shape_og, dtype=dy.dtype, device=dy.device) if ctx.has_residual else None, torch.zeros(ctx.x_shape_og, dtype=dy.dtype, device=dy.device) if ctx.has_x1 and ctx.dropout_p > 0.0 else 
None, torch.zeros( ctx.weight1_shape, dtype=ctx.weight1_dtype, device=ctx.weight1_device, ) if ctx.has_weight1 else None, torch.zeros( ctx.bias1_shape, dtype=ctx.bias1_dtype, device=ctx.bias1_device ) if ctx.has_bias1 else None, None, None, None, None, None, None, None, None, None, None, ) x, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd = ( ctx.saved_tensors ) dy = dy.reshape(-1, dy.shape[-1]) if dy.stride(-1) != 1: dy = dy.contiguous() assert dy.shape == x.shape if weight1 is not None: dy1, args = args[0], args[1:] dy1 = dy1.reshape(-1, dy1.shape[-1]) if dy1.stride(-1) != 1: dy1 = dy1.contiguous() assert dy1.shape == x.shape else: dy1 = None if ctx.prenorm: dresidual = args[0] dresidual = dresidual.reshape(-1, dresidual.shape[-1]) if dresidual.stride(-1) != 1: dresidual = dresidual.contiguous() assert dresidual.shape == x.shape else: dresidual = None dx, dw, db, dresidual_in, dx1, dw1, db1 = _layer_norm_bwd( dy, x, weight, bias, ctx.eps, mean, rstd, dresidual, dy1, weight1, bias1, seeds, ctx.dropout_p, rowscale, ctx.has_residual, ctx.has_x1, ctx.zero_centered_weight, ctx.is_rms_norm, x_dtype=ctx.x_dtype, ) return ( dx.reshape(ctx.x_shape_og), dw, db, dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None, dx1.reshape(ctx.x_shape_og) if dx1 is not None else None, dw1, db1, None, None, None, None, None, None, None, None, None, None, ) def layer_norm_fn( x, weight, bias, residual=None, x1=None, weight1=None, bias1=None, eps=1e-6, dropout_p=0.0, rowscale=None, prenorm=False, residual_in_fp32=False, zero_centered_weight=False, is_rms_norm=False, return_dropout_mask=False, out=None, residual_out=None, ): return LayerNormFn.apply( x, weight, bias, residual, x1, weight1, bias1, eps, dropout_p, rowscale, prenorm, residual_in_fp32, zero_centered_weight, is_rms_norm, return_dropout_mask, out, residual_out, ) def rms_norm_fn( x, weight, bias, residual=None, x1=None, weight1=None, bias1=None, eps=1e-6, dropout_p=0.0, rowscale=None, prenorm=False, residual_in_fp32=False, zero_centered_weight=False, return_dropout_mask=False, out=None, residual_out=None, ): return LayerNormFn.apply( x, weight, bias, residual, x1, weight1, bias1, eps, dropout_p, rowscale, prenorm, residual_in_fp32, zero_centered_weight, True, return_dropout_mask, out, residual_out, ) class RMSNorm(torch.nn.Module): def __init__( self, hidden_size, eps=1e-5, dropout_p=0.0, zero_centered_weight=False, device=None, dtype=None, ): factory_kwargs = {"device": device, "dtype": dtype} super().__init__() self.eps = eps if dropout_p > 0.0: self.drop = torch.nn.Dropout(dropout_p) else: self.drop = None self.zero_centered_weight = zero_centered_weight self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs)) self.register_parameter("bias", None) self.reset_parameters() def reset_parameters(self): if not self.zero_centered_weight: torch.nn.init.ones_(self.weight) else: torch.nn.init.zeros_(self.weight) def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False): return rms_norm_fn( x, self.weight, self.bias, residual=residual, eps=self.eps, dropout_p=self.drop.p if self.drop is not None and self.training else 0.0, prenorm=prenorm, residual_in_fp32=residual_in_fp32, zero_centered_weight=self.zero_centered_weight, ) else: from torch.nn import RMSNorm if is_flash_attn_available(): from flash_attn import flash_attn_varlen_func from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input from flash_attn.ops.activations import swiglu else: def swiglu(x, y): return F.silu(x.float(), 
inplace=False).to(x.dtype) * y

    warnings.warn(
        "Cannot import flash_attn; install flash_attn to use Flash2Varlen attention for better performance"
    )


@dataclass
class TeaCacheParams:
    """
    TeaCache parameters for `OmniGen2Transformer3DModel`.

    See https://github.com/ali-vilab/TeaCache/ for a more comprehensive understanding.

    Args:
        previous_residual (Optional[torch.Tensor]):
            The tensor difference between the output and the input of the transformer layers from the previous
            timestep.
        previous_modulated_inp (Optional[torch.Tensor]):
            The modulated input from the previous timestep used to indicate the change of the transformer layer's
            output.
        accumulated_rel_l1_distance (float): The accumulated relative L1 distance.
        is_first_or_last_step (bool): Whether the current timestep is the first or last step.
    """

    previous_residual: Optional[torch.Tensor] = None
    previous_modulated_inp: Optional[torch.Tensor] = None
    accumulated_rel_l1_distance: float = 0
    is_first_or_last_step: bool = False


class TimestepEmbedding(nn.Module):
    def __init__(
        self,
        in_channels: int,
        time_embed_dim: int,
        act_fn: str = "silu",
        out_dim: int = None,
        post_act_fn: Optional[str] = None,
        cond_proj_dim=None,
        sample_proj_bias=True,
    ):
        super().__init__()

        self.linear_1 = nn.Linear(in_channels, time_embed_dim, sample_proj_bias)

        if cond_proj_dim is not None:
            self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
        else:
            self.cond_proj = None

        self.act = get_activation(act_fn)

        if out_dim is not None:
            time_embed_dim_out = out_dim
        else:
            time_embed_dim_out = time_embed_dim
        self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias)

        if post_act_fn is None:
            self.post_act = None
        else:
            self.post_act = get_activation(post_act_fn)

        self.initialize_weights()

    def initialize_weights(self):
        nn.init.normal_(self.linear_1.weight, std=0.02)
        nn.init.zeros_(self.linear_1.bias)
        nn.init.normal_(self.linear_2.weight, std=0.02)
        nn.init.zeros_(self.linear_2.bias)

    def forward(self, sample, condition=None):
        if condition is not None:
            sample = sample + self.cond_proj(condition)
        sample = self.linear_1(sample)

        if self.act is not None:
            sample = self.act(sample)

        sample = self.linear_2(sample)

        if self.post_act is not None:
            sample = self.post_act(sample)
        return sample


def apply_rotary_emb(
    x: torch.Tensor,
    freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
    use_real: bool = True,
    use_real_unbind_dim: int = -1,
) -> torch.Tensor:
    """
    Apply rotary embeddings to an input tensor using the given frequency tensor.

    This function applies rotary embeddings to the given query or key tensor 'x' using the provided frequency tensor
    'freqs_cis'. The input tensor is reshaped as complex numbers, and the frequency tensor is reshaped for
    broadcasting compatibility. The resulting tensor contains rotary embeddings and is returned as a real tensor.

    Args:
        x (`torch.Tensor`): Query or key tensor to apply rotary embeddings. [B, H, S, D]
        freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)

    Returns:
        torch.Tensor: The query or key tensor with rotary embeddings applied.
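    Example (a minimal sketch; values are illustrative, and in this module the attention processors call this with
    ``use_real=False`` and a complex ``freqs_cis``):

        >>> B, H, S, D = 1, 2, 4, 8
        >>> x = torch.randn(B, H, S, D)
        >>> cos, sin = torch.ones(S, D), torch.zeros(S, D)  # identity rotation
        >>> out = apply_rotary_emb(x, (cos, sin), use_real=True)
        >>> torch.allclose(out, x)
        True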
""" if use_real: cos, sin = freqs_cis # [S, D] cos = cos[None, None] sin = sin[None, None] cos, sin = cos.to(x.device), sin.to(x.device) if use_real_unbind_dim == -1: # Used for flux, cogvideox, hunyuan-dit x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind( -1 ) # [B, S, H, D//2] x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3) elif use_real_unbind_dim == -2: # Used for Stable Audio, OmniGen and CogView4 x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind( -2 ) # [B, S, H, D//2] x_rotated = torch.cat([-x_imag, x_real], dim=-1) else: raise ValueError( f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2." ) out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype) return out else: # used for lumina # x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2)) x_rotated = torch.view_as_complex( x.float().reshape(*x.shape[:-1], x.shape[-1] // 2, 2) ) freqs_cis = freqs_cis.unsqueeze(2) x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3) return x_out.type_as(x) class OmniGen2AttnProcessorFlash2Varlen: """ Processor for implementing scaled dot-product attention with flash attention and variable length sequences. This processor implements: - Flash attention with variable length sequences - Rotary position embeddings (RoPE) - Query-Key normalization - Proportional attention scaling Args: None """ def __init__(self) -> None: """Initialize the attention processor.""" if not is_flash_attn_available(): raise ImportError( "OmniGen2AttnProcessorFlash2Varlen requires flash_attn. " "Please install flash_attn." ) def _upad_input( self, query_layer: torch.Tensor, key_layer: torch.Tensor, value_layer: torch.Tensor, attention_mask: torch.Tensor, query_length: int, num_heads: int, ) -> Tuple[ torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Tuple[torch.Tensor, torch.Tensor], Tuple[int, int], ]: """ Unpad the input tensors for flash attention. 
Args: query_layer: Query tensor of shape (batch_size, seq_len, num_heads, head_dim) key_layer: Key tensor of shape (batch_size, seq_len, num_kv_heads, head_dim) value_layer: Value tensor of shape (batch_size, seq_len, num_kv_heads, head_dim) attention_mask: Attention mask tensor of shape (batch_size, seq_len) query_length: Length of the query sequence num_heads: Number of attention heads Returns: Tuple containing: - Unpadded query tensor - Unpadded key tensor - Unpadded value tensor - Query indices - Tuple of cumulative sequence lengths for query and key - Tuple of maximum sequence lengths for query and key """ def _get_unpad_data( attention_mask: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor, int]: """Helper function to get unpadding data from attention mask.""" seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() max_seqlen_in_batch = seqlens_in_batch.max().item() cu_seqlens = F.pad( torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0) ) return indices, cu_seqlens, max_seqlen_in_batch indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape # Unpad key and value layers key_layer = index_first_axis( key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k, ) value_layer = index_first_axis( value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k, ) # Handle different query length cases if query_length == kv_seq_len: query_layer = index_first_axis( query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k, ) cu_seqlens_q = cu_seqlens_k max_seqlen_in_batch_q = max_seqlen_in_batch_k indices_q = indices_k elif query_length == 1: max_seqlen_in_batch_q = 1 cu_seqlens_q = torch.arange( batch_size + 1, dtype=torch.int32, device=query_layer.device ) indices_q = cu_seqlens_q[:-1] query_layer = query_layer.squeeze(1) else: attention_mask = attention_mask[:, -query_length:] query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input( query_layer, attention_mask ) return ( query_layer, key_layer, value_layer, indices_q, (cu_seqlens_q, cu_seqlens_k), (max_seqlen_in_batch_q, max_seqlen_in_batch_k), ) def __call__( self, attn: Attention, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, image_rotary_emb: Optional[torch.Tensor] = None, base_sequence_length: Optional[int] = None, ) -> torch.Tensor: """ Process attention computation with flash attention. 
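        When ``base_sequence_length`` is given, the softmax scale is multiplied by
        ``sqrt(log(sequence_length, base_sequence_length))`` (proportional attention). As an illustration
        (assumed numbers), ``sequence_length=1024`` with ``base_sequence_length=256`` scales the usual
        ``1/sqrt(head_dim)`` factor by ``sqrt(1.25) ≈ 1.118``.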
Args: attn: Attention module hidden_states: Hidden states tensor of shape (batch_size, seq_len, hidden_dim) encoder_hidden_states: Encoder hidden states tensor attention_mask: Optional attention mask tensor image_rotary_emb: Optional rotary embeddings for image tokens base_sequence_length: Optional base sequence length for proportional attention Returns: torch.Tensor: Processed hidden states after attention computation """ batch_size, sequence_length, _ = hidden_states.shape # Get Query-Key-Value Pair query = attn.to_q(hidden_states) key = attn.to_k(encoder_hidden_states) value = attn.to_v(encoder_hidden_states) query_dim = query.shape[-1] inner_dim = key.shape[-1] head_dim = query_dim // attn.heads dtype = query.dtype # Get key-value heads kv_heads = inner_dim // head_dim # Reshape tensors for attention computation query = query.view(batch_size, -1, attn.heads, head_dim) key = key.view(batch_size, -1, kv_heads, head_dim) value = value.view(batch_size, -1, kv_heads, head_dim) # Apply Query-Key normalization if attn.norm_q is not None: query = attn.norm_q(query) if attn.norm_k is not None: key = attn.norm_k(key) # Apply Rotary Position Embeddings if image_rotary_emb is not None: query = apply_rotary_emb(query, image_rotary_emb, use_real=False) key = apply_rotary_emb(key, image_rotary_emb, use_real=False) query, key = query.to(dtype), key.to(dtype) # Calculate attention scale if base_sequence_length is not None: softmax_scale = ( math.sqrt(math.log(sequence_length, base_sequence_length)) * attn.scale ) else: softmax_scale = attn.scale # Unpad input for flash attention ( query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens, ) = self._upad_input( query, key, value, attention_mask, sequence_length, attn.heads ) cu_seqlens_q, cu_seqlens_k = cu_seq_lens max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens # Handle different number of heads if kv_heads < attn.heads: key_states = repeat( key_states, "l h c -> l (h k) c", k=attn.heads // kv_heads ) value_states = repeat( value_states, "l h c -> l (h k) c", k=attn.heads // kv_heads ) # Apply flash attention attn_output_unpad = flash_attn_varlen_func( query_states, key_states, value_states, cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k, max_seqlen_q=max_seqlen_in_batch_q, max_seqlen_k=max_seqlen_in_batch_k, dropout_p=0.0, causal=False, softmax_scale=softmax_scale, ) # Pad output and apply final transformations hidden_states = pad_input( attn_output_unpad, indices_q, batch_size, sequence_length ) hidden_states = hidden_states.flatten(-2) hidden_states = hidden_states.type_as(query) # Apply output projection hidden_states = attn.to_out[0](hidden_states) hidden_states = attn.to_out[1](hidden_states) return hidden_states class OmniGen2AttnProcessor: """ Processor for implementing scaled dot-product attention with flash attention and variable length sequences. This processor is optimized for PyTorch 2.0 and implements: - Flash attention with variable length sequences - Rotary position embeddings (RoPE) - Query-Key normalization - Proportional attention scaling Args: None Raises: ImportError: If PyTorch version is less than 2.0 """ def __init__(self) -> None: """Initialize the attention processor.""" if not hasattr(F, "scaled_dot_product_attention"): raise ImportError( "OmniGen2AttnProcessorFlash2Varlen requires PyTorch 2.0. " "Please upgrade PyTorch to version 2.0 or later." 
) def __call__( self, attn: Attention, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, image_rotary_emb: Optional[torch.Tensor] = None, base_sequence_length: Optional[int] = None, ) -> torch.Tensor: """ Process attention computation with flash attention. Args: attn: Attention module hidden_states: Hidden states tensor of shape (batch_size, seq_len, hidden_dim) encoder_hidden_states: Encoder hidden states tensor attention_mask: Optional attention mask tensor image_rotary_emb: Optional rotary embeddings for image tokens base_sequence_length: Optional base sequence length for proportional attention Returns: torch.Tensor: Processed hidden states after attention computation """ batch_size, sequence_length, _ = hidden_states.shape # Get Query-Key-Value Pair query = attn.to_q(hidden_states) key = attn.to_k(encoder_hidden_states) value = attn.to_v(encoder_hidden_states) query_dim = query.shape[-1] inner_dim = key.shape[-1] head_dim = query_dim // attn.heads dtype = query.dtype # Get key-value heads kv_heads = inner_dim // head_dim # Reshape tensors for attention computation query = query.view(batch_size, -1, attn.heads, head_dim) key = key.view(batch_size, -1, kv_heads, head_dim) value = value.view(batch_size, -1, kv_heads, head_dim) # Apply Query-Key normalization if attn.norm_q is not None: query = attn.norm_q(query) if attn.norm_k is not None: key = attn.norm_k(key) # Apply Rotary Position Embeddings if image_rotary_emb is not None: query = apply_rotary_emb(query, image_rotary_emb, use_real=False) key = apply_rotary_emb(key, image_rotary_emb, use_real=False) query, key = query.to(dtype), key.to(dtype) # Calculate attention scale if base_sequence_length is not None: softmax_scale = ( math.sqrt(math.log(sequence_length, base_sequence_length)) * attn.scale ) else: softmax_scale = attn.scale # scaled_dot_product_attention expects attention_mask shape to be # (batch, heads, source_length, target_length) if attention_mask is not None: attention_mask = attention_mask.bool().view(batch_size, 1, 1, -1) query = query.transpose(1, 2) key = key.transpose(1, 2) value = value.transpose(1, 2) # explicitly repeat key and value to match query length, otherwise using enable_gqa=True results in MATH backend of sdpa in our test of pytorch2.6 key = key.repeat_interleave(query.size(-3) // key.size(-3), -3) value = value.repeat_interleave(query.size(-3) // value.size(-3), -3) hidden_states = F.scaled_dot_product_attention( query, key, value, attn_mask=attention_mask, scale=softmax_scale ) hidden_states = hidden_states.transpose(1, 2).reshape( batch_size, -1, attn.heads * head_dim ) hidden_states = hidden_states.type_as(query) # Apply output projection hidden_states = attn.to_out[0](hidden_states) hidden_states = attn.to_out[1](hidden_states) return hidden_states class LuminaRMSNormZero(nn.Module): """ Norm layer adaptive RMS normalization zero. Parameters: embedding_dim (`int`): The size of each embedding vector. 
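    The forward pass projects the conditioning embedding through ``SiLU`` and a linear layer into four chunks
    (``scale_msa``, ``gate_msa``, ``scale_mlp``, ``gate_mlp``), returns ``RMSNorm(x) * (1 + scale_msa)`` together
    with the remaining chunks, and leaves it to the caller to gate the attention and feed-forward branches.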
""" def __init__( self, embedding_dim: int, norm_eps: float, norm_elementwise_affine: bool, ): super().__init__() self.silu = nn.SiLU() self.linear = nn.Linear( min(embedding_dim, 1024), 4 * embedding_dim, bias=True, ) self.norm = RMSNorm(embedding_dim, eps=norm_eps) def forward( self, x: torch.Tensor, emb: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: emb = self.linear(self.silu(emb)) scale_msa, gate_msa, scale_mlp, gate_mlp = emb.chunk(4, dim=1) x = self.norm(x) * (1 + scale_msa[:, None]) return x, gate_msa, scale_mlp, gate_mlp class LuminaLayerNormContinuous(nn.Module): def __init__( self, embedding_dim: int, conditioning_embedding_dim: int, # NOTE: It is a bit weird that the norm layer can be configured to have scale and shift parameters # because the output is immediately scaled and shifted by the projected conditioning embeddings. # Note that AdaLayerNorm does not let the norm layer have scale and shift parameters. # However, this is how it was implemented in the original code, and it's rather likely you should # set `elementwise_affine` to False. elementwise_affine=True, eps=1e-5, bias=True, norm_type="layer_norm", out_dim: Optional[int] = None, ): super().__init__() # AdaLN self.silu = nn.SiLU() self.linear_1 = nn.Linear(conditioning_embedding_dim, embedding_dim, bias=bias) if norm_type == "layer_norm": self.norm = nn.LayerNorm(embedding_dim, eps, elementwise_affine, bias) elif norm_type == "rms_norm": self.norm = RMSNorm( embedding_dim, eps=eps, elementwise_affine=elementwise_affine ) else: raise ValueError(f"unknown norm_type {norm_type}") self.linear_2 = None if out_dim is not None: self.linear_2 = nn.Linear(embedding_dim, out_dim, bias=bias) def forward( self, x: torch.Tensor, conditioning_embedding: torch.Tensor, ) -> torch.Tensor: # convert back to the original dtype in case `conditioning_embedding`` is upcasted to float32 (needed for hunyuanDiT) emb = self.linear_1(self.silu(conditioning_embedding).to(x.dtype)) scale = emb x = self.norm(x) * (1 + scale)[:, None, :] if self.linear_2 is not None: x = self.linear_2(x) return x class LuminaFeedForward(nn.Module): r""" A feed-forward layer. Parameters: hidden_size (`int`): The dimensionality of the hidden layers in the model. This parameter determines the width of the model's hidden representations. intermediate_size (`int`): The intermediate dimension of the feedforward layer. multiple_of (`int`, *optional*): Value to ensure hidden dimension is a multiple of this value. ffn_dim_multiplier (float, *optional*): Custom multiplier for hidden dimension. Defaults to None. 
""" def __init__( self, dim: int, inner_dim: int, multiple_of: Optional[int] = 256, ffn_dim_multiplier: Optional[float] = None, ): super().__init__() self.swiglu = swiglu # custom hidden_size factor multiplier if ffn_dim_multiplier is not None: inner_dim = int(ffn_dim_multiplier * inner_dim) inner_dim = multiple_of * ((inner_dim + multiple_of - 1) // multiple_of) self.linear_1 = nn.Linear( dim, inner_dim, bias=False, ) self.linear_2 = nn.Linear( inner_dim, dim, bias=False, ) self.linear_3 = nn.Linear( dim, inner_dim, bias=False, ) def forward(self, x): h1, h2 = self.linear_1(x), self.linear_3(x) return self.linear_2(self.swiglu(h1, h2)) class Lumina2CombinedTimestepCaptionEmbedding(nn.Module): def __init__( self, hidden_size: int = 4096, text_feat_dim: int = 2048, frequency_embedding_size: int = 256, norm_eps: float = 1e-5, timestep_scale: float = 1.0, ) -> None: super().__init__() self.time_proj = Timesteps( num_channels=frequency_embedding_size, flip_sin_to_cos=True, downscale_freq_shift=0.0, scale=timestep_scale, ) self.timestep_embedder = TimestepEmbedding( in_channels=frequency_embedding_size, time_embed_dim=min(hidden_size, 1024) ) self.caption_embedder = nn.Sequential( RMSNorm(text_feat_dim, eps=norm_eps), nn.Linear(text_feat_dim, hidden_size, bias=True), ) self._initialize_weights() def _initialize_weights(self): nn.init.trunc_normal_(self.caption_embedder[1].weight, std=0.02) nn.init.zeros_(self.caption_embedder[1].bias) def forward( self, timestep: torch.Tensor, text_hidden_states: torch.Tensor, dtype: torch.dtype, ) -> Tuple[torch.Tensor, torch.Tensor]: timestep_proj = self.time_proj(timestep).to(dtype=dtype) time_embed = self.timestep_embedder(timestep_proj) caption_embed = self.caption_embedder(text_hidden_states) return time_embed, caption_embed class OmniGen2RotaryPosEmbed(nn.Module): def __init__( self, theta: int, axes_dim: Tuple[int, int, int], axes_lens: Tuple[int, int, int] = (300, 512, 512), patch_size: int = 2, ): super().__init__() self.theta = theta self.axes_dim = axes_dim self.axes_lens = axes_lens self.patch_size = patch_size @staticmethod def get_freqs_cis( axes_dim: Tuple[int, int, int], axes_lens: Tuple[int, int, int], theta: int ) -> List[torch.Tensor]: freqs_cis = [] freqs_dtype = ( torch.float32 if torch.backends.mps.is_available() else torch.float64 ) for i, (d, e) in enumerate(zip(axes_dim, axes_lens)): emb = get_1d_rotary_pos_embed(d, e, theta=theta, freqs_dtype=freqs_dtype) freqs_cis.append(emb) return freqs_cis def _get_freqs_cis(self, freqs_cis, ids: torch.Tensor) -> torch.Tensor: device = ids.device if ids.device.type == "mps": ids = ids.to("cpu") result = [] for i in range(len(self.axes_dim)): freqs = freqs_cis[i].to(ids.device) index = ids[:, :, i : i + 1].repeat(1, 1, freqs.shape[-1]).to(torch.int64) result.append( torch.gather( freqs.unsqueeze(0).repeat(index.shape[0], 1, 1), dim=1, index=index ) ) return torch.cat(result, dim=-1).to(device) def forward( self, freqs_cis, attention_mask, l_effective_ref_img_len, l_effective_img_len, ref_img_sizes, img_sizes, device, ): batch_size = len(attention_mask) p = self.patch_size encoder_seq_len = attention_mask.shape[1] l_effective_cap_len = attention_mask.sum(dim=1).tolist() if isinstance(l_effective_img_len[0], list): # Check for t-dim case seq_lengths = [ cap_len + sum(ref_img_len) + sum(img_len) for cap_len, ref_img_len, img_len in zip( l_effective_cap_len, l_effective_ref_img_len, l_effective_img_len ) ] else: # Original case seq_lengths = [ cap_len + sum(ref_img_len) + img_len for cap_len, 
ref_img_len, img_len in zip( l_effective_cap_len, l_effective_ref_img_len, l_effective_img_len ) ] max_seq_len = max(seq_lengths) max_ref_img_len = max( [sum(ref_img_len) for ref_img_len in l_effective_ref_img_len] ) if isinstance(l_effective_img_len[0], list): max_img_len = max([sum(ln) for ln in l_effective_img_len]) else: max_img_len = max(l_effective_img_len) # Create position IDs position_ids = torch.zeros( batch_size, max_seq_len, 3, dtype=torch.int32, device=device ) for i, (cap_seq_len, seq_len) in enumerate( zip(l_effective_cap_len, seq_lengths) ): # add text position ids position_ids[i, :cap_seq_len] = repeat( torch.arange(cap_seq_len, dtype=torch.int32, device=device), "l -> l 3" ) pe_shift = cap_seq_len pe_shift_len = cap_seq_len if ref_img_sizes[i] is not None: for ref_img_size, ref_img_len in zip( ref_img_sizes[i], l_effective_ref_img_len[i] ): H, W = ref_img_size ref_H_tokens, ref_W_tokens = H // p, W // p assert ref_H_tokens * ref_W_tokens == ref_img_len # add image position ids row_ids = repeat( torch.arange(ref_H_tokens, dtype=torch.int32, device=device), "h -> h w", w=ref_W_tokens, ).flatten() col_ids = repeat( torch.arange(ref_W_tokens, dtype=torch.int32, device=device), "w -> h w", h=ref_H_tokens, ).flatten() position_ids[i, pe_shift_len : pe_shift_len + ref_img_len, 0] = ( pe_shift ) position_ids[i, pe_shift_len : pe_shift_len + ref_img_len, 1] = ( row_ids ) position_ids[i, pe_shift_len : pe_shift_len + ref_img_len, 2] = ( col_ids ) pe_shift += max(ref_H_tokens, ref_W_tokens) pe_shift_len += ref_img_len if isinstance(l_effective_img_len[i], list): # New case for img_size, img_len in zip(img_sizes[i], l_effective_img_len[i]): H, W = img_size H_tokens, W_tokens = H // p, W // p assert H_tokens * W_tokens == img_len row_ids = repeat( torch.arange(H_tokens, dtype=torch.int32, device=device), "h -> h w", w=W_tokens, ).flatten() col_ids = repeat( torch.arange(W_tokens, dtype=torch.int32, device=device), "w -> h w", h=H_tokens, ).flatten() end_idx = pe_shift_len + img_len position_ids[i, pe_shift_len:end_idx, 0] = pe_shift position_ids[i, pe_shift_len:end_idx, 1] = row_ids position_ids[i, pe_shift_len:end_idx, 2] = col_ids pe_shift += max(H_tokens, W_tokens) pe_shift_len = end_idx else: # Original case H, W = img_sizes[i] H_tokens, W_tokens = H // p, W // p assert H_tokens * W_tokens == l_effective_img_len[i] row_ids = repeat( torch.arange(H_tokens, dtype=torch.int32, device=device), "h -> h w", w=W_tokens, ).flatten() col_ids = repeat( torch.arange(W_tokens, dtype=torch.int32, device=device), "w -> h w", h=H_tokens, ).flatten() assert pe_shift_len + l_effective_img_len[i] == seq_len position_ids[i, pe_shift_len:seq_len, 0] = pe_shift position_ids[i, pe_shift_len:seq_len, 1] = row_ids position_ids[i, pe_shift_len:seq_len, 2] = col_ids # Get combined rotary embeddings freqs_cis = self._get_freqs_cis(freqs_cis, position_ids) # create separate rotary embeddings for captions and images cap_freqs_cis = torch.zeros( batch_size, encoder_seq_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype, ) ref_img_freqs_cis = torch.zeros( batch_size, max_ref_img_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype, ) img_freqs_cis = torch.zeros( batch_size, max_img_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype, ) for i, (cap_seq_len, ref_img_len, img_len, seq_len) in enumerate( zip( l_effective_cap_len, l_effective_ref_img_len, l_effective_img_len, seq_lengths, ) ): cap_freqs_cis[i, :cap_seq_len] = freqs_cis[i, :cap_seq_len] ref_img_freqs_cis[i, : 
sum(ref_img_len)] = freqs_cis[ i, cap_seq_len : cap_seq_len + sum(ref_img_len) ] if isinstance(img_len, list): img_len = sum(img_len) img_freqs_cis[i, :img_len] = freqs_cis[ i, cap_seq_len + sum(ref_img_len) : cap_seq_len + sum(ref_img_len) + img_len, ] return ( cap_freqs_cis, ref_img_freqs_cis, img_freqs_cis, freqs_cis, l_effective_cap_len, seq_lengths, ) def force_scheduler(cache_dic, current): if cache_dic["fresh_ratio"] == 0: # FORA linear_step_weight = 0.0 else: # TokenCache linear_step_weight = 0.0 step_factor = torch.tensor( 1 - linear_step_weight + 2 * linear_step_weight * current["step"] / current["num_steps"] ) threshold = torch.round(cache_dic["fresh_threshold"] / step_factor) # no force constrain for sensitive steps, cause the performance is good enough. # you may have a try. cache_dic["cal_threshold"] = threshold # return threshold def cal_type(cache_dic, current): """ Determine calculation type for this step """ if (cache_dic["fresh_ratio"] == 0.0) and (not cache_dic["taylor_cache"]): # FORA:Uniform first_step = current["step"] == 0 else: # ToCa: First enhanced first_step = current["step"] < cache_dic["first_enhance"] if not first_step: fresh_interval = cache_dic["cal_threshold"] else: fresh_interval = cache_dic["fresh_threshold"] if (first_step) or (cache_dic["cache_counter"] == fresh_interval - 1): current["type"] = "full" cache_dic["cache_counter"] = 0 current["activated_steps"].append(current["step"]) force_scheduler(cache_dic, current) elif cache_dic["taylor_cache"]: cache_dic["cache_counter"] += 1 current["type"] = "Taylor" elif ( cache_dic["cache_counter"] % 2 == 1 ): # 0: ToCa-Aggresive-ToCa, 1: Aggresive-ToCa-Aggresive cache_dic["cache_counter"] += 1 current["type"] = "ToCa" # 'cache_noise' 'ToCa' 'FORA' elif cache_dic["Delta-DiT"]: cache_dic["cache_counter"] += 1 current["type"] = "Delta-Cache" else: cache_dic["cache_counter"] += 1 current["type"] = "ToCa" def derivative_approximation(cache_dic: Dict, current: Dict, feature: torch.Tensor): """ Compute derivative approximation. :param cache_dic: Cache dictionary :param current: Information of the current step """ difference_distance = ( current["activated_steps"][-1] - current["activated_steps"][-2] ) updated_taylor_factors = {} updated_taylor_factors[0] = feature for i in range(cache_dic["max_order"]): if ( cache_dic["cache"][-1][current["stream"]][current["layer"]][ current["module"] ].get(i, None) is not None ) and (current["step"] > cache_dic["first_enhance"] - 2): updated_taylor_factors[i + 1] = ( updated_taylor_factors[i] - cache_dic["cache"][-1][current["stream"]][current["layer"]][ current["module"] ][i] ) / difference_distance else: break cache_dic["cache"][-1][current["stream"]][current["layer"]][current["module"]] = ( updated_taylor_factors ) def taylor_formula(cache_dic: Dict, current: Dict) -> torch.Tensor: """ Compute Taylor expansion error. :param cache_dic: Cache dictionary :param current: Information of the current step """ x = current["step"] - current["activated_steps"][-1] # x = current['t'] - current['activated_times'][-1] output = 0 for i in range( len( cache_dic["cache"][-1][current["stream"]][current["layer"]][ current["module"] ] ) ): output += ( (1 / math.factorial(i)) * cache_dic["cache"][-1][current["stream"]][current["layer"]][ current["module"] ][i] * (x**i) ) return output def taylor_cache_init(cache_dic: Dict, current: Dict): """ Initialize Taylor cache and allocate storage for different-order derivatives in the Taylor cache. 
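    On "full" steps, ``derivative_approximation`` stores the feature and its finite-difference derivatives of
    increasing order; on "Taylor" steps, ``taylor_formula`` predicts the feature as
    ``sum_i cache[i] / i! * (step - last_full_step) ** i`` instead of re-running the block.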
def taylor_cache_init(cache_dic: Dict, current: Dict):
    """
    Initialize the Taylor cache: allocate storage for the different-order
    derivatives of the current module.

    :param cache_dic: Cache dictionary
    :param current: Information of the current step
    """
    if (current["step"] == 0) and (cache_dic["taylor_cache"]):
        cache_dic["cache"][-1][current["stream"]][current["layer"]][
            current["module"]
        ] = {}


logger = logging.get_logger(__name__)


class OmniGen2TransformerBlock(nn.Module):
    """
    Transformer block for the OmniGen2 model.

    This block implements a transformer layer with:
    - Multi-head attention (using a flash-attention varlen processor when available)
    - Feed-forward network with SwiGLU activation
    - RMS normalization
    - Optional modulation for conditional generation

    Args:
        dim: Dimension of the input and output tensors
        num_attention_heads: Number of attention heads
        num_kv_heads: Number of key-value heads
        multiple_of: Round the feed-forward hidden dimension up to a multiple of this value
        ffn_dim_multiplier: Multiplier for the feed-forward network dimension
        norm_eps: Epsilon value for normalization layers
        modulation: Whether to use modulation for conditional generation
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        num_kv_heads: int,
        multiple_of: int,
        ffn_dim_multiplier: float,
        norm_eps: float,
        modulation: bool = True,
    ) -> None:
        """Initialize the transformer block."""
        super().__init__()
        self.head_dim = dim // num_attention_heads
        self.modulation = modulation

        # Prefer the flash-attention varlen processor; fall back to the default
        # processor when flash_attn is not installed.
        try:
            processor = OmniGen2AttnProcessorFlash2Varlen()
        except ImportError:
            processor = OmniGen2AttnProcessor()

        # Initialize attention layer
        self.attn = Attention(
            query_dim=dim,
            cross_attention_dim=None,
            dim_head=dim // num_attention_heads,
            qk_norm="rms_norm",
            heads=num_attention_heads,
            kv_heads=num_kv_heads,
            eps=1e-5,
            bias=False,
            out_bias=False,
            processor=processor,
        )

        # Initialize feed-forward network
        self.feed_forward = LuminaFeedForward(
            dim=dim,
            inner_dim=4 * dim,
            multiple_of=multiple_of,
            ffn_dim_multiplier=ffn_dim_multiplier,
        )

        # Initialize normalization layers
        if modulation:
            self.norm1 = LuminaRMSNormZero(
                embedding_dim=dim, norm_eps=norm_eps, norm_elementwise_affine=True
            )
        else:
            self.norm1 = RMSNorm(dim, eps=norm_eps)
        self.ffn_norm1 = RMSNorm(dim, eps=norm_eps)
        self.norm2 = RMSNorm(dim, eps=norm_eps)
        self.ffn_norm2 = RMSNorm(dim, eps=norm_eps)

        self.initialize_weights()

    def initialize_weights(self) -> None:
        """
        Initialize the weights of the transformer block.

        Uses Xavier uniform initialization for linear layers and zero
        initialization for biases.
        """
        nn.init.xavier_uniform_(self.attn.to_q.weight)
        nn.init.xavier_uniform_(self.attn.to_k.weight)
        nn.init.xavier_uniform_(self.attn.to_v.weight)
        nn.init.xavier_uniform_(self.attn.to_out[0].weight)

        nn.init.xavier_uniform_(self.feed_forward.linear_1.weight)
        nn.init.xavier_uniform_(self.feed_forward.linear_2.weight)
        nn.init.xavier_uniform_(self.feed_forward.linear_3.weight)

        if self.modulation:
            nn.init.zeros_(self.norm1.linear.weight)
            nn.init.zeros_(self.norm1.linear.bias)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        image_rotary_emb: torch.Tensor,
        temb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Forward pass of the transformer block.
Args: hidden_states: Input hidden states tensor attention_mask: Attention mask tensor image_rotary_emb: Rotary embeddings for image tokens temb: Optional timestep embedding tensor Returns: torch.Tensor: Output hidden states after transformer block processing """ enable_taylorseer = getattr(self, "enable_taylorseer", False) if enable_taylorseer: if self.modulation: if temb is None: raise ValueError("temb must be provided when modulation is enabled") if self.current["type"] == "full": self.current["module"] = "total" taylor_cache_init(cache_dic=self.cache_dic, current=self.current) norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1( hidden_states, temb ) attn_output = self.attn( hidden_states=norm_hidden_states, encoder_hidden_states=norm_hidden_states, attention_mask=attention_mask, image_rotary_emb=image_rotary_emb, ) hidden_states = hidden_states + gate_msa.unsqueeze( 1 ).tanh() * self.norm2(attn_output) mlp_output = self.feed_forward( self.ffn_norm1(hidden_states) * (1 + scale_mlp.unsqueeze(1)) ) hidden_states = hidden_states + gate_mlp.unsqueeze( 1 ).tanh() * self.ffn_norm2(mlp_output) derivative_approximation( cache_dic=self.cache_dic, current=self.current, feature=hidden_states, ) elif self.current["type"] == "Taylor": self.current["module"] = "total" hidden_states = taylor_formula( cache_dic=self.cache_dic, current=self.current ) else: norm_hidden_states = self.norm1(hidden_states) attn_output = self.attn( hidden_states=norm_hidden_states, encoder_hidden_states=norm_hidden_states, attention_mask=attention_mask, image_rotary_emb=image_rotary_emb, ) hidden_states = hidden_states + self.norm2(attn_output) mlp_output = self.feed_forward(self.ffn_norm1(hidden_states)) hidden_states = hidden_states + self.ffn_norm2(mlp_output) else: if self.modulation: if temb is None: raise ValueError("temb must be provided when modulation is enabled") norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1( hidden_states, temb ) attn_output = self.attn( hidden_states=norm_hidden_states, encoder_hidden_states=norm_hidden_states, attention_mask=attention_mask, image_rotary_emb=image_rotary_emb, ) hidden_states = hidden_states + gate_msa.unsqueeze( 1 ).tanh() * self.norm2(attn_output) mlp_output = self.feed_forward( self.ffn_norm1(hidden_states) * (1 + scale_mlp.unsqueeze(1)) ) hidden_states = hidden_states + gate_mlp.unsqueeze( 1 ).tanh() * self.ffn_norm2(mlp_output) else: norm_hidden_states = self.norm1(hidden_states) attn_output = self.attn( hidden_states=norm_hidden_states, encoder_hidden_states=norm_hidden_states, attention_mask=attention_mask, image_rotary_emb=image_rotary_emb, ) hidden_states = hidden_states + self.norm2(attn_output) mlp_output = self.feed_forward(self.ffn_norm1(hidden_states)) hidden_states = hidden_states + self.ffn_norm2(mlp_output) return hidden_states class OmniGen2Transformer3DModel( ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin ): """ OmniGen2 Transformer 3D Model (modified to output frame sequences). 
    A transformer-based diffusion model for image generation with:
    - Patch-based image processing
    - Rotary position embeddings
    - Multi-head attention
    - Conditional generation support

    Args:
        patch_size: Size of image patches
        in_channels: Number of input channels
        out_channels: Number of output channels (defaults to in_channels)
        hidden_size: Size of hidden layers
        num_layers: Number of transformer layers
        num_refiner_layers: Number of refiner layers
        num_attention_heads: Number of attention heads
        num_kv_heads: Number of key-value heads
        multiple_of: Round the feed-forward hidden dimension up to a multiple of this value
        ffn_dim_multiplier: Multiplier for the feed-forward network dimension
        norm_eps: Epsilon value for normalization layers
        axes_dim_rope: Per-axis dimensions for rotary position embeddings
        axes_lens: Maximum per-axis lengths for rotary position embeddings
        text_feat_dim: Dimension of text features
        timestep_scale: Scale factor for timestep embeddings
    """

    _supports_gradient_checkpointing = True
    _no_split_modules = ["OmniGen2TransformerBlock"]
    _skip_layerwise_casting_patterns = ["x_embedder", "norm"]

    @register_to_config
    def __init__(
        self,
        patch_size: int = 2,
        in_channels: int = 16,
        out_channels: Optional[int] = None,
        hidden_size: int = 2304,
        num_layers: int = 26,
        num_refiner_layers: int = 2,
        num_attention_heads: int = 24,
        num_kv_heads: int = 8,
        multiple_of: int = 256,
        ffn_dim_multiplier: Optional[float] = None,
        norm_eps: float = 1e-5,
        axes_dim_rope: Tuple[int, int, int] = (32, 32, 32),
        axes_lens: Tuple[int, int, int] = (300, 512, 512),
        text_feat_dim: int = 1024,
        timestep_scale: float = 1.0,
    ) -> None:
        """Initialize the OmniGen2 transformer model."""
        super().__init__()

        # Validate configuration: the per-head dimension must match the RoPE axes.
        if (hidden_size // num_attention_heads) != sum(axes_dim_rope):
            raise ValueError(
                f"hidden_size // num_attention_heads ({hidden_size // num_attention_heads}) "
                f"must equal sum(axes_dim_rope) ({sum(axes_dim_rope)})"
            )

        self.out_channels = out_channels or in_channels

        # Initialize embeddings
        self.rope_embedder = OmniGen2RotaryPosEmbed(
            theta=10000,
            axes_dim=axes_dim_rope,
            axes_lens=axes_lens,
            patch_size=patch_size,
        )

        self.x_embedder = nn.Linear(
            in_features=patch_size * patch_size * in_channels,
            out_features=hidden_size,
        )

        self.ref_image_patch_embedder = nn.Linear(
            in_features=patch_size * patch_size * in_channels,
            out_features=hidden_size,
        )

        self.time_caption_embed = Lumina2CombinedTimestepCaptionEmbedding(
            hidden_size=hidden_size,
            text_feat_dim=text_feat_dim,
            norm_eps=norm_eps,
            timestep_scale=timestep_scale,
        )

        # Initialize transformer blocks
        self.noise_refiner = nn.ModuleList(
            [
                OmniGen2TransformerBlock(
                    hidden_size,
                    num_attention_heads,
                    num_kv_heads,
                    multiple_of,
                    ffn_dim_multiplier,
                    norm_eps,
                    modulation=True,
                )
                for _ in range(num_refiner_layers)
            ]
        )

        self.ref_image_refiner = nn.ModuleList(
            [
                OmniGen2TransformerBlock(
                    hidden_size,
                    num_attention_heads,
                    num_kv_heads,
                    multiple_of,
                    ffn_dim_multiplier,
                    norm_eps,
                    modulation=True,
                )
                for _ in range(num_refiner_layers)
            ]
        )

        self.context_refiner = nn.ModuleList(
            [
                OmniGen2TransformerBlock(
                    hidden_size,
                    num_attention_heads,
                    num_kv_heads,
                    multiple_of,
                    ffn_dim_multiplier,
                    norm_eps,
                    modulation=False,
                )
                for _ in range(num_refiner_layers)
            ]
        )

        # 3. Transformer blocks
        self.layers = nn.ModuleList(
            [
                OmniGen2TransformerBlock(
                    hidden_size,
                    num_attention_heads,
                    num_kv_heads,
                    multiple_of,
                    ffn_dim_multiplier,
                    norm_eps,
                    modulation=True,
                )
                for _ in range(num_layers)
            ]
        )

        # 4.
Output norm & projection self.norm_out = LuminaLayerNormContinuous( embedding_dim=hidden_size, conditioning_embedding_dim=min(hidden_size, 1024), elementwise_affine=False, eps=1e-6, bias=True, out_dim=patch_size * patch_size * self.out_channels, ) # Add learnable embeddings to distinguish different images self.image_index_embedding = nn.Parameter( torch.randn(5, hidden_size) ) # support max 5 ref images self.gradient_checkpointing = False self.initialize_weights() # TeaCache settings self.enable_teacache = False self.teacache_rel_l1_thresh = 0.05 self.teacache_params = TeaCacheParams() coefficients = [-5.48259225, 11.48772289, -4.47407401, 2.47730926, -0.03316487] self.rescale_func = np.poly1d(coefficients) def initialize_weights(self) -> None: """ Initialize the weights of the model. Uses Xavier uniform initialization for linear layers. """ nn.init.xavier_uniform_(self.x_embedder.weight) nn.init.constant_(self.x_embedder.bias, 0.0) nn.init.xavier_uniform_(self.ref_image_patch_embedder.weight) nn.init.constant_(self.ref_image_patch_embedder.bias, 0.0) nn.init.zeros_(self.norm_out.linear_1.weight) nn.init.zeros_(self.norm_out.linear_1.bias) nn.init.zeros_(self.norm_out.linear_2.weight) nn.init.zeros_(self.norm_out.linear_2.bias) nn.init.normal_(self.image_index_embedding, std=0.02) def img_patch_embed_and_refine( self, hidden_states, ref_image_hidden_states, padded_img_mask, padded_ref_img_mask, noise_rotary_emb, ref_img_rotary_emb, l_effective_ref_img_len, l_effective_img_len, temb, ): batch_size = len(hidden_states) if isinstance(l_effective_img_len[0], list): l_effective_img_len_summed = [sum(ln) for ln in l_effective_img_len] else: l_effective_img_len_summed = l_effective_img_len max_combined_img_len = max( [ img_len + sum(ref_img_len) for img_len, ref_img_len in zip( l_effective_img_len_summed, l_effective_ref_img_len ) ] ) hidden_states = self.x_embedder(hidden_states) ref_image_hidden_states = self.ref_image_patch_embedder(ref_image_hidden_states) for i in range(batch_size): shift = 0 for j, ref_img_len in enumerate(l_effective_ref_img_len[i]): ref_image_hidden_states[i, shift : shift + ref_img_len, :] = ( ref_image_hidden_states[i, shift : shift + ref_img_len, :] + self.image_index_embedding[j] ) shift += ref_img_len for layer in self.noise_refiner: hidden_states = layer( hidden_states, padded_img_mask, noise_rotary_emb, temb ) flat_l_effective_ref_img_len = list(itertools.chain(*l_effective_ref_img_len)) num_ref_images = len(flat_l_effective_ref_img_len) max_ref_img_len = max(flat_l_effective_ref_img_len) batch_ref_img_mask = ref_image_hidden_states.new_zeros( num_ref_images, max_ref_img_len, dtype=torch.bool ) batch_ref_image_hidden_states = ref_image_hidden_states.new_zeros( num_ref_images, max_ref_img_len, self.config.hidden_size ) batch_ref_img_rotary_emb = hidden_states.new_zeros( num_ref_images, max_ref_img_len, ref_img_rotary_emb.shape[-1], dtype=ref_img_rotary_emb.dtype, ) batch_temb = temb.new_zeros(num_ref_images, *temb.shape[1:], dtype=temb.dtype) # sequence of ref imgs to batch idx = 0 for i in range(batch_size): shift = 0 for ref_img_len in l_effective_ref_img_len[i]: batch_ref_img_mask[idx, :ref_img_len] = True batch_ref_image_hidden_states[idx, :ref_img_len] = ( ref_image_hidden_states[i, shift : shift + ref_img_len] ) batch_ref_img_rotary_emb[idx, :ref_img_len] = ref_img_rotary_emb[ i, shift : shift + ref_img_len ] batch_temb[idx] = temb[i] shift += ref_img_len idx += 1 # refine ref imgs separately for layer in self.ref_image_refiner: batch_ref_image_hidden_states = 
layer( batch_ref_image_hidden_states, batch_ref_img_mask, batch_ref_img_rotary_emb, batch_temb, ) # batch of ref imgs to sequence idx = 0 for i in range(batch_size): shift = 0 for ref_img_len in l_effective_ref_img_len[i]: ref_image_hidden_states[i, shift : shift + ref_img_len] = ( batch_ref_image_hidden_states[idx, :ref_img_len] ) shift += ref_img_len idx += 1 combined_img_hidden_states = hidden_states.new_zeros( batch_size, max_combined_img_len, self.config.hidden_size ) for i, (ref_img_len, img_len) in enumerate( zip(l_effective_ref_img_len, l_effective_img_len_summed) ): combined_img_hidden_states[i, : sum(ref_img_len)] = ref_image_hidden_states[ i, : sum(ref_img_len) ] combined_img_hidden_states[ i, sum(ref_img_len) : sum(ref_img_len) + img_len ] = hidden_states[i, :img_len] return combined_img_hidden_states def flat_and_pad_to_seq(self, hidden_states, ref_image_hidden_states): batch_size = len(hidden_states) p = self.config.patch_size device = hidden_states[0].device if len(hidden_states[0].shape) == 3: img_sizes = [(img.size(1), img.size(2)) for img in hidden_states] l_effective_img_len = [(H // p) * (W // p) for (H, W) in img_sizes] else: img_sizes = [ [(img.size(1), img.size(2)) for img in imgs] for imgs in hidden_states ] l_effective_img_len = [ [(H // p) * (W // p) for (H, W) in _img_sizes] for _img_sizes in img_sizes ] if ref_image_hidden_states is not None: ref_img_sizes = [ [(img.size(1), img.size(2)) for img in imgs] if imgs is not None else None for imgs in ref_image_hidden_states ] l_effective_ref_img_len = [ [ (ref_img_size[0] // p) * (ref_img_size[1] // p) for ref_img_size in _ref_img_sizes ] if _ref_img_sizes is not None else [0] for _ref_img_sizes in ref_img_sizes ] else: ref_img_sizes = [None for _ in range(batch_size)] l_effective_ref_img_len = [[0] for _ in range(batch_size)] max_ref_img_len = max( [sum(ref_img_len) for ref_img_len in l_effective_ref_img_len] ) if len(hidden_states[0].shape) == 4: max_img_len = max([sum(img_len) for img_len in l_effective_img_len]) else: max_img_len = max(l_effective_img_len) # ref image patch embeddings flat_ref_img_hidden_states = [] for i in range(batch_size): if ref_img_sizes[i] is not None: imgs = [] for ref_img in ref_image_hidden_states[i]: C, H, W = ref_img.size() ref_img = rearrange( ref_img, "c (h p1) (w p2) -> (h w) (p1 p2 c)", p1=p, p2=p ) imgs.append(ref_img) img = torch.cat(imgs, dim=0) flat_ref_img_hidden_states.append(img) else: flat_ref_img_hidden_states.append(None) # image patch embeddings flat_hidden_states = [] if len(hidden_states[0].shape) == 4: # New case for i in range(batch_size): # Process each time step and concatenate batch_img_patches = [] for img in hidden_states[i]: C, H, W = img.size() img = rearrange( img, "c (h p1) (w p2) -> (h w) (p1 p2 c)", p1=p, p2=p ) batch_img_patches.append(img) # Concatenate patches for the current batch item across time flat_hidden_states.append(torch.cat(batch_img_patches, dim=0)) else: # Default for i in range(batch_size): img = hidden_states[i] C, H, W = img.size() img = rearrange(img, "c (h p1) (w p2) -> (h w) (p1 p2 c)", p1=p, p2=p) flat_hidden_states.append(img) padded_ref_img_hidden_states = torch.zeros( batch_size, max_ref_img_len, flat_hidden_states[0].shape[-1], device=device, dtype=flat_hidden_states[0].dtype, ) padded_ref_img_mask = torch.zeros( batch_size, max_ref_img_len, dtype=torch.bool, device=device ) for i in range(batch_size): if ref_img_sizes[i] is not None: padded_ref_img_hidden_states[i, : sum(l_effective_ref_img_len[i])] = ( 
flat_ref_img_hidden_states[i] ) padded_ref_img_mask[i, : sum(l_effective_ref_img_len[i])] = True padded_hidden_states = torch.zeros( batch_size, max_img_len, flat_hidden_states[0].shape[-1], device=device, dtype=flat_hidden_states[0].dtype, ) padded_img_mask = torch.zeros( batch_size, max_img_len, dtype=torch.bool, device=device ) for i in range(batch_size): if len(hidden_states[0].shape) == 4: # New case padded_hidden_states[i, : sum(l_effective_img_len[i])] = ( flat_hidden_states[i] ) padded_img_mask[i, : sum(l_effective_img_len[i])] = True else: padded_hidden_states[i, : l_effective_img_len[i]] = flat_hidden_states[ i ] padded_img_mask[i, : l_effective_img_len[i]] = True return ( padded_hidden_states, padded_ref_img_hidden_states, padded_img_mask, padded_ref_img_mask, l_effective_ref_img_len, l_effective_img_len, ref_img_sizes, img_sizes, ) def forward( self, hidden_states: Union[torch.Tensor, List[torch.Tensor]], timestep: torch.Tensor, text_hidden_states: torch.Tensor, freqs_cis: torch.Tensor, text_attention_mask: torch.Tensor, ref_image_hidden_states: Optional[List[List[torch.Tensor]]] = None, attention_kwargs: Optional[Dict[str, Any]] = None, return_dict: bool = False, ) -> Union[torch.Tensor, Transformer2DModelOutput]: enable_taylorseer = getattr(self, "enable_taylorseer", False) if enable_taylorseer: cal_type(self.cache_dic, self.current) if attention_kwargs is not None: attention_kwargs = attention_kwargs.copy() lora_scale = attention_kwargs.pop("scale", 1.0) else: lora_scale = 1.0 if USE_PEFT_BACKEND: # weight the lora layers by setting `lora_scale` for each PEFT layer scale_lora_layers(self, lora_scale) else: if ( attention_kwargs is not None and attention_kwargs.get("scale", None) is not None ): logger.warning( "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective." ) # 1. Condition, positional & patch embedding batch_size = len(hidden_states) is_hidden_states_tensor = isinstance(hidden_states, torch.Tensor) if is_hidden_states_tensor: assert hidden_states.ndim == 4 hidden_states = [_hidden_states for _hidden_states in hidden_states] device = hidden_states[0].device temb, text_hidden_states = self.time_caption_embed( timestep, text_hidden_states, hidden_states[0].dtype ) ( hidden_states, ref_image_hidden_states, img_mask, ref_img_mask, l_effective_ref_img_len, l_effective_img_len, ref_img_sizes, img_sizes, ) = self.flat_and_pad_to_seq(hidden_states, ref_image_hidden_states) ( context_rotary_emb, ref_img_rotary_emb, noise_rotary_emb, rotary_emb, encoder_seq_lengths, seq_lengths, ) = self.rope_embedder( freqs_cis, text_attention_mask, l_effective_ref_img_len, l_effective_img_len, ref_img_sizes, img_sizes, device, ) # 2. Context refinement for layer in self.context_refiner: text_hidden_states = layer( text_hidden_states, text_attention_mask, context_rotary_emb ) combined_img_hidden_states = self.img_patch_embed_and_refine( hidden_states, ref_image_hidden_states, img_mask, ref_img_mask, noise_rotary_emb, ref_img_rotary_emb, l_effective_ref_img_len, l_effective_img_len, temb, ) # 3. 
Joint Transformer blocks max_seq_len = max(seq_lengths) attention_mask = hidden_states.new_zeros( batch_size, max_seq_len, dtype=torch.bool ) joint_hidden_states = hidden_states.new_zeros( batch_size, max_seq_len, self.config.hidden_size ) for i, (encoder_seq_len, seq_len) in enumerate( zip(encoder_seq_lengths, seq_lengths) ): attention_mask[i, :seq_len] = True joint_hidden_states[i, :encoder_seq_len] = text_hidden_states[ i, :encoder_seq_len ] joint_hidden_states[i, encoder_seq_len:seq_len] = ( combined_img_hidden_states[i, : seq_len - encoder_seq_len] ) hidden_states = joint_hidden_states if self.enable_teacache: teacache_hidden_states = hidden_states.clone() teacache_temb = temb.clone() modulated_inp, _, _, _ = self.layers[0].norm1( teacache_hidden_states, teacache_temb ) if self.teacache_params.is_first_or_last_step: should_calc = True self.teacache_params.accumulated_rel_l1_distance = 0 else: self.teacache_params.accumulated_rel_l1_distance += self.rescale_func( ( (modulated_inp - self.teacache_params.previous_modulated_inp) .abs() .mean() / self.teacache_params.previous_modulated_inp.abs().mean() ) .cpu() .item() ) if ( self.teacache_params.accumulated_rel_l1_distance < self.teacache_rel_l1_thresh ): should_calc = False else: should_calc = True self.teacache_params.accumulated_rel_l1_distance = 0 self.teacache_params.previous_modulated_inp = modulated_inp if self.enable_teacache: if not should_calc: hidden_states += self.teacache_params.previous_residual else: ori_hidden_states = hidden_states.clone() for layer_idx, layer in enumerate(self.layers): if torch.is_grad_enabled() and self.gradient_checkpointing: hidden_states = self._gradient_checkpointing_func( layer, hidden_states, attention_mask, rotary_emb, temb ) else: hidden_states = layer( hidden_states, attention_mask, rotary_emb, temb ) self.teacache_params.previous_residual = ( hidden_states - ori_hidden_states ) else: if enable_taylorseer: self.current["stream"] = "layers_stream" for layer_idx, layer in enumerate(self.layers): if enable_taylorseer: layer.current = self.current layer.cache_dic = self.cache_dic layer.enable_taylorseer = True self.current["layer"] = layer_idx if torch.is_grad_enabled() and self.gradient_checkpointing: hidden_states = self._gradient_checkpointing_func( layer, hidden_states, attention_mask, rotary_emb, temb ) else: hidden_states = layer( hidden_states, attention_mask, rotary_emb, temb ) # 4. Output norm & projection hidden_states = self.norm_out(hidden_states, temb) p = self.config.patch_size output = [] for i, (img_size, img_len, seq_len) in enumerate( zip(img_sizes, l_effective_img_len, seq_lengths) ): if isinstance(img_len, list): batch_output = [] cur_st = seq_len - sum(img_len) for j in range(len(img_len)): height, width = img_size[j] cur_len = img_len[j] batch_output.append( rearrange( hidden_states[i][cur_st : cur_st + cur_len], "(h w) (p1 p2 c) -> c (h p1) (w p2)", h=height // p, w=width // p, p1=p, p2=p, ) ) cur_st += cur_len output.append(torch.stack(batch_output, dim=0)) else: height, width = img_size output.append( rearrange( hidden_states[i][seq_len - img_len : seq_len], "(h w) (p1 p2 c) -> c (h p1) (w p2)", h=height // p, w=width // p, p1=p, p2=p, ) ) if is_hidden_states_tensor: output = torch.stack(output, dim=0) if USE_PEFT_BACKEND: # remove `lora_scale` from each PEFT layer unscale_lora_layers(self, lora_scale) if enable_taylorseer: self.current["step"] += 1 if not return_dict: return output return Transformer2DModelOutput(sample=output)
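# ---------------------------------------------------------------------------
# Illustrative sketches (not part of the model). Both only exercise logic that
# already appears above; shapes, thresholds, and the random inputs are
# hypothetical values chosen for demonstration.


def _patchify_roundtrip_sketch() -> None:
    """Show that the token layout used by `flat_and_pad_to_seq` is exactly
    inverted by the rearrange at the end of `forward`."""
    p = 2  # patch_size
    c, h, w = 16, 8, 12  # latent channels / spatial size (divisible by p)
    latent = torch.randn(c, h, w)
    # image -> sequence of (h/p * w/p) tokens of size p*p*c, as in flat_and_pad_to_seq
    tokens = rearrange(latent, "c (h p1) (w p2) -> (h w) (p1 p2 c)", p1=p, p2=p)
    assert tokens.shape == ((h // p) * (w // p), p * p * c)
    # sequence -> image, as in the output projection of forward()
    restored = rearrange(
        tokens, "(h w) (p1 p2 c) -> c (h p1) (w p2)", h=h // p, w=w // p, p1=p, p2=p
    )
    assert torch.equal(latent, restored)


def _teacache_decision_sketch() -> None:
    """Mimic the TeaCache skip rule used in `forward`: accumulate a rescaled
    relative-L1 change of the modulated input and only run the transformer stack
    once it exceeds `teacache_rel_l1_thresh` (first/last-step forcing omitted)."""
    rescale_func = np.poly1d(
        [-5.48259225, 11.48772289, -4.47407401, 2.47730926, -0.03316487]
    )
    rel_l1_thresh = 0.05
    accumulated = 0.0
    previous = torch.randn(1, 32, 64)  # stand-in for the previous modulated input
    for step in range(4):
        modulated = previous + 0.01 * torch.randn_like(previous)
        rel_l1 = ((modulated - previous).abs().mean() / previous.abs().mean()).item()
        accumulated += float(rescale_func(rel_l1))
        should_calc = accumulated >= rel_l1_thresh
        if should_calc:
            accumulated = 0.0  # a full pass resets the accumulator
        previous = modulated
        print(step, round(rel_l1, 4), should_calc)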