File size: 3,219 Bytes
cbda9b7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import torch
import torch.nn as nn

from .convolve import convolve, flash_convolve

try:
    from flashfftconv import FlashFFTConv

    flash_fft_available = True
except ImportError as e:
    print(
        f"Unable to import FlashFFTConv: {e}. Falling back to PyTorch implementation."
    )
    flash_fft_available = False


class STU(nn.Module):
    def __init__(self, config, phi, n) -> None:
        super(STU, self).__init__()
        self.config = config
        if isinstance(config.torch_dtype, str):
            torch_dtype = getattr(torch, config.torch_dtype)
        else:
            torch_dtype = config.torch_dtype
        self.phi = phi.to(device=config.device, dtype=torch_dtype)
        self.n = n
        self.K = config.num_eigh
        self.d_in = config.n_embd
        self.d_out = config.n_embd
        self.use_hankel_L = config.use_hankel_L
        self.use_approx = config.use_approx
        self.flash_fft = (
            FlashFFTConv(self.n, dtype=torch.bfloat16)
            if config.use_flash_fft and flash_fft_available
            else None
        )
        if self.use_approx:
            self.M_inputs = nn.Parameter(
                torch.empty(self.d_in, self.d_out, dtype=torch_dtype)
            )
            self.M_filters = nn.Parameter(
                torch.empty(self.K, self.d_in, dtype=torch_dtype)
            )
        else:
            self.M_phi_plus = nn.Parameter(
                torch.empty(self.K, self.d_in, self.d_out, dtype=torch_dtype)
            )
            if not self.use_hankel_L:
                self.M_phi_minus = nn.Parameter(
                    torch.empty(self.K, self.d_in, self.d_out, dtype=torch_dtype)
                )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        dtype = self.M_inputs.dtype
        x = x.to(dtype=dtype)
        if self.use_approx:
            # Contract inputs and filters over the K and d_in dimensions, then convolve
            x_proj = x @ self.M_inputs
            phi_proj = self.phi @ self.M_filters
            x_proj = x_proj.to(dtype=dtype)
            phi_proj = phi_proj.to(dtype=dtype)
            if self.flash_fft:
                spectral_plus, spectral_minus = flash_convolve(
                    x_proj, phi_proj, self.flash_fft, self.use_approx
                )
            else:
                spectral_plus, spectral_minus = convolve(
                    x_proj, phi_proj, self.n, self.use_approx
                )
        else:
            # Convolve inputs and filters,
            if self.flash_fft:
                U_plus, U_minus = flash_convolve(
                    x, self.phi, self.flash_fft, self.use_approx
                )
            else:
                U_plus, U_minus = convolve(x, self.phi, self.n, self.use_approx)
            # Then, contract over the K and d_in dimensions
            spectral_plus = torch.tensordot(
                U_plus, self.M_phi_plus, dims=([2, 3], [0, 1])
            )
            if not self.use_hankel_L:
                spectral_minus = torch.tensordot(
                    U_minus, self.M_phi_minus, dims=([2, 3], [0, 1])
                )

        return spectral_plus if self.use_hankel_L else spectral_plus + spectral_minus