diff --git "a/transformer_fofpred.py" "b/transformer_fofpred.py" new file mode 100644--- /dev/null +++ "b/transformer_fofpred.py" @@ -0,0 +1,3211 @@ +"""FOFPred Transformer modified from OmniGen2 DiT.""" + +import importlib.util +import itertools +import math +import sys +import warnings +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.loaders import PeftAdapterMixin +from diffusers.loaders.single_file_model import FromOriginalModelMixin +from diffusers.models.activations import get_activation +from diffusers.models.attention_processor import Attention +from diffusers.models.embeddings import Timesteps, get_1d_rotary_pos_embed +from diffusers.models.modeling_outputs import Transformer2DModelOutput +from diffusers.models.modeling_utils import ModelMixin +from diffusers.utils import ( + USE_PEFT_BACKEND, + logging, + scale_lora_layers, + unscale_lora_layers, +) +from einops import rearrange, repeat + +# The package importlib_metadata is in a different place, depending on the python version. +if sys.version_info < (3, 8): + import importlib_metadata +else: + import importlib.metadata as importlib_metadata + + +def _is_package_available(pkg_name: str): + pkg_exists = importlib.util.find_spec(pkg_name) is not None + pkg_version = "N/A" + + if pkg_exists: + try: + pkg_version = importlib_metadata.version(pkg_name) + except (ImportError, importlib_metadata.PackageNotFoundError): + pkg_exists = False + + return pkg_exists, pkg_version + + +_triton_available, _triton_version = _is_package_available("triton") +_flash_attn_available, _flash_attn_version = _is_package_available("flash_attn") + + +def is_triton_available(): + return _triton_available + + +def is_flash_attn_available(): + return _flash_attn_available + + +if is_triton_available(): + import triton + import triton.language as tl + + def custom_amp_decorator(dec: Callable, cuda_amp_deprecated: bool): + def decorator(*args, **kwargs): + if cuda_amp_deprecated: + kwargs["device_type"] = "cuda" + return dec(*args, **kwargs) + + return decorator + + if hasattr(torch.amp, "custom_fwd"): # type: ignore[attr-defined] + deprecated = True + from torch.amp import custom_bwd, custom_fwd # type: ignore[attr-defined] + else: + deprecated = False + from torch.cuda.amp import custom_bwd, custom_fwd + + custom_fwd = custom_amp_decorator(custom_fwd, deprecated) + custom_bwd = custom_amp_decorator(custom_bwd, deprecated) + + def triton_autotune_configs(): + # Return configs with a valid warp count for the current device + configs = [] + # Maximum threads per block is architecture-dependent in theory, but in reality all are 1024 + max_threads_per_block = 1024 + # Default to warp size 32 if not defined by device + warp_size = getattr( + torch.cuda.get_device_properties(torch.cuda.current_device()), + "warp_size", + 32, + ) + # Autotune for warp counts which are powers of 2 and do not exceed thread per block limit + warp_count = 1 + while warp_count * warp_size <= max_threads_per_block: + configs.append(triton.Config({}, num_warps=warp_count)) + warp_count *= 2 + return configs + + def layer_norm_ref( + x, + weight, + bias, + residual=None, + x1=None, + weight1=None, + bias1=None, + eps=1e-6, + dropout_p=0.0, + rowscale=None, + prenorm=False, + zero_centered_weight=False, + dropout_mask=None, + dropout_mask1=None, + upcast=False, + ): + dtype = x.dtype + if upcast: + x = x.float() + weight = weight.float() + bias = bias.float() if bias is not None else None + residual = residual.float() if residual is not None else residual + x1 = x1.float() if x1 is not None else None + weight1 = weight1.float() if weight1 is not None else None + bias1 = bias1.float() if bias1 is not None else None + if zero_centered_weight: + weight = weight + 1.0 + if weight1 is not None: + weight1 = weight1 + 1.0 + if x1 is not None: + assert rowscale is None, "rowscale is not supported with parallel LayerNorm" + if rowscale is not None: + x = x * rowscale[..., None] + if dropout_p > 0.0: + if dropout_mask is not None: + x = x.masked_fill(~dropout_mask, 0.0) / (1.0 - dropout_p) + else: + x = F.dropout(x, p=dropout_p) + if x1 is not None: + if dropout_mask1 is not None: + x1 = x1.masked_fill(~dropout_mask1, 0.0) / (1.0 - dropout_p) + else: + x1 = F.dropout(x1, p=dropout_p) + if x1 is not None: + x = x + x1 + if residual is not None: + x = (x + residual).to(x.dtype) + out = F.layer_norm( + x.to(weight.dtype), x.shape[-1:], weight=weight, bias=bias, eps=eps + ).to(dtype) + if weight1 is None: + return out if not prenorm else (out, x) + else: + out1 = F.layer_norm( + x.to(weight1.dtype), x.shape[-1:], weight=weight1, bias=bias1, eps=eps + ).to(dtype) + return (out, out1) if not prenorm else (out, out1, x) + + def rms_norm_ref( + x, + weight, + bias, + residual=None, + x1=None, + weight1=None, + bias1=None, + eps=1e-6, + dropout_p=0.0, + rowscale=None, + prenorm=False, + zero_centered_weight=False, + dropout_mask=None, + dropout_mask1=None, + upcast=False, + ): + dtype = x.dtype + if upcast: + x = x.float() + weight = weight.float() + bias = bias.float() if bias is not None else None + residual = residual.float() if residual is not None else residual + x1 = x1.float() if x1 is not None else None + weight1 = weight1.float() if weight1 is not None else None + bias1 = bias1.float() if bias1 is not None else None + if zero_centered_weight: + weight = weight + 1.0 + if weight1 is not None: + weight1 = weight1 + 1.0 + if x1 is not None: + assert rowscale is None, "rowscale is not supported with parallel LayerNorm" + if rowscale is not None: + x = x * rowscale[..., None] + if dropout_p > 0.0: + if dropout_mask is not None: + x = x.masked_fill(~dropout_mask, 0.0) / (1.0 - dropout_p) + else: + x = F.dropout(x, p=dropout_p) + if x1 is not None: + if dropout_mask1 is not None: + x1 = x1.masked_fill(~dropout_mask1, 0.0) / (1.0 - dropout_p) + else: + x1 = F.dropout(x1, p=dropout_p) + if x1 is not None: + x = x + x1 + if residual is not None: + x = (x + residual).to(x.dtype) + rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps) + out = ( + (x * rstd * weight) + bias if bias is not None else (x * rstd * weight) + ).to(dtype) + if weight1 is None: + return out if not prenorm else (out, x) + else: + out1 = ( + (x * rstd * weight1) + bias1 + if bias1 is not None + else (x * rstd * weight1) + ).to(dtype) + return (out, out1) if not prenorm else (out, out1, x) + + @triton.autotune( + configs=triton_autotune_configs(), + key=["N", "HAS_RESIDUAL", "STORE_RESIDUAL_OUT", "IS_RMS_NORM", "HAS_BIAS"], + ) + # @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None}) + # @triton.heuristics({"HAS_RESIDUAL": lambda args: args["RESIDUAL"] is not None}) + @triton.heuristics({"HAS_X1": lambda args: args["X1"] is not None}) + @triton.heuristics({"HAS_W1": lambda args: args["W1"] is not None}) + @triton.heuristics({"HAS_B1": lambda args: args["B1"] is not None}) + @triton.jit + def _layer_norm_fwd_1pass_kernel( + X, # pointer to the input + Y, # pointer to the output + W, # pointer to the weights + B, # pointer to the biases + RESIDUAL, # pointer to the residual + X1, + W1, + B1, + Y1, + RESIDUAL_OUT, # pointer to the residual + ROWSCALE, + SEEDS, # Dropout seeds for each row + DROPOUT_MASK, + Mean, # pointer to the mean + Rstd, # pointer to the 1/std + stride_x_row, # how much to increase the pointer when moving by 1 row + stride_y_row, + stride_res_row, + stride_res_out_row, + stride_x1_row, + stride_y1_row, + M, # number of rows in X + N, # number of columns in X + eps, # epsilon to avoid division by zero + dropout_p, # Dropout probability + zero_centered_weight, # If true, add 1.0 to the weight + IS_RMS_NORM: tl.constexpr, + BLOCK_N: tl.constexpr, + HAS_RESIDUAL: tl.constexpr, + STORE_RESIDUAL_OUT: tl.constexpr, + HAS_BIAS: tl.constexpr, + HAS_DROPOUT: tl.constexpr, + STORE_DROPOUT_MASK: tl.constexpr, + HAS_ROWSCALE: tl.constexpr, + HAS_X1: tl.constexpr, + HAS_W1: tl.constexpr, + HAS_B1: tl.constexpr, + ): + # Map the program id to the row of X and Y it should compute. + row = tl.program_id(0) + X += row * stride_x_row + Y += row * stride_y_row + if HAS_RESIDUAL: + RESIDUAL += row * stride_res_row + if STORE_RESIDUAL_OUT: + RESIDUAL_OUT += row * stride_res_out_row + if HAS_X1: + X1 += row * stride_x1_row + if HAS_W1: + Y1 += row * stride_y1_row + # Compute mean and variance + cols = tl.arange(0, BLOCK_N) + x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32) + if HAS_ROWSCALE: + rowscale = tl.load(ROWSCALE + row).to(tl.float32) + x *= rowscale + if HAS_DROPOUT: + # Compute dropout mask + # 7 rounds is good enough, and reduces register pressure + keep_mask = ( + tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) + > dropout_p + ) + x = tl.where(keep_mask, x / (1.0 - dropout_p), 0.0) + if STORE_DROPOUT_MASK: + tl.store(DROPOUT_MASK + row * N + cols, keep_mask, mask=cols < N) + if HAS_X1: + x1 = tl.load(X1 + cols, mask=cols < N, other=0.0).to(tl.float32) + if HAS_ROWSCALE: + rowscale = tl.load(ROWSCALE + M + row).to(tl.float32) + x1 *= rowscale + if HAS_DROPOUT: + # Compute dropout mask + # 7 rounds is good enough, and reduces register pressure + keep_mask = ( + tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7) + > dropout_p + ) + x1 = tl.where(keep_mask, x1 / (1.0 - dropout_p), 0.0) + if STORE_DROPOUT_MASK: + tl.store( + DROPOUT_MASK + (M + row) * N + cols, keep_mask, mask=cols < N + ) + x += x1 + if HAS_RESIDUAL: + residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32) + x += residual + if STORE_RESIDUAL_OUT: + tl.store(RESIDUAL_OUT + cols, x, mask=cols < N) + if not IS_RMS_NORM: + mean = tl.sum(x, axis=0) / N + tl.store(Mean + row, mean) + xbar = tl.where(cols < N, x - mean, 0.0) + var = tl.sum(xbar * xbar, axis=0) / N + else: + xbar = tl.where(cols < N, x, 0.0) + var = tl.sum(xbar * xbar, axis=0) / N + rstd = 1 / tl.sqrt(var + eps) + tl.store(Rstd + row, rstd) + # Normalize and apply linear transformation + mask = cols < N + w = tl.load(W + cols, mask=mask).to(tl.float32) + if zero_centered_weight: + w += 1.0 + if HAS_BIAS: + b = tl.load(B + cols, mask=mask).to(tl.float32) + x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd + y = x_hat * w + b if HAS_BIAS else x_hat * w + # Write output + tl.store(Y + cols, y, mask=mask) + if HAS_W1: + w1 = tl.load(W1 + cols, mask=mask).to(tl.float32) + if zero_centered_weight: + w1 += 1.0 + if HAS_B1: + b1 = tl.load(B1 + cols, mask=mask).to(tl.float32) + y1 = x_hat * w1 + b1 if HAS_B1 else x_hat * w1 + tl.store(Y1 + cols, y1, mask=mask) + + def _layer_norm_fwd( + x, + weight, + bias, + eps, + residual=None, + x1=None, + weight1=None, + bias1=None, + dropout_p=0.0, + rowscale=None, + out_dtype=None, + residual_dtype=None, + zero_centered_weight=False, + is_rms_norm=False, + return_dropout_mask=False, + out=None, + residual_out=None, + ): + if residual is not None: + residual_dtype = residual.dtype + M, N = x.shape + assert x.stride(-1) == 1 + if residual is not None: + assert residual.stride(-1) == 1 + assert residual.shape == (M, N) + assert weight.shape == (N,) + assert weight.stride(-1) == 1 + if bias is not None: + assert bias.stride(-1) == 1 + assert bias.shape == (N,) + if x1 is not None: + assert x1.shape == x.shape + assert rowscale is None + assert x1.stride(-1) == 1 + if weight1 is not None: + assert weight1.shape == (N,) + assert weight1.stride(-1) == 1 + if bias1 is not None: + assert bias1.shape == (N,) + assert bias1.stride(-1) == 1 + if rowscale is not None: + assert rowscale.is_contiguous() + assert rowscale.shape == (M,) + # allocate output + if out is None: + out = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype) + else: + assert out.shape == x.shape + assert out.stride(-1) == 1 + if weight1 is not None: + y1 = torch.empty_like(out) + assert y1.stride(-1) == 1 + else: + y1 = None + if ( + residual is not None + or (residual_dtype is not None and residual_dtype != x.dtype) + or dropout_p > 0.0 + or rowscale is not None + or x1 is not None + ): + if residual_out is None: + residual_out = torch.empty( + M, + N, + device=x.device, + dtype=residual_dtype if residual_dtype is not None else x.dtype, + ) + else: + assert residual_out.shape == x.shape + assert residual_out.stride(-1) == 1 + else: + residual_out = None + mean = ( + torch.empty((M,), dtype=torch.float32, device=x.device) + if not is_rms_norm + else None + ) + rstd = torch.empty((M,), dtype=torch.float32, device=x.device) + if dropout_p > 0.0: + seeds = torch.randint( + 2**32, (M if x1 is None else 2 * M,), device=x.device, dtype=torch.int64 + ) + else: + seeds = None + if return_dropout_mask and dropout_p > 0.0: + dropout_mask = torch.empty( + M if x1 is None else 2 * M, N, device=x.device, dtype=torch.bool + ) + else: + dropout_mask = None + # Less than 64KB per feature: enqueue fused kernel + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) + if N > BLOCK_N: + raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") + with torch.cuda.device(x.device.index): + _layer_norm_fwd_1pass_kernel[(M,)]( + x, + out, + weight, + bias, + residual, + x1, + weight1, + bias1, + y1, + residual_out, + rowscale, + seeds, + dropout_mask, + mean, + rstd, + x.stride(0), + out.stride(0), + residual.stride(0) if residual is not None else 0, + residual_out.stride(0) if residual_out is not None else 0, + x1.stride(0) if x1 is not None else 0, + y1.stride(0) if y1 is not None else 0, + M, + N, + eps, + dropout_p, + zero_centered_weight, + is_rms_norm, + BLOCK_N, + residual is not None, + residual_out is not None, + bias is not None, + dropout_p > 0.0, + dropout_mask is not None, + rowscale is not None, + ) + # residual_out is None if residual is None and residual_dtype == input_dtype and dropout_p == 0.0 + if dropout_mask is not None and x1 is not None: + dropout_mask, dropout_mask1 = dropout_mask.tensor_split(2, dim=0) + else: + dropout_mask1 = None + return ( + out, + y1, + mean, + rstd, + residual_out if residual_out is not None else x, + seeds, + dropout_mask, + dropout_mask1, + ) + + @triton.autotune( + configs=triton_autotune_configs(), + key=[ + "N", + "HAS_DRESIDUAL", + "STORE_DRESIDUAL", + "IS_RMS_NORM", + "HAS_BIAS", + "HAS_DROPOUT", + ], + ) + # @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None}) + # @triton.heuristics({"HAS_DRESIDUAL": lambda args: args["DRESIDUAL"] is not None}) + # @triton.heuristics({"STORE_DRESIDUAL": lambda args: args["DRESIDUAL_IN"] is not None}) + @triton.heuristics({"HAS_ROWSCALE": lambda args: args["ROWSCALE"] is not None}) + @triton.heuristics({"HAS_DY1": lambda args: args["DY1"] is not None}) + @triton.heuristics({"HAS_DX1": lambda args: args["DX1"] is not None}) + @triton.heuristics({"HAS_B1": lambda args: args["DB1"] is not None}) + @triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None}) + @triton.jit + def _layer_norm_bwd_kernel( + X, # pointer to the input + W, # pointer to the weights + B, # pointer to the biases + Y, # pointer to the output to be recomputed + DY, # pointer to the output gradient + DX, # pointer to the input gradient + DW, # pointer to the partial sum of weights gradient + DB, # pointer to the partial sum of biases gradient + DRESIDUAL, + W1, + DY1, + DX1, + DW1, + DB1, + DRESIDUAL_IN, + ROWSCALE, + SEEDS, + Mean, # pointer to the mean + Rstd, # pointer to the 1/std + stride_x_row, # how much to increase the pointer when moving by 1 row + stride_y_row, + stride_dy_row, + stride_dx_row, + stride_dres_row, + stride_dy1_row, + stride_dx1_row, + stride_dres_in_row, + M, # number of rows in X + N, # number of columns in X + eps, # epsilon to avoid division by zero + dropout_p, + zero_centered_weight, + rows_per_program, + IS_RMS_NORM: tl.constexpr, + BLOCK_N: tl.constexpr, + HAS_DRESIDUAL: tl.constexpr, + STORE_DRESIDUAL: tl.constexpr, + HAS_BIAS: tl.constexpr, + HAS_DROPOUT: tl.constexpr, + HAS_ROWSCALE: tl.constexpr, + HAS_DY1: tl.constexpr, + HAS_DX1: tl.constexpr, + HAS_B1: tl.constexpr, + RECOMPUTE_OUTPUT: tl.constexpr, + ): + # Map the program id to the elements of X, DX, and DY it should compute. + row_block_id = tl.program_id(0) + row_start = row_block_id * rows_per_program + # Do not early exit if row_start >= M, because we need to write DW and DB + cols = tl.arange(0, BLOCK_N) + mask = cols < N + X += row_start * stride_x_row + if HAS_DRESIDUAL: + DRESIDUAL += row_start * stride_dres_row + if STORE_DRESIDUAL: + DRESIDUAL_IN += row_start * stride_dres_in_row + DY += row_start * stride_dy_row + DX += row_start * stride_dx_row + if HAS_DY1: + DY1 += row_start * stride_dy1_row + if HAS_DX1: + DX1 += row_start * stride_dx1_row + if RECOMPUTE_OUTPUT: + Y += row_start * stride_y_row + w = tl.load(W + cols, mask=mask).to(tl.float32) + if zero_centered_weight: + w += 1.0 + if RECOMPUTE_OUTPUT and HAS_BIAS: + b = tl.load(B + cols, mask=mask, other=0.0).to(tl.float32) + if HAS_DY1: + w1 = tl.load(W1 + cols, mask=mask).to(tl.float32) + if zero_centered_weight: + w1 += 1.0 + dw = tl.zeros((BLOCK_N,), dtype=tl.float32) + if HAS_BIAS: + db = tl.zeros((BLOCK_N,), dtype=tl.float32) + if HAS_DY1: + dw1 = tl.zeros((BLOCK_N,), dtype=tl.float32) + if HAS_B1: + db1 = tl.zeros((BLOCK_N,), dtype=tl.float32) + row_end = min((row_block_id + 1) * rows_per_program, M) + for row in range(row_start, row_end): + # Load data to SRAM + x = tl.load(X + cols, mask=mask, other=0).to(tl.float32) + dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32) + if HAS_DY1: + dy1 = tl.load(DY1 + cols, mask=mask, other=0).to(tl.float32) + if not IS_RMS_NORM: + mean = tl.load(Mean + row) + rstd = tl.load(Rstd + row) + # Compute dx + xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd + xhat = tl.where(mask, xhat, 0.0) + if RECOMPUTE_OUTPUT: + y = xhat * w + b if HAS_BIAS else xhat * w + tl.store(Y + cols, y, mask=mask) + wdy = w * dy + dw += dy * xhat + if HAS_BIAS: + db += dy + if HAS_DY1: + wdy += w1 * dy1 + dw1 += dy1 * xhat + if HAS_B1: + db1 += dy1 + if not IS_RMS_NORM: + c1 = tl.sum(xhat * wdy, axis=0) / N + c2 = tl.sum(wdy, axis=0) / N + dx = (wdy - (xhat * c1 + c2)) * rstd + else: + c1 = tl.sum(xhat * wdy, axis=0) / N + dx = (wdy - xhat * c1) * rstd + if HAS_DRESIDUAL: + dres = tl.load(DRESIDUAL + cols, mask=mask, other=0).to(tl.float32) + dx += dres + # Write dx + if STORE_DRESIDUAL: + tl.store(DRESIDUAL_IN + cols, dx, mask=mask) + if HAS_DX1: + if HAS_DROPOUT: + keep_mask = ( + tl.rand( + tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7 + ) + > dropout_p + ) + dx1 = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0) + else: + dx1 = dx + tl.store(DX1 + cols, dx1, mask=mask) + if HAS_DROPOUT: + keep_mask = ( + tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) + > dropout_p + ) + dx = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0) + if HAS_ROWSCALE: + rowscale = tl.load(ROWSCALE + row).to(tl.float32) + dx *= rowscale + tl.store(DX + cols, dx, mask=mask) + + X += stride_x_row + if HAS_DRESIDUAL: + DRESIDUAL += stride_dres_row + if STORE_DRESIDUAL: + DRESIDUAL_IN += stride_dres_in_row + if RECOMPUTE_OUTPUT: + Y += stride_y_row + DY += stride_dy_row + DX += stride_dx_row + if HAS_DY1: + DY1 += stride_dy1_row + if HAS_DX1: + DX1 += stride_dx1_row + tl.store(DW + row_block_id * N + cols, dw, mask=mask) + if HAS_BIAS: + tl.store(DB + row_block_id * N + cols, db, mask=mask) + if HAS_DY1: + tl.store(DW1 + row_block_id * N + cols, dw1, mask=mask) + if HAS_B1: + tl.store(DB1 + row_block_id * N + cols, db1, mask=mask) + + def _layer_norm_bwd( + dy, + x, + weight, + bias, + eps, + mean, + rstd, + dresidual=None, + dy1=None, + weight1=None, + bias1=None, + seeds=None, + dropout_p=0.0, + rowscale=None, + has_residual=False, + has_x1=False, + zero_centered_weight=False, + is_rms_norm=False, + x_dtype=None, + recompute_output=False, + ): + M, N = x.shape + assert x.stride(-1) == 1 + assert dy.stride(-1) == 1 + assert dy.shape == (M, N) + if dresidual is not None: + assert dresidual.stride(-1) == 1 + assert dresidual.shape == (M, N) + assert weight.shape == (N,) + assert weight.stride(-1) == 1 + if bias is not None: + assert bias.stride(-1) == 1 + assert bias.shape == (N,) + if dy1 is not None: + assert weight1 is not None + assert dy1.shape == dy.shape + assert dy1.stride(-1) == 1 + if weight1 is not None: + assert weight1.shape == (N,) + assert weight1.stride(-1) == 1 + if bias1 is not None: + assert bias1.shape == (N,) + assert bias1.stride(-1) == 1 + if seeds is not None: + assert seeds.is_contiguous() + assert seeds.shape == (M if not has_x1 else M * 2,) + if rowscale is not None: + assert rowscale.is_contiguous() + assert rowscale.shape == (M,) + # allocate output + dx = ( + torch.empty_like(x) + if x_dtype is None + else torch.empty(M, N, dtype=x_dtype, device=x.device) + ) + dresidual_in = ( + torch.empty_like(x) + if has_residual + and ( + dx.dtype != x.dtype or dropout_p > 0.0 or rowscale is not None or has_x1 + ) + else None + ) + dx1 = torch.empty_like(dx) if (has_x1 and dropout_p > 0.0) else None + y = ( + torch.empty(M, N, dtype=dy.dtype, device=dy.device) + if recompute_output + else None + ) + if recompute_output: + assert weight1 is None, ( + "recompute_output is not supported with parallel LayerNorm" + ) + + # Less than 64KB per feature: enqueue fused kernel + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) + if N > BLOCK_N: + raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") + # Increasing the multiple (e.g. 8) will allow more thread blocks to be launched and hide the + # latency of the gmem reads/writes, but will increase the time of summing up dw / db. + sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count * 8 + _dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device) + _db = ( + torch.empty((sm_count, N), dtype=torch.float32, device=bias.device) + if bias is not None + else None + ) + _dw1 = torch.empty_like(_dw) if weight1 is not None else None + _db1 = torch.empty_like(_db) if bias1 is not None else None + rows_per_program = math.ceil(M / sm_count) + grid = (sm_count,) + with torch.cuda.device(x.device.index): + _layer_norm_bwd_kernel[grid]( + x, + weight, + bias, + y, + dy, + dx, + _dw, + _db, + dresidual, + weight1, + dy1, + dx1, + _dw1, + _db1, + dresidual_in, + rowscale, + seeds, + mean, + rstd, + x.stride(0), + 0 if not recompute_output else y.stride(0), + dy.stride(0), + dx.stride(0), + dresidual.stride(0) if dresidual is not None else 0, + dy1.stride(0) if dy1 is not None else 0, + dx1.stride(0) if dx1 is not None else 0, + dresidual_in.stride(0) if dresidual_in is not None else 0, + M, + N, + eps, + dropout_p, + zero_centered_weight, + rows_per_program, + is_rms_norm, + BLOCK_N, + dresidual is not None, + dresidual_in is not None, + bias is not None, + dropout_p > 0.0, + ) + dw = _dw.sum(0).to(weight.dtype) + db = _db.sum(0).to(bias.dtype) if bias is not None else None + dw1 = _dw1.sum(0).to(weight1.dtype) if weight1 is not None else None + db1 = _db1.sum(0).to(bias1.dtype) if bias1 is not None else None + # Don't need to compute dresidual_in separately in this case + if ( + has_residual + and dx.dtype == x.dtype + and dropout_p == 0.0 + and rowscale is None + ): + dresidual_in = dx + if has_x1 and dropout_p == 0.0: + dx1 = dx + return ( + (dx, dw, db, dresidual_in, dx1, dw1, db1) + if not recompute_output + else (dx, dw, db, dresidual_in, dx1, dw1, db1, y) + ) + + class LayerNormFn(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x, + weight, + bias, + residual=None, + x1=None, + weight1=None, + bias1=None, + eps=1e-6, + dropout_p=0.0, + rowscale=None, + prenorm=False, + residual_in_fp32=False, + zero_centered_weight=False, + is_rms_norm=False, + return_dropout_mask=False, + out=None, + residual_out=None, + ): + x_shape_og = x.shape + # Check for zero sequence length + if x.numel() == 0: + ctx.zero_seq_length = True + # Only save minimal required tensors for backward + # ctx.save_for_backward(weight, bias, weight1, bias1) + ctx.x_shape_og = x_shape_og + ctx.weight_shape = weight.shape + ctx.weight_dtype = weight.dtype + ctx.weight_device = weight.device + + ctx.has_bias = bias is not None + ctx.bias_shape = bias.shape if bias is not None else None + ctx.bias_dtype = bias.dtype if bias is not None else None + ctx.bias_device = bias.device if bias is not None else None + + ctx.has_weight1 = weight1 is not None + ctx.weight1_shape = weight1.shape if weight1 is not None else None + ctx.weight1_dtype = weight1.dtype if weight1 is not None else None + ctx.weight1_device = weight1.device if weight1 is not None else None + + ctx.has_bias1 = bias1 is not None + ctx.bias1_shape = bias1.shape if bias1 is not None else None + ctx.bias1_dtype = bias1.dtype if bias1 is not None else None + ctx.bias1_device = bias1.device if bias1 is not None else None + + ctx.has_residual = residual is not None + ctx.has_x1 = x1 is not None + ctx.dropout_p = dropout_p + + # Handle output tensors with correct dtype + y = x # Preserve input tensor properties + y1 = torch.empty_like(x) if x1 is not None else None + + # Only create residual_out if prenorm is True + residual_out = ( + torch.empty( + x.shape, + dtype=torch.float32 if residual_in_fp32 else x.dtype, + device=x.device, + ) + if prenorm + else None + ) + + # Handle dropout masks + dropout_mask = None + dropout_mask1 = None + if return_dropout_mask: + dropout_mask = torch.empty_like(x, dtype=torch.uint8) + if x1 is not None: + dropout_mask1 = torch.empty_like(x, dtype=torch.uint8) + + # Return based on configuration + if not return_dropout_mask: + if weight1 is None: + return y if not prenorm else (y, residual_out) + else: + return (y, y1) if not prenorm else (y, y1, residual_out) + else: + if weight1 is None: + return ( + (y, dropout_mask, dropout_mask1) + if not prenorm + else (y, residual_out, dropout_mask, dropout_mask1) + ) + else: + return ( + (y, y1, dropout_mask, dropout_mask1) + if not prenorm + else (y, y1, residual_out, dropout_mask, dropout_mask1) + ) + + ctx.zero_seq_length = False + # reshape input data into 2D tensor + x = x.reshape(-1, x.shape[-1]) + if x.stride(-1) != 1: + x = x.contiguous() + if residual is not None: + assert residual.shape == x_shape_og + residual = residual.reshape(-1, residual.shape[-1]) + if residual.stride(-1) != 1: + residual = residual.contiguous() + if x1 is not None: + assert x1.shape == x_shape_og + assert rowscale is None, ( + "rowscale is not supported with parallel LayerNorm" + ) + x1 = x1.reshape(-1, x1.shape[-1]) + if x1.stride(-1) != 1: + x1 = x1.contiguous() + weight = weight.contiguous() + if bias is not None: + bias = bias.contiguous() + if weight1 is not None: + weight1 = weight1.contiguous() + if bias1 is not None: + bias1 = bias1.contiguous() + if rowscale is not None: + rowscale = rowscale.reshape(-1).contiguous() + residual_dtype = ( + residual.dtype + if residual is not None + else (torch.float32 if residual_in_fp32 else None) + ) + if out is not None: + out = out.reshape(-1, out.shape[-1]) + if residual_out is not None: + residual_out = residual_out.reshape(-1, residual_out.shape[-1]) + y, y1, mean, rstd, residual_out, seeds, dropout_mask, dropout_mask1 = ( + _layer_norm_fwd( + x, + weight, + bias, + eps, + residual, + x1, + weight1, + bias1, + dropout_p=dropout_p, + rowscale=rowscale, + residual_dtype=residual_dtype, + zero_centered_weight=zero_centered_weight, + is_rms_norm=is_rms_norm, + return_dropout_mask=return_dropout_mask, + out=out, + residual_out=residual_out, + ) + ) + ctx.save_for_backward( + residual_out, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd + ) + ctx.x_shape_og = x_shape_og + ctx.eps = eps + ctx.dropout_p = dropout_p + ctx.is_rms_norm = is_rms_norm + ctx.has_residual = residual is not None + ctx.has_x1 = x1 is not None + ctx.prenorm = prenorm + ctx.x_dtype = x.dtype + ctx.zero_centered_weight = zero_centered_weight + y = y.reshape(x_shape_og) + y1 = y1.reshape(x_shape_og) if y1 is not None else None + residual_out = ( + residual_out.reshape(x_shape_og) if residual_out is not None else None + ) + dropout_mask = ( + dropout_mask.reshape(x_shape_og) if dropout_mask is not None else None + ) + dropout_mask1 = ( + dropout_mask1.reshape(x_shape_og) if dropout_mask1 is not None else None + ) + if not return_dropout_mask: + if weight1 is None: + return y if not prenorm else (y, residual_out) + else: + return (y, y1) if not prenorm else (y, y1, residual_out) + else: + if weight1 is None: + return ( + (y, dropout_mask, dropout_mask1) + if not prenorm + else (y, residual_out, dropout_mask, dropout_mask1) + ) + else: + return ( + (y, y1, dropout_mask, dropout_mask1) + if not prenorm + else (y, y1, residual_out, dropout_mask, dropout_mask1) + ) + + @staticmethod + def backward(ctx, dy, *args): + if ctx.zero_seq_length: + return ( + torch.zeros(ctx.x_shape_og, dtype=dy.dtype, device=dy.device), + torch.zeros( + ctx.weight_shape, + dtype=ctx.weight_dtype, + device=ctx.weight_device, + ), + torch.zeros( + ctx.bias_shape, dtype=ctx.bias_dtype, device=ctx.bias_device + ) + if ctx.has_bias + else None, + torch.zeros(ctx.x_shape_og, dtype=dy.dtype, device=dy.device) + if ctx.has_residual + else None, + torch.zeros(ctx.x_shape_og, dtype=dy.dtype, device=dy.device) + if ctx.has_x1 and ctx.dropout_p > 0.0 + else None, + torch.zeros( + ctx.weight1_shape, + dtype=ctx.weight1_dtype, + device=ctx.weight1_device, + ) + if ctx.has_weight1 + else None, + torch.zeros( + ctx.bias1_shape, dtype=ctx.bias1_dtype, device=ctx.bias1_device + ) + if ctx.has_bias1 + else None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + ) + + x, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd = ( + ctx.saved_tensors + ) + dy = dy.reshape(-1, dy.shape[-1]) + if dy.stride(-1) != 1: + dy = dy.contiguous() + assert dy.shape == x.shape + if weight1 is not None: + dy1, args = args[0], args[1:] + dy1 = dy1.reshape(-1, dy1.shape[-1]) + if dy1.stride(-1) != 1: + dy1 = dy1.contiguous() + assert dy1.shape == x.shape + else: + dy1 = None + if ctx.prenorm: + dresidual = args[0] + dresidual = dresidual.reshape(-1, dresidual.shape[-1]) + if dresidual.stride(-1) != 1: + dresidual = dresidual.contiguous() + assert dresidual.shape == x.shape + else: + dresidual = None + + dx, dw, db, dresidual_in, dx1, dw1, db1 = _layer_norm_bwd( + dy, + x, + weight, + bias, + ctx.eps, + mean, + rstd, + dresidual, + dy1, + weight1, + bias1, + seeds, + ctx.dropout_p, + rowscale, + ctx.has_residual, + ctx.has_x1, + ctx.zero_centered_weight, + ctx.is_rms_norm, + x_dtype=ctx.x_dtype, + ) + return ( + dx.reshape(ctx.x_shape_og), + dw, + db, + dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None, + dx1.reshape(ctx.x_shape_og) if dx1 is not None else None, + dw1, + db1, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + ) + + def layer_norm_fn( + x, + weight, + bias, + residual=None, + x1=None, + weight1=None, + bias1=None, + eps=1e-6, + dropout_p=0.0, + rowscale=None, + prenorm=False, + residual_in_fp32=False, + zero_centered_weight=False, + is_rms_norm=False, + return_dropout_mask=False, + out=None, + residual_out=None, + ): + return LayerNormFn.apply( + x, + weight, + bias, + residual, + x1, + weight1, + bias1, + eps, + dropout_p, + rowscale, + prenorm, + residual_in_fp32, + zero_centered_weight, + is_rms_norm, + return_dropout_mask, + out, + residual_out, + ) + + def rms_norm_fn( + x, + weight, + bias, + residual=None, + x1=None, + weight1=None, + bias1=None, + eps=1e-6, + dropout_p=0.0, + rowscale=None, + prenorm=False, + residual_in_fp32=False, + zero_centered_weight=False, + return_dropout_mask=False, + out=None, + residual_out=None, + ): + return LayerNormFn.apply( + x, + weight, + bias, + residual, + x1, + weight1, + bias1, + eps, + dropout_p, + rowscale, + prenorm, + residual_in_fp32, + zero_centered_weight, + True, + return_dropout_mask, + out, + residual_out, + ) + + class RMSNorm(torch.nn.Module): + def __init__( + self, + hidden_size, + eps=1e-5, + dropout_p=0.0, + zero_centered_weight=False, + device=None, + dtype=None, + ): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.eps = eps + if dropout_p > 0.0: + self.drop = torch.nn.Dropout(dropout_p) + else: + self.drop = None + self.zero_centered_weight = zero_centered_weight + self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs)) + self.register_parameter("bias", None) + self.reset_parameters() + + def reset_parameters(self): + if not self.zero_centered_weight: + torch.nn.init.ones_(self.weight) + else: + torch.nn.init.zeros_(self.weight) + + def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False): + return rms_norm_fn( + x, + self.weight, + self.bias, + residual=residual, + eps=self.eps, + dropout_p=self.drop.p + if self.drop is not None and self.training + else 0.0, + prenorm=prenorm, + residual_in_fp32=residual_in_fp32, + zero_centered_weight=self.zero_centered_weight, + ) + + +else: + from torch.nn import RMSNorm + + +if is_flash_attn_available(): + from flash_attn import flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input + from flash_attn.ops.activations import swiglu +else: + + def swiglu(x, y): + return F.silu(x.float(), inplace=False).to(x.dtype) * y + + warnings.warn( + "Cannot import flash_attn, install flash_attn to use Flash2Varlen attention for better performance" + ) + + +@dataclass +class TeaCacheParams: + """ + TeaCache parameters for `OmniGen2Transformer3DModel` + See https://github.com/ali-vilab/TeaCache/ for a more comprehensive understanding + + Args: + previous_residual (Optional[torch.Tensor]): + The tensor difference between the output and the input of the transformer layers from the previous timestep. + previous_modulated_inp (Optional[torch.Tensor]): + The modulated input from the previous timestep used to indicate the change of the transformer layer's output. + accumulated_rel_l1_distance (float): + The accumulated relative L1 distance. + is_first_or_last_step (bool): + Whether the current timestep is the first or last step. + """ + + previous_residual: Optional[torch.Tensor] = None + previous_modulated_inp: Optional[torch.Tensor] = None + accumulated_rel_l1_distance: float = 0 + is_first_or_last_step: bool = False + + +class TimestepEmbedding(nn.Module): + def __init__( + self, + in_channels: int, + time_embed_dim: int, + act_fn: str = "silu", + out_dim: int = None, + post_act_fn: Optional[str] = None, + cond_proj_dim=None, + sample_proj_bias=True, + ): + super().__init__() + + self.linear_1 = nn.Linear(in_channels, time_embed_dim, sample_proj_bias) + + if cond_proj_dim is not None: + self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False) + else: + self.cond_proj = None + + self.act = get_activation(act_fn) + + if out_dim is not None: + time_embed_dim_out = out_dim + else: + time_embed_dim_out = time_embed_dim + self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias) + + if post_act_fn is None: + self.post_act = None + else: + self.post_act = get_activation(post_act_fn) + + self.initialize_weights() + + def initialize_weights(self): + nn.init.normal_(self.linear_1.weight, std=0.02) + nn.init.zeros_(self.linear_1.bias) + nn.init.normal_(self.linear_2.weight, std=0.02) + nn.init.zeros_(self.linear_2.bias) + + def forward(self, sample, condition=None): + if condition is not None: + sample = sample + self.cond_proj(condition) + sample = self.linear_1(sample) + + if self.act is not None: + sample = self.act(sample) + + sample = self.linear_2(sample) + + if self.post_act is not None: + sample = self.post_act(sample) + return sample + + +def apply_rotary_emb( + x: torch.Tensor, + freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]], + use_real: bool = True, + use_real_unbind_dim: int = -1, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings + to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are + reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting + tensors contain rotary embeddings and are returned as real tensors. + + Args: + x (`torch.Tensor`): + Query or key tensor to apply rotary embeddings. [B, H, S, D] xk (torch.Tensor): Key tensor to apply + freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],) + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings. + """ + if use_real: + cos, sin = freqs_cis # [S, D] + cos = cos[None, None] + sin = sin[None, None] + cos, sin = cos.to(x.device), sin.to(x.device) + + if use_real_unbind_dim == -1: + # Used for flux, cogvideox, hunyuan-dit + x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind( + -1 + ) # [B, S, H, D//2] + x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3) + elif use_real_unbind_dim == -2: + # Used for Stable Audio, OmniGen and CogView4 + x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind( + -2 + ) # [B, S, H, D//2] + x_rotated = torch.cat([-x_imag, x_real], dim=-1) + else: + raise ValueError( + f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2." + ) + + out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype) + + return out + else: + # used for lumina + # x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2)) + x_rotated = torch.view_as_complex( + x.float().reshape(*x.shape[:-1], x.shape[-1] // 2, 2) + ) + freqs_cis = freqs_cis.unsqueeze(2) + x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3) + + return x_out.type_as(x) + + +class OmniGen2AttnProcessorFlash2Varlen: + """ + Processor for implementing scaled dot-product attention with flash attention and variable length sequences. + + This processor implements: + - Flash attention with variable length sequences + - Rotary position embeddings (RoPE) + - Query-Key normalization + - Proportional attention scaling + + Args: + None + """ + + def __init__(self) -> None: + """Initialize the attention processor.""" + if not is_flash_attn_available(): + raise ImportError( + "OmniGen2AttnProcessorFlash2Varlen requires flash_attn. " + "Please install flash_attn." + ) + + def _upad_input( + self, + query_layer: torch.Tensor, + key_layer: torch.Tensor, + value_layer: torch.Tensor, + attention_mask: torch.Tensor, + query_length: int, + num_heads: int, + ) -> Tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + Tuple[torch.Tensor, torch.Tensor], + Tuple[int, int], + ]: + """ + Unpad the input tensors for flash attention. + + Args: + query_layer: Query tensor of shape (batch_size, seq_len, num_heads, head_dim) + key_layer: Key tensor of shape (batch_size, seq_len, num_kv_heads, head_dim) + value_layer: Value tensor of shape (batch_size, seq_len, num_kv_heads, head_dim) + attention_mask: Attention mask tensor of shape (batch_size, seq_len) + query_length: Length of the query sequence + num_heads: Number of attention heads + + Returns: + Tuple containing: + - Unpadded query tensor + - Unpadded key tensor + - Unpadded value tensor + - Query indices + - Tuple of cumulative sequence lengths for query and key + - Tuple of maximum sequence lengths for query and key + """ + + def _get_unpad_data( + attention_mask: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor, int]: + """Helper function to get unpadding data from attention mask.""" + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad( + torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0) + ) + return indices, cu_seqlens, max_seqlen_in_batch + + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + # Unpad key and value layers + key_layer = index_first_axis( + key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), + indices_k, + ) + value_layer = index_first_axis( + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), + indices_k, + ) + + # Handle different query length cases + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), + indices_k, + ) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input( + query_layer, attention_mask + ) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + base_sequence_length: Optional[int] = None, + ) -> torch.Tensor: + """ + Process attention computation with flash attention. + + Args: + attn: Attention module + hidden_states: Hidden states tensor of shape (batch_size, seq_len, hidden_dim) + encoder_hidden_states: Encoder hidden states tensor + attention_mask: Optional attention mask tensor + image_rotary_emb: Optional rotary embeddings for image tokens + base_sequence_length: Optional base sequence length for proportional attention + + Returns: + torch.Tensor: Processed hidden states after attention computation + """ + batch_size, sequence_length, _ = hidden_states.shape + + # Get Query-Key-Value Pair + query = attn.to_q(hidden_states) + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + query_dim = query.shape[-1] + inner_dim = key.shape[-1] + head_dim = query_dim // attn.heads + dtype = query.dtype + + # Get key-value heads + kv_heads = inner_dim // head_dim + + # Reshape tensors for attention computation + query = query.view(batch_size, -1, attn.heads, head_dim) + key = key.view(batch_size, -1, kv_heads, head_dim) + value = value.view(batch_size, -1, kv_heads, head_dim) + + # Apply Query-Key normalization + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # Apply Rotary Position Embeddings + if image_rotary_emb is not None: + query = apply_rotary_emb(query, image_rotary_emb, use_real=False) + key = apply_rotary_emb(key, image_rotary_emb, use_real=False) + + query, key = query.to(dtype), key.to(dtype) + + # Calculate attention scale + if base_sequence_length is not None: + softmax_scale = ( + math.sqrt(math.log(sequence_length, base_sequence_length)) * attn.scale + ) + else: + softmax_scale = attn.scale + + # Unpad input for flash attention + ( + query_states, + key_states, + value_states, + indices_q, + cu_seq_lens, + max_seq_lens, + ) = self._upad_input( + query, key, value, attention_mask, sequence_length, attn.heads + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + # Handle different number of heads + if kv_heads < attn.heads: + key_states = repeat( + key_states, "l h c -> l (h k) c", k=attn.heads // kv_heads + ) + value_states = repeat( + value_states, "l h c -> l (h k) c", k=attn.heads // kv_heads + ) + + # Apply flash attention + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=0.0, + causal=False, + softmax_scale=softmax_scale, + ) + + # Pad output and apply final transformations + hidden_states = pad_input( + attn_output_unpad, indices_q, batch_size, sequence_length + ) + hidden_states = hidden_states.flatten(-2) + hidden_states = hidden_states.type_as(query) + + # Apply output projection + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + + return hidden_states + + +class OmniGen2AttnProcessor: + """ + Processor for implementing scaled dot-product attention with flash attention and variable length sequences. + + This processor is optimized for PyTorch 2.0 and implements: + - Flash attention with variable length sequences + - Rotary position embeddings (RoPE) + - Query-Key normalization + - Proportional attention scaling + + Args: + None + + Raises: + ImportError: If PyTorch version is less than 2.0 + """ + + def __init__(self) -> None: + """Initialize the attention processor.""" + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "OmniGen2AttnProcessorFlash2Varlen requires PyTorch 2.0. " + "Please upgrade PyTorch to version 2.0 or later." + ) + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + base_sequence_length: Optional[int] = None, + ) -> torch.Tensor: + """ + Process attention computation with flash attention. + + Args: + attn: Attention module + hidden_states: Hidden states tensor of shape (batch_size, seq_len, hidden_dim) + encoder_hidden_states: Encoder hidden states tensor + attention_mask: Optional attention mask tensor + image_rotary_emb: Optional rotary embeddings for image tokens + base_sequence_length: Optional base sequence length for proportional attention + + Returns: + torch.Tensor: Processed hidden states after attention computation + """ + batch_size, sequence_length, _ = hidden_states.shape + + # Get Query-Key-Value Pair + query = attn.to_q(hidden_states) + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + query_dim = query.shape[-1] + inner_dim = key.shape[-1] + head_dim = query_dim // attn.heads + dtype = query.dtype + + # Get key-value heads + kv_heads = inner_dim // head_dim + + # Reshape tensors for attention computation + query = query.view(batch_size, -1, attn.heads, head_dim) + key = key.view(batch_size, -1, kv_heads, head_dim) + value = value.view(batch_size, -1, kv_heads, head_dim) + + # Apply Query-Key normalization + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # Apply Rotary Position Embeddings + if image_rotary_emb is not None: + query = apply_rotary_emb(query, image_rotary_emb, use_real=False) + key = apply_rotary_emb(key, image_rotary_emb, use_real=False) + + query, key = query.to(dtype), key.to(dtype) + + # Calculate attention scale + if base_sequence_length is not None: + softmax_scale = ( + math.sqrt(math.log(sequence_length, base_sequence_length)) * attn.scale + ) + else: + softmax_scale = attn.scale + + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + if attention_mask is not None: + attention_mask = attention_mask.bool().view(batch_size, 1, 1, -1) + + query = query.transpose(1, 2) + key = key.transpose(1, 2) + value = value.transpose(1, 2) + + # explicitly repeat key and value to match query length, otherwise using enable_gqa=True results in MATH backend of sdpa in our test of pytorch2.6 + key = key.repeat_interleave(query.size(-3) // key.size(-3), -3) + value = value.repeat_interleave(query.size(-3) // value.size(-3), -3) + + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, scale=softmax_scale + ) + hidden_states = hidden_states.transpose(1, 2).reshape( + batch_size, -1, attn.heads * head_dim + ) + hidden_states = hidden_states.type_as(query) + + # Apply output projection + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + + return hidden_states + + +class LuminaRMSNormZero(nn.Module): + """ + Norm layer adaptive RMS normalization zero. + + Parameters: + embedding_dim (`int`): The size of each embedding vector. + """ + + def __init__( + self, + embedding_dim: int, + norm_eps: float, + norm_elementwise_affine: bool, + ): + super().__init__() + self.silu = nn.SiLU() + self.linear = nn.Linear( + min(embedding_dim, 1024), + 4 * embedding_dim, + bias=True, + ) + + self.norm = RMSNorm(embedding_dim, eps=norm_eps) + + def forward( + self, + x: torch.Tensor, + emb: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + emb = self.linear(self.silu(emb)) + scale_msa, gate_msa, scale_mlp, gate_mlp = emb.chunk(4, dim=1) + x = self.norm(x) * (1 + scale_msa[:, None]) + return x, gate_msa, scale_mlp, gate_mlp + + +class LuminaLayerNormContinuous(nn.Module): + def __init__( + self, + embedding_dim: int, + conditioning_embedding_dim: int, + # NOTE: It is a bit weird that the norm layer can be configured to have scale and shift parameters + # because the output is immediately scaled and shifted by the projected conditioning embeddings. + # Note that AdaLayerNorm does not let the norm layer have scale and shift parameters. + # However, this is how it was implemented in the original code, and it's rather likely you should + # set `elementwise_affine` to False. + elementwise_affine=True, + eps=1e-5, + bias=True, + norm_type="layer_norm", + out_dim: Optional[int] = None, + ): + super().__init__() + + # AdaLN + self.silu = nn.SiLU() + self.linear_1 = nn.Linear(conditioning_embedding_dim, embedding_dim, bias=bias) + + if norm_type == "layer_norm": + self.norm = nn.LayerNorm(embedding_dim, eps, elementwise_affine, bias) + elif norm_type == "rms_norm": + self.norm = RMSNorm( + embedding_dim, eps=eps, elementwise_affine=elementwise_affine + ) + else: + raise ValueError(f"unknown norm_type {norm_type}") + + self.linear_2 = None + if out_dim is not None: + self.linear_2 = nn.Linear(embedding_dim, out_dim, bias=bias) + + def forward( + self, + x: torch.Tensor, + conditioning_embedding: torch.Tensor, + ) -> torch.Tensor: + # convert back to the original dtype in case `conditioning_embedding`` is upcasted to float32 (needed for hunyuanDiT) + emb = self.linear_1(self.silu(conditioning_embedding).to(x.dtype)) + scale = emb + x = self.norm(x) * (1 + scale)[:, None, :] + + if self.linear_2 is not None: + x = self.linear_2(x) + + return x + + +class LuminaFeedForward(nn.Module): + r""" + A feed-forward layer. + + Parameters: + hidden_size (`int`): + The dimensionality of the hidden layers in the model. This parameter determines the width of the model's + hidden representations. + intermediate_size (`int`): The intermediate dimension of the feedforward layer. + multiple_of (`int`, *optional*): Value to ensure hidden dimension is a multiple + of this value. + ffn_dim_multiplier (float, *optional*): Custom multiplier for hidden + dimension. Defaults to None. + """ + + def __init__( + self, + dim: int, + inner_dim: int, + multiple_of: Optional[int] = 256, + ffn_dim_multiplier: Optional[float] = None, + ): + super().__init__() + self.swiglu = swiglu + + # custom hidden_size factor multiplier + if ffn_dim_multiplier is not None: + inner_dim = int(ffn_dim_multiplier * inner_dim) + inner_dim = multiple_of * ((inner_dim + multiple_of - 1) // multiple_of) + + self.linear_1 = nn.Linear( + dim, + inner_dim, + bias=False, + ) + self.linear_2 = nn.Linear( + inner_dim, + dim, + bias=False, + ) + self.linear_3 = nn.Linear( + dim, + inner_dim, + bias=False, + ) + + def forward(self, x): + h1, h2 = self.linear_1(x), self.linear_3(x) + return self.linear_2(self.swiglu(h1, h2)) + + +class Lumina2CombinedTimestepCaptionEmbedding(nn.Module): + def __init__( + self, + hidden_size: int = 4096, + text_feat_dim: int = 2048, + frequency_embedding_size: int = 256, + norm_eps: float = 1e-5, + timestep_scale: float = 1.0, + ) -> None: + super().__init__() + + self.time_proj = Timesteps( + num_channels=frequency_embedding_size, + flip_sin_to_cos=True, + downscale_freq_shift=0.0, + scale=timestep_scale, + ) + + self.timestep_embedder = TimestepEmbedding( + in_channels=frequency_embedding_size, time_embed_dim=min(hidden_size, 1024) + ) + + self.caption_embedder = nn.Sequential( + RMSNorm(text_feat_dim, eps=norm_eps), + nn.Linear(text_feat_dim, hidden_size, bias=True), + ) + + self._initialize_weights() + + def _initialize_weights(self): + nn.init.trunc_normal_(self.caption_embedder[1].weight, std=0.02) + nn.init.zeros_(self.caption_embedder[1].bias) + + def forward( + self, + timestep: torch.Tensor, + text_hidden_states: torch.Tensor, + dtype: torch.dtype, + ) -> Tuple[torch.Tensor, torch.Tensor]: + timestep_proj = self.time_proj(timestep).to(dtype=dtype) + time_embed = self.timestep_embedder(timestep_proj) + caption_embed = self.caption_embedder(text_hidden_states) + return time_embed, caption_embed + + +class OmniGen2RotaryPosEmbed(nn.Module): + def __init__( + self, + theta: int, + axes_dim: Tuple[int, int, int], + axes_lens: Tuple[int, int, int] = (300, 512, 512), + patch_size: int = 2, + ): + super().__init__() + self.theta = theta + self.axes_dim = axes_dim + self.axes_lens = axes_lens + self.patch_size = patch_size + + @staticmethod + def get_freqs_cis( + axes_dim: Tuple[int, int, int], axes_lens: Tuple[int, int, int], theta: int + ) -> List[torch.Tensor]: + freqs_cis = [] + freqs_dtype = ( + torch.float32 if torch.backends.mps.is_available() else torch.float64 + ) + for i, (d, e) in enumerate(zip(axes_dim, axes_lens)): + emb = get_1d_rotary_pos_embed(d, e, theta=theta, freqs_dtype=freqs_dtype) + freqs_cis.append(emb) + return freqs_cis + + def _get_freqs_cis(self, freqs_cis, ids: torch.Tensor) -> torch.Tensor: + device = ids.device + if ids.device.type == "mps": + ids = ids.to("cpu") + + result = [] + for i in range(len(self.axes_dim)): + freqs = freqs_cis[i].to(ids.device) + index = ids[:, :, i : i + 1].repeat(1, 1, freqs.shape[-1]).to(torch.int64) + result.append( + torch.gather( + freqs.unsqueeze(0).repeat(index.shape[0], 1, 1), dim=1, index=index + ) + ) + return torch.cat(result, dim=-1).to(device) + + def forward( + self, + freqs_cis, + attention_mask, + l_effective_ref_img_len, + l_effective_img_len, + ref_img_sizes, + img_sizes, + device, + ): + batch_size = len(attention_mask) + p = self.patch_size + + encoder_seq_len = attention_mask.shape[1] + l_effective_cap_len = attention_mask.sum(dim=1).tolist() + + if isinstance(l_effective_img_len[0], list): # Check for t-dim case + seq_lengths = [ + cap_len + sum(ref_img_len) + sum(img_len) + for cap_len, ref_img_len, img_len in zip( + l_effective_cap_len, l_effective_ref_img_len, l_effective_img_len + ) + ] + else: # Original case + seq_lengths = [ + cap_len + sum(ref_img_len) + img_len + for cap_len, ref_img_len, img_len in zip( + l_effective_cap_len, l_effective_ref_img_len, l_effective_img_len + ) + ] + + max_seq_len = max(seq_lengths) + max_ref_img_len = max( + [sum(ref_img_len) for ref_img_len in l_effective_ref_img_len] + ) + if isinstance(l_effective_img_len[0], list): + max_img_len = max([sum(ln) for ln in l_effective_img_len]) + else: + max_img_len = max(l_effective_img_len) + + # Create position IDs + position_ids = torch.zeros( + batch_size, max_seq_len, 3, dtype=torch.int32, device=device + ) + + for i, (cap_seq_len, seq_len) in enumerate( + zip(l_effective_cap_len, seq_lengths) + ): + # add text position ids + position_ids[i, :cap_seq_len] = repeat( + torch.arange(cap_seq_len, dtype=torch.int32, device=device), "l -> l 3" + ) + + pe_shift = cap_seq_len + pe_shift_len = cap_seq_len + + if ref_img_sizes[i] is not None: + for ref_img_size, ref_img_len in zip( + ref_img_sizes[i], l_effective_ref_img_len[i] + ): + H, W = ref_img_size + ref_H_tokens, ref_W_tokens = H // p, W // p + assert ref_H_tokens * ref_W_tokens == ref_img_len + # add image position ids + + row_ids = repeat( + torch.arange(ref_H_tokens, dtype=torch.int32, device=device), + "h -> h w", + w=ref_W_tokens, + ).flatten() + col_ids = repeat( + torch.arange(ref_W_tokens, dtype=torch.int32, device=device), + "w -> h w", + h=ref_H_tokens, + ).flatten() + position_ids[i, pe_shift_len : pe_shift_len + ref_img_len, 0] = ( + pe_shift + ) + position_ids[i, pe_shift_len : pe_shift_len + ref_img_len, 1] = ( + row_ids + ) + position_ids[i, pe_shift_len : pe_shift_len + ref_img_len, 2] = ( + col_ids + ) + + pe_shift += max(ref_H_tokens, ref_W_tokens) + pe_shift_len += ref_img_len + + if isinstance(l_effective_img_len[i], list): # New case + for img_size, img_len in zip(img_sizes[i], l_effective_img_len[i]): + H, W = img_size + H_tokens, W_tokens = H // p, W // p + assert H_tokens * W_tokens == img_len + + row_ids = repeat( + torch.arange(H_tokens, dtype=torch.int32, device=device), + "h -> h w", + w=W_tokens, + ).flatten() + col_ids = repeat( + torch.arange(W_tokens, dtype=torch.int32, device=device), + "w -> h w", + h=H_tokens, + ).flatten() + + end_idx = pe_shift_len + img_len + + position_ids[i, pe_shift_len:end_idx, 0] = pe_shift + position_ids[i, pe_shift_len:end_idx, 1] = row_ids + position_ids[i, pe_shift_len:end_idx, 2] = col_ids + + pe_shift += max(H_tokens, W_tokens) + pe_shift_len = end_idx + else: # Original case + H, W = img_sizes[i] + H_tokens, W_tokens = H // p, W // p + assert H_tokens * W_tokens == l_effective_img_len[i] + + row_ids = repeat( + torch.arange(H_tokens, dtype=torch.int32, device=device), + "h -> h w", + w=W_tokens, + ).flatten() + col_ids = repeat( + torch.arange(W_tokens, dtype=torch.int32, device=device), + "w -> h w", + h=H_tokens, + ).flatten() + + assert pe_shift_len + l_effective_img_len[i] == seq_len + position_ids[i, pe_shift_len:seq_len, 0] = pe_shift + position_ids[i, pe_shift_len:seq_len, 1] = row_ids + position_ids[i, pe_shift_len:seq_len, 2] = col_ids + + # Get combined rotary embeddings + freqs_cis = self._get_freqs_cis(freqs_cis, position_ids) + + # create separate rotary embeddings for captions and images + cap_freqs_cis = torch.zeros( + batch_size, + encoder_seq_len, + freqs_cis.shape[-1], + device=device, + dtype=freqs_cis.dtype, + ) + ref_img_freqs_cis = torch.zeros( + batch_size, + max_ref_img_len, + freqs_cis.shape[-1], + device=device, + dtype=freqs_cis.dtype, + ) + img_freqs_cis = torch.zeros( + batch_size, + max_img_len, + freqs_cis.shape[-1], + device=device, + dtype=freqs_cis.dtype, + ) + + for i, (cap_seq_len, ref_img_len, img_len, seq_len) in enumerate( + zip( + l_effective_cap_len, + l_effective_ref_img_len, + l_effective_img_len, + seq_lengths, + ) + ): + cap_freqs_cis[i, :cap_seq_len] = freqs_cis[i, :cap_seq_len] + ref_img_freqs_cis[i, : sum(ref_img_len)] = freqs_cis[ + i, cap_seq_len : cap_seq_len + sum(ref_img_len) + ] + if isinstance(img_len, list): + img_len = sum(img_len) + img_freqs_cis[i, :img_len] = freqs_cis[ + i, + cap_seq_len + sum(ref_img_len) : cap_seq_len + + sum(ref_img_len) + + img_len, + ] + + return ( + cap_freqs_cis, + ref_img_freqs_cis, + img_freqs_cis, + freqs_cis, + l_effective_cap_len, + seq_lengths, + ) + + +def force_scheduler(cache_dic, current): + if cache_dic["fresh_ratio"] == 0: + # FORA + linear_step_weight = 0.0 + else: + # TokenCache + linear_step_weight = 0.0 + step_factor = torch.tensor( + 1 + - linear_step_weight + + 2 * linear_step_weight * current["step"] / current["num_steps"] + ) + threshold = torch.round(cache_dic["fresh_threshold"] / step_factor) + + # no force constrain for sensitive steps, cause the performance is good enough. + # you may have a try. + + cache_dic["cal_threshold"] = threshold + # return threshold + + +def cal_type(cache_dic, current): + """ + Determine calculation type for this step + """ + if (cache_dic["fresh_ratio"] == 0.0) and (not cache_dic["taylor_cache"]): + # FORA:Uniform + first_step = current["step"] == 0 + else: + # ToCa: First enhanced + first_step = current["step"] < cache_dic["first_enhance"] + + if not first_step: + fresh_interval = cache_dic["cal_threshold"] + else: + fresh_interval = cache_dic["fresh_threshold"] + + if (first_step) or (cache_dic["cache_counter"] == fresh_interval - 1): + current["type"] = "full" + cache_dic["cache_counter"] = 0 + current["activated_steps"].append(current["step"]) + force_scheduler(cache_dic, current) + + elif cache_dic["taylor_cache"]: + cache_dic["cache_counter"] += 1 + current["type"] = "Taylor" + + elif ( + cache_dic["cache_counter"] % 2 == 1 + ): # 0: ToCa-Aggresive-ToCa, 1: Aggresive-ToCa-Aggresive + cache_dic["cache_counter"] += 1 + current["type"] = "ToCa" + # 'cache_noise' 'ToCa' 'FORA' + elif cache_dic["Delta-DiT"]: + cache_dic["cache_counter"] += 1 + current["type"] = "Delta-Cache" + else: + cache_dic["cache_counter"] += 1 + current["type"] = "ToCa" + + +def derivative_approximation(cache_dic: Dict, current: Dict, feature: torch.Tensor): + """ + Compute derivative approximation. + + :param cache_dic: Cache dictionary + :param current: Information of the current step + """ + difference_distance = ( + current["activated_steps"][-1] - current["activated_steps"][-2] + ) + + updated_taylor_factors = {} + updated_taylor_factors[0] = feature + + for i in range(cache_dic["max_order"]): + if ( + cache_dic["cache"][-1][current["stream"]][current["layer"]][ + current["module"] + ].get(i, None) + is not None + ) and (current["step"] > cache_dic["first_enhance"] - 2): + updated_taylor_factors[i + 1] = ( + updated_taylor_factors[i] + - cache_dic["cache"][-1][current["stream"]][current["layer"]][ + current["module"] + ][i] + ) / difference_distance + else: + break + + cache_dic["cache"][-1][current["stream"]][current["layer"]][current["module"]] = ( + updated_taylor_factors + ) + + +def taylor_formula(cache_dic: Dict, current: Dict) -> torch.Tensor: + """ + Compute Taylor expansion error. + + :param cache_dic: Cache dictionary + :param current: Information of the current step + """ + x = current["step"] - current["activated_steps"][-1] + # x = current['t'] - current['activated_times'][-1] + output = 0 + + for i in range( + len( + cache_dic["cache"][-1][current["stream"]][current["layer"]][ + current["module"] + ] + ) + ): + output += ( + (1 / math.factorial(i)) + * cache_dic["cache"][-1][current["stream"]][current["layer"]][ + current["module"] + ][i] + * (x**i) + ) + + return output + + +def taylor_cache_init(cache_dic: Dict, current: Dict): + """ + Initialize Taylor cache and allocate storage for different-order derivatives in the Taylor cache. + + :param cache_dic: Cache dictionary + :param current: Information of the current step + """ + if (current["step"] == 0) and (cache_dic["taylor_cache"]): + cache_dic["cache"][-1][current["stream"]][current["layer"]][ + current["module"] + ] = {} + + +logger = logging.get_logger(__name__) + + +class OmniGen2TransformerBlock(nn.Module): + """ + Transformer block for OmniGen2 model. + + This block implements a transformer layer with: + - Multi-head attention with flash attention + - Feed-forward network with SwiGLU activation + - RMS normalization + - Optional modulation for conditional generation + + Args: + dim: Dimension of the input and output tensors + num_attention_heads: Number of attention heads + num_kv_heads: Number of key-value heads + multiple_of: Multiple of which the hidden dimension should be + ffn_dim_multiplier: Multiplier for the feed-forward network dimension + norm_eps: Epsilon value for normalization layers + modulation: Whether to use modulation for conditional generation + use_fused_rms_norm: Whether to use fused RMS normalization + use_fused_swiglu: Whether to use fused SwiGLU activation + """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + num_kv_heads: int, + multiple_of: int, + ffn_dim_multiplier: float, + norm_eps: float, + modulation: bool = True, + ) -> None: + """Initialize the transformer block.""" + super().__init__() + self.head_dim = dim // num_attention_heads + self.modulation = modulation + + try: + processor = OmniGen2AttnProcessorFlash2Varlen() + except ImportError: + processor = OmniGen2AttnProcessor() + + # Initialize attention layer + self.attn = Attention( + query_dim=dim, + cross_attention_dim=None, + dim_head=dim // num_attention_heads, + qk_norm="rms_norm", + heads=num_attention_heads, + kv_heads=num_kv_heads, + eps=1e-5, + bias=False, + out_bias=False, + processor=processor, + ) + + # Initialize feed-forward network + self.feed_forward = LuminaFeedForward( + dim=dim, + inner_dim=4 * dim, + multiple_of=multiple_of, + ffn_dim_multiplier=ffn_dim_multiplier, + ) + + # Initialize normalization layers + if modulation: + self.norm1 = LuminaRMSNormZero( + embedding_dim=dim, norm_eps=norm_eps, norm_elementwise_affine=True + ) + else: + self.norm1 = RMSNorm(dim, eps=norm_eps) + + self.ffn_norm1 = RMSNorm(dim, eps=norm_eps) + self.norm2 = RMSNorm(dim, eps=norm_eps) + self.ffn_norm2 = RMSNorm(dim, eps=norm_eps) + + self.initialize_weights() + + def initialize_weights(self) -> None: + """ + Initialize the weights of the transformer block. + + Uses Xavier uniform initialization for linear layers and zero initialization for biases. + """ + nn.init.xavier_uniform_(self.attn.to_q.weight) + nn.init.xavier_uniform_(self.attn.to_k.weight) + nn.init.xavier_uniform_(self.attn.to_v.weight) + nn.init.xavier_uniform_(self.attn.to_out[0].weight) + + nn.init.xavier_uniform_(self.feed_forward.linear_1.weight) + nn.init.xavier_uniform_(self.feed_forward.linear_2.weight) + nn.init.xavier_uniform_(self.feed_forward.linear_3.weight) + + if self.modulation: + nn.init.zeros_(self.norm1.linear.weight) + nn.init.zeros_(self.norm1.linear.bias) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + image_rotary_emb: torch.Tensor, + temb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Forward pass of the transformer block. + + Args: + hidden_states: Input hidden states tensor + attention_mask: Attention mask tensor + image_rotary_emb: Rotary embeddings for image tokens + temb: Optional timestep embedding tensor + + Returns: + torch.Tensor: Output hidden states after transformer block processing + """ + enable_taylorseer = getattr(self, "enable_taylorseer", False) + if enable_taylorseer: + if self.modulation: + if temb is None: + raise ValueError("temb must be provided when modulation is enabled") + + if self.current["type"] == "full": + self.current["module"] = "total" + taylor_cache_init(cache_dic=self.cache_dic, current=self.current) + + norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1( + hidden_states, temb + ) + attn_output = self.attn( + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_hidden_states, + attention_mask=attention_mask, + image_rotary_emb=image_rotary_emb, + ) + hidden_states = hidden_states + gate_msa.unsqueeze( + 1 + ).tanh() * self.norm2(attn_output) + mlp_output = self.feed_forward( + self.ffn_norm1(hidden_states) * (1 + scale_mlp.unsqueeze(1)) + ) + hidden_states = hidden_states + gate_mlp.unsqueeze( + 1 + ).tanh() * self.ffn_norm2(mlp_output) + + derivative_approximation( + cache_dic=self.cache_dic, + current=self.current, + feature=hidden_states, + ) + + elif self.current["type"] == "Taylor": + self.current["module"] = "total" + hidden_states = taylor_formula( + cache_dic=self.cache_dic, current=self.current + ) + else: + norm_hidden_states = self.norm1(hidden_states) + attn_output = self.attn( + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_hidden_states, + attention_mask=attention_mask, + image_rotary_emb=image_rotary_emb, + ) + hidden_states = hidden_states + self.norm2(attn_output) + mlp_output = self.feed_forward(self.ffn_norm1(hidden_states)) + hidden_states = hidden_states + self.ffn_norm2(mlp_output) + else: + if self.modulation: + if temb is None: + raise ValueError("temb must be provided when modulation is enabled") + + norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1( + hidden_states, temb + ) + attn_output = self.attn( + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_hidden_states, + attention_mask=attention_mask, + image_rotary_emb=image_rotary_emb, + ) + hidden_states = hidden_states + gate_msa.unsqueeze( + 1 + ).tanh() * self.norm2(attn_output) + mlp_output = self.feed_forward( + self.ffn_norm1(hidden_states) * (1 + scale_mlp.unsqueeze(1)) + ) + hidden_states = hidden_states + gate_mlp.unsqueeze( + 1 + ).tanh() * self.ffn_norm2(mlp_output) + else: + norm_hidden_states = self.norm1(hidden_states) + attn_output = self.attn( + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_hidden_states, + attention_mask=attention_mask, + image_rotary_emb=image_rotary_emb, + ) + hidden_states = hidden_states + self.norm2(attn_output) + mlp_output = self.feed_forward(self.ffn_norm1(hidden_states)) + hidden_states = hidden_states + self.ffn_norm2(mlp_output) + + return hidden_states + + +class OmniGen2Transformer3DModel( + ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin +): + """ + OmniGen2 Transformer 3D Model (modified to output frame sequences). + + A transformer-based diffusion model for image generation with: + - Patch-based image processing + - Rotary position embeddings + - Multi-head attention + - Conditional generation support + + Args: + patch_size: Size of image patches + in_channels: Number of input channels + out_channels: Number of output channels (defaults to in_channels) + hidden_size: Size of hidden layers + num_layers: Number of transformer layers + num_refiner_layers: Number of refiner layers + num_attention_heads: Number of attention heads + num_kv_heads: Number of key-value heads + multiple_of: Multiple of which the hidden dimension should be + ffn_dim_multiplier: Multiplier for feed-forward network dimension + norm_eps: Epsilon value for normalization layers + axes_dim_rope: Dimensions for rotary position embeddings + axes_lens: Lengths for rotary position embeddings + text_feat_dim: Dimension of text features + timestep_scale: Scale factor for timestep embeddings + use_fused_rms_norm: Whether to use fused RMS normalization + use_fused_swiglu: Whether to use fused SwiGLU activation + """ + + _supports_gradient_checkpointing = True + _no_split_modules = ["Omnigen2TransformerBlock"] + _skip_layerwise_casting_patterns = ["x_embedder", "norm"] + + @register_to_config + def __init__( + self, + patch_size: int = 2, + in_channels: int = 16, + out_channels: Optional[int] = None, + hidden_size: int = 2304, + num_layers: int = 26, + num_refiner_layers: int = 2, + num_attention_heads: int = 24, + num_kv_heads: int = 8, + multiple_of: int = 256, + ffn_dim_multiplier: Optional[float] = None, + norm_eps: float = 1e-5, + axes_dim_rope: Tuple[int, int, int] = (32, 32, 32), + axes_lens: Tuple[int, int, int] = (300, 512, 512), + text_feat_dim: int = 1024, + timestep_scale: float = 1.0, + ) -> None: + """Initialize the OmniGen2 transformer model.""" + super().__init__() + + # Validate configuration + if (hidden_size // num_attention_heads) != sum(axes_dim_rope): + raise ValueError( + f"hidden_size // num_attention_heads ({hidden_size // num_attention_heads}) " + f"must equal sum(axes_dim_rope) ({sum(axes_dim_rope)})" + ) + + self.out_channels = out_channels or in_channels + + # Initialize embeddings + self.rope_embedder = OmniGen2RotaryPosEmbed( + theta=10000, + axes_dim=axes_dim_rope, + axes_lens=axes_lens, + patch_size=patch_size, + ) + + self.x_embedder = nn.Linear( + in_features=patch_size * patch_size * in_channels, + out_features=hidden_size, + ) + + self.ref_image_patch_embedder = nn.Linear( + in_features=patch_size * patch_size * in_channels, + out_features=hidden_size, + ) + + self.time_caption_embed = Lumina2CombinedTimestepCaptionEmbedding( + hidden_size=hidden_size, + text_feat_dim=text_feat_dim, + norm_eps=norm_eps, + timestep_scale=timestep_scale, + ) + + # Initialize transformer blocks + self.noise_refiner = nn.ModuleList( + [ + OmniGen2TransformerBlock( + hidden_size, + num_attention_heads, + num_kv_heads, + multiple_of, + ffn_dim_multiplier, + norm_eps, + modulation=True, + ) + for _ in range(num_refiner_layers) + ] + ) + + self.ref_image_refiner = nn.ModuleList( + [ + OmniGen2TransformerBlock( + hidden_size, + num_attention_heads, + num_kv_heads, + multiple_of, + ffn_dim_multiplier, + norm_eps, + modulation=True, + ) + for _ in range(num_refiner_layers) + ] + ) + + self.context_refiner = nn.ModuleList( + [ + OmniGen2TransformerBlock( + hidden_size, + num_attention_heads, + num_kv_heads, + multiple_of, + ffn_dim_multiplier, + norm_eps, + modulation=False, + ) + for _ in range(num_refiner_layers) + ] + ) + + # 3. Transformer blocks + self.layers = nn.ModuleList( + [ + OmniGen2TransformerBlock( + hidden_size, + num_attention_heads, + num_kv_heads, + multiple_of, + ffn_dim_multiplier, + norm_eps, + modulation=True, + ) + for _ in range(num_layers) + ] + ) + + # 4. Output norm & projection + self.norm_out = LuminaLayerNormContinuous( + embedding_dim=hidden_size, + conditioning_embedding_dim=min(hidden_size, 1024), + elementwise_affine=False, + eps=1e-6, + bias=True, + out_dim=patch_size * patch_size * self.out_channels, + ) + + # Add learnable embeddings to distinguish different images + self.image_index_embedding = nn.Parameter( + torch.randn(5, hidden_size) + ) # support max 5 ref images + + self.gradient_checkpointing = False + + self.initialize_weights() + + # TeaCache settings + self.enable_teacache = False + self.teacache_rel_l1_thresh = 0.05 + self.teacache_params = TeaCacheParams() + + coefficients = [-5.48259225, 11.48772289, -4.47407401, 2.47730926, -0.03316487] + self.rescale_func = np.poly1d(coefficients) + + def initialize_weights(self) -> None: + """ + Initialize the weights of the model. + + Uses Xavier uniform initialization for linear layers. + """ + nn.init.xavier_uniform_(self.x_embedder.weight) + nn.init.constant_(self.x_embedder.bias, 0.0) + + nn.init.xavier_uniform_(self.ref_image_patch_embedder.weight) + nn.init.constant_(self.ref_image_patch_embedder.bias, 0.0) + + nn.init.zeros_(self.norm_out.linear_1.weight) + nn.init.zeros_(self.norm_out.linear_1.bias) + nn.init.zeros_(self.norm_out.linear_2.weight) + nn.init.zeros_(self.norm_out.linear_2.bias) + + nn.init.normal_(self.image_index_embedding, std=0.02) + + def img_patch_embed_and_refine( + self, + hidden_states, + ref_image_hidden_states, + padded_img_mask, + padded_ref_img_mask, + noise_rotary_emb, + ref_img_rotary_emb, + l_effective_ref_img_len, + l_effective_img_len, + temb, + ): + batch_size = len(hidden_states) + if isinstance(l_effective_img_len[0], list): + l_effective_img_len_summed = [sum(ln) for ln in l_effective_img_len] + else: + l_effective_img_len_summed = l_effective_img_len + max_combined_img_len = max( + [ + img_len + sum(ref_img_len) + for img_len, ref_img_len in zip( + l_effective_img_len_summed, l_effective_ref_img_len + ) + ] + ) + + hidden_states = self.x_embedder(hidden_states) + ref_image_hidden_states = self.ref_image_patch_embedder(ref_image_hidden_states) + + for i in range(batch_size): + shift = 0 + for j, ref_img_len in enumerate(l_effective_ref_img_len[i]): + ref_image_hidden_states[i, shift : shift + ref_img_len, :] = ( + ref_image_hidden_states[i, shift : shift + ref_img_len, :] + + self.image_index_embedding[j] + ) + shift += ref_img_len + + for layer in self.noise_refiner: + hidden_states = layer( + hidden_states, padded_img_mask, noise_rotary_emb, temb + ) + + flat_l_effective_ref_img_len = list(itertools.chain(*l_effective_ref_img_len)) + num_ref_images = len(flat_l_effective_ref_img_len) + max_ref_img_len = max(flat_l_effective_ref_img_len) + + batch_ref_img_mask = ref_image_hidden_states.new_zeros( + num_ref_images, max_ref_img_len, dtype=torch.bool + ) + batch_ref_image_hidden_states = ref_image_hidden_states.new_zeros( + num_ref_images, max_ref_img_len, self.config.hidden_size + ) + batch_ref_img_rotary_emb = hidden_states.new_zeros( + num_ref_images, + max_ref_img_len, + ref_img_rotary_emb.shape[-1], + dtype=ref_img_rotary_emb.dtype, + ) + batch_temb = temb.new_zeros(num_ref_images, *temb.shape[1:], dtype=temb.dtype) + + # sequence of ref imgs to batch + idx = 0 + for i in range(batch_size): + shift = 0 + for ref_img_len in l_effective_ref_img_len[i]: + batch_ref_img_mask[idx, :ref_img_len] = True + batch_ref_image_hidden_states[idx, :ref_img_len] = ( + ref_image_hidden_states[i, shift : shift + ref_img_len] + ) + batch_ref_img_rotary_emb[idx, :ref_img_len] = ref_img_rotary_emb[ + i, shift : shift + ref_img_len + ] + batch_temb[idx] = temb[i] + shift += ref_img_len + idx += 1 + + # refine ref imgs separately + for layer in self.ref_image_refiner: + batch_ref_image_hidden_states = layer( + batch_ref_image_hidden_states, + batch_ref_img_mask, + batch_ref_img_rotary_emb, + batch_temb, + ) + + # batch of ref imgs to sequence + idx = 0 + for i in range(batch_size): + shift = 0 + for ref_img_len in l_effective_ref_img_len[i]: + ref_image_hidden_states[i, shift : shift + ref_img_len] = ( + batch_ref_image_hidden_states[idx, :ref_img_len] + ) + shift += ref_img_len + idx += 1 + + combined_img_hidden_states = hidden_states.new_zeros( + batch_size, max_combined_img_len, self.config.hidden_size + ) + for i, (ref_img_len, img_len) in enumerate( + zip(l_effective_ref_img_len, l_effective_img_len_summed) + ): + combined_img_hidden_states[i, : sum(ref_img_len)] = ref_image_hidden_states[ + i, : sum(ref_img_len) + ] + combined_img_hidden_states[ + i, sum(ref_img_len) : sum(ref_img_len) + img_len + ] = hidden_states[i, :img_len] + + return combined_img_hidden_states + + def flat_and_pad_to_seq(self, hidden_states, ref_image_hidden_states): + batch_size = len(hidden_states) + p = self.config.patch_size + device = hidden_states[0].device + + if len(hidden_states[0].shape) == 3: + img_sizes = [(img.size(1), img.size(2)) for img in hidden_states] + l_effective_img_len = [(H // p) * (W // p) for (H, W) in img_sizes] + else: + img_sizes = [ + [(img.size(1), img.size(2)) for img in imgs] for imgs in hidden_states + ] + l_effective_img_len = [ + [(H // p) * (W // p) for (H, W) in _img_sizes] + for _img_sizes in img_sizes + ] + + if ref_image_hidden_states is not None: + ref_img_sizes = [ + [(img.size(1), img.size(2)) for img in imgs] + if imgs is not None + else None + for imgs in ref_image_hidden_states + ] + l_effective_ref_img_len = [ + [ + (ref_img_size[0] // p) * (ref_img_size[1] // p) + for ref_img_size in _ref_img_sizes + ] + if _ref_img_sizes is not None + else [0] + for _ref_img_sizes in ref_img_sizes + ] + else: + ref_img_sizes = [None for _ in range(batch_size)] + l_effective_ref_img_len = [[0] for _ in range(batch_size)] + + max_ref_img_len = max( + [sum(ref_img_len) for ref_img_len in l_effective_ref_img_len] + ) + if len(hidden_states[0].shape) == 4: + max_img_len = max([sum(img_len) for img_len in l_effective_img_len]) + else: + max_img_len = max(l_effective_img_len) + + # ref image patch embeddings + flat_ref_img_hidden_states = [] + for i in range(batch_size): + if ref_img_sizes[i] is not None: + imgs = [] + for ref_img in ref_image_hidden_states[i]: + C, H, W = ref_img.size() + ref_img = rearrange( + ref_img, "c (h p1) (w p2) -> (h w) (p1 p2 c)", p1=p, p2=p + ) + imgs.append(ref_img) + + img = torch.cat(imgs, dim=0) + flat_ref_img_hidden_states.append(img) + else: + flat_ref_img_hidden_states.append(None) + + # image patch embeddings + flat_hidden_states = [] + if len(hidden_states[0].shape) == 4: # New case + for i in range(batch_size): + # Process each time step and concatenate + batch_img_patches = [] + for img in hidden_states[i]: + C, H, W = img.size() + img = rearrange( + img, "c (h p1) (w p2) -> (h w) (p1 p2 c)", p1=p, p2=p + ) + batch_img_patches.append(img) + # Concatenate patches for the current batch item across time + flat_hidden_states.append(torch.cat(batch_img_patches, dim=0)) + else: # Default + for i in range(batch_size): + img = hidden_states[i] + C, H, W = img.size() + + img = rearrange(img, "c (h p1) (w p2) -> (h w) (p1 p2 c)", p1=p, p2=p) + flat_hidden_states.append(img) + + padded_ref_img_hidden_states = torch.zeros( + batch_size, + max_ref_img_len, + flat_hidden_states[0].shape[-1], + device=device, + dtype=flat_hidden_states[0].dtype, + ) + padded_ref_img_mask = torch.zeros( + batch_size, max_ref_img_len, dtype=torch.bool, device=device + ) + for i in range(batch_size): + if ref_img_sizes[i] is not None: + padded_ref_img_hidden_states[i, : sum(l_effective_ref_img_len[i])] = ( + flat_ref_img_hidden_states[i] + ) + padded_ref_img_mask[i, : sum(l_effective_ref_img_len[i])] = True + + padded_hidden_states = torch.zeros( + batch_size, + max_img_len, + flat_hidden_states[0].shape[-1], + device=device, + dtype=flat_hidden_states[0].dtype, + ) + padded_img_mask = torch.zeros( + batch_size, max_img_len, dtype=torch.bool, device=device + ) + for i in range(batch_size): + if len(hidden_states[0].shape) == 4: # New case + padded_hidden_states[i, : sum(l_effective_img_len[i])] = ( + flat_hidden_states[i] + ) + padded_img_mask[i, : sum(l_effective_img_len[i])] = True + else: + padded_hidden_states[i, : l_effective_img_len[i]] = flat_hidden_states[ + i + ] + padded_img_mask[i, : l_effective_img_len[i]] = True + + return ( + padded_hidden_states, + padded_ref_img_hidden_states, + padded_img_mask, + padded_ref_img_mask, + l_effective_ref_img_len, + l_effective_img_len, + ref_img_sizes, + img_sizes, + ) + + def forward( + self, + hidden_states: Union[torch.Tensor, List[torch.Tensor]], + timestep: torch.Tensor, + text_hidden_states: torch.Tensor, + freqs_cis: torch.Tensor, + text_attention_mask: torch.Tensor, + ref_image_hidden_states: Optional[List[List[torch.Tensor]]] = None, + attention_kwargs: Optional[Dict[str, Any]] = None, + return_dict: bool = False, + ) -> Union[torch.Tensor, Transformer2DModelOutput]: + enable_taylorseer = getattr(self, "enable_taylorseer", False) + if enable_taylorseer: + cal_type(self.cache_dic, self.current) + + if attention_kwargs is not None: + attention_kwargs = attention_kwargs.copy() + lora_scale = attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) + else: + if ( + attention_kwargs is not None + and attention_kwargs.get("scale", None) is not None + ): + logger.warning( + "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective." + ) + + # 1. Condition, positional & patch embedding + batch_size = len(hidden_states) + is_hidden_states_tensor = isinstance(hidden_states, torch.Tensor) + + if is_hidden_states_tensor: + assert hidden_states.ndim == 4 + hidden_states = [_hidden_states for _hidden_states in hidden_states] + + device = hidden_states[0].device + + temb, text_hidden_states = self.time_caption_embed( + timestep, text_hidden_states, hidden_states[0].dtype + ) + + ( + hidden_states, + ref_image_hidden_states, + img_mask, + ref_img_mask, + l_effective_ref_img_len, + l_effective_img_len, + ref_img_sizes, + img_sizes, + ) = self.flat_and_pad_to_seq(hidden_states, ref_image_hidden_states) + + ( + context_rotary_emb, + ref_img_rotary_emb, + noise_rotary_emb, + rotary_emb, + encoder_seq_lengths, + seq_lengths, + ) = self.rope_embedder( + freqs_cis, + text_attention_mask, + l_effective_ref_img_len, + l_effective_img_len, + ref_img_sizes, + img_sizes, + device, + ) + + # 2. Context refinement + for layer in self.context_refiner: + text_hidden_states = layer( + text_hidden_states, text_attention_mask, context_rotary_emb + ) + + combined_img_hidden_states = self.img_patch_embed_and_refine( + hidden_states, + ref_image_hidden_states, + img_mask, + ref_img_mask, + noise_rotary_emb, + ref_img_rotary_emb, + l_effective_ref_img_len, + l_effective_img_len, + temb, + ) + + # 3. Joint Transformer blocks + max_seq_len = max(seq_lengths) + + attention_mask = hidden_states.new_zeros( + batch_size, max_seq_len, dtype=torch.bool + ) + joint_hidden_states = hidden_states.new_zeros( + batch_size, max_seq_len, self.config.hidden_size + ) + for i, (encoder_seq_len, seq_len) in enumerate( + zip(encoder_seq_lengths, seq_lengths) + ): + attention_mask[i, :seq_len] = True + joint_hidden_states[i, :encoder_seq_len] = text_hidden_states[ + i, :encoder_seq_len + ] + joint_hidden_states[i, encoder_seq_len:seq_len] = ( + combined_img_hidden_states[i, : seq_len - encoder_seq_len] + ) + + hidden_states = joint_hidden_states + + if self.enable_teacache: + teacache_hidden_states = hidden_states.clone() + teacache_temb = temb.clone() + modulated_inp, _, _, _ = self.layers[0].norm1( + teacache_hidden_states, teacache_temb + ) + if self.teacache_params.is_first_or_last_step: + should_calc = True + self.teacache_params.accumulated_rel_l1_distance = 0 + else: + self.teacache_params.accumulated_rel_l1_distance += self.rescale_func( + ( + (modulated_inp - self.teacache_params.previous_modulated_inp) + .abs() + .mean() + / self.teacache_params.previous_modulated_inp.abs().mean() + ) + .cpu() + .item() + ) + if ( + self.teacache_params.accumulated_rel_l1_distance + < self.teacache_rel_l1_thresh + ): + should_calc = False + else: + should_calc = True + self.teacache_params.accumulated_rel_l1_distance = 0 + self.teacache_params.previous_modulated_inp = modulated_inp + + if self.enable_teacache: + if not should_calc: + hidden_states += self.teacache_params.previous_residual + else: + ori_hidden_states = hidden_states.clone() + for layer_idx, layer in enumerate(self.layers): + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func( + layer, hidden_states, attention_mask, rotary_emb, temb + ) + else: + hidden_states = layer( + hidden_states, attention_mask, rotary_emb, temb + ) + self.teacache_params.previous_residual = ( + hidden_states - ori_hidden_states + ) + else: + if enable_taylorseer: + self.current["stream"] = "layers_stream" + + for layer_idx, layer in enumerate(self.layers): + if enable_taylorseer: + layer.current = self.current + layer.cache_dic = self.cache_dic + layer.enable_taylorseer = True + self.current["layer"] = layer_idx + + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func( + layer, hidden_states, attention_mask, rotary_emb, temb + ) + else: + hidden_states = layer( + hidden_states, attention_mask, rotary_emb, temb + ) + + # 4. Output norm & projection + hidden_states = self.norm_out(hidden_states, temb) + + p = self.config.patch_size + output = [] + + for i, (img_size, img_len, seq_len) in enumerate( + zip(img_sizes, l_effective_img_len, seq_lengths) + ): + if isinstance(img_len, list): + batch_output = [] + cur_st = seq_len - sum(img_len) + for j in range(len(img_len)): + height, width = img_size[j] + cur_len = img_len[j] + batch_output.append( + rearrange( + hidden_states[i][cur_st : cur_st + cur_len], + "(h w) (p1 p2 c) -> c (h p1) (w p2)", + h=height // p, + w=width // p, + p1=p, + p2=p, + ) + ) + cur_st += cur_len + output.append(torch.stack(batch_output, dim=0)) + + else: + height, width = img_size + output.append( + rearrange( + hidden_states[i][seq_len - img_len : seq_len], + "(h w) (p1 p2 c) -> c (h p1) (w p2)", + h=height // p, + w=width // p, + p1=p, + p2=p, + ) + ) + if is_hidden_states_tensor: + output = torch.stack(output, dim=0) + + if USE_PEFT_BACKEND: + # remove `lora_scale` from each PEFT layer + unscale_lora_layers(self, lora_scale) + + if enable_taylorseer: + self.current["step"] += 1 + + if not return_dict: + return output + return Transformer2DModelOutput(sample=output)