| | import math |
| |
|
| | import numpy as np |
| | import torch |
| |
|
| | from .utils import logger |
| | from .utils import get_hankel |
| |
|
def get_spectral_filters(
    seq_len: int,
    K: int,
    use_hankel_L: bool = False,
    device: torch.device = None,
    dtype: torch.dtype = torch.bfloat16,
) -> torch.Tensor:
    """Compute the top-``K`` spectral filters for a sequence of length ``seq_len``.

    The Hankel matrix returned by ``get_hankel`` is eigendecomposed. The
    eigenvectors belonging to the ``K`` largest eigenvalues are kept, and each
    one is scaled by the fourth root of its eigenvalue.

    Args:
        seq_len: Sequence length forwarded to ``get_hankel``.
        K: Number of filters to keep. Values larger than the matrix size keep
            every column.
        use_hankel_L: Forwarded to ``get_hankel`` to select its Hankel_L variant.
        device: Device for the Hankel matrix and for the returned filters.
        dtype: Dtype of the returned filters. All intermediate math runs in
            float32 and is downcast only at the end.

    Returns:
        Tensor whose columns are the scaled eigenvectors, ordered by ascending
        eigenvalue (largest last). Presumably of shape ``(seq_len, K)``; this
        assumes ``get_hankel`` returns a ``seq_len x seq_len`` matrix (confirm).

    Raises:
        ValueError: if ``K`` is negative.
    """
    if K < 0:
        raise ValueError("K must be non-negative")

    # Build and decompose in float32. Building Z in a low-precision dtype such
    # as bfloat16 and upcasting afterwards cannot recover digits already lost.
    # torch.linalg.eigh also only accepts float/double (and complex) inputs.
    Z = get_hankel(seq_len, use_hankel_L, device=device, dtype=torch.float32)
    sigma, phi = torch.linalg.eigh(Z.to(torch.float32))

    # eigh returns eigenvalues in ascending order, so the K largest sit at the
    # end. An explicit start index avoids the ``[-0:]`` pitfall, where K == 0
    # would otherwise select every column.
    start = max(sigma.shape[0] - K, 0)
    sigma_k, phi_k = sigma[start:], phi[:, start:]

    # Z is presumably positive semidefinite, but float32 eigh can return tiny
    # negative eigenvalues. Clamp them so the fractional power gives 0, not NaN.
    phi_k = phi_k * sigma_k.clamp_min(0) ** 0.25

    # Downcast only now, so the scaling above was carried out in float32.
    return phi_k.to(device=device, dtype=dtype)
| |
|
| |
|
def compute_dimensions(n: int) -> tuple[int, int, int]:
    """Return ``(T_prime, sqrt_T_prime, k_max)`` for sequence length ``n``.

    ``sqrt_T_prime`` is ``ceil(sqrt(n - 2))``. ``T_prime`` equals
    ``sqrt_T_prime ** 2 + 2``, which is the smallest value ``>= n`` of the form
    ``m ** 2 + 2``. ``k_max`` equals ``sqrt_T_prime``.

    Args:
        n: Sequence length; must be an integer greater than 2.

    Returns:
        The tuple ``(T_prime, sqrt_T_prime, k_max)`` of ints.

    Raises:
        ValueError: if ``n <= 2``.
    """
    if n <= 2:
        raise ValueError("n must be greater than 2")

    # Exact integer ceil(sqrt(n - 2)) == isqrt(n - 3) + 1, valid because
    # n - 2 >= 1 here. math.sqrt goes through float64 and can round down for
    # large n (e.g. n - 2 == 2**60 + 1), which would give a square smaller
    # than n - 2.
    sqrt_T_prime = math.isqrt(n - 3) + 1
    T_prime = sqrt_T_prime**2 + 2
    k_max = sqrt_T_prime
    return T_prime, sqrt_T_prime, k_max
