import warnings import itertools from typing import Any, Dict, List, Optional, Tuple, Union from dataclasses import dataclass import math import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from einops import rearrange, repeat from diffusers.configuration_utils import ConfigMixin, register_to_config from diffusers.loaders import PeftAdapterMixin from diffusers.loaders.single_file_model import FromOriginalModelMixin from diffusers.utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers from diffusers.models.attention_processor import Attention from diffusers.models.modeling_outputs import Transformer2DModelOutput from diffusers.models.modeling_utils import ModelMixin from diffusers.models.embeddings import get_1d_rotary_pos_embed from diffusers.models.activations import get_activation from diffusers.models.embeddings import Timesteps import importlib.util import sys # The package importlib_metadata is in a different place, depending on the python version. 
if sys.version_info < (3, 8):
    import importlib_metadata
else:
    import importlib.metadata as importlib_metadata


def _is_package_available(pkg_name: str):
    """Return ``(exists, version)`` for ``pkg_name``.

    A package counts as available only if it is importable (``find_spec``) AND its
    distribution metadata can be read; otherwise ``exists`` is False. ``version`` is
    ``"N/A"`` when it could not be determined.
    """
    pkg_exists = importlib.util.find_spec(pkg_name) is not None
    pkg_version = "N/A"

    if pkg_exists:
        try:
            pkg_version = importlib_metadata.version(pkg_name)
        except (ImportError, importlib_metadata.PackageNotFoundError):
            pkg_exists = False

    return pkg_exists, pkg_version


_triton_available, _triton_version = _is_package_available("triton")
_flash_attn_available, _flash_attn_version = _is_package_available("flash_attn")


def is_triton_available():
    """Whether ``triton`` was found (with metadata) when this module was imported."""
    return _triton_available


def is_flash_attn_available():
    """Whether ``flash_attn`` was found (with metadata) when this module was imported."""
    return _flash_attn_available


if is_flash_attn_available():
    from flash_attn import flash_attn_varlen_func
    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input
else:
    warnings.warn("Cannot import flash_attn, install flash_attn to use Flash2Varlen attention for better performance")


