| import torch
|
| from typing import Union, Tuple, List, Optional
|
| import numpy as np
|
|
|
|
|
# Global precision switch for RoPE: when True, position grids / frequency tables
# are built in float32 and the rotation itself is computed in float32.
USE_FP32_ROPE_FREQS: bool = True
# dtype used for RoPE position grids / frequency tables while
# USE_FP32_ROPE_FREQS is False.
ROPE_FREQS_DTYPE: torch.dtype = torch.bfloat16
|
|
|
|
|
def set_use_fp32_rope_freqs(enabled: bool) -> None:
    """Globally enable or disable float32 RoPE computation.

    Args:
        enabled: Any truthy/falsy value; it is normalised to a real ``bool``
            before being stored in the module-level ``USE_FP32_ROPE_FREQS``.
    """
    global USE_FP32_ROPE_FREQS
    USE_FP32_ROPE_FREQS = True if enabled else False
|
|
|
|
|
def set_rope_freqs_dtype(dtype: torch.dtype) -> None:
    """Set the dtype used for RoPE tables when float32 mode is off.

    The value is stored verbatim in the module-level ``ROPE_FREQS_DTYPE``. It is
    only consulted (via ``_rope_freqs_dtype``) while ``USE_FP32_ROPE_FREQS`` is
    False.
    """
    global ROPE_FREQS_DTYPE
    chosen_dtype = dtype
    ROPE_FREQS_DTYPE = chosen_dtype
|
|
|
|
|
def _rope_freqs_dtype() -> torch.dtype:
    """Return the dtype RoPE positions and frequency tables are built in.

    ``torch.float32`` when ``USE_FP32_ROPE_FREQS`` is truthy, otherwise the
    configurable ``ROPE_FREQS_DTYPE``.
    """
    if not USE_FP32_ROPE_FREQS:
        return ROPE_FREQS_DTYPE
    return torch.float32
|
|
|
|
|
def _coerce_rope_positions(pos, dtype: Optional[torch.dtype] = None) -> torch.Tensor:
    """Normalise a RoPE position specification to a tensor of the RoPE dtype.

    Args:
        pos: One of
            * a Python ``int`` or NumPy integer ``n``: positions ``0, 1, ..., n-1``;
            * a ``np.ndarray`` of positions (converted to a CPU tensor);
            * a ``list`` / ``tuple`` of positions;
            * a ``torch.Tensor`` of positions (its device is kept).
        dtype: Target dtype. Defaults to ``_rope_freqs_dtype()`` so existing
            callers keep following the module-level precision settings.

    Returns:
        The positions as a tensor cast to ``dtype``.
    """
    rope_dtype = _rope_freqs_dtype() if dtype is None else dtype

    # NumPy integer scalars (e.g. np.int64 from shape arithmetic) are counts
    # just like Python ints; they have no ``.to`` method.
    if isinstance(pos, (int, np.integer)):
        return torch.arange(int(pos), dtype=rope_dtype)

    if isinstance(pos, np.ndarray):
        # torch.from_numpy rejects arrays with negative strides (e.g. ``a[::-1]``);
        # ascontiguousarray is a no-op for input that is already contiguous.
        return torch.from_numpy(np.ascontiguousarray(pos)).to(dtype=rope_dtype)

    if isinstance(pos, (list, tuple)):
        return torch.as_tensor(pos, dtype=rope_dtype)

    return pos.to(dtype=rope_dtype)
