|
|
|
|
|
import torch as th |
|
|
|
|
|
|
|
|
def discretize_zoh(Lambda, B_tilde, Delta):
    """Discretize a diagonalized, continuous-time linear SSM using zero-order hold.

    modified from: https://github.com/lindermanlab/S5/blob/3c18fdb6b06414da35e77b94b9cd855f6a95ef17/s5/ssm.py#L29

    Computes, elementwise over the diagonal state dimension P:

        Lambda_bar = exp(Lambda * Delta)
        B_bar      = (1 / Lambda) * (Lambda_bar - 1) * B_tilde

    Args:
        Lambda (complex64): diagonal state matrix (P,). May also carry leading
            batch dimensions (..., P) as long as it broadcasts against Delta.
            Entries must be nonzero (the formula divides by Lambda).
        B_tilde (complex64): input matrix (P, H)
        Delta (float32): discretization step sizes (P,), or (..., P) to
            discretize several intervals at once.

    Returns:
        discretized Lambda_bar (complex64), B_bar (complex64) with shapes
        (P,), (P, H) in the unbatched case, or (..., P), (..., P, H) when
        Lambda/Delta carry leading batch dimensions.
    """
    Lambda_bar = th.exp(Lambda * Delta)
    # Subtract the scalar 1 instead of a freshly built ones tensor: a new
    # tensor would be created on the CPU regardless of Lambda's device, and
    # sizing it from Lambda.shape[0] breaks broadcasting for batched inputs.
    B_bar = ((1 / Lambda) * (Lambda_bar - 1))[..., None] * B_tilde
    return Lambda_bar, B_bar
|
|
|
|
|
|
|
|
def apply_ssm(
    Lambda_bar_NP,
    B_bar_NPH,
    C_tilde_HP,
    input_sequence_NH,
    alpha_NP,
    conj_sym,
    initial_state_P=None,
):
    """Compute the NxH output of a discretized SSM given an NxH input.

    modified from: https://github.com/lindermanlab/S5/blob/3c18fdb6b06414da35e77b94b9cd855f6a95ef17/s5/ssm.py#L60
        - removed bidirectionality.
        - assume Lambda_bar is N-length.

    The hidden state follows the linear recurrence

        x_n = Lambda_bar_n * x_{n-1} + B_bar_n @ u_n + alpha_n

    starting from x_{-1} = initial_state_P (zeros if not given). The output
    is y_n = Re(C_tilde @ x_n), doubled when conj_sym is True.

    Args:
        Lambda_bar_NP (complex64): discretized diagonal state matrix for each interval (N, P)
        B_bar_NPH (complex64): "discretized" input matrix. Note: may be outside ZOH (N, P, H)
        C_tilde_HP (complex64): output matrix (H, P)
        input_sequence_NH (float32): input sequence of features (N, H)
        alpha_NP (complex64): mark-specific biases (N, P)
        conj_sym (bool): Whether conjugate symmetry is enforced
        initial_state_P (complex tensor, optional): state before the first
            step, shape (P,). Defaults to zeros with the dtype/device of the
            per-step input contributions.

    Returns:
        xs_NP (complex): the hidden state after each step (N, P)
        ys_NH (float): the SSM outputs (S5 layer preactivations) (N, H)
    """
    N, P, H = B_bar_NPH.shape

    # Per-step input contribution B_bar_n @ u_n + alpha_n, batched over N.
    # The real input is cast to B's own complex dtype (rather than a
    # hard-coded complex64) so complex128 parameters work as well.
    Bu_elements_NP = (
        th.einsum(
            "nph,nh->np",
            B_bar_NPH,
            input_sequence_NH.to(B_bar_NPH.dtype),
        )
        + alpha_NP
    )

    if initial_state_P is None:
        # Match dtype/device of the inputs so the recurrence never mixes
        # CPU and accelerator tensors.
        state = th.zeros(
            P, dtype=Bu_elements_NP.dtype, device=Bu_elements_NP.device
        )
    else:
        state = initial_state_P

    # Sequential scan over the N steps; only post-update states are kept.
    states = []
    for lam_P, bu_P in zip(Lambda_bar_NP, Bu_elements_NP):
        state = lam_P * state + bu_P
        states.append(state)
    # th.stack rejects an empty list, so an empty sequence gets an explicit
    # (0, P) result instead of crashing.
    xs = th.stack(states) if states else Bu_elements_NP.new_zeros((0, P))

    # y_n = C_tilde @ x_n for all n at once: (N, P) @ (P, H) -> (N, H).
    ys = (xs @ C_tilde_HP.T).real
    if conj_sym:
        # Only one member of each conjugate pair is represented; since
        # Re(conj(z)) == Re(z), doubling the real part accounts for the rest.
        ys = 2 * ys
    return xs, ys
|
|
|