File size: 3,262 Bytes
cbda9b7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import math

import numpy as np
import torch

from .utils import logger
from .utils import get_hankel

def get_spectral_filters(
    seq_len: int,
    K: int,
    use_hankel_L: bool = False,
    device: torch.device = None,
    dtype: torch.dtype = torch.bfloat16,
) -> torch.Tensor:
    """Compute the top-K spectral filters of the Hankel matrix.

    The Hankel matrix produced by ``get_hankel`` is eigendecomposed and the
    K eigenvectors with the largest eigenvalues are each scaled by
    ``eigenvalue ** 0.25``.

    Args:
        seq_len: Size argument forwarded to ``get_hankel`` (presumably the
            Hankel matrix is ``(seq_len, seq_len)`` — confirm in utils).
        K: Number of filters to keep. Values larger than the matrix size
            keep every eigenvector; ``K == 0`` yields an empty column set.
        use_hankel_L: Forwarded to ``get_hankel``.
        device: Device on which the Hankel matrix is built and on which the
            filters are returned.
        dtype: dtype of the returned filters.

    Returns:
        Tensor with one column per filter, ordered by ascending eigenvalue
        (the last column belongs to the largest eigenvalue).

    Raises:
        ValueError: If ``K`` is negative.
    """
    if K < 0:
        raise ValueError(f"K must be non-negative, got {K}")

    # torch.linalg.eigh only supports float32/float64 (no bf16/fp16), and
    # building the Hankel matrix directly in a half-precision dtype would
    # round its entries before the decomposition. Build and decompose it in
    # float32, or float64 when the caller asks for double precision.
    compute_dtype = torch.float64 if dtype == torch.float64 else torch.float32
    Z = get_hankel(seq_len, use_hankel_L, device=device, dtype=compute_dtype)

    # eigh returns eigenvalues in ascending order, so the top K are at the end.
    sigma, phi = torch.linalg.eigh(Z)

    # Use an explicit start index: ``sigma[-K:]`` would select *everything*
    # when K == 0.
    start = max(sigma.shape[0] - K, 0)
    sigma_k, phi_k = sigma[start:], phi[:, start:]

    # Near-zero eigenvalues can come out slightly negative from round-off;
    # clamp them so the fractional power does not produce NaN. The scaling
    # is done in compute_dtype and only the final result is downcast.
    filters = phi_k * sigma_k.clamp_min(0) ** 0.25

    return filters.to(device=device, dtype=dtype)


def compute_dimensions(n: int) -> tuple[int, int, int]:
    """Return the padded length and factor size used for tensorized filters.

    ``T_prime`` is the smallest value ``>= n`` of the form ``s**2 + 2`` with
    integer ``s``; ``sqrt_T_prime`` is that ``s`` (i.e. ``ceil(sqrt(n - 2))``)
    and ``k_max`` equals ``sqrt_T_prime``.

    Args:
        n: Sequence length; must be greater than 2.

    Returns:
        ``(T_prime, sqrt_T_prime, k_max)``.

    Raises:
        ValueError: If ``n <= 2``.
    """
    if n <= 2:
        raise ValueError("n must be greater than 2")

    # Exact integer ceil(sqrt(n - 2)). Float math.sqrt loses precision for
    # large inputs, e.g. sqrt(10**16 + 1) rounds to exactly 1e8 and its ceil
    # misses the +1. For real x, ceil(sqrt(x)) == ceil(sqrt(ceil(x))), so
    # taking math.ceil first keeps non-integer n working as before.
    m = math.ceil(n) - 2
    root = math.isqrt(m)
    sqrt_T_prime = root if root * root == m else root + 1

    T_prime = sqrt_T_prime ** 2 + 2
    k_max = sqrt_T_prime
    return T_prime, sqrt_T_prime, k_max

