from typing import List, Optional, Tuple, Union import torch from torch import nn from easy_tpp.model.torch_model.torch_baselayer import ScaledSoftplus from easy_tpp.model.torch_model.torch_basemodel import TorchBaseModel from easy_tpp.ssm.models import LLH, Int_Backward_LLH, Int_Forward_LLH class ComplexEmbedding(nn.Module): def __init__(self, *args, **kwargs): super(ComplexEmbedding, self).__init__() self.real_embedding = nn.Embedding(*args, **kwargs) self.imag_embedding = nn.Embedding(*args, **kwargs) self.real_embedding.weight.data *= 1e-3 self.imag_embedding.weight.data *= 1e-3 def forward(self, x): return torch.complex( self.real_embedding(x), self.imag_embedding(x), ) class IntensityNet(nn.Module): def __init__(self, input_dim, bias, num_event_types): super().__init__() self.intensity_net = nn.Linear(input_dim, num_event_types, bias=bias) self.softplus = ScaledSoftplus(num_event_types) def forward(self, x): return self.softplus(self.intensity_net(x)) class S2P2(TorchBaseModel): def __init__(self, model_config): """Initialize the model Args: model_config (EasyTPP.ModelConfig): config of model specs. """ super(S2P2, self).__init__(model_config) self.n_layers = model_config.num_layers self.P = model_config.model_specs["P"] # Hidden state dimension self.H = model_config.hidden_size # Residual stream dimension self.beta = model_config.model_specs.get("beta", 1.0) self.bias = model_config.model_specs.get("bias", True) self.simple_mark = model_config.model_specs.get("simple_mark", True) layer_kwargs = dict( P=self.P, H=self.H, dt_init_min=model_config.model_specs.get("dt_init_min", 1e-4), dt_init_max=model_config.model_specs.get("dt_init_max", 0.1), act_func=model_config.model_specs.get("act_func", "full_glu"), dropout_rate=model_config.model_specs.get("dropout_rate", 0.0), for_loop=model_config.model_specs.get("for_loop", False), pre_norm=model_config.model_specs.get("pre_norm", True), post_norm=model_config.model_specs.get("post_norm", False), simple_mark=self.simple_mark, relative_time=model_config.model_specs.get("relative_time", False), complex_values=model_config.model_specs.get("complex_values", True), ) int_forward_variant = model_config.model_specs.get("int_forward_variant", False) int_backward_variant = model_config.model_specs.get( "int_backward_variant", False ) assert ( int_forward_variant + int_backward_variant ) <= 1 # Only one at most is allowed to be specified if int_forward_variant: llh_layer = Int_Forward_LLH elif int_backward_variant: llh_layer = Int_Backward_LLH else: llh_layer = LLH self.backward_variant = int_backward_variant self.layers = nn.ModuleList( [ llh_layer(**layer_kwargs, is_first_layer=i == 0) for i in range(self.n_layers) ] ) self.layers_mark_emb = nn.Embedding( self.num_event_types_pad, self.H, ) # One embedding to share amongst layers to be used as input into a layer-specific and input-aware impulse self.layer_type_emb = None # Remove old embeddings from EasyTPP self.intensity_net = IntensityNet( input_dim=self.H, bias=self.bias, num_event_types=self.num_event_types, ) def _get_intensity( self, x_LP: Union[torch.tensor, List[torch.tensor]], right_us_BNH ) -> torch.Tensor: """ Assume time has already been evolved, take a vertical stack of hidden states and produce intensity. """ left_u_H = None for i, layer in enumerate(self.layers): if isinstance( x_LP, list ): # Sometimes it is convenient to pass as a list over the layers rather than a single tensor left_u_H = layer.depth_pass( x_LP[i], current_left_u_H=left_u_H, prev_right_u_H=right_us_BNH[i] ) else: left_u_H = layer.depth_pass( x_LP[..., i, :], current_left_u_H=left_u_H, prev_right_u_H=right_us_BNH[i], ) return self.intensity_net(left_u_H) # self.ScaledSoftplus(self.linear(left_u_H)) def _evolve_and_get_intensity_at_sampled_dts(self, x_LP, dt_G, right_us_H): left_u_GH = None for i, layer in enumerate(self.layers): x_GP = layer.get_left_limit( right_limit_P=x_LP[..., i, :], dt_G=dt_G, next_left_u_GH=left_u_GH, current_right_u_H=right_us_H[i], ) left_u_GH = layer.depth_pass( current_left_x_P=x_GP, current_left_u_H=left_u_GH, prev_right_u_H=right_us_H[i], ) return self.intensity_net(left_u_GH) # self.ScaledSoftplus(self.linear(left_u_GH)) def forward( self, batch, initial_state_BLP: Optional[torch.Tensor] = None, **kwargs ) -> Tuple[torch.Tensor, torch.Tensor]: """ Batch operations of self._forward """ t_BN, dt_BN, marks_BN, batch_non_pad_mask, _ = batch right_xs_BNP = [] # including both t_0 and t_N left_xs_BNm1P = [] right_us_BNH = [ None ] # Start with None as this is the 'input' to the first layer left_u_BNH, right_u_BNH = None, None alpha_BNP = self.layers_mark_emb(marks_BN) for l_i, layer in enumerate(self.layers): # for each event, compute the fixed impulse via alpha_m for event i of type m init_state = ( initial_state_BLP[:, l_i] if initial_state_BLP is not None else None ) # Returns right limit of xs and us for [t0, t1, ..., tN] # "layer" returns the right limit of xs at current layer, and us for the next layer (as transformations of ys) # x_BNP: at time [t_0, t_1, ..., t_{N-1}, t_N] # next_left_u_BNH: at time [t_0, t_1, ..., t_{N-1}, t_N] -- only available for backward variant # next_right_u_BNH: at time [t_0, t_1, ..., t_{N-1}, t_N] -- always returned but only used for RT x_BNP, next_layer_left_u_BNH, next_layer_right_u_BNH = layer.forward( left_u_BNH, right_u_BNH, alpha_BNP, dt_BN, init_state ) assert next_layer_right_u_BNH is not None right_xs_BNP.append(x_BNP) if next_layer_left_u_BNH is None: # NOT backward variant left_xs_BNm1P.append( layer.get_left_limit( # current and next at event level x_BNP[..., :-1, :], # at time [t_0, t_1, ..., t_{N-1}] dt_BN[..., 1:].unsqueeze( -1 ), # with dts [t1-t0, t2-t1, ..., t_N-t_{N-1}] current_right_u_H=right_u_BNH if right_u_BNH is None else right_u_BNH[ ..., :-1, : ], # at time [t_0, t_1, ..., t_{N-1}] next_left_u_GH=left_u_BNH if left_u_BNH is None else left_u_BNH[..., 1:, :].unsqueeze( -2 ), # at time [t_1, t_2 ..., t_N] ).squeeze(-2) ) right_us_BNH.append(next_layer_right_u_BNH) left_u_BNH, right_u_BNH = next_layer_left_u_BNH, next_layer_right_u_BNH right_xs_BNLP = torch.stack(right_xs_BNP, dim=-2) ret_val = { "right_xs_BNLP": right_xs_BNLP, # [t_0, ..., t_N] "right_us_BNH": right_us_BNH, # [t_0, ..., t_N]; list starting with None } if left_u_BNH is not None: # backward variant ret_val["left_u_BNm1H"] = left_u_BNH[ ..., 1:, : ] # The next inputs after last layer -> transformation of ys else: # NOT backward variant ret_val["left_xs_BNm1LP"] = torch.stack(left_xs_BNm1P, dim=-2) # 'seq_len - 1' left limit for [t_1, ..., t_N] for events (u if available, x if not) # 'seq_len' right limit for [t_0, t_1, ..., t_{N-1}, t_N] for events xs or us return ret_val def loglike_loss(self, batch, **kwargs): # hidden states at the left and right limits around event time; note for the shift by 1 in indices: # consider a sequence [t0, t1, ..., tN] # Produces the following: # left_x: x0, x1, x2, ... <-> x_{t_1-}, x_{t_2-}, x_{t_3-}, ..., x_{t_N-} (note the shift in indices) for all layers # OR ==> <-> u_{t_1-}, u_{t_2-}, u_{t_3-}, ..., u_{t_N-} for last layer # # right_x: x0, x1, x2, ... <-> x_{t_0+}, x_{t_1+}, ..., x_{t_N+} for all layers # right_u: u0, u1, u2, ... <-> u_{t_0+}, u_{t_1+}, ..., u_{t_N+} for all layers forward_results = self.forward( batch ) # N minus 1 comparing with sequence lengths right_xs_BNLP, right_us_BNH = ( forward_results["right_xs_BNLP"], forward_results["right_us_BNH"], ) right_us_BNm1H = [ None if right_u_BNH is None else right_u_BNH[:, :-1, :] for right_u_BNH in right_us_BNH ] ts_BN, dts_BN, marks_BN, batch_non_pad_mask, _ = batch # evaluate intensity values at each event *from the left limit*, _get_intensity: [LP] -> [M] # left_xs_B_Nm1_LP = left_xs_BNm1LP[:, :-1, ...] # discard the left limit of t_N # Note: no need to discard the left limit of t_N because "marks_mask" will deal with it if "left_u_BNm1H" in forward_results: # ONLY backward variant intensity_B_Nm1_M = self.intensity_net( forward_results["left_u_BNm1H"] ) # self.ScaledSoftplus(self.linear(forward_results["left_u_BNm1H"])) else: # NOT backward variant intensity_B_Nm1_M = self._get_intensity( forward_results["left_xs_BNm1LP"], right_us_BNm1H ) # sample dt in each interval for MC: [batch_size, num_times=N-1, num_mc_sample] # N-1 because we only consider the intervals between N events # G for grid points dts_sample_B_Nm1_G = self.make_dtime_loss_samples(dts_BN[:, 1:]) # evaluate intensity at dt_samples for MC *from the left limit* after decay -> shape (B, N-1, MC, M) intensity_dts_B_Nm1_G_M = self._evolve_and_get_intensity_at_sampled_dts( right_xs_BNLP[ :, :-1 ], # x_{t_i+} will evolve up to x_{t_{i+1}-} and many times between for i=0,...,N-1 dts_sample_B_Nm1_G, right_us_BNm1H, ) event_ll, non_event_ll, num_events = self.compute_loglikelihood( lambda_at_event=intensity_B_Nm1_M, lambdas_loss_samples=intensity_dts_B_Nm1_G_M, time_delta_seq=dts_BN[:, 1:], seq_mask=batch_non_pad_mask[:, 1:], type_seq=marks_BN[:, 1:], ) # compute loss to optimize loss = -(event_ll - non_event_ll).sum() return loss, num_events def compute_intensities_at_sample_times( self, event_times_BN, inter_event_times_BN, marks_BN, sample_dtimes, **kwargs ): """Compute the intensity at sampled times, not only event times. *from the left limit* Args: time_seq (tensor): [batch_size, seq_len], times seqs. time_delta_seq (tensor): [batch_size, seq_len], time delta seqs. event_seq (tensor): [batch_size, seq_len], event type seqs. sample_dtimes (tensor): [batch_size, seq_len, num_sample], sampled inter-event timestamps. Returns: tensor: [batch_size, num_times, num_mc_sample, num_event_types], intensity at each timestamp for each event type. """ compute_last_step_only = kwargs.get("compute_last_step_only", False) # assume inter_event_times_BN always starts from 0 _input = event_times_BN, inter_event_times_BN, marks_BN, None, None # 'seq_len - 1' left limit for [t_1, ..., t_N] # 'seq_len' right limit for [t_0, t_1, ..., t_{N-1}, t_N] forward_results = self.forward( _input ) # N minus 1 comparing with sequence lengths right_xs_BNLP, right_us_BNH = ( forward_results["right_xs_BNLP"], forward_results["right_us_BNH"], ) if ( compute_last_step_only ): # fix indices for right_us_BNH: list [None, tensor([BNH]), ...] right_us_B1H = [ None if right_u_BNH is None else right_u_BNH[:, -1:, :] for right_u_BNH in right_us_BNH ] sampled_intensity = self._evolve_and_get_intensity_at_sampled_dts( right_xs_BNLP[:, -1:, :, :], sample_dtimes[:, -1:, :], right_us_B1H ) # equiv. to right_xs_BNLP[:, -1, :, :][:, None, ...] else: sampled_intensity = self._evolve_and_get_intensity_at_sampled_dts( right_xs_BNLP, sample_dtimes, right_us_BNH ) return sampled_intensity # [B, N, MC, M]