if is_triton_available():
    # from ...ops.triton.layer_norm import RMSNorm
    import triton
    import triton.language as tl

    from typing import Callable

    def custom_amp_decorator(dec: Callable, cuda_amp_deprecated: bool):
        """Wrap an AMP ``custom_fwd``/``custom_bwd`` decorator factory.

        When the ``torch.amp`` variants are used (``cuda_amp_deprecated=True``), inject
        ``device_type="cuda"`` so call sites can stay agnostic of which API is present.
        """
        def decorator(*args, **kwargs):
            if cuda_amp_deprecated:
                kwargs["device_type"] = "cuda"
            return dec(*args, **kwargs)
        return decorator

    if hasattr(torch.amp, "custom_fwd"):  # type: ignore[attr-defined]
        deprecated = True
        from torch.amp import custom_fwd, custom_bwd  # type: ignore[attr-defined]
    else:
        deprecated = False
        from torch.cuda.amp import custom_fwd, custom_bwd

    custom_fwd = custom_amp_decorator(custom_fwd, deprecated)
    custom_bwd = custom_amp_decorator(custom_bwd, deprecated)

    def triton_autotune_configs():
        """Return one ``triton.Config`` per power-of-two warp count that fits in a block.

        Called at import time by the ``@triton.autotune`` decorators below.
        """
        configs = []
        # Maximum threads per block is architecture-dependent in theory, but in reality all are 1024
        max_threads_per_block = 1024
        # Default to warp size 32 if not defined by the device. Also fall back to 32 when
        # no CUDA device is visible (e.g. triton installed on a CPU-only machine):
        # querying the current device would otherwise raise and break module import.
        if torch.cuda.is_available():
            warp_size = getattr(
                torch.cuda.get_device_properties(torch.cuda.current_device()), "warp_size", 32
            )
        else:
            warp_size = 32
        # Autotune for warp counts which are powers of 2 and do not exceed thread per block limit
        warp_count = 1
        while warp_count * warp_size <= max_threads_per_block:
            configs.append(triton.Config({}, num_warps=warp_count))
            warp_count *= 2
        return configs

    # Forward kernel: one program per row. Optionally applies rowscale and dropout to x
    # (and to the parallel input x1), adds the residual, stores the pre-norm sum, then
    # normalizes (LayerNorm or RMSNorm) and applies weight/bias. If W1 is given, a second
    # affine transform of the same normalized row is written to Y1.
    @triton.autotune(
        configs=triton_autotune_configs(),
        key=["N", "HAS_RESIDUAL", "STORE_RESIDUAL_OUT", "IS_RMS_NORM", "HAS_BIAS"],
    )
    # @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
    # @triton.heuristics({"HAS_RESIDUAL": lambda args: args["RESIDUAL"] is not None})
    @triton.heuristics({"HAS_X1": lambda args: args["X1"] is not None})
    @triton.heuristics({"HAS_W1": lambda args: args["W1"] is not None})
    @triton.heuristics({"HAS_B1": lambda args: args["B1"] is not None})
    @triton.jit
    def _layer_norm_fwd_1pass_kernel(
        X,  # pointer to the input
        Y,  # pointer to the output
        W,  # pointer to the weights
        B,  # pointer to the biases
        RESIDUAL,  # pointer to the residual
        X1,
        W1,
        B1,
        Y1,
        RESIDUAL_OUT,  # pointer to the residual
        ROWSCALE,
        SEEDS,  # Dropout seeds for each row
        DROPOUT_MASK,
        Mean,  # pointer to the mean
        Rstd,  # pointer to the 1/std
        stride_x_row,  # how much to increase the pointer when moving by 1 row
        stride_y_row,
        stride_res_row,
        stride_res_out_row,
        stride_x1_row,
        stride_y1_row,
        M,  # number of rows in X
        N,  # number of columns in X
        eps,  # epsilon to avoid division by zero
        dropout_p,  # Dropout probability
        zero_centered_weight,  # If true, add 1.0 to the weight
        IS_RMS_NORM: tl.constexpr,
        BLOCK_N: tl.constexpr,
        HAS_RESIDUAL: tl.constexpr,
        STORE_RESIDUAL_OUT: tl.constexpr,
        HAS_BIAS: tl.constexpr,
        HAS_DROPOUT: tl.constexpr,
        STORE_DROPOUT_MASK: tl.constexpr,
        HAS_ROWSCALE: tl.constexpr,
        HAS_X1: tl.constexpr,
        HAS_W1: tl.constexpr,
        HAS_B1: tl.constexpr,
    ):
        # Map the program id to the row of X and Y it should compute.
        row = tl.program_id(0)
        X += row * stride_x_row
        Y += row * stride_y_row
        if HAS_RESIDUAL:
            RESIDUAL += row * stride_res_row
        if STORE_RESIDUAL_OUT:
            RESIDUAL_OUT += row * stride_res_out_row
        if HAS_X1:
            X1 += row * stride_x1_row
        if HAS_W1:
            Y1 += row * stride_y1_row
        # Compute mean and variance
        cols = tl.arange(0, BLOCK_N)
        x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
        if HAS_ROWSCALE:
            rowscale = tl.load(ROWSCALE + row).to(tl.float32)
            x *= rowscale
        if HAS_DROPOUT:
            # Compute dropout mask
            # 7 rounds is good enough, and reduces register pressure
            keep_mask = tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
            x = tl.where(keep_mask, x / (1.0 - dropout_p), 0.0)
            if STORE_DROPOUT_MASK:
                tl.store(DROPOUT_MASK + row * N + cols, keep_mask, mask=cols < N)
        if HAS_X1:
            x1 = tl.load(X1 + cols, mask=cols < N, other=0.0).to(tl.float32)
            if HAS_ROWSCALE:
                rowscale = tl.load(ROWSCALE + M + row).to(tl.float32)
                x1 *= rowscale
            if HAS_DROPOUT:
                # Compute dropout mask
                # 7 rounds is good enough, and reduces register pressure
                # x1 uses the second half of SEEDS / DROPOUT_MASK (offset by M rows).
                keep_mask = (
                    tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
                )
                x1 = tl.where(keep_mask, x1 / (1.0 - dropout_p), 0.0)
                if STORE_DROPOUT_MASK:
                    tl.store(DROPOUT_MASK + (M + row) * N + cols, keep_mask, mask=cols < N)
            x += x1
        if HAS_RESIDUAL:
            residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32)
            x += residual
        if STORE_RESIDUAL_OUT:
            tl.store(RESIDUAL_OUT + cols, x, mask=cols < N)
        if not IS_RMS_NORM:
            mean = tl.sum(x, axis=0) / N
            tl.store(Mean + row, mean)
            xbar = tl.where(cols < N, x - mean, 0.0)
            var = tl.sum(xbar * xbar, axis=0) / N
        else:
            xbar = tl.where(cols < N, x, 0.0)
            var = tl.sum(xbar * xbar, axis=0) / N
        rstd = 1 / tl.sqrt(var + eps)
        tl.store(Rstd + row, rstd)
        # Normalize and apply linear transformation
        mask = cols < N
        w = tl.load(W + cols, mask=mask).to(tl.float32)
        if zero_centered_weight:
            w += 1.0
        if HAS_BIAS:
            b = tl.load(B + cols, mask=mask).to(tl.float32)
        x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
        y = x_hat * w + b if HAS_BIAS else x_hat * w
        # Write output
        tl.store(Y + cols, y, mask=mask)
        if HAS_W1:
            w1 = tl.load(W1 + cols, mask=mask).to(tl.float32)
            if zero_centered_weight:
                w1 += 1.0
            if HAS_B1:
                b1 = tl.load(B1 + cols, mask=mask).to(tl.float32)
            y1 = x_hat * w1 + b1 if HAS_B1 else x_hat * w1
            tl.store(Y1 + cols, y1, mask=mask)

    def _layer_norm_fwd(
        x,
        weight,
        bias,
        eps,
        residual=None,
        x1=None,
        weight1=None,
        bias1=None,
        dropout_p=0.0,
        rowscale=None,
        out_dtype=None,
        residual_dtype=None,
        zero_centered_weight=False,
        is_rms_norm=False,
        return_dropout_mask=False,
        out=None,
        residual_out=None
    ):
        """Host-side launcher for ``_layer_norm_fwd_1pass_kernel``.

        ``x`` (and ``residual``/``x1`` if given) must be 2-D ``(M, N)`` with a unit stride on
        the last dim; ``weight``/``bias``/``weight1``/``bias1`` must be ``(N,)``.

        Returns:
            ``(out, y1, mean, rstd, residual_out_or_x, seeds, dropout_mask, dropout_mask1)``.
            ``mean`` is None for RMSNorm. ``residual_out_or_x`` is ``x`` itself when no
            pre-norm sum had to be materialized. When ``x1`` is given, the dropout mask is
            split into the masks for ``x`` and ``x1``.

        Raises:
            RuntimeError: if a row does not fit in a single 64KB block.
        """
        if residual is not None:
            residual_dtype = residual.dtype
        M, N = x.shape
        assert x.stride(-1) == 1
        if residual is not None:
            assert residual.stride(-1) == 1
            assert residual.shape == (M, N)
        assert weight.shape == (N,)
        assert weight.stride(-1) == 1
        if bias is not None:
            assert bias.stride(-1) == 1
            assert bias.shape == (N,)
        if x1 is not None:
            assert x1.shape == x.shape
            assert rowscale is None
            assert x1.stride(-1) == 1
        if weight1 is not None:
            assert weight1.shape == (N,)
            assert weight1.stride(-1) == 1
        if bias1 is not None:
            assert bias1.shape == (N,)
            assert bias1.stride(-1) == 1
        if rowscale is not None:
            assert rowscale.is_contiguous()
            assert rowscale.shape == (M,)
        # allocate output
        if out is None:
            out = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)
        else:
            assert out.shape == x.shape
        assert out.stride(-1) == 1
        if weight1 is not None:
            y1 = torch.empty_like(out)
            assert y1.stride(-1) == 1
        else:
            y1 = None
        # The pre-norm sum only needs its own buffer when it differs from x itself.
        if (
            residual is not None
            or (residual_dtype is not None and residual_dtype != x.dtype)
            or dropout_p > 0.0
            or rowscale is not None
            or x1 is not None
        ):
            if residual_out is None:
                residual_out = torch.empty(
                    M, N, device=x.device, dtype=residual_dtype if residual_dtype is not None else x.dtype
                )
            else:
                assert residual_out.shape == x.shape
            assert residual_out.stride(-1) == 1
        else:
            residual_out = None
        mean = torch.empty((M,), dtype=torch.float32, device=x.device) if not is_rms_norm else None
        rstd = torch.empty((M,), dtype=torch.float32, device=x.device)
        if dropout_p > 0.0:
            # One seed per row; x1 (if any) gets its own M seeds after those of x.
            seeds = torch.randint(
                2**32, (M if x1 is None else 2 * M,), device=x.device, dtype=torch.int64
            )
        else:
            seeds = None
        if return_dropout_mask and dropout_p > 0.0:
            dropout_mask = torch.empty(M if x1 is None else 2 * M, N, device=x.device, dtype=torch.bool)
        else:
            dropout_mask = None
        # Less than 64KB per feature: enqueue fused kernel
        MAX_FUSED_SIZE = 65536 // x.element_size()
        BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
        if N > BLOCK_N:
            raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
        with torch.cuda.device(x.device.index):
            _layer_norm_fwd_1pass_kernel[(M,)](
                x,
                out,
                weight,
                bias,
                residual,
                x1,
                weight1,
                bias1,
                y1,
                residual_out,
                rowscale,
                seeds,
                dropout_mask,
                mean,
                rstd,
                x.stride(0),
                out.stride(0),
                residual.stride(0) if residual is not None else 0,
                residual_out.stride(0) if residual_out is not None else 0,
                x1.stride(0) if x1 is not None else 0,
                y1.stride(0) if y1 is not None else 0,
                M,
                N,
                eps,
                dropout_p,
                zero_centered_weight,
                is_rms_norm,
                BLOCK_N,
                residual is not None,
                residual_out is not None,
                bias is not None,
                dropout_p > 0.0,
                dropout_mask is not None,
                rowscale is not None,
            )
        # residual_out is None if residual is None and residual_dtype == input_dtype and dropout_p == 0.0
        if dropout_mask is not None and x1 is not None:
            dropout_mask, dropout_mask1 = dropout_mask.tensor_split(2, dim=0)
        else:
            dropout_mask1 = None
        return (
            out,
            y1,
            mean,
            rstd,
            residual_out if residual_out is not None else x,
            seeds,
            dropout_mask,
            dropout_mask1,
        )

    # Backward kernel: each program walks `rows_per_program` consecutive rows, writes
    # dX (and dX1 / dRESIDUAL_IN when requested), and stores its partial sums of dW/dB
    # (and dW1/dB1) in row `row_block_id` of the scratch buffers; the host sums them.
    @triton.autotune(
        configs=triton_autotune_configs(),
        key=["N", "HAS_DRESIDUAL", "STORE_DRESIDUAL", "IS_RMS_NORM", "HAS_BIAS", "HAS_DROPOUT"],
    )
    # @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
    # @triton.heuristics({"HAS_DRESIDUAL": lambda args: args["DRESIDUAL"] is not None})
    # @triton.heuristics({"STORE_DRESIDUAL": lambda args: args["DRESIDUAL_IN"] is not None})
    @triton.heuristics({"HAS_ROWSCALE": lambda args: args["ROWSCALE"] is not None})
    @triton.heuristics({"HAS_DY1": lambda args: args["DY1"] is not None})
    @triton.heuristics({"HAS_DX1": lambda args: args["DX1"] is not None})
    @triton.heuristics({"HAS_B1": lambda args: args["DB1"] is not None})
    @triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None})
    @triton.jit
    def _layer_norm_bwd_kernel(
        X,  # pointer to the input
        W,  # pointer to the weights
        B,  # pointer to the biases
        Y,  # pointer to the output to be recomputed
        DY,  # pointer to the output gradient
        DX,  # pointer to the input gradient
        DW,  # pointer to the partial sum of weights gradient
        DB,  # pointer to the partial sum of biases gradient
        DRESIDUAL,
        W1,
        DY1,
        DX1,
        DW1,
        DB1,
        DRESIDUAL_IN,
        ROWSCALE,
        SEEDS,
        Mean,  # pointer to the mean
        Rstd,  # pointer to the 1/std
        stride_x_row,  # how much to increase the pointer when moving by 1 row
        stride_y_row,
        stride_dy_row,
        stride_dx_row,
        stride_dres_row,
        stride_dy1_row,
        stride_dx1_row,
        stride_dres_in_row,
        M,  # number of rows in X
        N,  # number of columns in X
        eps,  # epsilon to avoid division by zero
        dropout_p,
        zero_centered_weight,
        rows_per_program,
        IS_RMS_NORM: tl.constexpr,
        BLOCK_N: tl.constexpr,
        HAS_DRESIDUAL: tl.constexpr,
        STORE_DRESIDUAL: tl.constexpr,
        HAS_BIAS: tl.constexpr,
        HAS_DROPOUT: tl.constexpr,
        HAS_ROWSCALE: tl.constexpr,
        HAS_DY1: tl.constexpr,
        HAS_DX1: tl.constexpr,
        HAS_B1: tl.constexpr,
        RECOMPUTE_OUTPUT: tl.constexpr,
    ):
        # Map the program id to the elements of X, DX, and DY it should compute.
        row_block_id = tl.program_id(0)
        row_start = row_block_id * rows_per_program
        # Do not early exit if row_start >= M, because we need to write DW and DB
        cols = tl.arange(0, BLOCK_N)
        mask = cols < N
        X += row_start * stride_x_row
        if HAS_DRESIDUAL:
            DRESIDUAL += row_start * stride_dres_row
        if STORE_DRESIDUAL:
            DRESIDUAL_IN += row_start * stride_dres_in_row
        DY += row_start * stride_dy_row
        DX += row_start * stride_dx_row
        if HAS_DY1:
            DY1 += row_start * stride_dy1_row
        if HAS_DX1:
            DX1 += row_start * stride_dx1_row
        if RECOMPUTE_OUTPUT:
            Y += row_start * stride_y_row
        w = tl.load(W + cols, mask=mask).to(tl.float32)
        if zero_centered_weight:
            w += 1.0
        if RECOMPUTE_OUTPUT and HAS_BIAS:
            b = tl.load(B + cols, mask=mask, other=0.0).to(tl.float32)
        if HAS_DY1:
            w1 = tl.load(W1 + cols, mask=mask).to(tl.float32)
            if zero_centered_weight:
                w1 += 1.0
        dw = tl.zeros((BLOCK_N,), dtype=tl.float32)
        if HAS_BIAS:
            db = tl.zeros((BLOCK_N,), dtype=tl.float32)
        if HAS_DY1:
            dw1 = tl.zeros((BLOCK_N,), dtype=tl.float32)
            if HAS_B1:
                db1 = tl.zeros((BLOCK_N,), dtype=tl.float32)
        row_end = min((row_block_id + 1) * rows_per_program, M)
        for row in range(row_start, row_end):
            # Load data to SRAM
            x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)
            dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)
            if HAS_DY1:
                dy1 = tl.load(DY1 + cols, mask=mask, other=0).to(tl.float32)
            if not IS_RMS_NORM:
                mean = tl.load(Mean + row)
            rstd = tl.load(Rstd + row)
            # Compute dx
            xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
            xhat = tl.where(mask, xhat, 0.0)
            if RECOMPUTE_OUTPUT:
                y = xhat * w + b if HAS_BIAS else xhat * w
                tl.store(Y + cols, y, mask=mask)
            wdy = w * dy
            dw += dy * xhat
            if HAS_BIAS:
                db += dy
            if HAS_DY1:
                wdy += w1 * dy1
                dw1 += dy1 * xhat
                if HAS_B1:
                    db1 += dy1
            if not IS_RMS_NORM:
                c1 = tl.sum(xhat * wdy, axis=0) / N
                c2 = tl.sum(wdy, axis=0) / N
                dx = (wdy - (xhat * c1 + c2)) * rstd
            else:
                c1 = tl.sum(xhat * wdy, axis=0) / N
                dx = (wdy - xhat * c1) * rstd
            if HAS_DRESIDUAL:
                dres = tl.load(DRESIDUAL + cols, mask=mask, other=0).to(tl.float32)
                dx += dres
            # Write dx
            if STORE_DRESIDUAL:
                tl.store(DRESIDUAL_IN + cols, dx, mask=mask)
            if HAS_DX1:
                if HAS_DROPOUT:
                    # Regenerate the forward dropout mask of x1 from the same seed.
                    keep_mask = (
                        tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
                    )
                    dx1 = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0)
                else:
                    dx1 = dx
                tl.store(DX1 + cols, dx1, mask=mask)
            if HAS_DROPOUT:
                # Regenerate the forward dropout mask of x from the same seed.
                keep_mask = tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
                dx = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0)
            if HAS_ROWSCALE:
                rowscale = tl.load(ROWSCALE + row).to(tl.float32)
                dx *= rowscale
            tl.store(DX + cols, dx, mask=mask)

            X += stride_x_row
            if HAS_DRESIDUAL:
                DRESIDUAL += stride_dres_row
            if STORE_DRESIDUAL:
                DRESIDUAL_IN += stride_dres_in_row
            if RECOMPUTE_OUTPUT:
                Y += stride_y_row
            DY += stride_dy_row
            DX += stride_dx_row
            if HAS_DY1:
                DY1 += stride_dy1_row
            if HAS_DX1:
                DX1 += stride_dx1_row
        tl.store(DW + row_block_id * N + cols, dw, mask=mask)
        if HAS_BIAS:
            tl.store(DB + row_block_id * N + cols, db, mask=mask)
        if HAS_DY1:
            tl.store(DW1 + row_block_id * N + cols, dw1, mask=mask)
            if HAS_B1:
                tl.store(DB1 + row_block_id * N + cols, db1, mask=mask)

    def _layer_norm_bwd(
        dy,
        x,
        weight,
        bias,
        eps,
        mean,
        rstd,
        dresidual=None,
        dy1=None,
        weight1=None,
        bias1=None,
        seeds=None,
        dropout_p=0.0,
        rowscale=None,
        has_residual=False,
        has_x1=False,
        zero_centered_weight=False,
        is_rms_norm=False,
        x_dtype=None,
        recompute_output=False,
    ):
        """Host-side launcher for ``_layer_norm_bwd_kernel``.

        ``x`` here is the saved pre-norm sum (2-D ``(M, N)``). Returns
        ``(dx, dw, db, dresidual_in, dx1, dw1, db1)``, plus the recomputed output ``y``
        when ``recompute_output`` is True.

        Raises:
            RuntimeError: if a row does not fit in a single 64KB block.
        """
        M, N = x.shape
        assert x.stride(-1) == 1
        assert dy.stride(-1) == 1
        assert dy.shape == (M, N)
        if dresidual is not None:
            assert dresidual.stride(-1) == 1
            assert dresidual.shape == (M, N)
        assert weight.shape == (N,)
        assert weight.stride(-1) == 1
        if bias is not None:
            assert bias.stride(-1) == 1
            assert bias.shape == (N,)
        if dy1 is not None:
            assert weight1 is not None
            assert dy1.shape == dy.shape
            assert dy1.stride(-1) == 1
        if weight1 is not None:
            assert weight1.shape == (N,)
            assert weight1.stride(-1) == 1
        if bias1 is not None:
            assert bias1.shape == (N,)
            assert bias1.stride(-1) == 1
        if seeds is not None:
            assert seeds.is_contiguous()
            assert seeds.shape == (M if not has_x1 else M * 2,)
        if rowscale is not None:
            assert rowscale.is_contiguous()
            assert rowscale.shape == (M,)
        # allocate output
        dx = (
            torch.empty_like(x)
            if x_dtype is None
            else torch.empty(M, N, dtype=x_dtype, device=x.device)
        )
        dresidual_in = (
            torch.empty_like(x)
            if has_residual
            and (dx.dtype != x.dtype or dropout_p > 0.0 or rowscale is not None or has_x1)
            else None
        )
        dx1 = torch.empty_like(dx) if (has_x1 and dropout_p > 0.0) else None
        y = torch.empty(M, N, dtype=dy.dtype, device=dy.device) if recompute_output else None
        if recompute_output:
            assert weight1 is None, "recompute_output is not supported with parallel LayerNorm"

        # Less than 64KB per feature: enqueue fused kernel
        MAX_FUSED_SIZE = 65536 // x.element_size()
        BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
        if N > BLOCK_N:
            raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
        # Increasing the multiple (e.g. 8) will allow more thread blocks to be launched and hide the
        # latency of the gmem reads/writes, but will increase the time of summing up dw / db.
        sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count * 8
        _dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device)
        _db = (
            torch.empty((sm_count, N), dtype=torch.float32, device=bias.device)
            if bias is not None
            else None
        )
        _dw1 = torch.empty_like(_dw) if weight1 is not None else None
        # Allocate independently of `_db`: bias1 may be given while bias is None.
        _db1 = (
            torch.empty((sm_count, N), dtype=torch.float32, device=bias1.device)
            if bias1 is not None
            else None
        )
        rows_per_program = math.ceil(M / sm_count)
        grid = (sm_count,)
        with torch.cuda.device(x.device.index):
            _layer_norm_bwd_kernel[grid](
                x,
                weight,
                bias,
                y,
                dy,
                dx,
                _dw,
                _db,
                dresidual,
                weight1,
                dy1,
                dx1,
                _dw1,
                _db1,
                dresidual_in,
                rowscale,
                seeds,
                mean,
                rstd,
                x.stride(0),
                0 if not recompute_output else y.stride(0),
                dy.stride(0),
                dx.stride(0),
                dresidual.stride(0) if dresidual is not None else 0,
                dy1.stride(0) if dy1 is not None else 0,
                dx1.stride(0) if dx1 is not None else 0,
                dresidual_in.stride(0) if dresidual_in is not None else 0,
                M,
                N,
                eps,
                dropout_p,
                zero_centered_weight,
                rows_per_program,
                is_rms_norm,
                BLOCK_N,
                dresidual is not None,
                dresidual_in is not None,
                bias is not None,
                dropout_p > 0.0,
            )
        # Reduce the per-program partial sums.
        dw = _dw.sum(0).to(weight.dtype)
        db = _db.sum(0).to(bias.dtype) if bias is not None else None
        dw1 = _dw1.sum(0).to(weight1.dtype) if weight1 is not None else None
        db1 = _db1.sum(0).to(bias1.dtype) if bias1 is not None else None
        # Don't need to compute dresidual_in separately in this case
        if has_residual and dx.dtype == x.dtype and dropout_p == 0.0 and rowscale is None:
            dresidual_in = dx
        if has_x1 and dropout_p == 0.0:
            dx1 = dx
        return (
            (dx, dw, db, dresidual_in, dx1, dw1, db1)
            if not recompute_output
            else (dx, dw, db, dresidual_in, dx1, dw1, db1, y)
        )

    class LayerNormFn(torch.autograd.Function):
        """Autograd wrapper around the fused Triton LayerNorm/RMSNorm kernels.

        Inputs of any leading shape are flattened to 2-D ``(-1, N)`` for the kernels and
        outputs are reshaped back. The returned tuple depends on ``weight1``, ``prenorm``
        and ``return_dropout_mask`` (see the return statements in ``forward``). Inputs with
        zero elements short-circuit without launching any kernel.
        """

        @staticmethod
        def forward(
            ctx,
            x,
            weight,
            bias,
            residual=None,
            x1=None,
            weight1=None,
            bias1=None,
            eps=1e-6,
            dropout_p=0.0,
            rowscale=None,
            prenorm=False,
            residual_in_fp32=False,
            zero_centered_weight=False,
            is_rms_norm=False,
            return_dropout_mask=False,
            out=None,
            residual_out=None
        ):
            x_shape_og = x.shape
            # Check for zero sequence length
            if x.numel() == 0:
                ctx.zero_seq_length = True
                # Only save minimal required tensors for backward: shapes/dtypes/devices are
                # enough to build zero gradients without touching the kernels.
                # ctx.save_for_backward(weight, bias, weight1, bias1)
                ctx.x_shape_og = x_shape_og
                ctx.weight_shape = weight.shape
                ctx.weight_dtype = weight.dtype
                ctx.weight_device = weight.device
                ctx.has_bias = bias is not None
                ctx.bias_shape = bias.shape if bias is not None else None
                ctx.bias_dtype = bias.dtype if bias is not None else None
                ctx.bias_device = bias.device if bias is not None else None
                ctx.has_weight1 = weight1 is not None
                ctx.weight1_shape = weight1.shape if weight1 is not None else None
                ctx.weight1_dtype = weight1.dtype if weight1 is not None else None
                ctx.weight1_device = weight1.device if weight1 is not None else None
                ctx.has_bias1 = bias1 is not None
                ctx.bias1_shape = bias1.shape if bias1 is not None else None
                ctx.bias1_dtype = bias1.dtype if bias1 is not None else None
                ctx.bias1_device = bias1.device if bias1 is not None else None
                ctx.has_residual = residual is not None
                ctx.has_x1 = x1 is not None
                ctx.dropout_p = dropout_p

                # Handle output tensors with correct dtype
                y = x  # Preserve input tensor properties
                # Mirror the regular path, where y1 exists exactly when weight1 is given.
                y1 = torch.empty_like(x) if weight1 is not None else None

                # Only create residual_out if prenorm is True
                residual_out = torch.empty(
                    x.shape, dtype=torch.float32 if residual_in_fp32 else x.dtype, device=x.device
                ) if prenorm else None

                # Handle dropout masks
                dropout_mask = None
                dropout_mask1 = None
                if return_dropout_mask:
                    dropout_mask = torch.empty_like(x, dtype=torch.uint8)
                    if x1 is not None:
                        dropout_mask1 = torch.empty_like(x, dtype=torch.uint8)

                # Return based on configuration
                if not return_dropout_mask:
                    if weight1 is None:
                        return y if not prenorm else (y, residual_out)
                    else:
                        return (y, y1) if not prenorm else (y, y1, residual_out)
                else:
                    if weight1 is None:
                        return ((y, dropout_mask, dropout_mask1) if not prenorm
                                else (y, residual_out, dropout_mask, dropout_mask1))
                    else:
                        return ((y, y1, dropout_mask, dropout_mask1) if not prenorm
                                else (y, y1, residual_out, dropout_mask, dropout_mask1))

            ctx.zero_seq_length = False
            # reshape input data into 2D tensor
            x = x.reshape(-1, x.shape[-1])
            if x.stride(-1) != 1:
                x = x.contiguous()
            if residual is not None:
                assert residual.shape == x_shape_og
                residual = residual.reshape(-1, residual.shape[-1])
                if residual.stride(-1) != 1:
                    residual = residual.contiguous()
            if x1 is not None:
                assert x1.shape == x_shape_og
                assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
                x1 = x1.reshape(-1, x1.shape[-1])
                if x1.stride(-1) != 1:
                    x1 = x1.contiguous()
            weight = weight.contiguous()
            if bias is not None:
                bias = bias.contiguous()
            if weight1 is not None:
                weight1 = weight1.contiguous()
            if bias1 is not None:
                bias1 = bias1.contiguous()
            if rowscale is not None:
                rowscale = rowscale.reshape(-1).contiguous()
            residual_dtype = (
                residual.dtype
                if residual is not None
                else (torch.float32 if residual_in_fp32 else None)
            )
            if out is not None:
                out = out.reshape(-1, out.shape[-1])
            if residual_out is not None:
                residual_out = residual_out.reshape(-1, residual_out.shape[-1])
            y, y1, mean, rstd, residual_out, seeds, dropout_mask, dropout_mask1 = _layer_norm_fwd(
                x,
                weight,
                bias,
                eps,
                residual,
                x1,
                weight1,
                bias1,
                dropout_p=dropout_p,
                rowscale=rowscale,
                residual_dtype=residual_dtype,
                zero_centered_weight=zero_centered_weight,
                is_rms_norm=is_rms_norm,
                return_dropout_mask=return_dropout_mask,
                out=out,
                residual_out=residual_out
            )
            # residual_out is the pre-norm sum (or x itself); backward re-normalizes it.
            ctx.save_for_backward(
                residual_out, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd
            )
            ctx.x_shape_og = x_shape_og
            ctx.eps = eps
            ctx.dropout_p = dropout_p
            ctx.is_rms_norm = is_rms_norm
            ctx.has_residual = residual is not None
            ctx.has_x1 = x1 is not None
            ctx.prenorm = prenorm
            ctx.x_dtype = x.dtype
            ctx.zero_centered_weight = zero_centered_weight
            y = y.reshape(x_shape_og)
            y1 = y1.reshape(x_shape_og) if y1 is not None else None
            residual_out = residual_out.reshape(x_shape_og) if residual_out is not None else None
            dropout_mask = dropout_mask.reshape(x_shape_og) if dropout_mask is not None else None
            dropout_mask1 = dropout_mask1.reshape(x_shape_og) if dropout_mask1 is not None else None
            if not return_dropout_mask:
                if weight1 is None:
                    return y if not prenorm else (y, residual_out)
                else:
                    return (y, y1) if not prenorm else (y, y1, residual_out)
            else:
                if weight1 is None:
                    return (
                        (y, dropout_mask, dropout_mask1)
                        if not prenorm
                        else (y, residual_out, dropout_mask, dropout_mask1)
                    )
                else:
                    return (
                        (y, y1, dropout_mask, dropout_mask1)
                        if not prenorm
                        else (y, y1, residual_out, dropout_mask, dropout_mask1)
                    )

        @staticmethod
        def backward(ctx, dy, *args):
            # One gradient slot per forward input (17); non-tensor inputs get None.
            if ctx.zero_seq_length:
                return (
                    torch.zeros(ctx.x_shape_og, dtype=dy.dtype, device=dy.device),
                    torch.zeros(ctx.weight_shape, dtype=ctx.weight_dtype, device=ctx.weight_device),
                    torch.zeros(ctx.bias_shape, dtype=ctx.bias_dtype, device=ctx.bias_device) if ctx.has_bias else None,
                    torch.zeros(ctx.x_shape_og, dtype=dy.dtype, device=dy.device) if ctx.has_residual else None,
                    torch.zeros(ctx.x_shape_og, dtype=dy.dtype, device=dy.device) if ctx.has_x1 and ctx.dropout_p > 0.0 else None,
                    torch.zeros(ctx.weight1_shape, dtype=ctx.weight1_dtype, device=ctx.weight1_device) if ctx.has_weight1 else None,
                    torch.zeros(ctx.bias1_shape, dtype=ctx.bias1_dtype, device=ctx.bias1_device) if ctx.has_bias1 else None,
                    None,
                    None,
                    None,
                    None,
                    None,
                    None,
                    None,
                    None,
                    None,
                    None,
                )

            x, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd = ctx.saved_tensors
            dy = dy.reshape(-1, dy.shape[-1])
            if dy.stride(-1) != 1:
                dy = dy.contiguous()
            assert dy.shape == x.shape
            # Extra incoming grads follow the forward output order: [dy1], [dresidual], ...
            if weight1 is not None:
                dy1, args = args[0], args[1:]
                dy1 = dy1.reshape(-1, dy1.shape[-1])
                if dy1.stride(-1) != 1:
                    dy1 = dy1.contiguous()
                assert dy1.shape == x.shape
            else:
                dy1 = None
            if ctx.prenorm:
                dresidual = args[0]
                dresidual = dresidual.reshape(-1, dresidual.shape[-1])
                if dresidual.stride(-1) != 1:
                    dresidual = dresidual.contiguous()
                assert dresidual.shape == x.shape
            else:
                dresidual = None
            dx, dw, db, dresidual_in, dx1, dw1, db1 = _layer_norm_bwd(
                dy,
                x,
                weight,
                bias,
                ctx.eps,
                mean,
                rstd,
                dresidual,
                dy1,
                weight1,
                bias1,
                seeds,
                ctx.dropout_p,
                rowscale,
                ctx.has_residual,
                ctx.has_x1,
                ctx.zero_centered_weight,
                ctx.is_rms_norm,
                x_dtype=ctx.x_dtype,
            )
            return (
                dx.reshape(ctx.x_shape_og),
                dw,
                db,
                dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None,
                dx1.reshape(ctx.x_shape_og) if dx1 is not None else None,
                dw1,
                db1,
                None,
                None,
                None,
                None,
                None,
                None,
                None,
                None,
                None,
                None,
            )

    def rms_norm_fn(
        x,
        weight,
        bias,
        residual=None,
        x1=None,
        weight1=None,
        bias1=None,
        eps=1e-6,
        dropout_p=0.0,
        rowscale=None,
        prenorm=False,
        residual_in_fp32=False,
        zero_centered_weight=False,
        return_dropout_mask=False,
        out=None,
        residual_out=None
    ):
        """Functional fused RMSNorm: ``LayerNormFn.apply`` with ``is_rms_norm=True``."""
        return LayerNormFn.apply(
            x,
            weight,
            bias,
            residual,
            x1,
            weight1,
            bias1,
            eps,
            dropout_p,
            rowscale,
            prenorm,
            residual_in_fp32,
            zero_centered_weight,
            True,
            return_dropout_mask,
            out,
            residual_out
        )

    class RMSNorm(torch.nn.Module):
        """Fused Triton RMSNorm with optional dropout, residual add and zero-centered weight.

        Always owns a learnable ``weight`` of size ``hidden_size`` and no bias. Unlike
        ``torch.nn.RMSNorm``, the constructor has no ``elementwise_affine`` argument.
        Dropout is only applied in training mode.
        """

        def __init__(self, hidden_size, eps=1e-5, dropout_p=0.0, zero_centered_weight=False, device=None, dtype=None):
            factory_kwargs = {"device": device, "dtype": dtype}
            super().__init__()
            self.eps = eps
            if dropout_p > 0.0:
                self.drop = torch.nn.Dropout(dropout_p)
            else:
                self.drop = None
            self.zero_centered_weight = zero_centered_weight
            self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
            self.register_parameter("bias", None)
            self.reset_parameters()

        def reset_parameters(self):
            # A zero-centered weight is stored as (w - 1), so both branches mean "scale of 1".
            if not self.zero_centered_weight:
                torch.nn.init.ones_(self.weight)
            else:
                torch.nn.init.zeros_(self.weight)

        def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False):
            return rms_norm_fn(
                x,
                self.weight,
                self.bias,
                residual=residual,
                eps=self.eps,
                dropout_p=self.drop.p if self.drop is not None and self.training else 0.0,
                prenorm=prenorm,
                residual_in_fp32=residual_in_fp32,
                zero_centered_weight=self.zero_centered_weight,
            )
