|
|
import torch |
|
|
from typing import Union, Tuple, List, Optional |
|
|
import numpy as np |
|
|
|
|
|
|
|
|
# Global toggle: when True, RoPE frequency tables are built in float32
# regardless of ROPE_FREQS_DTYPE. Mutated via set_use_fp32_rope_freqs().
USE_FP32_ROPE_FREQS = True


# Dtype used for RoPE frequency tables when the fp32 override above is off.
# Mutated via set_rope_freqs_dtype(); read through _rope_freqs_dtype().
ROPE_FREQS_DTYPE = torch.bfloat16
|
|
|
|
|
|
|
|
def set_use_fp32_rope_freqs(enabled: bool) -> None:
    """Globally enable/disable computing RoPE frequency tables in float32.

    Args:
        enabled: truthy to force float32 frequencies, falsy to fall back to
            the configurable ``ROPE_FREQS_DTYPE``.
    """
    global USE_FP32_ROPE_FREQS
    USE_FP32_ROPE_FREQS = True if enabled else False
|
|
|
|
|
|
|
|
def set_rope_freqs_dtype(dtype: torch.dtype) -> None:
    """Set the dtype used for RoPE frequency tables.

    Only takes effect when the float32 override (``USE_FP32_ROPE_FREQS``)
    is disabled.

    Args:
        dtype: the torch dtype to store in the module-level setting.
    """
    global ROPE_FREQS_DTYPE
    ROPE_FREQS_DTYPE = dtype
|
|
|
|
|
|
|
|
def _rope_freqs_dtype() -> torch.dtype:
    """Resolve the effective dtype for RoPE frequency tables.

    The float32 override wins; otherwise the configured ``ROPE_FREQS_DTYPE``
    is used.
    """
    if USE_FP32_ROPE_FREQS:
        return torch.float32
    return ROPE_FREQS_DTYPE