| |
|
def get_tensorized_spectral_filters_explicit(n: int, k: int, device: torch.device) -> torch.Tensor:
    """Sum the Kronecker products of every pair of scaled top-``k`` Hankel eigenvectors.

    The function proceeds in four steps:

    1. Build a Hankel matrix via ``get_hankel(sqrt_T_prime)``, where
       ``sqrt_T_prime`` comes from ``compute_dimensions(n)``.
    2. Keep the eigenvectors of its ``k`` largest eigenvalues.
    3. Scale each kept eigenvector by the fourth root of its eigenvalue.
    4. Accumulate ``kron(f_i, f_j)`` over all pairs ``(i, j)`` in an explicit
       double loop.

    Args:
        n: Sequence length; must be greater than 2.
        k: Number of eigenvectors to use, capped at ``k_max`` from
            ``compute_dimensions``.
        device: Device on which the Hankel matrix and the result are placed.

    Returns:
        A 1-D tensor of length ``sqrt_T_prime ** 2``, in torch's default float
        dtype (the accumulator is created without an explicit dtype). Because
        every pair is summed, the result equals ``kron(s, s)``, where ``s`` is
        the sum of the scaled eigenvectors.
        NOTE(review): unlike ``get_tensorized_spectral_filters``, the pairwise
        products are summed into a single vector rather than kept as separate
        columns. Confirm this is intended.

    Raises:
        ValueError: from ``compute_dimensions`` if ``n <= 2``.
    """
    _, sqrt_T_prime, k_max = compute_dimensions(n)
    k = min(k, k_max)

    Z = get_hankel(sqrt_T_prime).to(device)
    sigma, phi = torch.linalg.eigh(Z)

    # eigh sorts eigenvalues ascending, so the k largest are at the end.
    start = max(sigma.shape[0] - k, 0)
    # Scale every kept eigenvector once, up front. The original recomputed this
    # scaling for each (i, j) pair inside the double loop. Tiny negative
    # eigenvalues from round-off are clamped so the power yields 0, not NaN.
    scaled = phi[:, start:] * sigma[start:].clamp_min(0) ** 0.25
    num_filters = scaled.shape[1]

    result = torch.zeros(sqrt_T_prime * sqrt_T_prime, device=device)
    for i in range(num_filters):
        for j in range(num_filters):
            result += torch.kron(scaled[:, i], scaled[:, j])

    return result
| |
|
| |
|
def get_tensorized_spectral_filters(
    n: int = 8192,
    k: int = 24,
    use_hankel_L: bool = False,
    device: torch.device = None,
    dtype: torch.dtype = torch.bfloat16,
) -> torch.Tensor:
    """
    Compute tensorized spectral filters for given sequence length and filter count.

    Each factor is built from a Hankel matrix of size ``sqrt_T_prime``, where
    ``sqrt_T_prime`` comes from ``compute_dimensions(n)``. The factor keeps the
    eigenvectors of that matrix's ``k`` largest eigenvalues, each scaled by the
    fourth root of its eigenvalue. The two factors are then combined with a
    Kronecker product.

    Args:
        n: Sequence length; must be greater than 2.
        k: Number of filters per factor, capped at ``k_max`` from ``compute_dimensions``.
        use_hankel_L: Hankel_main ⊗ Hankel_L? Default is Hankel_main ⊗ Hankel_main.
        device: Device of the returned tensor. The eigendecomposition itself runs
            wherever ``get_hankel`` places its output.
        dtype: Dtype of the returned tensor. The eigendecomposition runs in
            whatever dtype ``get_hankel`` returns.

    Returns:
        ``torch.kron(phi_i, phi_j)``. Column ``a * phi_j.shape[1] + b`` of the
        result is ``kron(phi_i[:, a], phi_j[:, b])``. Presumably the shape is
        ``(sqrt_T_prime ** 2, k ** 2)``; this assumes ``get_hankel(m)`` returns
        an ``m x m`` matrix (confirm).

    Raises:
        AssertionError: if CUDA is not available.
        ValueError: from ``compute_dimensions`` if ``n <= 2``.
    """
    assert torch.cuda.is_available(), "CUDA is required."

    _, sqrt_T_prime, k_max = compute_dimensions(n)
    k = min(k, k_max)

    def _scaled_top_k(Z: torch.Tensor) -> torch.Tensor:
        """Return the top-k eigenvectors of ``Z``, each scaled by ``eigenvalue ** 0.25``."""
        sigma, phi = torch.linalg.eigh(Z)
        # eigh sorts eigenvalues ascending. An explicit start index keeps the k
        # largest and, unlike ``[:, -k:]``, selects nothing when k == 0.
        start = max(sigma.shape[0] - k, 0)
        # Clamp tiny negative round-off eigenvalues so the power gives 0, not NaN.
        return phi[:, start:] * sigma[start:].clamp_min(0) ** 0.25

    phi_i = _scaled_top_k(get_hankel(sqrt_T_prime))

    if use_hankel_L:
        logger.info("Mixing Hankel_L with Hankel_main to generate tensorized filters.")
        phi_j = _scaled_top_k(get_hankel(sqrt_T_prime, True))
    else:
        phi_j = phi_i

    filters = torch.kron(phi_i, phi_j)
    return filters.to(device=device, dtype=dtype)
| |
|