from typing import Optional, Tuple

import torch as th
import torch.nn as nn
import torch.nn.functional as F

from .initializers import (
    make_DPLR_HiPPO,
    # , lecun_normal_
    # init_VinvB,
    init_log_steps,
)

MATRIX_SCALING_FACTOR = 1


class LLH(nn.Module):
    """
    Single linear state-space layer with impulses at event times.

    This is canon:
    L -- number of layers
    N -- number of events.
    P -- Hidden dimension. Dimensionality of x.
    H -- output dimension. Dimensionality of y/u.

    Between events the hidden state decays as ``x <- exp(Lambda * dt) * x``;
    at each event an impulse (``E @ mark_embedding``, plus ``B @ u`` in this
    base class) is added.  The read-out is ``y = 2 * Re(C x) + D * u``.
    """

    def __init__(
        self,
        P: int,
        H: int,
        dt_init_min: float = 1e-4,
        dt_init_max: float = 0.1,
        dropout_rate: float = 0.0,
        act_func: str = "gelu",  # F.gelu,
        for_loop: bool = False,
        pre_norm: bool = True,
        post_norm: bool = False,
        simple_mark: bool = True,
        is_first_layer: bool = False,
        relative_time: bool = False,
        complex_values: bool = True,
    ):
        """
        :param P: hidden-state dimension.
        :param H: input/output feature dimension.
        :param dt_init_min: stored on the instance; not otherwise used in this file.
        :param dt_init_max: stored on the instance; not otherwise used in this file.
        :param dropout_rate: dropout probability used inside the activation block.
        :param act_func: one of ``"gelu"``, ``"full_glu"``, ``"half_glu"``.
        :param for_loop: use the sequential recurrence instead of the log-space scan.
        :param pre_norm: apply the layer norm to the inputs ``u`` before the SSM.
        :param post_norm: apply the layer norm to the layer outputs.
        :param simple_mark: if False, additionally create ``mark_a_net`` / ``mark_u_net``.
        :param is_first_layer: first layer receives no ``u``; B and D are not created.
        :param relative_time: predict the step size from ``u`` through ``delta_net``.
        :param complex_values: keep Lambda / B / C / E and the state complex-valued.
        :raises NotImplementedError: if ``act_func`` is not recognised.
        """
        super().__init__()

        # Inscribe the args.
        self.P = P
        self.H = H
        self.dt_init_min = dt_init_min
        self.dt_init_max = dt_init_max
        self.dropout_rate = dropout_rate
        self.complex_values = complex_values

        # Select the activation function.
        if act_func == "gelu":
            self.act_func = nn.Sequential(nn.GELU(), nn.Dropout(p=self.dropout_rate))
        elif act_func == "full_glu":
            self.act_func = nn.Sequential(
                nn.Linear(self.H, 2 * self.H),
                nn.Dropout(p=self.dropout_rate),
                nn.GLU(),
                nn.Dropout(p=self.dropout_rate),
            )
        elif act_func == "half_glu":
            # ref: https://github.com/lindermanlab/S5/blob/main/s5/layers.py#L76
            self.act_func1 = nn.Sequential(
                nn.Dropout(p=self.dropout_rate),
                nn.GELU(),
                nn.Linear(self.H, self.H),
            )
            # A bound method (rather than a lambda building a fresh nn.Dropout
            # per call) so that dropout follows train()/eval() and so that
            # copy.deepcopy / pickling rebind to the right instance.
            self.act_func = self._half_glu
        else:
            raise NotImplementedError(
                "Unrecognized activation function {}".format(act_func)
            )

        # Assume we always use conjugate symmetry.
        self.conj_sym = True

        # Allow a learnable initial state.
        # Needs to be =/= 0 since we take the log to compute
        if self.complex_values:
            self.initial_state_P = nn.Parameter(
                th.complex(th.randn(self.P), th.randn(self.P)) * 1e-3,
                requires_grad=True,
            )
        else:
            self.initial_state_P = nn.Parameter(th.randn(self.P), requires_grad=True)

        self.norm = nn.LayerNorm(self.H)

        self.for_loop = for_loop
        self.pre_norm = pre_norm
        self.post_norm = post_norm
        self.is_first_layer = is_first_layer
        self.relative_time = relative_time

        self._init_ssm_params()

        self.simple_mark = simple_mark
        if not simple_mark:
            self.mark_a_net = nn.Linear(self.H, self.P, bias=True)
            self.mark_u_net = nn.Linear(self.H, self.P, bias=False)  # Only need one bias
            self.mark_a_net.weight.data = self._small_init(self.mark_a_net.weight.data)
            self.mark_a_net.bias.data = self._small_init(self.mark_a_net.bias.data)
            self.mark_u_net.weight.data = self._small_init(self.mark_u_net.weight.data)

    def _half_glu(self, x: th.Tensor) -> th.Tensor:
        """Gate ``x`` with ``sigmoid(act_func1(x))`` and apply dropout (training only)."""
        gate = th.sigmoid(self.act_func1(x))
        return F.dropout(x * gate, p=self.dropout_rate, training=self.training)

    def _small_init(self, like: th.Tensor, scale: float = 1e-3) -> th.Tensor:
        """
        Draw a small random tensor shaped like ``like``.

        Xavier-normal is used for tensors with >= 2 dims; 1-D tensors (biases)
        fall back to a standard normal because Xavier cannot compute fans for
        them.  The result is complex when ``self.complex_values`` is set.
        """

        def draw() -> th.Tensor:
            if like.dim() >= 2:
                return nn.init.xavier_normal_(th.empty_like(like)) * scale
            return th.randn_like(like) * scale

        real_part = draw()
        if not self.complex_values:
            return real_part
        return th.complex(real_part, draw())

    def _init_ssm_params(self):
        """Create the SSM parameters; B and D are skipped for the first layer."""
        self._init_A()
        if not self.is_first_layer:
            self._init_B()
        self._init_C()
        # Could group, but left in same order to not mess with initialization
        if not self.is_first_layer:
            self._init_D()
        self._init_E()

    def _init_A(self):
        """Initialise the diagonal state matrix Lambda and the step-size parameters."""
        # Define the initial diagonal HiPPO matrix.
        # We throw the HiPPO B away.
        Lambda_P, _, _, V_PP, _ = make_DPLR_HiPPO(self.P)
        # Real part is stored as log(-Re) so that Re(Lambda) stays negative.
        self.Lambda_P_log_neg_real = th.nn.Parameter((-Lambda_P.real).log())
        self.Lambda_P_imag = th.nn.Parameter(Lambda_P.imag)

        # Store these for use later.
        self._V_PP = V_PP
        self._Vc_PP = V_PP.conj().T

        # We also initialize the step size.
        if self.relative_time:
            # nn.Parameter(init_log_steps(self.P, self.dt_init_min, self.dt_init_max))
            self.delta_net = nn.Linear(self.H, self.P, bias=True)
            with th.no_grad():
                nn.init.xavier_normal_(self.delta_net.weight)
                # Inverse softplus of 1, so softplus(bias) == 1 at initialisation.
                bias = th.ones(self.P)
                bias += th.log(-th.expm1(-bias))
                self.delta_net.bias.copy_(bias)
        else:
            self.log_step_size_P = nn.Parameter(
                th.zeros(size=(self.P,)), requires_grad=False
            )

    @property
    def Lambda_P(self):
        """Diagonal of the state matrix (complex, or its real part only)."""
        if self.complex_values:
            return th.complex(
                -self.Lambda_P_log_neg_real.exp(),
                self.Lambda_P_imag,
            )
        return -self.Lambda_P_log_neg_real.exp()

    def _init_B(self):
        """Initialise the input matrix B in the eigenbasis of A."""
        # Initialize the B outside the eigenbasis and then transform.
        B = nn.init.xavier_normal_(th.zeros((self.P, self.H))) * MATRIX_SCALING_FACTOR
        B_tilde_PH = self._Vc_PP @ B.type(th.complex64)
        self.B_tilde_PH = (
            th.nn.Parameter(B_tilde_PH)
            if self.complex_values
            else th.nn.Parameter(B_tilde_PH.real)
        )

    def _init_C(self):
        """Initialise the output matrix C in the eigenbasis of A."""
        # Use the "complex_normal" initialization.
        # See ~https://github.com/lindermanlab/S5/blob/52cc7e22d6963459ad99a8674e4d3cfb0a480008/s5/ssm.py#L183
        C = nn.init.xavier_normal_(th.zeros((self.H, self.P))) * MATRIX_SCALING_FACTOR
        C_tilde_HP = C.type(th.complex64) @ self._V_PP
        self.C_tilde_HP = (
            th.nn.Parameter(C_tilde_HP)
            if self.complex_values
            else th.nn.Parameter(C_tilde_HP.real)
        )
        # self.C_tilde_HP.data *= 1e-3

    def _init_D(self):
        """Initialise the (diagonal) feedthrough term D."""
        # Initialize feedthrough (D) matrix. Note the intensity depends on all layers.
        D_HH = th.zeros(self.H)
        nn.init.normal_(D_HH, std=1.0)
        self.D_HH = nn.Parameter(D_HH, requires_grad=True)

    def _init_E(self):
        """Initialise the mark-embedding-to-impulse matrix E in the eigenbasis of A."""
        E = (
            th.nn.init.xavier_normal_(th.zeros((self.P, self.H)))
            * MATRIX_SCALING_FACTOR
        )
        E_tilde_PH = self._Vc_PP @ E.type(th.complex64)
        self.E_tilde_PH = (
            th.nn.Parameter(E_tilde_PH)
            if self.complex_values
            else th.nn.Parameter(E_tilde_PH.real)
        )

    def compute_impulse(self, right_u_H, mark_embedding_H):
        """
        Compute impulse to add to left limit of x to make right limit.

        :param right_u_H: unused here; kept for interface compatibility.
        :param mark_embedding_H: [..., H] mark embeddings.
        :return: [..., P] impulse ``E @ mark_embedding``.
        """
        alpha_P = th.einsum(
            "ph,...h->...p",
            self.E_tilde_PH,
            mark_embedding_H.type(th.complex64)
            if self.complex_values
            else mark_embedding_H,
        )
        return alpha_P

    def get_lambda(self, right_u_NH, shift_u=True):
        """
        Return the step-size-rescaled Lambda.

        :param right_u_NH: layer input used to predict the step size in the
            relative-time variant; may be None.
        :param shift_u: if True, shift ``right_u_NH`` one step along the
            sequence axis (zero at the front) so step i is predicted from u_{i-1}.
        :return: ``{"lambda_rescaled_NP": ...}`` when the step size depends on
            ``u``; otherwise ``{"lambda_rescaled_P": ...}``.
        """
        if self.relative_time and (right_u_NH is not None):
            if shift_u:  # during "forward" when dts = [0, t1-t0, ..., t_N-t_{N-1}]
                # pad default 0 at beginning of second to last dim
                right_u_NH = F.pad(right_u_NH[..., :-1, :], (0, 0, 1, 0))
            # predict delta_i from right_u_i
            lambda_rescaled_NP = (
                F.softplus(self.delta_net(right_u_NH)) * self.Lambda_P
            )
            return {"lambda_rescaled_NP": lambda_rescaled_NP}

        if self.relative_time:
            lambda_rescaled_P = F.softplus(self.delta_net.bias) * self.Lambda_P
        else:
            lambda_rescaled_P = th.exp(self.log_step_size_P) * self.Lambda_P
        return {"lambda_rescaled_P": lambda_rescaled_P}

    # ------------------------------------------------------------------
    # Private helpers shared by this class and its subclasses.
    # ------------------------------------------------------------------
    def _lambda_dt(self, dt_N: th.Tensor, right_u_NH: Optional[th.Tensor]) -> th.Tensor:
        """Return ``Lambda * dt`` per step, shape [..., N, P]."""
        lambda_res = self.get_lambda(right_u_NH=right_u_NH, shift_u=True)
        if "lambda_rescaled_P" in lambda_res:  # original formulation
            return th.einsum("...n,p->...np", dt_N, lambda_res["lambda_rescaled_P"])
        # relative time
        return th.einsum("...n,...np->...np", dt_N, lambda_res["lambda_rescaled_NP"])

    def _lambda_bar(
        self, dt_G: th.Tensor, current_right_u_H: Optional[th.Tensor]
    ) -> th.Tensor:
        """Return ``exp(Lambda * dt)`` for each query offset, shape [..., G, P]."""
        # U should already be shifted
        lambda_res = self.get_lambda(current_right_u_H, shift_u=False)
        if "lambda_rescaled_P" in lambda_res:
            exponent_GP = th.einsum(
                "...g,p->...gp", dt_G, lambda_res["lambda_rescaled_P"]
            )
        else:
            exponent_GP = th.einsum(
                "...g,...p->...gp", dt_G, lambda_res["lambda_rescaled_NP"]
            )
        return th.exp(exponent_GP)

    def _evolve(self, initial_state_P, lambda_dt_NP, drift_NP, impulse_NP):
        """
        Run the recurrence
        ``left_i = exp(lambda_dt_i) * right_{i-1} + drift_i``,
        ``right_i = left_i + impulse_i`` with ``right_{-1} = initial_state_P``.

        :param drift_NP: [..., N, P] tensor, or the float 0.0 when there is none.
        :return: ``(left_x_NP, right_x_NP)``, both [..., N, P].
        """
        decay_NP = lambda_dt_NP.exp()
        has_drift = th.is_tensor(drift_NP)

        if self.for_loop:
            right_x_P = initial_state_P
            left_steps, right_steps = [], []
            for i in range(lambda_dt_NP.shape[-2]):
                drift_P = drift_NP[..., i, :] if has_drift else drift_NP
                left_x_P = decay_NP[..., i, :] * right_x_P + drift_P
                right_x_P = left_x_P + impulse_NP[..., i, :]
                left_steps.append(left_x_P)
                right_steps.append(right_x_P)
            return th.stack(left_steps, dim=-2), th.stack(right_steps, dim=-2)

        # Trick inspired by: https://github.com/PeaBrane/mamba-tiny/blob/master/scans.py
        # .unsqueeze(-2) to add sequence dimension to initial state
        # NOTE(review): the log of a negative real is NaN, so this path
        # presumably relies on complex_values=True -- confirm.
        log_drive_Np1_P = th.concat(
            (initial_state_P.unsqueeze(-2), drift_NP + impulse_NP), dim=-2
        ).log()
        lambda_dt_star = F.pad(lambda_dt_NP.cumsum(-2), (0, 0, 1, 0))
        right_x_Np1_P = (
            th.logcumsumexp(log_drive_Np1_P - lambda_dt_star, -2) + lambda_dt_star
        ).exp()  # Contains initial_state_P in index 0
        # Evolves previous hidden state forward to compute left limit
        left_x_NP = decay_NP * right_x_Np1_P[..., :-1, :] + drift_NP
        return left_x_NP, right_x_Np1_P[..., 1:, :]

    def _readout(self, x_P: th.Tensor) -> th.Tensor:
        """Return ``conj_sym_mult * Re(C x)`` for any leading dimensions."""
        conj_sym_mult = 2 if self.conj_sym else 1
        return conj_sym_mult * th.einsum("...p,hp->...h", x_P, self.C_tilde_HP).real

    # ------------------------------------------------------------------

    def forward(
        self,
        left_u_NH: Optional[th.Tensor],  # Very first layer, should feed in `None`
        right_u_NH: Optional[th.Tensor],  # Very first layer, should feed in `None`
        mark_embedding_NH: th.Tensor,
        dt_N: th.Tensor,
        initial_state_P: Optional[th.Tensor] = None,
    ) -> Tuple[th.Tensor, Optional[th.Tensor], Optional[th.Tensor]]:
        """
        Apply the linear SSM to the inputs.

        In the context of TPPs, this returns the right limit of the "intensity function".
        This intensity will have been passed through a non-linearity, though, and so there
        is no guarantee that it is positive.

        :param left_u_NH: [..., seq_len, input_dim] left-limit inputs, or None.
        :param right_u_NH: [..., seq_len, input_dim] right-limit inputs, or None.
        :param mark_embedding_NH: [..., seq_len, input_dim]
        :param dt_N: [..., seq_len]
        :param initial_state_P: [..., hidden_dim]; defaults to the learned initial state.
        :return: ``(right_x_NP, next_layer_left_u_NH, next_layer_right_u_NH)``;
            either ``u`` output is None when ``_ssm`` does not produce it.
        """
        # Pull out the dimensions.
        *leading_dims, _, _ = mark_embedding_NH.shape
        num_leading_dims = len(leading_dims)

        if initial_state_P is None:
            # Pad and expand to match leading dimensions of input
            initial_state_P = self.initial_state_P.view(
                *[1 for _ in range(num_leading_dims)], -1
            ).expand(*leading_dims, -1)

        # Add layer norm
        prime_left_u_NH = left_u_NH
        prime_right_u_NH = right_u_NH
        if prime_left_u_NH is not None:  # ONLY for backward variant
            assert all(
                u_d == a_d
                for u_d, a_d in zip(prime_left_u_NH.shape, mark_embedding_NH.shape)
            )  # Shapes should match
            if self.pre_norm:
                prime_left_u_NH = self.norm(prime_left_u_NH)

        if prime_right_u_NH is not None:
            assert all(
                u_d == a_d
                for u_d, a_d in zip(prime_right_u_NH.shape, mark_embedding_NH.shape)
            )  # Shapes should match
            if self.pre_norm:
                prime_right_u_NH = self.norm(prime_right_u_NH)

        right_x_NP, left_y_NH, right_y_NH = self._ssm(
            left_u_NH=prime_left_u_NH,
            right_u_NH=prime_right_u_NH,
            impulse_NP=self.compute_impulse(prime_right_u_NH, mark_embedding_NH),
            dt_N=dt_N,
            initial_state_P=initial_state_P,
        )
        # Given the following:
        #   right_u: u0, u1, u2, ... <-> u_{t_0}, u_{t_1}, u_{t_2}, ...
        #   left_u: u0, u1, u2, ... <-> u_{t_0-}, u_{t_1-}, u_{t_2-}, ...
        #   a: a0, a1, a2, ... <-> mark embeddings for m_0, m_1, m_2, ... at times t_0, t_1, t_2
        #   dt: dt0, dt1, dt2, ... <-> 0, t_1-t_0, t_2-t_1, ...
        #   initial_state_p: hidden state to evolve to to compute x_{0}
        # Returns the following:
        #   right_x: x0, x1, x2, ... <-> x_{t_0}, x_{t_1}, x_{t_2}, ...
        #   right_y: y0, y1, y2, ... <-> y_{t_0}, y_{t_1}, y_{t_2}, ...
        #   left_y: y0, y1, y2, ... <-> y_{t_0-}, y_{t_1-}, y_{t_2-}, ...

        next_layer_left_u_NH = next_layer_right_u_NH = None
        if left_y_NH is not None:
            # Residual connection uses the un-normalised input.
            next_layer_left_u_NH = self.act_func(left_y_NH) + (
                left_u_NH if left_u_NH is not None else 0.0
            )
            if self.post_norm:
                next_layer_left_u_NH = self.norm(next_layer_left_u_NH)
        if right_y_NH is not None:
            next_layer_right_u_NH = self.act_func(right_y_NH) + (
                right_u_NH if right_u_NH is not None else 0.0
            )
            if self.post_norm:
                next_layer_right_u_NH = self.norm(next_layer_right_u_NH)

        return right_x_NP, next_layer_left_u_NH, next_layer_right_u_NH

    def _ssm(
        self,
        left_u_NH: Optional[th.Tensor],  # Very first layer, should feed in `None`
        right_u_NH: Optional[th.Tensor],  # Very first layer, should feed in `None`
        impulse_NP: th.Tensor,
        dt_N: th.Tensor,  # [0, t_1 - t_0, ..., t_N - t_{N-1}]
        initial_state_P: th.Tensor,
    ) -> Tuple[th.Tensor, Optional[th.Tensor], Optional[th.Tensor]]:
        """
        Run the SSM where ``B u`` enters as part of the impulse at each event.

        :return: ``(right_x_NP, None, right_y_NH)``.  This implementation does
            not use ``left_u_NH``, nor does it compute a left-limit ``y``.
        """
        u_NH = right_u_NH
        if u_NH is not None:
            impulse_NP = impulse_NP + th.einsum(
                "ph,...nh->...np",
                self.B_tilde_PH,
                u_NH.type(th.complex64) if self.complex_values else u_NH,
            )
            # D_HH should really be D_H
            y_u_res_NH = th.einsum("...nh,h->...nh", u_NH, self.D_HH)
        else:
            assert self.is_first_layer
            y_u_res_NH = 0.0

        lambda_dt_NP = self._lambda_dt(dt_N, right_u_NH)
        _, right_x_NP = self._evolve(initial_state_P, lambda_dt_NP, 0.0, impulse_NP)

        y_NH = self._readout(right_x_NP) + y_u_res_NH
        return right_x_NP, None, y_NH

    def get_left_limit(
        self,
        right_limit_P: th.Tensor,  # Along with dt, can have any number of leading dimensions, produces a tensor of dim ...MP
        dt_G: th.Tensor,
        current_right_u_H: Optional[th.Tensor],
        next_left_u_GH: Optional[th.Tensor],
    ) -> th.Tensor:
        """
        To get the left limit, we roll on the layer for the right dt.
        Computed for a single point (vmap for multiple).

        :param right_limit_P: at [t_0, ..., t_{N-1}]
        :param dt_G: Length of time to roll the layer on for. at [t_1 - t_0, ..., t_N - t_{N-1}]
        :param current_right_u_H: at [t_0, ..., t_{N-1}] -- for relative-time variant
        :param next_left_u_GH: at [t_1, ..., t_N] -- for backward variant (unused here)
        :return: [..., G, P] decayed hidden state.
        """
        if current_right_u_H is not None and self.pre_norm:
            current_right_u_H = self.norm(current_right_u_H)

        lambda_bar_GP = self._lambda_bar(dt_G, current_right_u_H)
        return th.einsum("...p,...gp->...gp", right_limit_P, lambda_bar_GP)

    def depth_pass(
        self,
        current_left_x_P: th.Tensor,  # No leading dimensions (seq, batch, etc.) here because we accommodate any of them
        current_left_u_H: Optional[
            th.Tensor
        ],  # Just assume that x and u match in the leading dimensions. Produces y_H with equivalent leading dimensions
        prev_right_u_H: Optional[
            th.Tensor
        ],  # Just assume that x and u match in the leading dimensions. Produces y_H with equivalent leading dimensions
    ) -> th.Tensor:
        """
        Map a left-limit hidden state (and this layer's input) to the next layer's input.

        :param current_left_x_P: [..., P] hidden state.
        :param current_left_u_H: [..., H] layer input, or None for the first layer.
        :param prev_right_u_H: unused here; kept for interface compatibility.
        :return: [..., H] activation(y) plus the residual input.
        """
        if current_left_u_H is not None:
            prime_u_H = self.norm(current_left_u_H) if self.pre_norm else current_left_u_H
            # D_HH should really be D_H
            y_u_res_H = th.einsum("...h,h->...h", prime_u_H, self.D_HH)
        else:
            assert self.is_first_layer
            y_u_res_H = 0.0

        y_H = self._readout(current_left_x_P) + y_u_res_H

        # Apply an activation function.
        new_u_H = self.act_func(y_H) + (
            current_left_u_H if current_left_u_H is not None else 0.0
        )
        if self.post_norm:
            new_u_H = self.norm(new_u_H)
        return new_u_H