else:
    from torch.nn import RMSNorm
    warnings.warn("Cannot import triton, install triton to use fused RMSNorm for better performance")


def swiglu(x, y):
    """Return ``silu(x) * y``, computing the SiLU in float32 and casting back to ``x.dtype``."""
    return F.silu(x.float(), inplace=False).to(x.dtype) * y


logger = logging.get_logger(__name__)


@dataclass
class TeaCacheParams:
    """Mutable per-run state for a TeaCache-style step cache.

    NOTE(review): field meanings are inferred from their names only; confirm against callers.
    """

    # Residual cached from a previous step (None until first populated).
    previous_residual: Optional[torch.Tensor] = None
    # Modulated input seen at the previous step, used for the distance estimate.
    previous_modulated_inp: Optional[torch.Tensor] = None
    # Running sum of relative L1 distances since the last full computation.
    accumulated_rel_l1_distance: float = 0
    # Flag set by the caller for the first/last denoising step.
    is_first_or_last_step: bool = False


class TimestepEmbedding(nn.Module):
    """Two-layer MLP ``linear_1 -> act -> linear_2 [-> post_act]`` for timestep features.

    Args:
        in_channels: input feature size.
        time_embed_dim: hidden size (and output size unless ``out_dim`` is given).
        act_fn: activation name passed to diffusers ``get_activation``.
        out_dim: optional output size.
        post_act_fn: optional activation name applied after ``linear_2``.
        cond_proj_dim: if given, ``condition`` inputs are projected to ``in_channels`` and
            added to ``sample`` before ``linear_1``.
        sample_proj_bias: whether ``linear_1`` and ``linear_2`` have biases.
    """

    def __init__(
        self,
        in_channels: int,
        time_embed_dim: int,
        act_fn: str = "silu",
        out_dim: Optional[int] = None,
        post_act_fn: Optional[str] = None,
        cond_proj_dim=None,
        sample_proj_bias=True,
    ):
        super().__init__()

        self.linear_1 = nn.Linear(in_channels, time_embed_dim, sample_proj_bias)

        if cond_proj_dim is not None:
            self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
        else:
            self.cond_proj = None

        self.act = get_activation(act_fn)

        if out_dim is not None:
            time_embed_dim_out = out_dim
        else:
            time_embed_dim_out = time_embed_dim

        self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias)

        if post_act_fn is None:
            self.post_act = None
        else:
            self.post_act = get_activation(post_act_fn)

        self.initialize_weights()

    def initialize_weights(self):
        """Normal(std=0.02) weights and zero biases for both linear layers."""
        nn.init.normal_(self.linear_1.weight, std=0.02)
        # Biases are absent when sample_proj_bias=False; zeros_(None) would crash.
        if self.linear_1.bias is not None:
            nn.init.zeros_(self.linear_1.bias)
        nn.init.normal_(self.linear_2.weight, std=0.02)
        if self.linear_2.bias is not None:
            nn.init.zeros_(self.linear_2.bias)

    def forward(self, sample, condition=None):
        if condition is not None:
            sample = sample + self.cond_proj(condition)
        sample = self.linear_1(sample)

        if self.act is not None:
            sample = self.act(sample)

        sample = self.linear_2(sample)

        if self.post_act is not None:
            sample = self.post_act(sample)
        return sample