def get_tensorized_spectral_filters_explicit(n: int, k: int, device: torch.device) -> torch.Tensor:
    """Sum of Kronecker products of the scaled top-k Hankel eigenvectors.

    With ``phi_i = eigvec_i * eigval_i ** 0.25`` for the top-``k`` eigenpairs
    of ``get_hankel(sqrt_T_prime)``, this returns
    ``sum_{i,j} kron(phi_i, phi_j)``. Because ``kron`` is bilinear, that
    double sum equals ``kron(sum_i phi_i, sum_j phi_j)``, which is how it is
    computed here: one Kronecker product instead of ``k**2``.

    NOTE(review): the result is a single vector (all ``(i, j)`` pairs are
    summed), unlike ``get_tensorized_spectral_filters`` which keeps one
    column per pair — confirm the summation is intended.

    Args:
        n: Sequence length, passed to ``compute_dimensions``.
        k: Number of eigenvectors; clamped to ``k_max``. ``k <= 0`` yields
            an all-zero result.
        device: Device the Hankel matrix is moved to and the result lives on.

    Returns:
        Vector of length ``sqrt_T_prime ** 2`` in torch's default dtype.

    Raises:
        ValueError: If ``n <= 2`` (from ``compute_dimensions``).
    """
    _, sqrt_T_prime, k_max = compute_dimensions(n)
    # Non-positive k selects no eigenvectors (an all-zero result).
    k = max(min(k, k_max), 0)

    Z = get_hankel(sqrt_T_prime).to(device)
    sigma, phi = torch.linalg.eigh(Z)

    # eigh sorts eigenvalues ascending, so the top k are the trailing ones.
    # An explicit start index avoids the ``[-0:]`` selects-everything trap.
    start = max(sigma.shape[0] - k, 0)
    # Clamp round-off negatives so the fractional power cannot yield NaN.
    scaled = phi[:, start:] * sigma[start:].clamp_min(0) ** 0.25

    # Summing over an empty column set gives zeros, matching k == 0.
    summed = scaled.sum(dim=1)
    result = torch.kron(summed, summed)

    # Keep the default-dtype result produced by the original torch.zeros buffer.
    return result.to(dtype=torch.get_default_dtype())


def get_tensorized_spectral_filters(
    n: int = 8192,
    k: int = 24,
    use_hankel_L: bool = False,
    device: torch.device = None,
    dtype: torch.dtype = torch.bfloat16,
) -> torch.Tensor:
    """
    Compute tensorized spectral filters for given sequence length and filter count.

    The top-k eigenvectors of the Hankel matrix of size ``sqrt_T_prime``
    (from ``compute_dimensions``) are scaled by ``eigenvalue ** 0.25`` and
    combined with ``torch.kron``, giving one column per ``(i, j)`` pair.

    Args:
        n: Sequence length
        k: Number of filters per factor; clamped to ``k_max``.
        use_hankel_L: Hankel_main ⊗ Hankel_L? Default is Hankel_main ⊗ Hankel_main.
        device: Device of the returned filters.
        dtype: dtype of the returned filters.

    Returns:
        Presumably shape ``(sqrt_T_prime**2, k_eff**2)`` with
        ``k_eff = min(k, k_max)``, assuming ``get_hankel`` returns a square
        ``(sqrt_T_prime, sqrt_T_prime)`` matrix — confirm in utils.

    Raises:
        AssertionError: If CUDA is not available.
        ValueError: If ``n <= 2`` (from ``compute_dimensions``) or ``k < 0``.
    """
    # NOTE(review): nothing below launches CUDA work directly; the
    # decomposition runs wherever get_hankel places Z. Kept as-is.
    assert torch.cuda.is_available(), "CUDA is required."

    _, sqrt_T_prime, k_max = compute_dimensions(n)
    if k < 0:
        raise ValueError(f"k must be non-negative, got {k}")
    k = min(k, k_max)

    def _scaled_top_k(Z: torch.Tensor) -> torch.Tensor:
        # Top-k eigenvectors of Z, each scaled by its eigenvalue ** 0.25.
        sigma, phi = torch.linalg.eigh(Z)  # eigenvalues ascending
        # Explicit start index: ``sigma[-k:]`` would select all when k == 0.
        start = max(sigma.shape[0] - k, 0)
        # Clamp round-off negatives so the fractional power cannot yield NaN.
        return phi[:, start:] * sigma[start:].clamp_min(0) ** 0.25

    phi_i = _scaled_top_k(get_hankel(sqrt_T_prime))

    # TODO: when use_hankel_L is true we may also want Hankel_L for phi_i;
    # introduce a separate "mix" flag so mixing != use_hankel_L.
    if use_hankel_L:
        logger.info("Mixing Hankel_L with Hankel_main to generate tensorized filters.")
        phi_j = _scaled_top_k(get_hankel(sqrt_T_prime, True))
    else:
        phi_j = phi_i

    filters = torch.kron(phi_i, phi_j)
    return filters.to(device=device, dtype=dtype)