class Int_Forward_LLH(LLH):
    # LLH but Bu_t is integrated w.r.t dt instead of dN_t
    # After discretization, when evolving x_t to x_t', applies ZOH on u_t over [t,t'] forward in time
    # (as opposed to u_{t'} backwards over [t,t'])

    def _ssm(
        self,
        left_u_NH: Optional[th.Tensor],  # Very first layer, should feed in `None`
        right_u_NH: Optional[th.Tensor],  # Very first layer, should feed in `None`
        impulse_NP: th.Tensor,
        dt_N: th.Tensor,
        initial_state_P: th.Tensor,
    ) -> Tuple[th.Tensor, Optional[th.Tensor], Optional[th.Tensor]]:
        """
        Apply the linear SSM to the inputs.

        In the context of TPPs, this returns the right limit of the "intensity function".
        This intensity will have been passed through a non-linearity, though, and so there
        is no guarantee that it is positive.

        :param left_u_NH: [..., seq_len, input_dim]
        :param right_u_NH: [..., seq_len, input_dim]
        :param impulse_NP: [..., seq_len, hidden_dim]
        :param dt_N: [..., seq_len]
        :param initial_state_P: [..., hidden_dim]
        :return: ``(right_x_NP, left_y_NH, right_y_NH)``
        """
        # Uses the unshifted right_u; get_lambda performs its own shift.
        lambda_dt_NP = self._lambda_dt(dt_N, right_u_NH)

        if left_u_NH is not None:
            left_Du_NH = th.einsum("...nh,h->...nh", left_u_NH, self.D_HH)
        else:
            assert self.is_first_layer
            left_Du_NH = 0.0

        if right_u_NH is not None:
            # Shift so that step i is driven by u at the previous event.
            right_u_NH = F.pad(right_u_NH[..., :-1, :], (0, 0, 1, 0))
            right_Bu_NP = th.einsum(
                "...np,ph,...nh->...np",
                lambda_dt_NP.exp() - 1.0,  # dts: [0, t1-t0, t2-t1, ...]
                self.B_tilde_PH,
                right_u_NH.type(th.complex64) if self.complex_values else right_u_NH,
            )
            # NOTE(review): this feedthrough uses the *shifted* right_u --
            # confirm that is intended for the right-limit output.
            right_Du_NH = th.einsum("...nh,h->...nh", right_u_NH, self.D_HH)
        else:
            assert self.is_first_layer
            right_Bu_NP = right_Du_NH = 0.0

        # The shared recurrence applies the Bu drift whenever it exists, so the
        # for-loop and scan paths agree.
        left_x_NP, right_x_NP = self._evolve(
            initial_state_P, lambda_dt_NP, right_Bu_NP, impulse_NP
        )

        left_y_NH = self._readout(left_x_NP) + left_Du_NH  # ys for [t0, t1, ...]
        right_y_NH = self._readout(right_x_NP) + right_Du_NH  # ys for [t0, t1, ...]

        return right_x_NP, left_y_NH, right_y_NH

    def get_left_limit(
        self,
        right_limit_P: th.Tensor,  # Along with dt, can have any number of leading dimensions, produces a tensor of dim ...MP
        dt_G: th.Tensor,
        current_right_u_H: Optional[th.Tensor],
        next_left_u_GH: Optional[th.Tensor],
    ) -> th.Tensor:
        """
        To get the left limit, we roll on the layer for the right dt.
        Computed for a single point (vmap for multiple).

        :param right_limit_P: [..., P] hidden state to roll forward.
        :param dt_G: Length of time to roll the layer on for.
        :param current_right_u_H: [..., H] input held constant over the interval, or None.
        :param next_left_u_GH: unused in the forward variant.
        :return: [..., G, P] hidden state after decay plus the integrated ``B u`` term.
        """
        # Normalise exactly once, matching what forward() feeds to _ssm.
        if current_right_u_H is not None and self.pre_norm:
            current_right_u_H = self.norm(current_right_u_H)

        lambda_bar_GP = self._lambda_bar(dt_G, current_right_u_H)
        int_hidden_GP = th.einsum("...p,...gp->...gp", right_limit_P, lambda_bar_GP)

        if current_right_u_H is None:  # no Bu term
            assert self.is_first_layer
            return int_hidden_GP

        # add Bu to impulse
        impulse_GP = th.einsum(
            "...gp,ph,...h->...gp",
            lambda_bar_GP - 1.0,
            self.B_tilde_PH,
            current_right_u_H.type(th.complex64)
            if self.complex_values
            else current_right_u_H,
        )
        return int_hidden_GP + impulse_GP