def apply_rotary_emb(
    x: torch.Tensor,
    freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
    use_real: bool = True,
    use_real_unbind_dim: int = -1,
) -> torch.Tensor:
    """
    Apply rotary position embeddings to a query or key tensor.

    Args:
        x (`torch.Tensor`):
            Query or key tensor to rotate. Must be 4-D (results are flattened from dim 3).
            With ``use_real=True`` the ``[S, D]`` cos/sin tables are broadcast as
            ``[1, 1, S, D]``, so ``x`` is expected as ``[B, H, S, D]``. With
            ``use_real=False`` ``freqs_cis`` is unsqueezed at dim 2, which presumably
            matches ``x`` laid out as ``[B, S, H, D]`` — confirm against callers.
        freqs_cis (`torch.Tensor` or `Tuple[torch.Tensor, torch.Tensor]`):
            With ``use_real=True`` a ``(cos, sin)`` pair, each ``[S, D]``. With
            ``use_real=False`` a complex frequency tensor.
        use_real (`bool`):
            Use the real cos/sin formulation (True) or complex multiplication (False).
        use_real_unbind_dim (`int`):
            Real path only: ``-1`` treats the last dim as interleaved (real, imag) pairs,
            ``-2`` treats it as two contiguous halves.

    Returns:
        `torch.Tensor`: ``x`` with rotary embeddings applied, in ``x``'s dtype.

    Raises:
        ValueError: if ``use_real`` and ``use_real_unbind_dim`` is neither -1 nor -2.
    """
    if use_real:
        cos, sin = freqs_cis  # [S, D]
        cos = cos[None, None]
        sin = sin[None, None]
        cos, sin = cos.to(x.device), sin.to(x.device)

        if use_real_unbind_dim == -1:
            # Interleaved pairs: x_real / x_imag are each [..., D // 2]
            x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)
            x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
        elif use_real_unbind_dim == -2:
            # Split halves: x_real / x_imag are each [..., D // 2]
            x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2)
            x_rotated = torch.cat([-x_imag, x_real], dim=-1)
        else:
            raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")

        out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)

        return out
    else:
        # used for lumina
        # x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
        x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], x.shape[-1] // 2, 2))
        freqs_cis = freqs_cis.unsqueeze(2)
        x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)

        return x_out.type_as(x)


