|
|
import math |
|
|
|
|
|
import numpy as np |
|
|
import numpy as onp |
|
|
import torch as th |
|
|
from numpy.linalg import eigh |
|
|
|
|
|
|
|
|
def make_HiPPO(P):
    """Create a HiPPO-LegS matrix.

    From https://github.com/srush/annotated-s4/blob/main/s4/s4.py

    Args:
        P (int32): state size

    Returns:
        P x P HiPPO LegS matrix
    """
    scale = np.sqrt(1 + 2 * np.arange(P))
    # Outer product of the scaling vector with itself; keep the lower
    # triangle and subtract n along the diagonal, then negate.
    dense = np.outer(scale, scale)
    legs = np.tril(dense) - np.diag(np.arange(P))
    return -legs
|
|
|
|
|
|
|
|
def make_NPLR_HiPPO(P):
    """
    Makes components needed for NPLR representation of HiPPO-LegS

    From https://github.com/srush/annotated-s4/blob/main/s4/s4.py

    Args:
        P (int32): state size

    Returns:
        P x P HiPPO LegS matrix, low-rank factor, HiPPO input matrix B
    """
    A = make_HiPPO(P)

    idx = np.arange(P)
    # Rank-1 correction term used to symmetrize the HiPPO matrix.
    low_rank = np.sqrt(idx + 0.5)
    # HiPPO input matrix.
    B_in = np.sqrt(2 * idx + 1.0)

    return A, low_rank, B_in
|
|
|
|
|
|
|
|
def make_DPLR_HiPPO(P):
    """
    Makes components needed for DPLR representation of HiPPO-LegS

    From https://github.com/srush/annotated-s4/blob/main/s4/s4.py

    Note, we will only use the diagonal part

    Args:
        P: state size

    Returns:
        eigenvalues Lambda, low-rank term R1, conjugated HiPPO input matrix B,
        eigenvectors V, HiPPO B pre-conjugation
    """
    A, R1, B_pre = make_NPLR_HiPPO(P)

    # Add the rank-1 term back in to form the normal matrix S.
    S = A + R1[:, np.newaxis] * R1[np.newaxis, :]

    # Use the mean of S's diagonal, replicated across all entries, as the
    # (shared) real part of the eigenvalues.
    diag_S = np.diagonal(S)
    Lambda_real = np.mean(diag_S) * np.ones_like(diag_S)

    # eigh expects a Hermitian input; S * -1j is treated as such here, and
    # its eigenvalues supply the imaginary parts, with eigenvectors V.
    Lambda_imag, V = eigh(S * -1j)

    # Rotate the low-rank factor and input matrix into the eigenbasis.
    Vct = V.conj().T
    R1_rot = Vct @ R1
    B_rot = Vct @ B_pre

    Lambda = Lambda_real + 1j * Lambda_imag
    return (
        th.tensor(onp.asarray(Lambda), dtype=th.complex64),
        th.tensor(onp.asarray(R1_rot)),
        th.tensor(onp.asarray(B_rot)),
        th.tensor(onp.asarray(V), dtype=th.complex64),
        th.tensor(onp.asarray(B_pre)),
    )
|
|
|
|
|
|
|
|
def init_log_steps(P, dt_min, dt_max):
    """Initialize an array of learnable timescale parameters.

    Values are drawn uniformly in log space between dt_min and dt_max.

    Args:
        P: number of timescales to sample
        dt_min: smallest timescale
        dt_max: largest timescale

    Returns:
        initialized array of log-timescales (float32): (P,)
    """
    lo = math.log(dt_min)
    hi = math.log(dt_max)
    # th.rand gives uniform samples in [0, 1); affine-map them to [lo, hi).
    u = th.rand(size=(P,))
    return u * (hi - lo) + lo
|
|
|
|
|
|
|
|
def lecun_normal_(tensor: th.Tensor) -> th.Tensor:
    """In-place LeCun-normal initialization.

    Fills ``tensor`` with samples from N(0, 1/fan_in), where fan_in is the
    size of the last dimension, and returns the same tensor.
    """
    fan_in = tensor.shape[-1]
    std = math.sqrt(1 / fan_in)
    # normal_ mutates in place; no_grad keeps the init out of autograd.
    with th.no_grad():
        return tensor.normal_(0, std)
|
|
|
|
|
|
|
|
def init_VinvB(shape, Vinv):
    """Initialize B_tilde = V^{-1} B.  First samples B, then computes V^{-1} B.

    Note we will parameterize this with two different matrices for complex
    numbers.

    Modified from https://github.com/lindermanlab/S5/blob/52cc7e22d6963459ad99a8674e4d3cfb0a480008/s5/ssm.py#L165

    Args:
        shape (tuple): desired shape (P,H)
        Vinv: (complex64) the inverse eigenvectors used for initialization

    Returns:
        B_tilde (complex64) of shape (P,H)
    """
    # Sample a real-valued B with LeCun-normal init, then rotate it into
    # the eigenbasis via V^{-1}.
    sampled_B = lecun_normal_(th.zeros(shape))
    return Vinv @ sampled_B.type(th.complex64)
|
|
|