class Int_Backward_LLH(Int_Forward_LLH):
    # LLH but Bu_t is integrated w.r.t dt instead of dN_t
    # After discretization, when evolving x_t to x_t', applies ZOH on u_t' over [t,t'] backwards in time
    # (as opposed to u_{t} forwards over [t,t'])

    def _ssm(
        self,
        left_u_NH: Optional[th.Tensor],  # Very first layer, should feed in `None`
        right_u_NH: Optional[th.Tensor],  # Very first layer, should feed in `None`
        impulse_NP: th.Tensor,
        dt_N: th.Tensor,
        initial_state_P: th.Tensor,
    ) -> Tuple[th.Tensor, Optional[th.Tensor], Optional[th.Tensor]]:
        """
        Apply the linear SSM to the inputs.

        In the context of TPPs, this returns the right limit of the "intensity function".
        This intensity will have been passed through a non-linearity, though, and so there
        is no guarantee that it is positive.

        :param left_u_NH: [..., seq_len, input_dim]
        :param right_u_NH: [..., seq_len, input_dim]
        :param impulse_NP: [..., seq_len, hidden_dim]
        :param dt_N: [..., seq_len]
        :param initial_state_P: [..., hidden_dim]
        :return: ``(right_x_NP, left_y_NH, right_y_NH)``
        """
        lambda_dt_NP = self._lambda_dt(dt_N, right_u_NH)

        if left_u_NH is not None:
            left_Bu_NP = th.einsum(
                "...np,ph,...nh->...np",
                lambda_dt_NP.exp() - 1.0,  # dts: [0, t1-t0, t2-t1, ...]
                self.B_tilde_PH,
                left_u_NH.type(th.complex64) if self.complex_values else left_u_NH,
            )
            left_Du_NH = th.einsum("...nh,h->...nh", left_u_NH, self.D_HH)
        else:
            assert self.is_first_layer
            left_Bu_NP = left_Du_NH = 0.0

        if right_u_NH is not None:
            right_Du_NH = th.einsum("...nh,h->...nh", right_u_NH, self.D_HH)
        else:
            assert self.is_first_layer
            right_Du_NH = 0.0

        left_x_NP, right_x_NP = self._evolve(
            initial_state_P, lambda_dt_NP, left_Bu_NP, impulse_NP
        )

        left_y_NH = self._readout(left_x_NP) + left_Du_NH  # ys for [t0, t1, ...]
        right_y_NH = self._readout(right_x_NP) + right_Du_NH  # ys for [t0, t1, ...]

        return right_x_NP, left_y_NH, right_y_NH

    def get_left_limit(
        self,
        right_limit_P: th.Tensor,  # Along with dt, can have any number of leading dimensions, produces a tensor of dim ...MP
        dt_G: th.Tensor,
        current_right_u_H: Optional[th.Tensor],
        next_left_u_GH: Optional[th.Tensor],
    ) -> th.Tensor:
        """
        To get the left limit, we roll on the layer for the right dt.
        Computed for a single point (vmap for multiple).

        :param right_limit_P: [..., P] hidden state to roll forward.
        :param dt_G: Length of time to roll the layer on for.
        :param current_right_u_H: [..., H] input used only to predict the step size, or None.
        :param next_left_u_GH: [..., G, H] left-limit input at the query times, or None.
        :return: [..., G, P] hidden state after decay plus the integrated ``B u`` term.
        """
        if current_right_u_H is not None and self.pre_norm:
            current_right_u_H = self.norm(current_right_u_H)

        lambda_bar_GP = self._lambda_bar(dt_G, current_right_u_H)
        int_hidden_GP = th.einsum("...p,...gp->...gp", right_limit_P, lambda_bar_GP)

        if next_left_u_GH is None:  # no Bu term
            assert self.is_first_layer
            return int_hidden_GP

        # add Bu to impulse
        if self.pre_norm:
            next_left_u_GH = self.norm(next_left_u_GH)
        impulse_GP = th.einsum(
            "...gp,ph,...gh->...gp",
            lambda_bar_GP - 1.0,
            self.B_tilde_PH,
            next_left_u_GH.type(th.complex64)
            if self.complex_values
            else next_left_u_GH,
        )
        return int_hidden_GP + impulse_GP