class ThinkGenRotaryPosEmbed(nn.Module):
    """3-axis rotary position embedding for a [caption | reference images | image] sequence.

    Each token gets a 3-component position id; per-axis frequency tables (see
    ``get_freqs_cis``) are gathered by those ids and concatenated along the last dim.
    """

    def __init__(self, theta: int, axes_dim: Tuple[int, int, int], axes_lens: Tuple[int, int, int] = (300, 512, 512), patch_size: int = 2):
        super().__init__()
        self.theta = theta
        self.axes_dim = axes_dim
        self.axes_lens = axes_lens
        self.patch_size = patch_size

    @staticmethod
    def get_freqs_cis(axes_dim: Tuple[int, int, int],
                      axes_lens: Tuple[int, int, int],
                      theta: int) -> List[torch.Tensor]:
        """Build one 1-D rotary table per axis via diffusers ``get_1d_rotary_pos_embed``.

        float64 frequencies are used except when MPS is available (float32 there).
        """
        freqs_cis = []
        freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
        for i, (d, e) in enumerate(zip(axes_dim, axes_lens)):
            emb = get_1d_rotary_pos_embed(d, e, theta=theta, freqs_dtype=freqs_dtype)
            freqs_cis.append(emb)
        return freqs_cis

    def _get_freqs_cis(self, freqs_cis, ids: torch.Tensor) -> torch.Tensor:
        """Gather per-axis tables at ``ids[..., axis]`` and concatenate them on the last dim.

        ``ids`` is indexed as ``ids[:, :, i]`` (batch, seq, axis). The gather runs on CPU
        when ``ids`` lives on MPS; the result is moved back to ``ids``' original device.
        """
        device = ids.device
        if ids.device.type == "mps":
            ids = ids.to("cpu")

        result = []
        for i in range(len(self.axes_dim)):
            freqs = freqs_cis[i].to(ids.device)
            index = ids[:, :, i : i + 1].repeat(1, 1, freqs.shape[-1]).to(torch.int64)
            result.append(torch.gather(freqs.unsqueeze(0).repeat(index.shape[0], 1, 1), dim=1, index=index))
        return torch.cat(result, dim=-1).to(device)

    def forward(
        self,
        freqs_cis,
        attention_mask,
        l_effective_ref_img_len,
        l_effective_img_len,
        ref_img_sizes,
        img_sizes,
        device
    ):
        """Compute rotary tables for each sample's caption, reference-image and image tokens.

        Position ids: caption token t gets ``(t, t, t)``. Each reference image gets
        ``(shift, row, col)``, where ``shift`` starts at the caption length and grows by
        ``max(H_tokens, W_tokens)`` after every reference image. The target image gets
        ``(shift, row, col)`` with the final ``shift``. Token grids are ``H // patch_size``
        by ``W // patch_size``.

        Args:
            freqs_cis: per-axis tables, as returned by ``get_freqs_cis``.
            attention_mask: caption mask; ``.sum(dim=1)`` gives each caption length and
                ``.shape[1]`` the padded caption length.
            l_effective_ref_img_len: per sample, a list of reference-image token counts.
            l_effective_img_len: per sample, the target-image token count.
            ref_img_sizes: per sample, a list of ``(H, W)`` reference sizes, or None.
            img_sizes: per sample, the target ``(H, W)``.
            device: device for the position ids and returned tables.

        Returns:
            ``(cap_freqs_cis, ref_img_freqs_cis, img_freqs_cis, freqs_cis,
            l_effective_cap_len, seq_lengths)``; the three split tables are zero-padded to
            the per-batch maximum length of each segment.
        """
        batch_size = len(attention_mask)
        p = self.patch_size

        encoder_seq_len = attention_mask.shape[1]
        l_effective_cap_len = attention_mask.sum(dim=1).tolist()

        seq_lengths = [
            cap_len + sum(ref_img_len) + img_len
            for cap_len, ref_img_len, img_len in zip(l_effective_cap_len, l_effective_ref_img_len, l_effective_img_len)
        ]

        max_seq_len = max(seq_lengths)
        max_ref_img_len = max([sum(ref_img_len) for ref_img_len in l_effective_ref_img_len])
        max_img_len = max(l_effective_img_len)

        # Create position IDs
        position_ids = torch.zeros(batch_size, max_seq_len, 3, dtype=torch.int32, device=device)

        for i, (cap_seq_len, seq_len) in enumerate(zip(l_effective_cap_len, seq_lengths)):
            # add text position ids
            position_ids[i, :cap_seq_len] = repeat(torch.arange(cap_seq_len, dtype=torch.int32, device=device), "l -> l 3")

            pe_shift = cap_seq_len
            pe_shift_len = cap_seq_len

            if ref_img_sizes[i] is not None:
                for ref_img_size, ref_img_len in zip(ref_img_sizes[i], l_effective_ref_img_len[i]):
                    H, W = ref_img_size
                    ref_H_tokens, ref_W_tokens = H // p, W // p
                    assert ref_H_tokens * ref_W_tokens == ref_img_len
                    # add image position ids

                    row_ids = repeat(torch.arange(ref_H_tokens, dtype=torch.int32, device=device), "h -> h w", w=ref_W_tokens).flatten()
                    col_ids = repeat(torch.arange(ref_W_tokens, dtype=torch.int32, device=device), "w -> h w", h=ref_H_tokens).flatten()
                    position_ids[i, pe_shift_len:pe_shift_len + ref_img_len, 0] = pe_shift
                    position_ids[i, pe_shift_len:pe_shift_len + ref_img_len, 1] = row_ids
                    position_ids[i, pe_shift_len:pe_shift_len + ref_img_len, 2] = col_ids

                    pe_shift += max(ref_H_tokens, ref_W_tokens)
                    pe_shift_len += ref_img_len

            H, W = img_sizes[i]
            H_tokens, W_tokens = H // p, W // p
            assert H_tokens * W_tokens == l_effective_img_len[i]

            row_ids = repeat(torch.arange(H_tokens, dtype=torch.int32, device=device), "h -> h w", w=W_tokens).flatten()
            col_ids = repeat(torch.arange(W_tokens, dtype=torch.int32, device=device), "w -> h w", h=H_tokens).flatten()

            assert pe_shift_len + l_effective_img_len[i] == seq_len
            position_ids[i, pe_shift_len: seq_len, 0] = pe_shift
            position_ids[i, pe_shift_len: seq_len, 1] = row_ids
            position_ids[i, pe_shift_len: seq_len, 2] = col_ids

        # Get combined rotary embeddings
        freqs_cis = self._get_freqs_cis(freqs_cis, position_ids)

        # create separate rotary embeddings for captions and images
        cap_freqs_cis = torch.zeros(
            batch_size, encoder_seq_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype
        )
        ref_img_freqs_cis = torch.zeros(
            batch_size, max_ref_img_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype
        )
        img_freqs_cis = torch.zeros(
            batch_size, max_img_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype
        )

        for i, (cap_seq_len, ref_img_len, img_len, seq_len) in enumerate(
            zip(l_effective_cap_len, l_effective_ref_img_len, l_effective_img_len, seq_lengths)
        ):
            cap_freqs_cis[i, :cap_seq_len] = freqs_cis[i, :cap_seq_len]
            ref_img_freqs_cis[i, :sum(ref_img_len)] = freqs_cis[i, cap_seq_len:cap_seq_len + sum(ref_img_len)]
            img_freqs_cis[i, :img_len] = freqs_cis[i, cap_seq_len + sum(ref_img_len):cap_seq_len + sum(ref_img_len) + img_len]

        return (
            cap_freqs_cis,
            ref_img_freqs_cis,
            img_freqs_cis,
            freqs_cis,
            l_effective_cap_len,
            seq_lengths,
        )
class LuminaRMSNormZero(nn.Module):
    """
    Adaptive RMS normalization ("norm zero") driven by a conditioning embedding.

    The conditioning embedding goes through SiLU and a linear layer producing
    ``4 * embedding_dim`` features. These are split into
    ``scale_msa, gate_msa, scale_mlp, gate_mlp``. Only ``scale_msa`` is applied
    here; the other three are returned for the caller to use.

    Parameters:
        embedding_dim (`int`): The size of each embedding vector.
        norm_eps (`float`): Epsilon passed to the inner RMSNorm.
        norm_elementwise_affine (`bool`): Accepted for interface compatibility;
            it is currently NOT forwarded to the inner RMSNorm.
    """

    def __init__(
        self,
        embedding_dim: int,
        norm_eps: float,
        norm_elementwise_affine: bool,
    ):
        super().__init__()
        self.silu = nn.SiLU()
        # The conditioning width is capped at 1024, the same cap used for the
        # timestep embedding width (`min(hidden_size, 1024)`).
        self.linear = nn.Linear(
            min(embedding_dim, 1024),
            4 * embedding_dim,
            bias=True,
        )
        self.norm = RMSNorm(embedding_dim, eps=norm_eps)

    def forward(
        self,
        x: torch.Tensor,
        emb: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Normalize ``x`` and scale it with the attention-branch modulation.

        Args:
            x: Tensor to normalize. Presumably (batch, seq_len, embedding_dim),
                since the per-sample scale is broadcast with ``[:, None]``
                (TODO confirm).
            emb: Conditioning embedding of shape (batch, min(embedding_dim, 1024)).
                Required despite the ``Optional`` annotation.

        Returns:
            ``(scaled_normed_x, gate_msa, scale_mlp, gate_mlp)``.
        """
        emb = self.linear(self.silu(emb))
        scale_msa, gate_msa, scale_mlp, gate_mlp = emb.chunk(4, dim=1)
        x = self.norm(x) * (1 + scale_msa[:, None])
        return x, gate_msa, scale_mlp, gate_mlp


class LuminaLayerNormContinuous(nn.Module):
    """
    Output normalization scaled by a projected conditioning embedding.

    Computes ``norm(x) * (1 + linear_1(silu(cond)))``, optionally followed by a
    projection ``linear_2`` to ``out_dim`` features.
    """

    def __init__(
        self,
        embedding_dim: int,
        conditioning_embedding_dim: int,
        # NOTE: It is a bit weird that the norm layer can be configured to have scale and shift parameters
        # because the output is immediately scaled and shifted by the projected conditioning embeddings.
        # Note that AdaLayerNorm does not let the norm layer have scale and shift parameters.
        # However, this is how it was implemented in the original code, and it's rather likely you should
        # set `elementwise_affine` to False.
        elementwise_affine=True,
        eps=1e-5,
        bias=True,
        norm_type="layer_norm",
        out_dim: Optional[int] = None,
    ):
        super().__init__()

        # AdaLN
        self.silu = nn.SiLU()
        self.linear_1 = nn.Linear(conditioning_embedding_dim, embedding_dim, bias=bias)

        if norm_type == "layer_norm":
            self.norm = nn.LayerNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine, bias=bias)
        elif norm_type == "rms_norm":
            self.norm = RMSNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine)
        else:
            raise ValueError(f"unknown norm_type {norm_type}")

        # Optional final projection (e.g. back to patch space).
        self.linear_2 = None
        if out_dim is not None:
            self.linear_2 = nn.Linear(embedding_dim, out_dim, bias=bias)

    def forward(
        self,
        x: torch.Tensor,
        conditioning_embedding: torch.Tensor,
    ) -> torch.Tensor:
        """
        Args:
            x: Tensor of shape (batch, seq_len, embedding_dim); the scale is
                broadcast over dim 1 via ``[:, None, :]``.
            conditioning_embedding: Tensor of shape (batch, conditioning_embedding_dim).

        Returns:
            The scaled, normalized tensor, projected to ``out_dim`` if configured.
        """
        # Convert back to the original dtype in case `conditioning_embedding` is
        # upcast to float32 (needed for HunyuanDiT).
        scale = self.linear_1(self.silu(conditioning_embedding).to(x.dtype))
        x = self.norm(x) * (1 + scale)[:, None, :]
        if self.linear_2 is not None:
            x = self.linear_2(x)
        return x


class LuminaFeedForward(nn.Module):
    r"""
    A gated feed-forward layer: ``linear_2(swiglu(linear_1(x), linear_3(x)))``.

    The gating is performed by the module-level ``swiglu`` helper (presumably
    SwiGLU, i.e. ``silu(h1) * h2`` — TODO confirm against its definition).

    Parameters:
        dim (`int`):
            Input and output feature dimension.
        inner_dim (`int`):
            Base hidden dimension, before `ffn_dim_multiplier` and rounding to
            `multiple_of` are applied.
        multiple_of (`int`, *optional*, defaults to 256):
            The hidden dimension is rounded up to a multiple of this value.
        ffn_dim_multiplier (`float`, *optional*):
            Custom multiplier applied to `inner_dim` before rounding. Defaults to None.
    """

    def __init__(
        self,
        dim: int,
        inner_dim: int,
        multiple_of: Optional[int] = 256,
        ffn_dim_multiplier: Optional[float] = None,
    ):
        super().__init__()

        self.swiglu = swiglu

        # custom hidden_size factor multiplier
        if ffn_dim_multiplier is not None:
            inner_dim = int(ffn_dim_multiplier * inner_dim)
        # Round up to the next multiple of `multiple_of`.
        inner_dim = multiple_of * ((inner_dim + multiple_of - 1) // multiple_of)

        self.linear_1 = nn.Linear(
            dim,
            inner_dim,
            bias=False,
        )
        self.linear_2 = nn.Linear(
            inner_dim,
            dim,
            bias=False,
        )
        self.linear_3 = nn.Linear(
            dim,
            inner_dim,
            bias=False,
        )

    def forward(self, x):
        h1, h2 = self.linear_1(x), self.linear_3(x)
        return self.linear_2(self.swiglu(h1, h2))


class Lumina2CombinedTimestepCaptionEmbedding(nn.Module):
    """
    Embeds diffusion timesteps and caption (text) features.

    * Timesteps: sinusoidal projection (``Timesteps``) followed by
      ``TimestepEmbedding`` to width ``min(hidden_size, 1024)``.
    * Captions: RMSNorm + Linear mapping ``2 * text_feat_dim`` features to
      ``hidden_size``. The factor 2 is presumably two concatenated text
      feature streams (TODO confirm with the text encoder).
    """

    def __init__(
        self,
        hidden_size: int = 4096,
        # NOTE(review): the original trailing comment reads "2048"; 204800 looks
        # like a typo. Callers in this file pass `text_feat_dim` explicitly.
        text_feat_dim: int = 204800,
        frequency_embedding_size: int = 256,
        norm_eps: float = 1e-5,
        timestep_scale: float = 1.0,
    ) -> None:
        super().__init__()

        self.time_proj = Timesteps(
            num_channels=frequency_embedding_size, flip_sin_to_cos=True, downscale_freq_shift=0.0, scale=timestep_scale
        )

        self.timestep_embedder = TimestepEmbedding(
            in_channels=frequency_embedding_size, time_embed_dim=min(hidden_size, 1024)
        )

        self.caption_embedder = nn.Sequential(
            RMSNorm(text_feat_dim * 2, eps=norm_eps),
            nn.Linear(text_feat_dim * 2, hidden_size, bias=True),
        )

        self._initialize_weights()

    def _initialize_weights(self):
        """
        Initialize the caption projection: truncated-normal weights (std 0.02) and zero biases.

        Only ``nn.Linear`` layers are touched. The previous version matched any
        submodule with a ``weight`` attribute and so also re-initialized the
        RMSNorm gain with small random (possibly negative) values, instead of
        leaving it at the norm's own default. It also printed debug output on
        every construction.
        """
        for module in self.caption_embedder.modules():
            if isinstance(module, nn.Linear):
                nn.init.trunc_normal_(module.weight, std=0.02)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

    def forward(
        self, timestep: torch.Tensor, text_hidden_states: torch.Tensor, dtype: torch.dtype
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            timestep: Diffusion timesteps.
            text_hidden_states: Caption features whose last dim is ``2 * text_feat_dim``.
            dtype: dtype the sinusoidal timestep projection is cast to before embedding.

        Returns:
            ``(time_embed, caption_embed)``.
        """
        timestep_proj = self.time_proj(timestep).to(dtype=dtype)
        time_embed = self.timestep_embedder(timestep_proj)
        caption_embed = self.caption_embedder(text_hidden_states)
        return time_embed, caption_embed


class ThinkGenAttnProcessor:
    """
    Attention processor based on ``torch.nn.functional.scaled_dot_product_attention``.

    This is the fallback used when flash_attn is unavailable. It implements:
    - Padded (batch, seq) attention with an optional key-padding mask
    - Rotary position embeddings (RoPE)
    - Query-Key normalization
    - Proportional attention scaling
    - Grouped-query attention, by explicitly repeating key/value heads

    Args:
        None

    Raises:
        ImportError: If PyTorch version is less than 2.0
    """

    def __init__(self) -> None:
        """Initialize the attention processor."""
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "ThinkGenAttnProcessor requires PyTorch 2.0. "
                "Please upgrade PyTorch to version 2.0 or later."
            )

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
        base_sequence_length: Optional[int] = None,
    ) -> torch.Tensor:
        """
        Process attention computation with scaled dot-product attention.

        Args:
            attn: Attention module
            hidden_states: Hidden states tensor of shape (batch_size, seq_len, hidden_dim)
            encoder_hidden_states: Encoder hidden states tensor (source of keys/values)
            attention_mask: Optional (batch_size, kv_seq_len) mask; truthy entries mark
                key positions that may be attended to
            image_rotary_emb: Optional rotary embeddings applied to queries and keys
            base_sequence_length: Optional base sequence length for proportional attention

        Returns:
            torch.Tensor: Processed hidden states after attention computation
        """
        batch_size, sequence_length, _ = hidden_states.shape

        # Get Query-Key-Value Pair
        query = attn.to_q(hidden_states)
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query_dim = query.shape[-1]
        inner_dim = key.shape[-1]
        head_dim = query_dim // attn.heads
        dtype = query.dtype

        # Get key-value heads (fewer than query heads under grouped-query attention)
        kv_heads = inner_dim // head_dim

        # Reshape tensors for attention computation
        query = query.view(batch_size, -1, attn.heads, head_dim)
        key = key.view(batch_size, -1, kv_heads, head_dim)
        value = value.view(batch_size, -1, kv_heads, head_dim)

        # Apply Query-Key normalization
        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # Apply Rotary Position Embeddings
        if image_rotary_emb is not None:
            query = apply_rotary_emb(query, image_rotary_emb, use_real=False)
            key = apply_rotary_emb(key, image_rotary_emb, use_real=False)

        # Norm / RoPE may change dtype; restore the projection dtype.
        query, key = query.to(dtype), key.to(dtype)

        # Calculate attention scale (proportional attention when a base length is given)
        if base_sequence_length is not None:
            softmax_scale = math.sqrt(math.log(sequence_length, base_sequence_length)) * attn.scale
        else:
            softmax_scale = attn.scale

        # scaled_dot_product_attention expects attention_mask shape to be
        # (batch, heads, source_length, target_length); a (batch, 1, 1, kv_len)
        # boolean mask broadcasts over heads and query positions.
        if attention_mask is not None:
            attention_mask = attention_mask.bool().view(batch_size, 1, 1, -1)

        query = query.transpose(1, 2)
        key = key.transpose(1, 2)
        value = value.transpose(1, 2)

        # explicitly repeat key and value to match query length, otherwise using enable_gqa=True results in MATH backend of sdpa in our test of pytorch2.6
        key = key.repeat_interleave(query.size(-3) // key.size(-3), -3)
        value = value.repeat_interleave(query.size(-3) // value.size(-3), -3)

        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, scale=softmax_scale
        )
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.type_as(query)

        # Apply output projection
        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)

        return hidden_states


