|
|
import torch |
|
|
from torch import nn |
|
|
import math |
|
|
|
|
|
from easy_tpp.model.torch_model.torch_basemodel import TorchBaseModel |
|
|
|
|
|
class RMTPP(TorchBaseModel):
    """Torch implementation of Recurrent Marked Temporal Point Processes, KDD 2016.

    https://www.kdd.org/kdd2016/papers/files/rpp1081-duA.pdf

    Tensor names in this class carry a shape suffix:
    B = batch size, N = sequence length, H = hidden size,
    G = number of sampled time deltas per event, M = number of event types.
    """

    # Upper bound applied to the log-intensity before exponentiation, so the
    # intensity itself never exceeds 1e5.
    _MAX_LOG_INTENSITY = math.log(1e5)

    def __init__(self, model_config):
        """Initialize the model.

        Args:
            model_config (EasyTPP.ModelConfig): config of model specs.
        """
        super().__init__(model_config)

        # NOTE(review): hidden_size, num_event_types and layer_type_emb (used in
        # forward) are not defined in this class; presumably they are set by
        # TorchBaseModel.__init__ from model_config - confirm against the base class.
        self.layer_temporal_emb = nn.Linear(1, self.hidden_size)
        self.layer_rnn = nn.RNN(
            input_size=self.hidden_size,
            hidden_size=self.hidden_size,
            num_layers=1,
            batch_first=True,
        )

        # Parameters of the intensity function: a linear map of the hidden state
        # plus a per-event-type time slope (w_t) and bias (b_t).
        self.hidden_to_intensity_logits = nn.Linear(self.hidden_size, self.num_event_types)
        self.b_t = nn.Parameter(torch.zeros(1, self.num_event_types))
        self.w_t = nn.Parameter(torch.zeros(1, self.num_event_types))
        nn.init.xavier_normal_(self.b_t)
        nn.init.xavier_normal_(self.w_t)

    def evolve_and_get_intentsity(self, right_hiddens_BNH, dts_BNG):
        """Compute the intensity of every event type after given elapsed times (Eq. 11).

        Evaluates ``exp(min(linear(h) + w_t * dt + b_t, log(1e5)))``.

        The misspelled method name is kept as-is so existing callers keep working.

        Args:
            right_hiddens_BNH (tensor): [..., hidden_size], hidden states right after events.
            dts_BNG (tensor): [..., num_sample], elapsed times since the corresponding event;
                leading dims must broadcast against those of ``right_hiddens_BNH``.

        Returns:
            tensor: [..., num_sample, num_event_types], intensities capped at 1e5.
        """
        # Insert a singleton sample axis -> [..., 1, M] so it broadcasts over G.
        past_influence_BNGM = self.hidden_to_intensity_logits(right_hiddens_BNH[..., None, :])
        # dts gets a trailing axis -> [..., G, 1] so it broadcasts over M.
        log_intensity_BNGM = (
            past_influence_BNGM
            + self.w_t[None, None, :] * dts_BNG[..., None]
            + self.b_t[None, None, :]
        )
        # Clamp in log space before exp to avoid overflow.
        return log_intensity_BNGM.clamp(max=self._MAX_LOG_INTENSITY).exp()

    def forward(self, batch):
        """Run the RNN over a batch and evaluate intensities at the event times.

        Suppose we have inputs with original sequence length N+1
        ts: [t0, t1, ..., t_N]
        dts: [0, t1 - t0, t2 - t1, ..., t_N - t_{N-1}]
        marks: [k0, k1, ..., k_N] (k0 and kN could be padded marks if t0 and tN
        correspond to left and right windows)

        Args:
            batch (tuple): 5-tuple ``(ts, dts, marks, _, _)``; the last two entries are ignored.

        Returns:
            tuple:
                left limits of *intensity* at [t_1, ..., t_N] of shape
                (batch_size, seq_len - 1, num_event_types);
                right limits of *hidden states* at [t_0, ..., t_{N-1}, t_N] of shape
                (batch_size, seq_len, hidden_size).
                We need the right limit of t_N to sample continuation.
        """
        t_BN, dt_BN, marks_BN, _, _ = batch

        mark_emb_BNH = self.layer_type_emb(marks_BN)
        # The temporal embedding is fed the timestamps (first batch entry), not the deltas.
        time_emb_BNH = self.layer_temporal_emb(t_BN[..., None])
        right_hiddens_BNH, _ = self.layer_rnn(mark_emb_BNH + time_emb_BNH)

        # The hidden state after event i-1 is evolved by dt_i to reach event i.
        # dt gets a singleton sample axis (G = 1), which is squeezed out again.
        left_intensity_B_Nm1_M = self.evolve_and_get_intentsity(
            right_hiddens_BNH[:, :-1, :],
            dt_BN[:, 1:][..., None],
        ).squeeze(-2)
        return left_intensity_B_Nm1_M, right_hiddens_BNH

    def loglike_loss(self, batch):
        """Compute the log-likelihood loss.

        Args:
            batch (list): batch input, unpacked as
                ``(ts, dts, marks, non_pad_mask, _)``; the last entry is ignored.

        Returns:
            tuple: loglikelihood loss (negative log-likelihood summed over all
            entries) and num of events.
        """
        ts_BN, dts_BN, marks_BN, batch_non_pad_mask, _ = batch

        # Intensities at the observed event times (event term of the likelihood).
        left_intensity_B_Nm1_M, right_hiddens_BNH = self.forward(
            (ts_BN, dts_BN, marks_BN, None, None)
        )
        right_hiddens_B_Nm1_H = right_hiddens_BNH[..., :-1, :]

        # Intensities at sampled times inside each inter-event interval
        # (used for the non-event term of the likelihood).
        dts_sample_B_Nm1_G = self.make_dtime_loss_samples(dts_BN[:, 1:])
        intensity_dts_B_Nm1_G_M = self.evolve_and_get_intentsity(
            right_hiddens_B_Nm1_H, dts_sample_B_Nm1_G
        )

        # The first position has no preceding event, so targets start at index 1.
        event_ll, non_event_ll, num_events = self.compute_loglikelihood(
            lambda_at_event=left_intensity_B_Nm1_M,
            lambdas_loss_samples=intensity_dts_B_Nm1_G_M,
            time_delta_seq=dts_BN[:, 1:],
            seq_mask=batch_non_pad_mask[:, 1:],
            type_seq=marks_BN[:, 1:],
        )

        loss = -(event_ll - non_event_ll).sum()
        return loss, num_events

    def compute_intensities_at_sample_times(self, time_seqs, time_delta_seqs, type_seqs, sample_dtimes, **kwargs):
        """Compute the intensity at sampled times, not only event times.

        Args:
            time_seqs (tensor): [batch_size, seq_len], times seqs.
            time_delta_seqs (tensor): [batch_size, seq_len], time delta seqs.
            type_seqs (tensor): [batch_size, seq_len], event type seqs.
            sample_dtimes (tensor): [batch_size, seq_len, num_sample], sampled inter-event timestamps.
            **kwargs: ``compute_last_step_only`` (bool, default False) restricts the
                computation to the last position of the sequence.

        Returns:
            tensor: [batch_size, num_times, num_mc_sample, num_event_types],
            intensity at each timestamp for each event type. ``num_times`` is 1
            when ``compute_last_step_only`` is set, otherwise seq_len.
        """
        compute_last_step_only = kwargs.get('compute_last_step_only', False)

        _input = time_seqs, time_delta_seqs, type_seqs, None, None
        _, right_hiddens_BNH = self.forward(_input)

        if compute_last_step_only:
            # Slice with -1: (not -1) to keep the sequence axis.
            right_hiddens_BNH = right_hiddens_BNH[:, -1:, :]
            sample_dtimes = sample_dtimes[:, -1:, :]

        return self.evolve_and_get_intentsity(right_hiddens_BNH, sample_dtimes)
|
|
|