File size: 3,219 Bytes
f43af3c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import math

import numpy as np
import numpy as onp
import torch as th
from numpy.linalg import eigh


def make_HiPPO(P):
    """Build the P x P HiPPO-LegS state matrix (already negated).

    Adapted from https://github.com/srush/annotated-s4/blob/main/s4/s4.py

    Entry (n, k) of the result is:
        -sqrt((2n + 1) * (2k + 1))   for n > k
        -(n + 1)                     for n == k
        0                            for n < k

    Args:
        P (int32): state size
    Returns:
        P x P HiPPO LegS matrix (float64 numpy array)
    """
    idx = np.arange(P)
    root = np.sqrt(1 + 2 * idx)
    # outer(root, root)[n, k] == sqrt((2n+1)(2k+1))
    lower = np.tril(np.outer(root, root))
    # subtracting n on the diagonal turns (2n+1) into (n+1) before negation
    return -(lower - np.diag(idx))


def make_NPLR_HiPPO(P):
    """Return the pieces of the normal-plus-low-rank (NPLR) form of HiPPO-LegS.

    Adapted from https://github.com/srush/annotated-s4/blob/main/s4/s4.py

    Adding the rank-1 term R1 R1^T to the returned HiPPO matrix gives a normal
    matrix: its diagonal becomes -(n+1) + (n+0.5) = -1/2, and the
    off-diagonal part becomes skew-symmetric.

    Args:
        P (int32): state size

    Returns:
        (hippo, R1, B):
            hippo: P x P HiPPO LegS matrix (as built by make_HiPPO)
            R1: low-rank factor, R1[n] = sqrt(n + 0.5)
            B: HiPPO input vector, B[n] = sqrt(2n + 1)
    """
    n = np.arange(P)
    neg_hippo = make_HiPPO(P)
    low_rank = np.sqrt(n + 0.5)
    b_vec = np.sqrt(2 * n + 1.0)
    return neg_hippo, low_rank, b_vec


def make_DPLR_HiPPO(P):
    """Return the diagonal-plus-low-rank (DPLR) form of HiPPO-LegS as tensors.

    Adapted from https://github.com/srush/annotated-s4/blob/main/s4/s4.py
    Only the diagonal part is expected to be used downstream.

    Args:
        P (int): state size

    Returns:
        Tuple of torch tensors:
            Lambda (complex64, (P,)): eigenvalues of the normal matrix S
            R1 (complex, (P,)): low-rank factor expressed in the eigenbasis, V^* R1
            B (complex, (P,)): HiPPO input vector in the eigenbasis, V^* B
            V (complex64, (P, P)): eigenvectors of S
            B_orig (float64, (P,)): HiPPO input vector before the basis change
    """
    A, low_rank, b_vec = make_NPLR_HiPPO(P)

    # S = A + R1 R1^T is normal (constant diagonal plus a skew-symmetric part).
    S = A + np.outer(low_rank, low_rank)

    # The real parts of the eigenvalues are all equal to the (constant) diagonal.
    diag_vals = np.diagonal(S)
    lam_re = np.mean(diag_vals) * np.ones_like(diag_vals)

    # eigh reads only one triangle and ignores imaginary parts of the diagonal,
    # so this effectively diagonalizes the Hermitian matrix -i * (skew part):
    # S = V Lambda V^*
    lam_im, eigvecs = eigh(S * -1j)

    to_eigbasis = eigvecs.conj().T
    low_rank_eig = to_eigbasis @ low_rank
    b_eig = to_eigbasis @ b_vec

    lam = lam_re + 1j * lam_im
    return (
        th.tensor(lam, dtype=th.complex64),
        th.tensor(low_rank_eig),
        th.tensor(b_eig),
        th.tensor(eigvecs, dtype=th.complex64),
        th.tensor(b_vec),
    )


def init_log_steps(P, dt_min, dt_max):
    """Sample P learnable log-timescales, uniform in [log(dt_min), log(dt_max)).

    Args:
        P (int): number of timescales to draw
        dt_min (float): smallest timescale (must be > 0)
        dt_max (float): largest timescale (must be > 0)
    Returns:
        float32 tensor of shape (P,) holding log-timescales
    """
    log_lo = math.log(dt_min)
    log_span = math.log(dt_max) - log_lo
    # affine map of U[0, 1) onto [log_lo, log_lo + log_span)
    return th.rand(size=(P,)) * log_span + log_lo


def lecun_normal_(tensor: th.Tensor) -> th.Tensor:
    """Fill ``tensor`` in place with LeCun-normal samples and return it.

    Samples come from N(0, 1 / fan_in). fan_in is taken to be the size of the
    LAST dimension, which is the input dimension of a PyTorch-style
    ``(out_features, in_features)`` weight. This is a plain (untruncated)
    normal. It is not equivalent to ``torch.nn.init.xavier_normal_``, which
    scales by ``fan_in + fan_out``.

    Args:
        tensor: tensor with at least one dimension; modified in place.

    Returns:
        The same ``tensor`` object, now filled.

    Raises:
        ValueError: if ``tensor`` is zero-dimensional (no fan-in can be derived).
    """
    if tensor.dim() == 0:
        raise ValueError(
            "lecun_normal_ requires a tensor with at least one dimension"
        )
    fan_in = tensor.shape[-1]
    if fan_in == 0:
        # Empty tensor: there is nothing to sample, and 1 / fan_in would
        # divide by zero.
        return tensor
    std = math.sqrt(1.0 / fan_in)
    # no_grad so in-place filling works on leaf tensors that require grad
    with th.no_grad():
        return tensor.normal_(0, std)


def init_VinvB(shape, Vinv):
    """Initialize B_tilde = V^{-1} B by sampling B and projecting it with Vinv.

    First samples a real matrix B of the requested ``shape`` with
    ``lecun_normal_``, then returns ``Vinv @ B`` as a complex matrix. The
    result is presumably split by the caller into two real matrices (real and
    imaginary parts) for parameterization; confirm against the caller.

    Modified from https://github.com/lindermanlab/S5/blob/52cc7e22d6963459ad99a8674e4d3cfb0a480008/s5/ssm.py#L165

    Args:
        shape (tuple): desired shape of B, (P, H).
        Vinv (Tensor): the inverse eigenvectors used for initialization,
            presumably (P, P) (TODO confirm). May be real or complex, on any
            device.

    Returns:
        B_tilde of shape (P, H). Its dtype is complex64 when Vinv is real or
        complex64; otherwise it is Vinv's complex dtype (e.g. complex128).
    """
    # Sample B on Vinv's device so the matmul below does not mix devices.
    B = lecun_normal_(th.zeros(shape, device=Vinv.device))
    # torch.matmul does not promote dtypes, so cast both operands to one
    # complex dtype (a no-op for Vinv when it is already that dtype).
    cdtype = Vinv.dtype if Vinv.is_complex() else th.complex64
    return Vinv.to(cdtype) @ B.to(cdtype)