import torch
import torch.nn as nn

from easy_tpp.model.torch_model.torch_baselayer import EncoderLayer, MultiHeadAttention, TimePositionalEncoding, ScaledSoftplus
from easy_tpp.model.torch_model.torch_basemodel import TorchBaseModel


class THP(TorchBaseModel):
    """Torch implementation of Transformer Hawkes Process, ICML 2020, https://arxiv.org/abs/2002.09291.

    Note: Part of the code is collected from https://github.com/yangalan123/anhp-andtt/tree/master/thp.
    """

    def __init__(self, model_config):
        """Initialize the model.

        Args:
            model_config (EasyTPP.ModelConfig): config of model specs.
        """
        super(THP, self).__init__(model_config)
        self.d_model = model_config.hidden_size
        self.d_time = model_config.time_emb_size
        self.use_norm = model_config.use_ln

        self.n_layers = model_config.num_layers
        self.n_head = model_config.num_heads
        self.dropout = model_config.dropout_rate

        self.layer_temporal_encoding = TimePositionalEncoding(self.d_model, device=self.device)

        self.factor_intensity_base = nn.Parameter(torch.empty([1, self.num_event_types], device=self.device))
        self.factor_intensity_decay = nn.Parameter(torch.empty([1, self.num_event_types], device=self.device))
        nn.init.xavier_normal_(self.factor_intensity_base)
        nn.init.xavier_normal_(self.factor_intensity_decay)
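
        # Parameterization note (as written in this implementation): factor_intensity_decay and
        # factor_intensity_base are per-event-type scalars. Together with layer_intensity_hidden
        # defined below, the pre-softplus intensity is linear in the elapsed time tau since the
        # last event, roughly
        #     lambda_k(t_i + tau) = Softplus_k(decay_k * tau + w_k^T h_i + base_k),
        # where h_i is the Transformer hidden state of event i.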
        self.layer_intensity_hidden = nn.Linear(self.d_model, self.num_event_types)
        self.softplus = ScaledSoftplus(self.num_event_types)

        self.feed_forward = nn.Sequential(
            nn.Linear(self.d_model, self.d_model * 2),
            nn.ReLU(),
            nn.Linear(self.d_model * 2, self.d_model)
        )

        self.stack_layers = nn.ModuleList(
            [EncoderLayer(
                self.d_model,
                MultiHeadAttention(self.n_head, self.d_model, self.d_model, self.dropout,
                                   output_linear=False),
                use_residual=False,
                feed_forward=self.feed_forward,
                dropout=self.dropout
            ) for _ in range(self.n_layers)])
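
        # Note: the same self.feed_forward instance is passed to every EncoderLayer above,
        # so the position-wise feed-forward weights are shared across the stacked layers.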

    def forward(self, time_seqs, type_seqs, attention_mask):
        """Call the model.

        Args:
            time_seqs (tensor): [batch_size, seq_len], timestamp seqs.
            type_seqs (tensor): [batch_size, seq_len], event type seqs.
            attention_mask (tensor): [batch_size, seq_len, seq_len], attention masks.

        Returns:
            tensor: [batch_size, seq_len, hidden_size], hidden states at event times.
        """
        # [batch_size, seq_len, hidden_size]
        tem_enc = self.layer_temporal_encoding(time_seqs)
        enc_output = self.layer_type_emb(type_seqs)

        # add the temporal encoding before every layer, then apply masked self-attention
        for enc_layer in self.stack_layers:
            enc_output += tem_enc
            enc_output = enc_layer(
                enc_output,
                mask=attention_mask)

        return enc_output

    def loglike_loss(self, batch):
        """Compute the loglike loss.

        Args:
            batch (tuple, list): batch input.

        Returns:
            tuple: loglike loss, num events.
        """
        time_seqs, time_delta_seqs, type_seqs, batch_non_pad_mask, attention_mask = batch

        enc_out = self.forward(time_seqs[:, :-1], type_seqs[:, :-1], attention_mask[:, :-1, :-1])
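
        # enc_out[:, i] is the hidden state of event i (i = 0, ..., seq_len - 2); it parameterizes
        # the intensity over the following inter-event interval, which is why the delta-time, mask
        # and type tensors below are all shifted by one via the [:, 1:] slices.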

        # [1, 1, num_event_types]
        factor_intensity_decay = self.factor_intensity_decay[None, ...]
        factor_intensity_base = self.factor_intensity_base[None, ...]

        # pre-activation intensity at the observed event times
        # [batch_size, seq_len - 1, num_event_types]
        intensity_states = factor_intensity_decay * time_delta_seqs[:, 1:, None] + self.layer_intensity_hidden(
            enc_out) + factor_intensity_base

        lambda_at_event = self.softplus(intensity_states)
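
        # lambda_at_event holds the intensity of every event type at each observed event time;
        # the base-class compute_loglikelihood (presumably) selects the observed type per event
        # to form the sum-of-log-intensity term.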

        # Monte Carlo samples inside each inter-event interval for the integral term
        # [batch_size, seq_len - 1, num_samples]
        sample_dtimes = self.make_dtime_loss_samples(time_delta_seqs[:, 1:])

        # [batch_size, seq_len - 1, num_samples, num_event_types]
        state_t_sample = self.compute_states_at_sample_times(event_states=enc_out,
                                                             sample_dtimes=sample_dtimes)
        lambda_t_sample = self.softplus(state_t_sample)
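
        # The non-event term of the log-likelihood is the integral of the total intensity over the
        # observation window; with the samples above it is approximated in the usual Monte Carlo way
        # (assuming compute_loglikelihood implements this estimator), roughly
        #     integral of lambda(t) dt  ~=  sum_i dt_i * mean_s sum_k lambda_k(t_i + tau_s).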

        event_ll, non_event_ll, num_events = self.compute_loglikelihood(lambda_at_event=lambda_at_event,
                                                                        lambdas_loss_samples=lambda_t_sample,
                                                                        time_delta_seq=time_delta_seqs[:, 1:],
                                                                        seq_mask=batch_non_pad_mask[:, 1:],
                                                                        type_seq=type_seqs[:, 1:])

        loss = - (event_ll - non_event_ll).sum()
        return loss, num_events

    def compute_states_at_sample_times(self, event_states, sample_dtimes):
        """Compute the (pre-activation) intensity states at sampled times.

        Args:
            event_states (tensor): [batch_size, seq_len, hidden_size].
            sample_dtimes (tensor): [batch_size, seq_len, num_samples].

        Returns:
            tensor: [batch_size, seq_len, num_samples, num_event_types], intensity states at each sampled time.
        """
        # [batch_size, seq_len, 1, hidden_size]
        event_states = event_states[:, :, None, :]

        # [batch_size, seq_len, num_samples, 1]
        sample_dtimes = sample_dtimes[..., None]

        # [1, 1, 1, num_event_types]
        factor_intensity_decay = self.factor_intensity_decay[None, None, ...]
        factor_intensity_base = self.factor_intensity_base[None, None, ...]

        # [batch_size, seq_len, num_samples, num_event_types]
        intensity_states = factor_intensity_decay * sample_dtimes + self.layer_intensity_hidden(
            event_states) + factor_intensity_base

        return intensity_states

    def compute_intensities_at_sample_times(self,
                                            time_seqs,
                                            time_delta_seqs,
                                            type_seqs,
                                            sample_dtimes,
                                            **kwargs):
        """Compute the intensities at sampled times, not only at event times.

        Args:
            time_seqs (tensor): [batch_size, seq_len], times seqs.
            time_delta_seqs (tensor): [batch_size, seq_len], time delta seqs.
            type_seqs (tensor): [batch_size, seq_len], event type seqs.
            sample_dtimes (tensor): [batch_size, seq_len, num_samples], sampled inter-event timestamps.

        Returns:
            tensor: [batch_size, seq_len, num_samples, num_event_types], intensity at all sampled times.
        """
        attention_mask = kwargs.get('attention_mask', None)
        compute_last_step_only = kwargs.get('compute_last_step_only', False)

        if attention_mask is None:
            batch_size, seq_len = time_seqs.size()
            # default to a causal mask: position j is masked for queries at position i whenever j > i
            attention_mask = torch.triu(torch.ones(seq_len, seq_len, device=self.device), diagonal=1).unsqueeze(0)
            attention_mask = attention_mask.expand(batch_size, -1, -1).to(torch.bool)

        # [batch_size, seq_len, hidden_size]
        enc_out = self.forward(time_seqs, type_seqs, attention_mask)

        # [batch_size, seq_len, num_samples, num_event_types]
        encoder_output = self.compute_states_at_sample_times(enc_out, sample_dtimes)

        if compute_last_step_only:
            # keep only the last event position, e.g. for one-step-ahead prediction
            lambdas = self.softplus(encoder_output[:, -1:, :, :])
        else:
            # [batch_size, seq_len, num_samples, num_event_types]
            lambdas = self.softplus(encoder_output)
        return lambdas
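

# Minimal usage sketch (illustration only; the exact ModelConfig fields and import path below
# are assumptions, not part of this module):
#
#     from easy_tpp.config_factory import ModelConfig   # hypothetical import path
#     config = ModelConfig(hidden_size=32, time_emb_size=16, num_layers=2, num_heads=2,
#                          dropout_rate=0.1, use_ln=False, num_event_types=5)
#     model = THP(config)
#     loss, num_events = model.loglike_loss(batch)  # batch as produced by the EasyTPP data loader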