from typing import List, Optional, Union

import torch
from torch import nn

from easy_tpp.model.torch_model.torch_baselayer import ScaledSoftplus
from easy_tpp.model.torch_model.torch_basemodel import TorchBaseModel
from easy_tpp.ssm.models import LLH, Int_Backward_LLH, Int_Forward_LLH
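
# Tensor-name suffixes follow the shapes they annotate (inferred from usage
# below): B = batch, N = sequence length, L = layer stack depth, P = state
# dimension, H = hidden size, G = number of sampled time offsets,
# M = number of event types; "Nm1" marks a length of N - 1.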


class ComplexEmbedding(nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__()
        self.real_embedding = nn.Embedding(*args, **kwargs)
        self.imag_embedding = nn.Embedding(*args, **kwargs)

        # Shrink the default initialization of both components.
        self.real_embedding.weight.data *= 1e-3
        self.imag_embedding.weight.data *= 1e-3

    def forward(self, x):
        return torch.complex(
            self.real_embedding(x),
            self.imag_embedding(x),
        )
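
# Usage sketch (hypothetical sizes): takes the same arguments as nn.Embedding
# but returns a complex-valued tensor, e.g.
#   emb = ComplexEmbedding(10, 4)
#   z = emb(torch.tensor([1, 2]))  # complex64, shape [2, 4]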


class IntensityNet(nn.Module):
    def __init__(self, input_dim, bias, num_event_types):
        super().__init__()
        self.intensity_net = nn.Linear(input_dim, num_event_types, bias=bias)
        self.softplus = ScaledSoftplus(num_event_types)

    def forward(self, x):
        return self.softplus(self.intensity_net(x))
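
# Usage sketch (hypothetical sizes): maps hidden states to one positive
# intensity per event type via a linear layer and a scaled softplus, e.g.
#   net = IntensityNet(input_dim=8, bias=True, num_event_types=3)
#   lam = net(torch.randn(2, 5, 8))  # positive, shape [2, 5, 3]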


class S2P2(TorchBaseModel):
    def __init__(self, model_config):
        """Initialize the model.

        Args:
            model_config (EasyTPP.ModelConfig): config of model specs.
        """
        super().__init__(model_config)
        self.n_layers = model_config.num_layers
        self.P = model_config.model_specs["P"]
        self.H = model_config.hidden_size
        self.beta = model_config.model_specs.get("beta", 1.0)
        self.bias = model_config.model_specs.get("bias", True)
        self.simple_mark = model_config.model_specs.get("simple_mark", True)

        layer_kwargs = dict(
            P=self.P,
            H=self.H,
            dt_init_min=model_config.model_specs.get("dt_init_min", 1e-4),
            dt_init_max=model_config.model_specs.get("dt_init_max", 0.1),
            act_func=model_config.model_specs.get("act_func", "full_glu"),
            dropout_rate=model_config.model_specs.get("dropout_rate", 0.0),
            for_loop=model_config.model_specs.get("for_loop", False),
            pre_norm=model_config.model_specs.get("pre_norm", True),
            post_norm=model_config.model_specs.get("post_norm", False),
            simple_mark=self.simple_mark,
            relative_time=model_config.model_specs.get("relative_time", False),
            complex_values=model_config.model_specs.get("complex_values", True),
        )

        int_forward_variant = model_config.model_specs.get("int_forward_variant", False)
        int_backward_variant = model_config.model_specs.get("int_backward_variant", False)
        assert not (
            int_forward_variant and int_backward_variant
        ), "int_forward_variant and int_backward_variant are mutually exclusive."

        if int_forward_variant:
            llh_layer = Int_Forward_LLH
        elif int_backward_variant:
            llh_layer = Int_Backward_LLH
        else:
            llh_layer = LLH

        self.backward_variant = int_backward_variant

        self.layers = nn.ModuleList(
            [
                llh_layer(**layer_kwargs, is_first_layer=(i == 0))
                for i in range(self.n_layers)
            ]
        )
        self.layers_mark_emb = nn.Embedding(self.num_event_types_pad, self.H)
        self.layer_type_emb = None  # the mark embedding above is used instead
        self.intensity_net = IntensityNet(
            input_dim=self.H,
            bias=self.bias,
            num_event_types=self.num_event_types,
        )
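
    # A hypothetical `model_specs` covering the keys read above ("P" is
    # required; every other value falls back to the default shown):
    #   {"P": 32, "beta": 1.0, "bias": True, "simple_mark": True,
    #    "dt_init_min": 1e-4, "dt_init_max": 0.1, "act_func": "full_glu",
    #    "dropout_rate": 0.0, "for_loop": False, "pre_norm": True,
    #    "post_norm": False, "relative_time": False, "complex_values": True,
    #    "int_forward_variant": False, "int_backward_variant": False}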

    def _get_intensity(
        self, x_LP: Union[torch.Tensor, List[torch.Tensor]], right_us_BNH
    ) -> torch.Tensor:
        """Assuming time has already been evolved, run the depth pass over a
        vertical stack of per-layer states and produce intensities.
        """
        left_u_H = None
        for i, layer in enumerate(self.layers):
            if isinstance(x_LP, list):
                left_u_H = layer.depth_pass(
                    x_LP[i], current_left_u_H=left_u_H, prev_right_u_H=right_us_BNH[i]
                )
            else:
                left_u_H = layer.depth_pass(
                    x_LP[..., i, :],
                    current_left_u_H=left_u_H,
                    prev_right_u_H=right_us_BNH[i],
                )

        return self.intensity_net(left_u_H)

    def _evolve_and_get_intensity_at_sampled_dts(self, x_LP, dt_G, right_us_H):
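        """Evolve each layer's right-limit state forward by the sampled time
        offsets dt_G, then run the depth pass and map the top hidden state to
        intensities."""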
        left_u_GH = None
        for i, layer in enumerate(self.layers):
            x_GP = layer.get_left_limit(
                right_limit_P=x_LP[..., i, :],
                dt_G=dt_G,
                next_left_u_GH=left_u_GH,
                current_right_u_H=right_us_H[i],
            )
            left_u_GH = layer.depth_pass(
                current_left_x_P=x_GP,
                current_left_u_H=left_u_GH,
                prev_right_u_H=right_us_H[i],
            )
        return self.intensity_net(left_u_GH)

    def forward(
        self, batch, initial_state_BLP: Optional[torch.Tensor] = None, **kwargs
    ) -> dict:
        """Run the layer stack over a batch of event sequences, collecting the
        right-limit states plus whatever left-limit quantities each layer
        exposes.
        """
        t_BN, dt_BN, marks_BN, batch_non_pad_mask, _ = batch

        right_xs_BNP = []
        left_xs_BNm1P = []
        right_us_BNH = [None]  # the first layer receives no hidden input
        left_u_BNH, right_u_BNH = None, None
        alpha_BNP = self.layers_mark_emb(marks_BN)

        for l_i, layer in enumerate(self.layers):
            init_state = (
                initial_state_BLP[:, l_i] if initial_state_BLP is not None else None
            )

            x_BNP, next_layer_left_u_BNH, next_layer_right_u_BNH = layer.forward(
                left_u_BNH, right_u_BNH, alpha_BNP, dt_BN, init_state
            )
            assert next_layer_right_u_BNH is not None

            right_xs_BNP.append(x_BNP)
            if next_layer_left_u_BNH is None:
                left_xs_BNm1P.append(
                    layer.get_left_limit(
                        x_BNP[..., :-1, :],
                        dt_BN[..., 1:].unsqueeze(-1),
                        current_right_u_H=(
                            right_u_BNH
                            if right_u_BNH is None
                            else right_u_BNH[..., :-1, :]
                        ),
                        next_left_u_GH=(
                            left_u_BNH
                            if left_u_BNH is None
                            else left_u_BNH[..., 1:, :].unsqueeze(-2)
                        ),
                    ).squeeze(-2)
                )
            right_us_BNH.append(next_layer_right_u_BNH)

            left_u_BNH, right_u_BNH = next_layer_left_u_BNH, next_layer_right_u_BNH

        right_xs_BNLP = torch.stack(right_xs_BNP, dim=-2)

        ret_val = {
            "right_xs_BNLP": right_xs_BNLP,
            "right_us_BNH": right_us_BNH,
        }

        if left_u_BNH is not None:
            ret_val["left_u_BNm1H"] = left_u_BNH[..., 1:, :]
        else:
            ret_val["left_xs_BNm1LP"] = torch.stack(left_xs_BNm1P, dim=-2)

        return ret_val

    def loglike_loss(self, batch, **kwargs):
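        """Compute the negative log-likelihood loss for a batch.

        The point-process log-likelihood is the sum of log-intensities at the
        observed events minus the integrated total intensity; the integral is
        approximated with Monte Carlo samples over each inter-event interval.
        """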
        forward_results = self.forward(batch)
        right_xs_BNLP, right_us_BNH = (
            forward_results["right_xs_BNLP"],
            forward_results["right_us_BNH"],
        )
        right_us_BNm1H = [
            None if right_u_BNH is None else right_u_BNH[:, :-1, :]
            for right_u_BNH in right_us_BNH
        ]

        ts_BN, dts_BN, marks_BN, batch_non_pad_mask, _ = batch

        # Intensities at the observed event times (left limits).
        if "left_u_BNm1H" in forward_results:
            intensity_B_Nm1_M = self.intensity_net(forward_results["left_u_BNm1H"])
        else:
            intensity_B_Nm1_M = self._get_intensity(
                forward_results["left_xs_BNm1LP"], right_us_BNm1H
            )

        # Intensities at Monte Carlo samples within each inter-event interval.
        dts_sample_B_Nm1_G = self.make_dtime_loss_samples(dts_BN[:, 1:])

        intensity_dts_B_Nm1_G_M = self._evolve_and_get_intensity_at_sampled_dts(
            right_xs_BNLP[:, :-1],
            dts_sample_B_Nm1_G,
            right_us_BNm1H,
        )

        event_ll, non_event_ll, num_events = self.compute_loglikelihood(
            lambda_at_event=intensity_B_Nm1_M,
            lambdas_loss_samples=intensity_dts_B_Nm1_G_M,
            time_delta_seq=dts_BN[:, 1:],
            seq_mask=batch_non_pad_mask[:, 1:],
            type_seq=marks_BN[:, 1:],
        )

        loss = -(event_ll - non_event_ll).sum()

        return loss, num_events

    def compute_intensities_at_sample_times(
        self, event_times_BN, inter_event_times_BN, marks_BN, sample_dtimes, **kwargs
    ):
        """Compute the intensity at sampled times, not only event times, *from the left limit*.

        Args:
            event_times_BN (tensor): [batch_size, seq_len], event time seqs.
            inter_event_times_BN (tensor): [batch_size, seq_len], inter-event time seqs.
            marks_BN (tensor): [batch_size, seq_len], event type seqs.
            sample_dtimes (tensor): [batch_size, seq_len, num_sample], sampled inter-event timestamps.

        Returns:
            tensor: [batch_size, num_times, num_mc_sample, num_event_types],
                intensity at each timestamp for each event type.
        """
        compute_last_step_only = kwargs.get("compute_last_step_only", False)

        # self.forward only consumes the first three entries; the mask slots
        # are unused there, so None placeholders suffice.
        _input = event_times_BN, inter_event_times_BN, marks_BN, None, None

        forward_results = self.forward(_input)
        right_xs_BNLP, right_us_BNH = (
            forward_results["right_xs_BNLP"],
            forward_results["right_us_BNH"],
        )

        if compute_last_step_only:
            right_us_B1H = [
                None if right_u_BNH is None else right_u_BNH[:, -1:, :]
                for right_u_BNH in right_us_BNH
            ]
            sampled_intensity = self._evolve_and_get_intensity_at_sampled_dts(
                right_xs_BNLP[:, -1:, :, :], sample_dtimes[:, -1:, :], right_us_B1H
            )
        else:
            sampled_intensity = self._evolve_and_get_intensity_at_sampled_dts(
                right_xs_BNLP, sample_dtimes, right_us_BNH
            )
        return sampled_intensity