|
|
|
|
|
|
|
|
|
def get_1d_rotary_pos_embed_riflex(
    dim: int,
    pos: Union[np.ndarray, int],
    theta: float = 10000.0,
    use_real=False,
    k: Optional[int] = None,
    L_test: Optional[int] = None,
):
    """RIFLEx: precompute 1-D RoPE tables with an adjusted intrinsic frequency.

    Builds the standard RoPE frequencies ``theta ** (-2i / dim)`` for
    ``i = 0 .. dim/2 - 1``. When ``k`` is given, the ``k``-th (1-based)
    frequency is replaced by ``0.9 * 2*pi / L_test``.

    Args:
        dim (`int`): Rotary dimension; must be even.
        pos (`int`, `np.ndarray` or `torch.Tensor`): Positions. An ``int`` ``n``
            means positions ``0 .. n-1``.
        theta (`float`, *optional*, defaults to 10000.0): RoPE base.
        use_real (`bool`, *optional*): If True, return ``(cos, sin)``. Otherwise
            return complex ``cis`` values.
        k (`int`, *optional*): 1-based index of the intrinsic frequency (see
            ``identify_k``). Must satisfy ``1 <= k <= dim // 2``. ``None``
            disables the adjustment.
        L_test (`int`, *optional*): Number of (latent) frames at inference.
            Must be a positive number whenever ``k`` is given.

    Returns:
        If ``use_real``: ``(freqs_cos, freqs_sin)``, each ``[S, dim]``. Every
        frequency is repeated twice, interleaved.
        Otherwise: complex tensor ``[S, dim/2]``.

    Raises:
        AssertionError: if ``dim`` is odd.
        ValueError: if ``k`` is out of range, or ``L_test`` is missing or
            non-positive while ``k`` is set.
    """
    assert dim % 2 == 0
    if k is not None:
        # Without this check k=0 would silently overwrite the *last* frequency
        # (index -1), and a missing L_test would fail with an opaque TypeError.
        if not 1 <= k <= dim // 2:
            raise ValueError(f"k must be in [1, {dim // 2}] for dim={dim}, but got {k}")
        if L_test is None or L_test <= 0:
            raise ValueError(f"L_test must be a positive frame count when k is set, but got {L_test}")

    pos = _coerce_rope_positions(pos)

    freqs = 1.0 / (
        theta ** (torch.arange(0, dim, 2, device=pos.device, dtype=pos.dtype)[: (dim // 2)] / dim)
    )

    if k is not None:
        # New angular frequency 0.9 * 2*pi / L_test has period L_test / 0.9 > L_test,
        # so this component does not complete a full cycle within L_test frames.
        freqs[k - 1] = 0.9 * 2 * torch.pi / L_test

    freqs = torch.outer(pos, freqs)  # [S, dim/2]

    if use_real:
        freqs_cos = freqs.cos().repeat_interleave(2, dim=1)
        freqs_sin = freqs.sin().repeat_interleave(2, dim=1)
        return freqs_cos, freqs_sin

    # torch.polar is documented for float32/float64 inputs only; upcast
    # reduced-precision tables (e.g. bf16 when USE_FP32_ROPE_FREQS is off).
    if freqs.dtype not in (torch.float32, torch.float64):
        freqs = freqs.float()
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    return freqs_cis
|
|
|
def identify_k(b: float, d: int, N: int):
    """Find the RoPE frequency component whose period best matches an observed repetition.

    For every component ``j = 1 .. d/2`` the angular frequency is
    ``b ** (-2(j-1)/d)``. Its period in positions is ``round(2*pi / frequency)``.
    The component whose period is closest to ``N`` is returned. On ties, the
    lowest index wins.

    Args:
        b (`float`): RoPE base frequency.
        d (`int`): Dimension of the frequency tensor.
        N (`int`): First observed repetition frame in latent space.

    Returns:
        k (`int`): 1-based index of the intrinsic frequency component.
        N_k (`int`): Period of that component in latent space.

    Example:
        In HunyuanVideo, b=256 and d=16, and repetition occurs after roughly 8s
        (N=48 in latent space)::

            k, N_k = identify_k(b=256, d=16, N=48)  # -> (4, 50)
    """
    periods = [
        round(2 * torch.pi / (1.0 / (b ** (2 * (idx - 1) / d))))
        for idx in range(1, d // 2 + 1)
    ]
    # min() returns the first index among equal distances, matching list.index(min(...)).
    best = min(range(len(periods)), key=lambda i: abs(periods[i] - N))
    return best + 1, periods[best]
|
|
|
def _to_tuple(x, dim=2):
    """Broadcast an int to a ``dim``-tuple, or pass through a length-``dim`` sequence unchanged.

    Raises:
        ValueError: if ``x`` is not an int and ``len(x) != dim``.
    """
    if not isinstance(x, int):
        if len(x) != dim:
            raise ValueError(f"Expected length {dim} or int, but got {x}")
        return x
    return tuple([x] * dim)
|
|
|
|
|
def get_meshgrid_nd(start, *args, dim=2, dtype=None, device=None):
    """Build an n-D coordinate meshgrid (``indexing="ij"``) as one stacked tensor.

    Each of start/stop/num is an int or a ``dim``-length sequence. Three call
    forms are supported:

        * ``get_meshgrid_nd(num)``: coordinates ``0 .. num-1`` per axis;
        * ``get_meshgrid_nd(start, stop)``: ``start .. stop-1`` per axis (step 1);
        * ``get_meshgrid_nd(start, stop, num)``: ``num`` evenly spaced points in
          ``[start, stop)``.

    An axis whose start equals its stop always yields the single coordinate
    ``start``.

    Args:
        start: See above.
        *args: Zero, one or two extra positional arguments (see above).
        dim (int): Number of grid dimensions. Defaults to 2.
        dtype: Coordinate dtype. Defaults to ``torch.float32``.
        device: Device for the coordinate tensors. ``None`` uses torch's default.

    Returns:
        torch.Tensor: ``[dim, len_0, ..., len_{dim-1}]``. Slice ``i`` holds the
        coordinates along axis ``i``.

    Raises:
        ValueError: if more than two extra positional arguments are given, or a
            sequence argument does not have length ``dim``.
    """
    grid_dtype = dtype if dtype is not None else torch.float32

    n_extra = len(args)
    if n_extra == 0:
        num = _to_tuple(start, dim=dim)
        start, stop = (0,) * dim, num
    elif n_extra == 1:
        start = _to_tuple(start, dim=dim)
        stop = _to_tuple(args[0], dim=dim)
        num = [hi - lo for lo, hi in zip(start, stop)]
    elif n_extra == 2:
        start = _to_tuple(start, dim=dim)
        stop = _to_tuple(args[0], dim=dim)
        num = _to_tuple(args[1], dim=dim)
    else:
        raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}")

    def _axis_coords(lo, hi, count):
        # Degenerate axis: a single coordinate regardless of ``count``.
        if lo == hi:
            return torch.tensor([lo], dtype=grid_dtype, device=device)
        # ``count + 1`` points over [lo, hi], then drop the endpoint -> [lo, hi).
        return torch.linspace(lo, hi, count + 1, dtype=grid_dtype, device=device)[:count]

    axes = [_axis_coords(start[i], stop[i], num[i]) for i in range(dim)]
    return torch.stack(torch.meshgrid(*axes, indexing="ij"), dim=0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def reshape_for_broadcast(
    freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
    x: torch.Tensor,
    head_first=False,
):
    """View a RoPE frequency table so it broadcasts against ``x``.

    The table must be 2-D, ``[seq, last]``. Here ``seq`` is ``x.shape[1]``
    (or ``x.shape[-2]`` when ``head_first``) and ``last`` is ``x.shape[-1]``.
    The result keeps those two axes of ``x`` and uses size 1 everywhere else.

    Notes:
        With FlashMHAModified, ``head_first`` should be False.
        With Attention, ``head_first`` should be True.

    Args:
        freqs_cis: A single table, or a ``(cos, sin)`` pair. For a pair, only the
            first element's shape is checked and both are reshaped.
        x: Target tensor the table will be broadcast against.
        head_first (bool): Whether the head dimension comes first (after batch).

    Returns:
        The reshaped table, or a tuple of the two reshaped tables.

    Raises:
        AssertionError: if ``x`` has fewer than 2 dims or the table shape does
            not match ``x``.
    """
    ndim = x.ndim
    assert ndim > 1

    is_pair = isinstance(freqs_cis, tuple)
    reference = freqs_cis[0] if is_pair else freqs_cis

    seq_axis = ndim - 2 if head_first else 1
    kept_axes = {seq_axis, ndim - 1}

    assert reference.shape == (
        x.shape[seq_axis],
        x.shape[-1],
    ), f"freqs_cis shape {reference.shape} does not match x shape {x.shape}"

    shape = [size if axis in kept_axes else 1 for axis, size in enumerate(x.shape)]

    if is_pair:
        return freqs_cis[0].view(*shape), freqs_cis[1].view(*shape)
    return freqs_cis.view(*shape)
|
|
|
|
|
|
|
def _apply_rope_inplace_inner(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> None:
    """Rotate adjacent channel pairs of ``x`` in place.

    Each pair ``(x0, x1)`` along the last dimension becomes
    ``(x0*cos - x1*sin, x1*cos + x0*sin)``. ``cos``/``sin`` are paired the same
    way (``(..., -1, 2)`` views) and must broadcast against ``x`` after that
    pairing. Presumably they are the repeat-interleaved tables from
    ``get_1d_rotary_pos_embed(use_real=True)``; TODO confirm for all callers.
    Nothing is returned; ``x`` is mutated.
    """
    # .view (not reshape) so that writes through x0/x1 land in x's own storage;
    # this requires x to be view-compatible (e.g. contiguous in the last dim).
    x_view = x.view(*x.shape[:-1], -1, 2)
    cos_view = cos.view(*cos.shape[:-1], -1, 2)
    sin_view = sin.view(*sin.shape[:-1], -1, 2)
    x0 = x_view[..., 0]
    x1 = x_view[..., 1]
    # x0 is overwritten first, but the x1 update below needs its original value.
    x0_orig = x0.clone()
    # x0 <- x0 * cos - x1 * sin   (x1 is still unmodified at this point)
    x0.mul_(cos_view[..., 0]).addcmul_(x1, sin_view[..., 0], value=-1)
    # x1 <- x1 * cos + x0_orig * sin
    x1.mul_(cos_view[..., 1]).addcmul_(x0_orig, sin_view[..., 1])
|
|
|
|
|
def _apply_rope_inplace(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, use_fp32: bool) -> torch.Tensor:
    """Apply the rotary embedding to ``x`` in place and return ``x``.

    Args:
        x: Tensor whose last dimension holds interleaved rotation pairs.
        cos: Cosine table, broadcastable against ``x`` (see ``_apply_rope_inplace_inner``).
        sin: Sine table, same layout as ``cos``.
        use_fp32: If True and ``x`` is not float32, the rotation is computed on a
            float32 copy and the result is written back into ``x``'s storage.

    Returns:
        ``x`` itself, mutated.
    """
    if use_fp32 and x.dtype != torch.float32:
        x_work = x.to(torch.float32)
        _apply_rope_inplace_inner(x_work, cos, sin)
        # copy_ performs the float32 -> x.dtype cast itself. An explicit
        # .to(x.dtype) first would allocate an extra full-size temporary for
        # the same result.
        x.copy_(x_work)
        return x
    _apply_rope_inplace_inner(x, cos, sin)
    return x
|
|
|
def apply_rotary_emb_single(
    qklist,
    freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
    head_first: bool = False,
):
    """Apply rotary embeddings to the single tensor held in ``qklist[0]``.

    Args:
        qklist: Mutable list whose first element is the tensor to rotate (e.g.
            ``[B, S, H, D]`` when ``head_first`` is False). The list is cleared,
            so the caller's list no longer references the tensor. Presumably
            this lets memory be reclaimed early; confirm with callers.
        freqs_cis: Either a ``(cos, sin)`` tuple of real tables of shape
            ``[S, D]``, applied in place, or a complex tensor of shape
            ``[S, D/2]``, applied out of place exactly like the complex branch
            of ``apply_rotary_emb``.
        head_first (bool): Whether the head dimension comes first (after batch).

    Returns:
        The rotated tensor. For the ``(cos, sin)`` form this is the input tensor
        itself; for the complex form it is a new tensor of the input's dtype.
    """
    xq = qklist[0]
    qklist.clear()

    if not isinstance(freqs_cis, tuple):
        # Complex form: treat adjacent channel pairs as complex numbers.
        xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
        freqs = reshape_for_broadcast(freqs_cis, xq_, head_first).to(xq.device)
        # flatten(-2) folds the trailing (D/2, 2) real view back into D.
        return torch.view_as_real(xq_ * freqs).flatten(-2).type_as(xq)

    cos, sin = reshape_for_broadcast(freqs_cis, xq, head_first)
    use_fp32 = USE_FP32_ROPE_FREQS
    # Compute in float32 when requested, otherwise in the activation's own dtype.
    target_dtype = torch.float32 if use_fp32 else xq.dtype
    if cos.device != xq.device or cos.dtype != target_dtype:
        cos = cos.to(device=xq.device, dtype=target_dtype)
    if sin.device != xq.device or sin.dtype != target_dtype:
        sin = sin.to(device=xq.device, dtype=target_dtype)
    _apply_rope_inplace(xq, cos, sin, use_fp32)
    return xq
|
|
|
|
|
def apply_rotary_emb(
    qklist,
    freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
    head_first: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Apply rotary embeddings to a query/key pair.

    Args:
        qklist: Mutable two-element list ``[xq, xk]`` holding the query and key
            tensors (``[B, S, H, D]`` when ``head_first`` is False). The list is
            cleared before the rotation, so the caller's list no longer holds
            references to the inputs.
        freqs_cis: Either a ``(cos, sin)`` tuple of real tables of shape
            ``[S, D]``, or a complex tensor of shape ``[S, D/2]``.
        head_first (bool): Whether the head dimension comes first (after batch).

    Returns:
        ``(xq_out, xk_out)``. With the ``(cos, sin)`` form these are ``xq`` and
        ``xk`` themselves, rotated in place. With the complex form they are new
        tensors of the inputs' dtypes, and the inputs are left unchanged.
    """
    xq, xk = qklist
    qklist.clear()

    if isinstance(freqs_cis, tuple):
        cos, sin = reshape_for_broadcast(freqs_cis, xq, head_first)
        use_fp32 = USE_FP32_ROPE_FREQS
        # Compute in float32 when requested, otherwise in the activations' own dtype.
        target_dtype = torch.float32 if use_fp32 else xq.dtype
        if cos.device != xq.device or cos.dtype != target_dtype:
            cos = cos.to(device=xq.device, dtype=target_dtype)
        if sin.device != xq.device or sin.dtype != target_dtype:
            sin = sin.to(device=xq.device, dtype=target_dtype)
        _apply_rope_inplace(xq, cos, sin, use_fp32)
        _apply_rope_inplace(xk, cos, sin, use_fp32)
        return xq, xk

    # Complex form: treat adjacent channel pairs as complex numbers and multiply.
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_, head_first).to(xq.device)
    # flatten(-2) folds the trailing (D/2, 2) real view back into D; for 4-D
    # inputs this is the same as flatten(3), and it also works for other ranks.
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(-2).type_as(xq)
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(-2).type_as(xk)
    return xq_out, xk_out
|
def get_nd_rotary_pos_embed(
    start,
    *args,
    theta=10000.0,
    use_real=True,
    theta_rescale_factor: Union[float, List[float]] = 1.0,
    interpolation_factor: Union[float, List[float]] = 1.0,
    k=6,
    L_test=66,
    enable_riflex=False,
    rope_dim_list=(44, 42, 42),
    head_dim=128,
):
    """n-D RoPE for tokens laid out on an n-D grid.

    Each grid axis ``i`` gets its own 1-D RoPE of width ``rope_dim_list[i]``.
    The per-axis tables are concatenated along the channel dimension.

    Args:
        start (int | tuple of int | list of int): Grid extent, forwarded to
            ``get_meshgrid_nd``. If ``len(args) == 0``, start is num. If
            ``len(args) == 1``, start is start and ``args[0]`` is stop (step 1).
            If ``len(args) == 2``, start is start, ``args[0]`` is stop and
            ``args[1]`` is num.
        *args: See above.
        theta (float): RoPE base. Defaults to 10000.0.
        use_real (bool): If True, return ``(cos, sin)`` separately (useful where
            complex64 is unsupported, e.g. TensorRT). Otherwise return complex
            values.
        theta_rescale_factor (float or list of float): Per-axis rescale factor
            for theta (ignored on the RIFLEx axis).
        interpolation_factor (float or list of float): Per-axis multiplier
            applied to positions (ignored on the RIFLEx axis).
        k (int): RIFLEx intrinsic-frequency index for axis 0.
        L_test (int): RIFLEx inference length for axis 0.
        enable_riflex (bool): Use ``get_1d_rotary_pos_embed_riflex`` for axis 0.
        rope_dim_list (sequence of int or None): Width of each axis' RoPE.
            Its length is the number of axes and its sum must equal ``head_dim``.
            If None, ``head_dim`` is split evenly over the number of axes. That
            number is taken from the first list/tuple extent among
            ``start, *args``, or is 3 if all extents are ints.
        head_dim (int): Attention head dimension.

    Returns:
        If ``use_real``: ``(cos, sin)``, each ``[N, head_dim]``, where ``N`` is
        the number of grid points.
        Otherwise: complex tensor ``[N, head_dim / 2]``.
    """
    if rope_dim_list is None:
        n_axes = next(
            (len(extent) for extent in (start, *args) if isinstance(extent, (list, tuple))),
            3,
        )
        rope_dim_list = [head_dim // n_axes for _ in range(n_axes)]
    assert (
        sum(rope_dim_list) == head_dim
    ), "sum(rope_dim_list) should equal to head_dim of attention layer"

    n_axes = len(rope_dim_list)
    grid = get_meshgrid_nd(start, *args, dim=n_axes, dtype=_rope_freqs_dtype())

    if isinstance(theta_rescale_factor, (int, float)):
        theta_rescale_factor = [theta_rescale_factor] * n_axes
    elif isinstance(theta_rescale_factor, list) and len(theta_rescale_factor) == 1:
        theta_rescale_factor = [theta_rescale_factor[0]] * n_axes
    assert len(theta_rescale_factor) == n_axes, "len(theta_rescale_factor) should equal to len(rope_dim_list)"

    if isinstance(interpolation_factor, (int, float)):
        interpolation_factor = [interpolation_factor] * n_axes
    elif isinstance(interpolation_factor, list) and len(interpolation_factor) == 1:
        interpolation_factor = [interpolation_factor[0]] * n_axes
    assert len(interpolation_factor) == n_axes, "len(interpolation_factor) should equal to len(rope_dim_list)"

    embs = []
    for i, axis_dim in enumerate(rope_dim_list):
        axis_pos = grid[i].reshape(-1)
        # use_real is forwarded so that the use_real=False path gets complex
        # tables rather than (cos, sin) tuples, which torch.cat cannot join.
        if i == 0 and enable_riflex:
            emb = get_1d_rotary_pos_embed_riflex(axis_dim, axis_pos, theta, use_real=use_real, k=k, L_test=L_test)
        else:
            emb = get_1d_rotary_pos_embed(
                axis_dim,
                axis_pos,
                theta,
                use_real=use_real,
                theta_rescale_factor=theta_rescale_factor[i],
                interpolation_factor=interpolation_factor[i],
            )
        embs.append(emb)

    if use_real:
        cos = torch.cat([emb[0] for emb in embs], dim=1)
        sin = torch.cat([emb[1] for emb in embs], dim=1)
        return cos, sin
    return torch.cat(embs, dim=1)
|
|
|
|
|
def get_1d_rotary_pos_embed(
    dim: int,
    pos: Union[torch.FloatTensor, int],
    theta: float = 10000.0,
    use_real: bool = False,
    theta_rescale_factor: float = 1.0,
    interpolation_factor: float = 1.0,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """Precompute 1-D RoPE tables (``cis`` = ``cos + i*sin``).

    Frequencies are ``theta ** (-2i / dim)`` for ``i = 0 .. dim/2 - 1``. The
    rotation angles are ``outer(pos * interpolation_factor, freqs)``.

    Args:
        dim (int): Rotary dimension; must be even.
        pos (int, np.ndarray or torch.Tensor): Positions. An ``int`` ``n``
            means positions ``0 .. n-1``.
        theta (float, optional): RoPE base. Defaults to 10000.0.
        use_real (bool, optional): If True, return real ``(cos, sin)`` tables.
            Otherwise return complex numbers.
        theta_rescale_factor (float, optional): If not 1.0, ``theta`` is
            multiplied by ``theta_rescale_factor ** (dim / (dim - 2))``.
            Defaults to 1.0.
        interpolation_factor (float, optional): Multiplier applied to the
            positions before computing angles. Defaults to 1.0.

    Returns:
        freqs_cis: Complex tensor ``[S, D/2]`` (when ``use_real`` is False).
        freqs_cos, freqs_sin: Real tensors ``[S, D]``, each frequency repeated
            twice, interleaved (when ``use_real`` is True).

    Raises:
        AssertionError: if ``dim`` is odd (the tables would be narrower than ``dim``).
    """
    # Same precondition as get_1d_rotary_pos_embed_riflex; an odd dim would
    # produce cos/sin of width dim - 1.
    assert dim % 2 == 0
    pos = _coerce_rope_positions(pos)

    # With dim == 2 the only frequency is theta ** 0 == 1, so rescaling theta
    # cannot change the result; skipping it avoids dividing by dim - 2 == 0.
    if theta_rescale_factor != 1.0 and dim > 2:
        theta *= theta_rescale_factor ** (dim / (dim - 2))

    freqs = 1.0 / (
        theta ** (torch.arange(0, dim, 2, device=pos.device, dtype=pos.dtype)[: (dim // 2)] / dim)
    )

    freqs = torch.outer(pos * interpolation_factor, freqs)  # [S, D/2]
    if use_real:
        freqs_cos = freqs.cos().repeat_interleave(2, dim=1)
        freqs_sin = freqs.sin().repeat_interleave(2, dim=1)
        return freqs_cos, freqs_sin

    # torch.polar is documented for float32/float64 inputs only; upcast
    # reduced-precision tables (e.g. bf16 when USE_FP32_ROPE_FREQS is off).
    if freqs.dtype not in (torch.float32, torch.float64):
        freqs = freqs.float()
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    return freqs_cis
|
|
|
def get_rotary_pos_embed(latents_size, enable_RIFLEx=False, patch_size=None):
    """Build the ``(cos, sin)`` RoPE tables for a latent of size ``latents_size``.

    Args:
        latents_size: Latent extents (e.g. ``[T, H, W]``). Each entry must be
            divisible by the matching patch size. Patch sizes are paired with
            extents from the left.
        enable_RIFLEx (bool): If True, apply the RIFLEx adjustment on the first
            axis, with ``latents_size[0]`` as ``L_test``.
        patch_size (int or sequence of int, optional): One patch size for every
            axis, or per-axis patch sizes. Defaults to ``[1, 2, 2]``.

    Returns:
        ``(freqs_cos, freqs_sin)`` from ``get_nd_rotary_pos_embed(use_real=True)``.

    Raises:
        AssertionError: if an extent is not divisible by its patch size.
    """
    target_ndim = 3
    if patch_size is None:
        patch_size = [1, 2, 2]

    if isinstance(patch_size, int):
        assert all(s % patch_size == 0 for s in latents_size), (
            f"Latent size(last {target_ndim} dimensions) should be divisible by patch size({patch_size}), "
            f"but got {latents_size}."
        )
        rope_sizes = [s // patch_size for s in latents_size]
    else:
        # Any per-axis sequence (list or tuple); previously a non-list left
        # rope_sizes unbound.
        assert all(s % patch_size[idx] == 0 for idx, s in enumerate(latents_size)), (
            f"Latent size(last {target_ndim} dimensions) should be divisible by patch size({patch_size}), "
            f"but got {latents_size}."
        )
        rope_sizes = [s // patch_size[idx] for idx, s in enumerate(latents_size)]

    if len(rope_sizes) != target_ndim:
        # Left-pad missing leading axes with extent 1.
        rope_sizes = [1] * (target_ndim - len(rope_sizes)) + rope_sizes

    freqs_cos, freqs_sin = get_nd_rotary_pos_embed(
        rope_sizes,
        theta=10000,
        use_real=True,
        theta_rescale_factor=1,
        L_test=latents_size[0],
        enable_riflex=enable_RIFLEx,
    )
    return (freqs_cos, freqs_sin)
|
|
|