class ThinkGenAttnProcessorFlash2Varlen:
    """
    Processor implementing attention with flash attention over variable-length sequences.

    This processor implements:
    - Flash attention with variable length sequences (padding is removed via the mask)
    - Rotary position embeddings (RoPE)
    - Query-Key normalization
    - Proportional attention scaling

    Args:
        None
    """

    def __init__(self) -> None:
        """Initialize the attention processor."""
        if not is_flash_attn_available():
            raise ImportError(
                "ThinkGenAttnProcessorFlash2Varlen requires flash_attn. "
                "Please install flash_attn."
            )

    def _upad_input(
        self,
        query_layer: torch.Tensor,
        key_layer: torch.Tensor,
        value_layer: torch.Tensor,
        attention_mask: torch.Tensor,
        query_length: int,
        num_heads: int,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Tuple[torch.Tensor, torch.Tensor], Tuple[int, int]]:
        """
        Unpad the input tensors for flash attention.

        Args:
            query_layer: Query tensor of shape (batch_size, seq_len, num_heads, head_dim)
            key_layer: Key tensor of shape (batch_size, seq_len, num_kv_heads, head_dim)
            value_layer: Value tensor of shape (batch_size, seq_len, num_kv_heads, head_dim)
            attention_mask: Attention mask tensor of shape (batch_size, seq_len)
            query_length: Length of the query sequence
            num_heads: Number of attention heads

        Returns:
            Tuple containing:
            - Unpadded query tensor
            - Unpadded key tensor
            - Unpadded value tensor
            - Query indices
            - Tuple of cumulative sequence lengths for query and key
            - Tuple of maximum sequence lengths for query and key
        """

        def _get_unpad_data(attention_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, int]:
            """Return (flat indices of valid tokens, cumulative seqlens with leading 0, max seqlen)."""
            seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
            indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
            max_seqlen_in_batch = seqlens_in_batch.max().item()
            cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
            return indices, cu_seqlens, max_seqlen_in_batch

        indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
        batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape

        # Unpad key and value layers
        key_layer = index_first_axis(
            key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
            indices_k,
        )
        value_layer = index_first_axis(
            value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
            indices_k,
        )

        # Handle different query length cases
        if query_length == kv_seq_len:
            # Self-attention: queries share the key mask.
            query_layer = index_first_axis(
                query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim),
                indices_k,
            )
            cu_seqlens_q = cu_seqlens_k
            max_seqlen_in_batch_q = max_seqlen_in_batch_k
            indices_q = indices_k
        elif query_length == 1:
            # One query token per sample.
            max_seqlen_in_batch_q = 1
            cu_seqlens_q = torch.arange(
                batch_size + 1, dtype=torch.int32, device=query_layer.device
            )
            indices_q = cu_seqlens_q[:-1]
            query_layer = query_layer.squeeze(1)
        else:
            # Queries align with the last `query_length` key positions.
            attention_mask = attention_mask[:, -query_length:]
            query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)

        return (
            query_layer,
            key_layer,
            value_layer,
            indices_q,
            (cu_seqlens_q, cu_seqlens_k),
            (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
        )

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
        base_sequence_length: Optional[int] = None,
    ) -> torch.Tensor:
        """
        Process attention computation with flash attention.

        Args:
            attn: Attention module
            hidden_states: Hidden states tensor of shape (batch_size, seq_len, hidden_dim)
            encoder_hidden_states: Encoder hidden states tensor (source of keys/values)
            attention_mask: Optional (batch_size, kv_seq_len) mask of valid key positions.
                ``None`` means every position is valid.
            image_rotary_emb: Optional rotary embeddings applied to queries and keys
            base_sequence_length: Optional base sequence length for proportional attention

        Returns:
            torch.Tensor: Processed hidden states after attention computation
        """
        batch_size, sequence_length, _ = hidden_states.shape

        # Get Query-Key-Value Pair
        query = attn.to_q(hidden_states)
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query_dim = query.shape[-1]
        inner_dim = key.shape[-1]
        head_dim = query_dim // attn.heads
        dtype = query.dtype

        # Get key-value heads (fewer than query heads under grouped-query attention)
        kv_heads = inner_dim // head_dim

        # Reshape tensors for attention computation
        query = query.view(batch_size, -1, attn.heads, head_dim)
        key = key.view(batch_size, -1, kv_heads, head_dim)
        value = value.view(batch_size, -1, kv_heads, head_dim)

        # Apply Query-Key normalization
        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # Apply Rotary Position Embeddings
        if image_rotary_emb is not None:
            query = apply_rotary_emb(query, image_rotary_emb, use_real=False)
            key = apply_rotary_emb(key, image_rotary_emb, use_real=False)

        # Norm / RoPE may change dtype; restore the projection dtype.
        query, key = query.to(dtype), key.to(dtype)

        # Calculate attention scale (proportional attention when a base length is given)
        if base_sequence_length is not None:
            softmax_scale = math.sqrt(math.log(sequence_length, base_sequence_length)) * attn.scale
        else:
            softmax_scale = attn.scale

        # The unpadding step needs a mask. Treat a missing mask as "all keys
        # valid" instead of crashing on `None.sum(...)` inside `_upad_input`.
        if attention_mask is None:
            attention_mask = torch.ones(batch_size, key.shape[1], dtype=torch.bool, device=key.device)

        # Unpad input for flash attention
        (
            query_states,
            key_states,
            value_states,
            indices_q,
            cu_seq_lens,
            max_seq_lens,
        ) = self._upad_input(query, key, value, attention_mask, sequence_length, attn.heads)

        cu_seqlens_q, cu_seqlens_k = cu_seq_lens
        max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens

        # Handle different number of heads: repeat each kv head consecutively
        # (same layout as `repeat_interleave` on the head axis).
        if kv_heads < attn.heads:
            key_states = repeat(key_states, "l h c -> l (h k) c", k=attn.heads // kv_heads)
            value_states = repeat(value_states, "l h c -> l (h k) c", k=attn.heads // kv_heads)

        # Apply flash attention
        attn_output_unpad = flash_attn_varlen_func(
            query_states,
            key_states,
            value_states,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            max_seqlen_q=max_seqlen_in_batch_q,
            max_seqlen_k=max_seqlen_in_batch_k,
            dropout_p=0.0,
            causal=False,
            softmax_scale=softmax_scale,
        )

        # Pad output back to (batch, seq, heads, head_dim), then merge heads
        hidden_states = pad_input(attn_output_unpad, indices_q, batch_size, sequence_length)
        hidden_states = hidden_states.flatten(-2)
        hidden_states = hidden_states.type_as(query)

        # Apply output projection
        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)

        return hidden_states


class ThinkGenTransformerBlock(nn.Module):
    """
    Transformer block for ThinkGen model.

    This block implements a transformer layer with:
    - Multi-head self-attention (flash-attn varlen processor if available, else SDPA)
    - Gated feed-forward network (LuminaFeedForward)
    - RMS normalization (sandwich: pre-norm and post-norm around each sublayer)
    - Optional timestep modulation (LuminaRMSNormZero) for conditional generation

    Args:
        dim: Dimension of the input and output tensors
        num_attention_heads: Number of attention heads
        num_kv_heads: Number of key-value heads
        multiple_of: Multiple of which the hidden dimension should be
        ffn_dim_multiplier: Multiplier for the feed-forward network dimension
        norm_eps: Epsilon value for normalization layers
        modulation: Whether to use modulation for conditional generation
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        num_kv_heads: int,
        multiple_of: int,
        ffn_dim_multiplier: float,
        norm_eps: float,
        modulation: bool = True,
    ) -> None:
        """Initialize the transformer block."""
        super().__init__()
        self.head_dim = dim // num_attention_heads
        self.modulation = modulation

        # Prefer flash-attn varlen; its constructor raises ImportError when unavailable.
        try:
            processor = ThinkGenAttnProcessorFlash2Varlen()
        except ImportError:
            processor = ThinkGenAttnProcessor()

        # Initialize attention layer
        self.attn = Attention(
            query_dim=dim,
            cross_attention_dim=None,
            dim_head=dim // num_attention_heads,
            qk_norm="rms_norm",
            heads=num_attention_heads,
            kv_heads=num_kv_heads,
            eps=1e-5,
            bias=False,
            out_bias=False,
            processor=processor,
        )

        # Initialize feed-forward network
        self.feed_forward = LuminaFeedForward(
            dim=dim,
            inner_dim=4 * dim,
            multiple_of=multiple_of,
            ffn_dim_multiplier=ffn_dim_multiplier
        )

        # Initialize normalization layers
        if modulation:
            self.norm1 = LuminaRMSNormZero(
                embedding_dim=dim,
                norm_eps=norm_eps,
                norm_elementwise_affine=True
            )
        else:
            self.norm1 = RMSNorm(dim, eps=norm_eps)

        self.ffn_norm1 = RMSNorm(dim, eps=norm_eps)
        self.norm2 = RMSNorm(dim, eps=norm_eps)
        self.ffn_norm2 = RMSNorm(dim, eps=norm_eps)

        self.initialize_weights()

    def initialize_weights(self) -> None:
        """
        Initialize the weights of the transformer block.

        Uses Xavier uniform initialization for the attention and feed-forward
        projections. When modulation is enabled, the modulation projection's
        weight and bias are zero-initialized, so all scales and gates start at 0.
        """
        nn.init.xavier_uniform_(self.attn.to_q.weight)
        nn.init.xavier_uniform_(self.attn.to_k.weight)
        nn.init.xavier_uniform_(self.attn.to_v.weight)
        nn.init.xavier_uniform_(self.attn.to_out[0].weight)

        nn.init.xavier_uniform_(self.feed_forward.linear_1.weight)
        nn.init.xavier_uniform_(self.feed_forward.linear_2.weight)
        nn.init.xavier_uniform_(self.feed_forward.linear_3.weight)

        if self.modulation:
            nn.init.zeros_(self.norm1.linear.weight)
            nn.init.zeros_(self.norm1.linear.bias)

    def _modulated_forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        image_rotary_emb: torch.Tensor,
        temb: torch.Tensor,
    ) -> torch.Tensor:
        """Run attention + FFN with timestep scales and tanh-gated residuals."""
        norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb)
        attn_output = self.attn(
            hidden_states=norm_hidden_states,
            encoder_hidden_states=norm_hidden_states,
            attention_mask=attention_mask,
            image_rotary_emb=image_rotary_emb,
        )
        hidden_states = hidden_states + gate_msa.unsqueeze(1).tanh() * self.norm2(attn_output)
        mlp_output = self.feed_forward(self.ffn_norm1(hidden_states) * (1 + scale_mlp.unsqueeze(1)))
        hidden_states = hidden_states + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(mlp_output)
        return hidden_states

    def _plain_forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        image_rotary_emb: torch.Tensor,
    ) -> torch.Tensor:
        """Run attention + FFN with plain (ungated) residuals."""
        norm_hidden_states = self.norm1(hidden_states)
        attn_output = self.attn(
            hidden_states=norm_hidden_states,
            encoder_hidden_states=norm_hidden_states,
            attention_mask=attention_mask,
            image_rotary_emb=image_rotary_emb,
        )
        hidden_states = hidden_states + self.norm2(attn_output)
        mlp_output = self.feed_forward(self.ffn_norm1(hidden_states))
        hidden_states = hidden_states + self.ffn_norm2(mlp_output)
        return hidden_states

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        image_rotary_emb: torch.Tensor,
        temb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Forward pass of the transformer block.

        When the externally-set ``enable_taylorseer`` attribute is truthy and the
        block is modulated, the behavior depends on ``self.current['type']``:
        ``'full'`` computes the block and records derivatives via
        ``derivative_approximation``; ``'Taylor'`` replaces the output with
        ``taylor_formula``; any other value returns ``hidden_states`` unchanged.
        (``cache_dic`` / ``current`` are set outside this class.)

        Args:
            hidden_states: Input hidden states tensor
            attention_mask: Attention mask tensor
            image_rotary_emb: Rotary embeddings for the tokens
            temb: Optional timestep embedding tensor (required when modulation is enabled)

        Returns:
            torch.Tensor: Output hidden states after transformer block processing

        Raises:
            ValueError: If modulation is enabled and ``temb`` is None.
        """
        if not self.modulation:
            return self._plain_forward(hidden_states, attention_mask, image_rotary_emb)

        if temb is None:
            raise ValueError("temb must be provided when modulation is enabled")

        if not getattr(self, 'enable_taylorseer', False):
            return self._modulated_forward(hidden_states, attention_mask, image_rotary_emb, temb)

        step_type = self.current['type']
        if step_type == 'full':
            self.current['module'] = 'total'
            taylor_cache_init(cache_dic=self.cache_dic, current=self.current)
            hidden_states = self._modulated_forward(hidden_states, attention_mask, image_rotary_emb, temb)
            derivative_approximation(cache_dic=self.cache_dic, current=self.current, feature=hidden_states)
        elif step_type == 'Taylor':
            self.current['module'] = 'total'
            hidden_states = taylor_formula(cache_dic=self.cache_dic, current=self.current)

        return hidden_states


class ThinkGenTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
    """
    ThinkGen Transformer 2D Model.

    A transformer-based diffusion model for image generation with:
    - Patch-based image processing
    - Rotary position embeddings
    - Multi-head attention
    - Conditional generation support

    Args:
        patch_size: Size of image patches
        in_channels: Number of input channels
        out_channels: Number of output channels (defaults to in_channels)
        hidden_size: Size of hidden layers
        num_layers: Number of transformer layers
        num_refiner_layers: Number of refiner layers
        num_attention_heads: Number of attention heads
        num_kv_heads: Number of key-value heads
        multiple_of: Multiple of which the hidden dimension should be
        ffn_dim_multiplier: Multiplier for feed-forward network dimension
        norm_eps: Epsilon value for normalization layers
        axes_dim_rope: Dimensions for rotary position embeddings
        axes_lens: Lengths for rotary position embeddings
        text_feat_dim: Dimension of text features
        timestep_scale: Scale factor for timestep embeddings
        use_fused_rms_norm: Whether to use fused RMS normalization
        use_fused_swiglu: Whether to use fused SwiGLU activation
    """
_supports_gradient_checkpointing = True _no_split_modules = ["ThinkGenTransformerBlock"] _skip_layerwise_casting_patterns = ["x_embedder", "norm"] @register_to_config def __init__( self, patch_size: int = 2, in_channels: int = 16, out_channels: Optional[int] = None, hidden_size: int = 2304, num_layers: int = 26, num_refiner_layers: int = 2, num_attention_heads: int = 24, num_kv_heads: int = 8, multiple_of: int = 256, ffn_dim_multiplier: Optional[float] = None, norm_eps: float = 1e-5, axes_dim_rope: Tuple[int, int, int] = (32, 32, 32), axes_lens: Tuple[int, int, int] = (300, 512, 512), text_feat_dim: int = 1024, timestep_scale: float = 1.0 ) -> None: """Initialize the ThinkGen transformer model.""" super().__init__() # Validate configuration if (hidden_size // num_attention_heads) != sum(axes_dim_rope): raise ValueError( f"hidden_size // num_attention_heads ({hidden_size // num_attention_heads}) " f"must equal sum(axes_dim_rope) ({sum(axes_dim_rope)})" ) self.out_channels = out_channels or in_channels # Initialize embeddings self.rope_embedder = ThinkGenRotaryPosEmbed( theta=10000, axes_dim=axes_dim_rope, axes_lens=axes_lens, patch_size=patch_size, ) self.x_embedder = nn.Linear( in_features=patch_size * patch_size * in_channels, out_features=hidden_size, ) self.ref_image_patch_embedder = nn.Linear( in_features=patch_size * patch_size * in_channels, out_features=hidden_size, ) self.time_caption_embed = Lumina2CombinedTimestepCaptionEmbedding( hidden_size=hidden_size, text_feat_dim=text_feat_dim, norm_eps=norm_eps, timestep_scale=timestep_scale ) # Initialize transformer blocks self.noise_refiner = nn.ModuleList([ ThinkGenTransformerBlock( hidden_size, num_attention_heads, num_kv_heads, multiple_of, ffn_dim_multiplier, norm_eps, modulation=True ) for _ in range(num_refiner_layers) ]) self.ref_image_refiner = nn.ModuleList([ ThinkGenTransformerBlock( hidden_size, num_attention_heads, num_kv_heads, multiple_of, ffn_dim_multiplier, norm_eps, modulation=True ) for _ in 
range(num_refiner_layers) ]) self.context_refiner = nn.ModuleList( [ ThinkGenTransformerBlock( hidden_size, num_attention_heads, num_kv_heads, multiple_of, ffn_dim_multiplier, norm_eps, modulation=False ) for _ in range(num_refiner_layers) ] ) # 3. Transformer blocks self.layers = nn.ModuleList( [ ThinkGenTransformerBlock( hidden_size, num_attention_heads, num_kv_heads, multiple_of, ffn_dim_multiplier, norm_eps, modulation=True ) for _ in range(num_layers) ] ) # 4. Output norm & projection self.norm_out = LuminaLayerNormContinuous( embedding_dim=hidden_size, conditioning_embedding_dim=min(hidden_size, 1024), elementwise_affine=False, eps=1e-6, bias=True, out_dim=patch_size * patch_size * self.out_channels ) # Add learnable embeddings to distinguish different images self.image_index_embedding = nn.Parameter(torch.randn(5, hidden_size)) # support max 5 ref images self.gradient_checkpointing = False self.initialize_weights() # TeaCache settings self.enable_teacache = False self.teacache_rel_l1_thresh = 0.05 self.teacache_params = TeaCacheParams() coefficients = [-5.48259225, 11.48772289, -4.47407401, 2.47730926, -0.03316487] self.rescale_func = np.poly1d(coefficients) self.prepad_embed = nn.Parameter(torch.randn(1, 23, 8192)) print("add prepad_embed parameter ! ") self.register_buffer('prepad_mask', torch.ones(1, 23).to(torch.int64)) def initialize_weights(self) -> None: """ Initialize the weights of the model. Uses Xavier uniform initialization for linear layers. 
""" nn.init.xavier_uniform_(self.x_embedder.weight) nn.init.constant_(self.x_embedder.bias, 0.0) nn.init.xavier_uniform_(self.ref_image_patch_embedder.weight) nn.init.constant_(self.ref_image_patch_embedder.bias, 0.0) nn.init.zeros_(self.norm_out.linear_1.weight) nn.init.zeros_(self.norm_out.linear_1.bias) nn.init.zeros_(self.norm_out.linear_2.weight) nn.init.zeros_(self.norm_out.linear_2.bias) nn.init.normal_(self.image_index_embedding, std=0.02) def img_patch_embed_and_refine( self, hidden_states, ref_image_hidden_states, padded_img_mask, padded_ref_img_mask, noise_rotary_emb, ref_img_rotary_emb, l_effective_ref_img_len, l_effective_img_len, temb ): batch_size = len(hidden_states) max_combined_img_len = max([img_len + sum(ref_img_len) for img_len, ref_img_len in zip(l_effective_img_len, l_effective_ref_img_len)]) hidden_states = self.x_embedder(hidden_states) ref_image_hidden_states = self.ref_image_patch_embedder(ref_image_hidden_states) # 添加image_index_embedding for i in range(batch_size): shift = 0 for j, ref_img_len in enumerate(l_effective_ref_img_len[i]): ref_image_hidden_states[i, shift:shift + ref_img_len, :] = ref_image_hidden_states[i, shift:shift + ref_img_len, :] + self.image_index_embedding[j] shift += ref_img_len for layer in self.noise_refiner: hidden_states = layer(hidden_states, padded_img_mask, noise_rotary_emb, temb) flat_l_effective_ref_img_len = list(itertools.chain(*l_effective_ref_img_len)) num_ref_images = len(flat_l_effective_ref_img_len) max_ref_img_len = max(flat_l_effective_ref_img_len) batch_ref_img_mask = ref_image_hidden_states.new_zeros(num_ref_images, max_ref_img_len, dtype=torch.bool) batch_ref_image_hidden_states = ref_image_hidden_states.new_zeros(num_ref_images, max_ref_img_len, self.config.hidden_size) batch_ref_img_rotary_emb = hidden_states.new_zeros(num_ref_images, max_ref_img_len, ref_img_rotary_emb.shape[-1], dtype=ref_img_rotary_emb.dtype) batch_temb = temb.new_zeros(num_ref_images, *temb.shape[1:], dtype=temb.dtype) # 
sequence of ref imgs to batch idx = 0 for i in range(batch_size): shift = 0 for ref_img_len in l_effective_ref_img_len[i]: batch_ref_img_mask[idx, :ref_img_len] = True batch_ref_image_hidden_states[idx, :ref_img_len] = ref_image_hidden_states[i, shift:shift + ref_img_len] batch_ref_img_rotary_emb[idx, :ref_img_len] = ref_img_rotary_emb[i, shift:shift + ref_img_len] batch_temb[idx] = temb[i] shift += ref_img_len idx += 1 # refine ref imgs separately for layer in self.ref_image_refiner: batch_ref_image_hidden_states = layer(batch_ref_image_hidden_states, batch_ref_img_mask, batch_ref_img_rotary_emb, batch_temb) # batch of ref imgs to sequence idx = 0 for i in range(batch_size): shift = 0 for ref_img_len in l_effective_ref_img_len[i]: ref_image_hidden_states[i, shift:shift + ref_img_len] = batch_ref_image_hidden_states[idx, :ref_img_len] shift += ref_img_len idx += 1 combined_img_hidden_states = hidden_states.new_zeros(batch_size, max_combined_img_len, self.config.hidden_size) for i, (ref_img_len, img_len) in enumerate(zip(l_effective_ref_img_len, l_effective_img_len)): combined_img_hidden_states[i, :sum(ref_img_len)] = ref_image_hidden_states[i, :sum(ref_img_len)] combined_img_hidden_states[i, sum(ref_img_len):sum(ref_img_len) + img_len] = hidden_states[i, :img_len] return combined_img_hidden_states def flat_and_pad_to_seq(self, hidden_states, ref_image_hidden_states): batch_size = len(hidden_states) p = self.config.patch_size device = hidden_states[0].device img_sizes = [(img.size(1), img.size(2)) for img in hidden_states] l_effective_img_len = [(H // p) * (W // p) for (H, W) in img_sizes] if ref_image_hidden_states is not None: ref_img_sizes = [[(img.size(1), img.size(2)) for img in imgs] if imgs is not None else None for imgs in ref_image_hidden_states] l_effective_ref_img_len = [[(ref_img_size[0] // p) * (ref_img_size[1] // p) for ref_img_size in _ref_img_sizes] if _ref_img_sizes is not None else [0] for _ref_img_sizes in ref_img_sizes] else: ref_img_sizes = 
[None for _ in range(batch_size)] l_effective_ref_img_len = [[0] for _ in range(batch_size)] max_ref_img_len = max([sum(ref_img_len) for ref_img_len in l_effective_ref_img_len]) max_img_len = max(l_effective_img_len) # ref image patch embeddings flat_ref_img_hidden_states = [] for i in range(batch_size): if ref_img_sizes[i] is not None: imgs = [] for ref_img in ref_image_hidden_states[i]: C, H, W = ref_img.size() ref_img = rearrange(ref_img, 'c (h p1) (w p2) -> (h w) (p1 p2 c)', p1=p, p2=p) imgs.append(ref_img) img = torch.cat(imgs, dim=0) flat_ref_img_hidden_states.append(img) else: flat_ref_img_hidden_states.append(None) # image patch embeddings flat_hidden_states = [] for i in range(batch_size): img = hidden_states[i] C, H, W = img.size() img = rearrange(img, 'c (h p1) (w p2) -> (h w) (p1 p2 c)', p1=p, p2=p) flat_hidden_states.append(img) padded_ref_img_hidden_states = torch.zeros(batch_size, max_ref_img_len, flat_hidden_states[0].shape[-1], device=device, dtype=flat_hidden_states[0].dtype) padded_ref_img_mask = torch.zeros(batch_size, max_ref_img_len, dtype=torch.bool, device=device) for i in range(batch_size): if ref_img_sizes[i] is not None: padded_ref_img_hidden_states[i, :sum(l_effective_ref_img_len[i])] = flat_ref_img_hidden_states[i] padded_ref_img_mask[i, :sum(l_effective_ref_img_len[i])] = True padded_hidden_states = torch.zeros(batch_size, max_img_len, flat_hidden_states[0].shape[-1], device=device, dtype=flat_hidden_states[0].dtype) padded_img_mask = torch.zeros(batch_size, max_img_len, dtype=torch.bool, device=device) for i in range(batch_size): padded_hidden_states[i, :l_effective_img_len[i]] = flat_hidden_states[i] padded_img_mask[i, :l_effective_img_len[i]] = True return ( padded_hidden_states, padded_ref_img_hidden_states, padded_img_mask, padded_ref_img_mask, l_effective_ref_img_len, l_effective_img_len, ref_img_sizes, img_sizes, ) def forward( self, hidden_states: Union[torch.Tensor, List[torch.Tensor]], timestep: torch.Tensor, 
text_hidden_states: torch.Tensor, freqs_cis: torch.Tensor, text_attention_mask: torch.Tensor, ref_image_hidden_states: Optional[List[List[torch.Tensor]]] = None, attention_kwargs: Optional[Dict[str, Any]] = None, return_dict: bool = False, ) -> Union[torch.Tensor, Transformer2DModelOutput]: enable_taylorseer = getattr(self, 'enable_taylorseer', False) # if self.prepad_embed.dtype != text_hidden_states.dtype: # self.prepad_embed = self.prepad_embed.to(text_hidden_states.dtype) # if self.prepad_mask.device != text_attention_mask.device: # self.prepad_mask = self.prepad_mask.to(text_attention_mask.device) bs = text_hidden_states.shape[0] prepad_embed = self.prepad_embed.repeat(bs, 1, 1) prepad_mask = self.prepad_mask.repeat(bs, 1) text_hidden_states = torch.cat([prepad_embed, text_hidden_states], dim = 1) text_attention_mask = torch.cat([prepad_mask, text_attention_mask], dim = 1) if enable_taylorseer: cal_type(self.cache_dic, self.current) if attention_kwargs is not None: attention_kwargs = attention_kwargs.copy() lora_scale = attention_kwargs.pop("scale", 1.0) else: lora_scale = 1.0 if USE_PEFT_BACKEND: # weight the lora layers by setting `lora_scale` for each PEFT layer scale_lora_layers(self, lora_scale) else: if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None: logger.warning( "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective." ) # 1. 
Condition, positional & patch embedding batch_size = len(hidden_states) is_hidden_states_tensor = isinstance(hidden_states, torch.Tensor) if is_hidden_states_tensor: assert hidden_states.ndim == 4 hidden_states = [_hidden_states for _hidden_states in hidden_states] device = hidden_states[0].device temb, text_hidden_states = self.time_caption_embed(timestep, text_hidden_states, hidden_states[0].dtype) ( hidden_states, ref_image_hidden_states, img_mask, ref_img_mask, l_effective_ref_img_len, l_effective_img_len, ref_img_sizes, img_sizes, ) = self.flat_and_pad_to_seq(hidden_states, ref_image_hidden_states) ( context_rotary_emb, ref_img_rotary_emb, noise_rotary_emb, rotary_emb, encoder_seq_lengths, seq_lengths, ) = self.rope_embedder( freqs_cis, text_attention_mask, l_effective_ref_img_len, l_effective_img_len, ref_img_sizes, img_sizes, device, ) # 2. Context refinement for layer in self.context_refiner: text_hidden_states = layer(text_hidden_states, text_attention_mask, context_rotary_emb) combined_img_hidden_states = self.img_patch_embed_and_refine( hidden_states, ref_image_hidden_states, img_mask, ref_img_mask, noise_rotary_emb, ref_img_rotary_emb, l_effective_ref_img_len, l_effective_img_len, temb, ) # 3. 
Joint Transformer blocks (joint text embed 和 image embed) max_seq_len = max(seq_lengths) attention_mask = hidden_states.new_zeros(batch_size, max_seq_len, dtype=torch.bool) joint_hidden_states = hidden_states.new_zeros(batch_size, max_seq_len, self.config.hidden_size) for i, (encoder_seq_len, seq_len) in enumerate(zip(encoder_seq_lengths, seq_lengths)): attention_mask[i, :seq_len] = True joint_hidden_states[i, :encoder_seq_len] = text_hidden_states[i, :encoder_seq_len] joint_hidden_states[i, encoder_seq_len:seq_len] = combined_img_hidden_states[i, :seq_len - encoder_seq_len] hidden_states = joint_hidden_states if self.enable_teacache: teacache_hidden_states = hidden_states.clone() teacache_temb = temb.clone() modulated_inp, _, _, _ = self.layers[0].norm1(teacache_hidden_states, teacache_temb) if self.teacache_params.is_first_or_last_step: should_calc = True self.teacache_params.accumulated_rel_l1_distance = 0 else: self.teacache_params.accumulated_rel_l1_distance += self.rescale_func( ((modulated_inp - self.teacache_params.previous_modulated_inp).abs().mean() \ / self.teacache_params.previous_modulated_inp.abs().mean()).cpu().item() ) if self.teacache_params.accumulated_rel_l1_distance < self.teacache_rel_l1_thresh: should_calc = False else: should_calc = True self.teacache_params.accumulated_rel_l1_distance = 0 self.teacache_params.previous_modulated_inp = modulated_inp if self.enable_teacache: if not should_calc: hidden_states += self.teacache_params.previous_residual else: ori_hidden_states = hidden_states.clone() for layer_idx, layer in enumerate(self.layers): if torch.is_grad_enabled() and self.gradient_checkpointing: hidden_states = self._gradient_checkpointing_func( layer, hidden_states, attention_mask, rotary_emb, temb ) else: hidden_states = layer(hidden_states, attention_mask, rotary_emb, temb) self.teacache_params.previous_residual = hidden_states - ori_hidden_states else: if enable_taylorseer: self.current['stream'] = 'layers_stream' for layer_idx, 
layer in enumerate(self.layers): if enable_taylorseer: layer.current = self.current layer.cache_dic = self.cache_dic layer.enable_taylorseer = True self.current['layer'] = layer_idx if torch.is_grad_enabled() and self.gradient_checkpointing: hidden_states = self._gradient_checkpointing_func( layer, hidden_states, attention_mask, rotary_emb, temb ) else: hidden_states = layer(hidden_states, attention_mask, rotary_emb, temb) # 4. Output norm & projection hidden_states = self.norm_out(hidden_states, temb) p = self.config.patch_size output = [] for i, (img_size, img_len, seq_len) in enumerate(zip(img_sizes, l_effective_img_len, seq_lengths)): height, width = img_size output.append(rearrange(hidden_states[i][seq_len - img_len:seq_len], '(h w) (p1 p2 c) -> c (h p1) (w p2)', h=height // p, w=width // p, p1=p, p2=p)) if is_hidden_states_tensor: output = torch.stack(output, dim=0) if USE_PEFT_BACKEND: # remove `lora_scale` from each PEFT layer unscale_lora_layers(self, lora_scale) if enable_taylorseer: self.current['step'] += 1 if not return_dict: return output return Transformer2DModelOutput(sample=output)