|
|
|
|
|
|
|
|
def _coerce_rope_positions(pos):
    """Normalize a position spec to a torch tensor in the RoPE freq dtype.

    Accepts an ``int`` (expanded to ``arange(pos)``), a numpy array, or a
    torch tensor; always returns a tensor cast to ``_rope_freqs_dtype()``.
    """
    target = _rope_freqs_dtype()
    if isinstance(pos, int):
        return torch.arange(pos, dtype=target)
    if isinstance(pos, np.ndarray):
        pos = torch.from_numpy(pos)
    return pos.to(dtype=target)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_1d_rotary_pos_embed_riflex(
    dim: int,
    pos: Union[np.ndarray, int],
    theta: float = 10000.0,
    use_real=False,
    k: Optional[int] = None,
    L_test: Optional[int] = None,
):
    """
    RIFLEx: Precompute the frequency tensor for complex exponentials (cis) with given dimensions.

    Computes standard 1-D RoPE inverse frequencies scaled by 'theta', then, when
    `k` is given, overwrites the k-th (1-based) "intrinsic" frequency so it spans
    ~0.9 of a full period over `L_test` frames (the RIFLEx adjustment against
    temporal repetition).

    Args:
        dim (`int`): Dimension of the frequency tensor; must be even.
        pos (`np.ndarray` or `int`): Position indices for the frequency tensor. [S] or scalar.
        theta (`float`, *optional*, defaults to 10000.0):
            Scaling factor for frequency computation.
        use_real (`bool`, *optional*):
            If True, return real part and imaginary part separately. Otherwise, return complex numbers.
        k (`int`, *optional*, defaults to None): the index for the intrinsic frequency in RoPE.
        L_test (`int`, *optional*, defaults to None): the number of frames for inference;
            required whenever `k` is provided.

    Returns:
        `torch.Tensor`: Precomputed frequency tensor with complex exponentials [S, D/2],
        or a `(cos, sin)` pair of real tensors [S, D] when `use_real` is True.

    Raises:
        ValueError: if `k` is provided without `L_test`.
    """
    assert dim % 2 == 0

    pos = _coerce_rope_positions(pos)

    # Inverse frequencies, computed in the dtype carried by `pos`
    # (the configured RoPE freq dtype).
    freqs = 1.0 / (
        theta ** (torch.arange(0, dim, 2, device=pos.device, dtype=pos.dtype)[: (dim // 2)] / dim)
    )

    if k is not None:
        # Bug fix: previously a missing L_test crashed with an opaque
        # "unsupported operand type(s)" TypeError inside the division.
        if L_test is None:
            raise ValueError("L_test must be provided when k is specified")
        # RIFLEx: cap the intrinsic frequency to ~0.9 of a period over L_test frames.
        freqs[k - 1] = 0.9 * 2 * torch.pi / L_test

    freqs = torch.outer(pos, freqs)

    if use_real:
        freqs_cos = freqs.cos().repeat_interleave(2, dim=1)
        freqs_sin = freqs.sin().repeat_interleave(2, dim=1)
        return freqs_cos, freqs_sin
    else:
        freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
        return freqs_cis
|
|
|
|
|
def identify_k( b: float, d: int, N: int):
    """
    Locate the intrinsic frequency component of a RoPE-based pre-trained
    diffusion transformer.

    Args:
        b (`float`): The base frequency for RoPE.
        d (`int`): Dimension of the frequency tensor.
        N (`int`): the first observed repetition frame in latent space.

    Returns:
        k (`int`): 1-based index of the intrinsic frequency component.
        N_k (`int`): period of that component in latent space.

    Example:
        For HunyuanVideo (b=256, d=16) the repetition appears around 8s,
        i.e. N=48 in latent space:

            k, N_k = identify_k(b=256, d=16, N=48)

        which yields k=4 and N_k=50.
    """
    # Period (in frames) of every RoPE frequency component.
    periods = []
    for j in range(1, d // 2 + 1):
        inv_freq = 1.0 / (b ** (2 * (j - 1) / d))
        periods.append(round(2 * torch.pi / inv_freq))

    # The intrinsic component is the one whose period is closest to N
    # (first match wins on ties, like list.index over the diffs).
    best = min(range(len(periods)), key=lambda idx: abs(periods[idx] - N))
    return best + 1, periods[best]
|
|
|
|
|
def _to_tuple(x, dim=2): |
|
|
if isinstance(x, int): |
|
|
return (x,) * dim |
|
|
elif len(x) == dim: |
|
|
return x |
|
|
else: |
|
|
raise ValueError(f"Expected length {dim} or int, but got {x}") |
|
|
|
|
|
|
|
|
def get_meshgrid_nd(start, *args, dim=2, dtype=None, device=None):
    """
    Build an n-D meshgrid from start/stop/num specs.

    Args:
        start (int or tuple): If len(args) == 0, start is num; if len(args) == 1,
            start is start and args[0] is stop with step 1; if len(args) == 2,
            start is start, args[0] is stop, args[1] is num. Each may be an int
            or an n-tuple; tuples are stacked along the dim order.
        *args: See above.
        dim (int): Number of meshgrid axes. Defaults to 2.
        dtype: dtype of the produced grid; defaults to torch.float32.
        device: device of the produced grid.

    Returns:
        torch.Tensor: stacked grid of shape [dim, ...].
    """
    resolved_dtype = dtype if dtype is not None else torch.float32

    if len(args) == 0:
        # Only counts given: grid spans [0, num) per axis.
        stop = _to_tuple(start, dim=dim)
        start = (0,) * dim
        num = stop
    elif len(args) == 1:
        # start/stop given, implicit step of 1.
        start = _to_tuple(start, dim=dim)
        stop = _to_tuple(args[0], dim=dim)
        num = tuple(hi - lo for lo, hi in zip(start, stop))
    elif len(args) == 2:
        # start/stop/num all given.
        start = _to_tuple(start, dim=dim)
        stop = _to_tuple(args[0], dim=dim)
        num = _to_tuple(args[1], dim=dim)
    else:
        raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}")

    axis_grid = []
    for lo, hi, n in zip(start, stop, num):
        if lo == hi:
            # Degenerate axis: a single coordinate.
            axis_grid.append(torch.tensor([lo], dtype=resolved_dtype, device=device))
        else:
            # n evenly spaced points in [lo, hi) (endpoint dropped).
            axis_grid.append(
                torch.linspace(lo, hi, n + 1, dtype=resolved_dtype, device=device)[:n]
            )

    mesh = torch.meshgrid(*axis_grid, indexing="ij")
    return torch.stack(mesh, dim=0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def reshape_for_broadcast(
    freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
    x: torch.Tensor,
    head_first=False,
):
    """
    Reshape frequency tensor(s) so they broadcast against `x`.

    The sequence and head-dim axes of `freqs_cis` are kept; every other axis
    of `x` becomes a singleton, so elementwise ops broadcast correctly.

    Notes:
        When using FlashMHAModified, head_first should be False.
        When using Attention, head_first should be True.

    Args:
        freqs_cis: a single tensor or a (cos, sin) tuple of shape
            [seq, head_dim] to reshape.
        x: target tensor the result must broadcast with.
        head_first (bool): head dimension first (except batch dim) or not.

    Returns:
        The reshaped tensor, or a tuple of two reshaped tensors when a tuple
        was passed in.

    Raises:
        AssertionError: on a shape mismatch between freqs_cis and x.
    """
    ndim = x.ndim
    assert 0 <= 1 < ndim

    # Which axes of x the freqs occupy, and therefore the broadcast view.
    if head_first:
        expected = (x.shape[-2], x.shape[-1])
        kept_axes = {ndim - 2, ndim - 1}
    else:
        expected = (x.shape[1], x.shape[-1])
        kept_axes = {1, ndim - 1}
    view_shape = [d if i in kept_axes else 1 for i, d in enumerate(x.shape)]

    if isinstance(freqs_cis, tuple):
        assert freqs_cis[0].shape == expected, f"freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}"
        return freqs_cis[0].view(*view_shape), freqs_cis[1].view(*view_shape)

    assert freqs_cis.shape == expected, f"freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}"
    return freqs_cis.view(*view_shape)
|
|
|
|
|
|
|
|
|
|
|
def _apply_rope_inplace_inner(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> None: |
|
|
x_view = x.view(*x.shape[:-1], -1, 2) |
|
|
cos_view = cos.view(*cos.shape[:-1], -1, 2) |
|
|
sin_view = sin.view(*sin.shape[:-1], -1, 2) |
|
|
x0 = x_view[..., 0] |
|
|
x1 = x_view[..., 1] |
|
|
x0_orig = x0.clone() |
|
|
x0.mul_(cos_view[..., 0]).addcmul_(x1, sin_view[..., 0], value=-1) |
|
|
x1.mul_(cos_view[..., 1]).addcmul_(x0_orig, sin_view[..., 1]) |
|
|
|
|
|
|
|
|
def _apply_rope_inplace(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, use_fp32: bool) -> torch.Tensor:
    """Apply the rotary rotation to x in place, optionally in float32.

    When `use_fp32` is set and x is not already float32, the math runs on a
    float32 working copy and the result is cast back into x's storage, so
    low-precision tensors keep the fp32-accurate rotation.

    Returns:
        x itself (mutated).
    """
    if use_fp32 and x.dtype != torch.float32:
        work = x.to(torch.float32)
        _apply_rope_inplace_inner(work, cos, sin)
        x.copy_(work.to(x.dtype))
    else:
        _apply_rope_inplace_inner(x, cos, sin)
    return x
|
|
|
|
|
def apply_rotary_emb_single( qklist,
    freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
    head_first: bool = False,
):
    """Apply rotary embedding in place to a single tensor (query or key).

    The tensor is taken out of `qklist`, which is then cleared so the
    caller's reference is dropped. cos/sin tables are moved to the tensor's
    device and to float32 (or the tensor's dtype when the fp32 override is
    off) before the in-place rotation.

    NOTE(review): this unconditionally unpacks `freqs_cis` as a (cos, sin)
    pair; a plain tensor argument would fail — confirm callers only pass
    tuples here.

    Returns:
        The rotated tensor (same object, mutated in place).
    """
    query = qklist[0]
    qklist.clear()
    cos, sin = reshape_for_broadcast(freqs_cis, query, head_first)
    fp32 = USE_FP32_ROPE_FREQS
    wanted = torch.float32 if fp32 else query.dtype
    if cos.device != query.device or cos.dtype != wanted:
        cos = cos.to(device=query.device, dtype=wanted)
    if sin.device != query.device or sin.dtype != wanted:
        sin = sin.to(device=query.device, dtype=wanted)
    _apply_rope_inplace(query, cos, sin, fp32)
    return query
|
|
|
|
|
|
|
|
def apply_rotary_emb( qklist,
    freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
    head_first: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Apply rotary embeddings to query/key tensors using the given frequency tensor(s).

    When `freqs_cis` is a (cos, sin) tuple, the rotation is applied in place
    on xq and xk (optionally in float32, governed by USE_FP32_ROPE_FREQS).
    When it is a single complex tensor, the inputs are viewed as complex
    pairs, multiplied, and new real tensors are returned.

    Args:
        qklist: two-element list [xq, xk]; cleared so the caller's references
            are dropped before the (possibly fp32) working copies are made.
            xq/xk are expected as [B, S, H, D] when head_first is False.
        freqs_cis (torch.Tensor or tuple): precomputed frequencies —
            a (cos, sin) pair of real tensors, or one complex tensor.
        head_first (bool): head dimension first (except batch dim) or not.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: query and key with rotary embeddings
        applied (the same tensor objects on the in-place path).
    """
    xq, xk = qklist
    qklist.clear()
    if isinstance(freqs_cis, tuple):
        cos, sin = reshape_for_broadcast(freqs_cis, xq, head_first)
        use_fp32 = USE_FP32_ROPE_FREQS
        target_dtype = torch.float32 if use_fp32 else xq.dtype
        if cos.device != xq.device or cos.dtype != target_dtype:
            cos = cos.to(device=xq.device, dtype=target_dtype)
        if sin.device != xq.device or sin.dtype != target_dtype:
            sin = sin.to(device=xq.device, dtype=target_dtype)
        _apply_rope_inplace(xq, cos, sin, use_fp32)
        _apply_rope_inplace(xk, cos, sin, use_fp32)
        xq_out = xq
        xk_out = xk
    else:
        # Complex path: pair up the last dim, rotate via complex multiply.
        xq_ = torch.view_as_complex(
            xq.float().reshape(*xq.shape[:-1], -1, 2)
        )
        freqs_cis = reshape_for_broadcast(freqs_cis, xq_, head_first).to(
            xq.device
        )
        # .flatten(3) assumes 4-D [B, S, H, D] inputs -- TODO confirm callers.
        xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3).type_as(xq)
        xk_ = torch.view_as_complex(
            xk.float().reshape(*xk.shape[:-1], -1, 2)
        )
        xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3).type_as(xk)

    # Bug fix: the original had a second, unreachable
    # `return xq_out, xk_out` after this one (dead code) — removed.
    return xq_out, xk_out
|
|
def get_nd_rotary_pos_embed(
    start,
    *args,
    theta=10000.0,
    use_real=True,
    theta_rescale_factor: Union[float, List[float]] = 1.0,
    interpolation_factor: Union[float, List[float]] = 1.0,
    k = 6,
    L_test = 66,
    enable_riflex = False,
    rope_dim_list = (44, 42, 42),
    head_dim = 128,
):
    """
    n-D version of precompute_freqs_cis: RoPE for tokens with n-d structure.

    Args:
        rope_dim_list (sequence of int): Dimension of each rope axis;
            len(rope_dim_list) should equal n and sum(rope_dim_list) should
            equal head_dim of the attention layer. (Default is a tuple rather
            than a mutable list.)
        start (int | tuple of int | list of int): If len(args) == 0, start is num;
            if len(args) == 1, start is start, args[0] is stop, step is 1;
            if len(args) == 2, start is start, args[0] is stop, args[1] is num.
        *args: See above.
        theta (float): Scaling factor for frequency computation. Defaults to 10000.0.
        use_real (bool): If True, return real (cos, sin) parts separately;
            otherwise return complex numbers. (Note: the per-axis embeddings are
            always computed with use_real=True; the complex branch below would
            receive tuples — flagged for review.)
        theta_rescale_factor (float): Rescale factor for theta. Defaults to 1.0.
        interpolation_factor (float or list): per-axis position interpolation.
        k, L_test, enable_riflex: RIFLEx parameters for the temporal axis.

    Returns:
        (cos, sin) pair [HW, D] when use_real, else a complex tensor [HW, D/2].
    """
    if rope_dim_list is None:
        # Bug fix: `target_ndim` was previously undefined here (NameError).
        # Infer the number of axes from `start` when it is a sequence.
        if not isinstance(start, (tuple, list)):
            raise ValueError(
                "rope_dim_list is None and the number of axes cannot be "
                "inferred from `start`; pass rope_dim_list explicitly"
            )
        target_ndim = len(start)
        rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
    assert (
        sum(rope_dim_list) == head_dim
    ), "sum(rope_dim_list) should equal to head_dim of attention layer"

    ndim = len(rope_dim_list)

    grid = get_meshgrid_nd(
        start, *args, dim=ndim, dtype=_rope_freqs_dtype()
    )

    # Broadcast scalar / length-1 factors to one entry per axis.
    if isinstance(theta_rescale_factor, (int, float)):
        theta_rescale_factor = [theta_rescale_factor] * ndim
    elif isinstance(theta_rescale_factor, list) and len(theta_rescale_factor) == 1:
        theta_rescale_factor = [theta_rescale_factor[0]] * ndim
    assert len(theta_rescale_factor) == ndim, "len(theta_rescale_factor) should equal to len(rope_dim_list)"

    if isinstance(interpolation_factor, (int, float)):
        interpolation_factor = [interpolation_factor] * ndim
    elif isinstance(interpolation_factor, list) and len(interpolation_factor) == 1:
        interpolation_factor = [interpolation_factor[0]] * ndim
    assert len(interpolation_factor) == ndim, "len(interpolation_factor) should equal to len(rope_dim_list)"

    embs = []
    for i in range(ndim):
        # RIFLEx only adjusts the first (temporal) axis.
        if i == 0 and enable_riflex:
            emb = get_1d_rotary_pos_embed_riflex(rope_dim_list[i], grid[i].reshape(-1), theta, use_real=True, k=k, L_test=L_test)
        else:
            emb = get_1d_rotary_pos_embed(rope_dim_list[i], grid[i].reshape(-1), theta, use_real=True, theta_rescale_factor=theta_rescale_factor[i],interpolation_factor=interpolation_factor[i],)
        embs.append(emb)

    if use_real:
        cos = torch.cat([emb[0] for emb in embs], dim=1)
        sin = torch.cat([emb[1] for emb in embs], dim=1)
        return cos, sin
    else:
        emb = torch.cat(embs, dim=1)
        return emb
|
|
|
|
|
|
|
|
def get_1d_rotary_pos_embed(
    dim: int,
    pos: Union[torch.FloatTensor, int],
    theta: float = 10000.0,
    use_real: bool = False,
    theta_rescale_factor: float = 1.0,
    interpolation_factor: float = 1.0,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """
    Precompute the 1-D RoPE frequency tensor.
    (`cis` means `cos + i * sin`, where i is the imaginary unit.)

    Args:
        dim (int): Dimension of the frequency tensor.
        pos (int or torch.FloatTensor): Position indices; [S] tensor/array or
            a scalar count.
        theta (float, optional): Base for frequency computation. Defaults to 10000.0.
        use_real (bool, optional): If True, return cos and sin parts separately
            (each repeated to width D); otherwise return complex numbers.
        theta_rescale_factor (float, optional): NTK-style rescale for theta.
        interpolation_factor (float, optional): multiplier on positions.

    Returns:
        Complex tensor [S, D/2], or (cos, sin) real tensors [S, D] when use_real.
    """
    pos = _coerce_rope_positions(pos)

    # NTK-aware rescaling of the base frequency.
    if theta_rescale_factor != 1.0:
        theta *= theta_rescale_factor ** (dim / (dim - 2))

    # Inverse frequencies, computed in the dtype carried by `pos`.
    inv_freq = 1.0 / (
        theta ** (torch.arange(0, dim, 2, device=pos.device, dtype=pos.dtype)[: (dim // 2)] / dim)
    )
    angles = torch.outer(pos * interpolation_factor, inv_freq)

    if not use_real:
        return torch.polar(torch.ones_like(angles), angles)

    cos_table = angles.cos().repeat_interleave(2, dim=1)
    sin_table = angles.sin().repeat_interleave(2, dim=1)
    return cos_table, sin_table
|
|
|
|
|
def get_rotary_pos_embed(latents_size, enable_RIFLEx = False):
    """Build (cos, sin) RoPE tables for a video latent of the given size.

    Divides each latent axis by a fixed [1, 2, 2] patch size (t, h, w),
    left-pads to 3 axes with 1s if needed, and delegates to
    get_nd_rotary_pos_embed. L_test is taken from the first latent axis
    (presumably the temporal one — confirm against callers).
    """
    target_ndim = 3
    ndim = 5 - 2

    patch_size = [1, 2, 2]
    if isinstance(patch_size, int):
        # Currently unreachable: patch_size is hard-coded to a list above.
        assert all(s % patch_size == 0 for s in latents_size), (
            f"Latent size(last {ndim} dimensions) should be divisible by patch size({patch_size}), "
            f"but got {latents_size}."
        )
        rope_sizes = [s // patch_size for s in latents_size]
    elif isinstance(patch_size, list):
        assert all(
            s % patch_size[idx] == 0
            for idx, s in enumerate(latents_size)
        ), (
            f"Latent size(last {ndim} dimensions) should be divisible by patch size({patch_size}), "
            f"but got {latents_size}."
        )
        rope_sizes = [s // patch_size[idx] for idx, s in enumerate(latents_size)]

    # Pad missing leading axes (e.g. time) so the rope is always 3-D.
    if len(rope_sizes) != target_ndim:
        rope_sizes = [1] * (target_ndim - len(rope_sizes)) + rope_sizes

    freqs_cos, freqs_sin = get_nd_rotary_pos_embed(
        rope_sizes,
        theta=10000,
        use_real=True,
        theta_rescale_factor=1,
        L_test=latents_size[0],
        enable_riflex=enable_RIFLEx,
    )
    return (freqs_cos, freqs_sin)
|
|
|