|
|
import torch |
|
|
import torch.nn as nn |
|
|
from easy_tpp.utils import logger |
|
|
|
|
|
|
|
|
class EventSampler(nn.Module):
    """Event Sequence Sampler based on thinning algorithm, which corresponds to Algorithm 2 of
    The Neural Hawkes Process: A Neurally Self-Modulating Multivariate Point Process,
    https://arxiv.org/abs/1612.09328.

    The implementation uses code from https://github.com/yangalan123/anhp-andtt/blob/master/anhp/esm/thinning.py.
    """

    def __init__(self, num_sample, num_exp, over_sample_rate, num_samples_boundary, dtime_max, patience_counter,
                 device):
        """Initialize the event sampler.

        Args:
            num_sample (int): number of sampled next event times via thinning algo for computing predictions.
            num_exp (int): number of i.i.d. Exp(intensity_bound) draws at one time in thinning algorithm
            over_sample_rate (float): multiplier for the intensity up bound.
            num_samples_boundary (int): number of sampled event times to compute the boundary of the intensity.
            dtime_max (float): max value of delta times in sampling
            patience_counter (int): the maximum iteration used in adaptive thinning. It is only stored here;
                none of the methods of this class read it.
            device (torch.device): torch device on which the sampler allocates its random tensors.
        """
        super(EventSampler, self).__init__()
        self.num_sample = num_sample
        self.num_exp = num_exp
        self.over_sample_rate = over_sample_rate
        self.num_samples_boundary = num_samples_boundary
        self.dtime_max = dtime_max
        self.patience_counter = patience_counter
        self.device = device

    def compute_intensity_upper_bound(self, time_seq, time_delta_seq, event_seq, intensity_fn,
                                      compute_last_step_only):
        """Compute the upper bound of intensity at each event timestamp.

        The intensity is evaluated on an evenly spaced grid of ``num_samples_boundary`` offsets spanning
        ``[0, time_delta]`` for every position; the largest total intensity found on that grid, scaled by
        ``over_sample_rate``, is used as the bound.

        Args:
            time_seq (tensor): [batch_size, seq_len], timestamp seqs.
            time_delta_seq (tensor): [batch_size, seq_len], time delta seqs.
            event_seq (tensor): [batch_size, seq_len], event type seqs.
            intensity_fn (fn): a function that computes the intensity. It is called as
                ``intensity_fn(time_seq, time_delta_seq, event_seq, sampled_dtimes, max_steps=...,
                compute_last_step_only=...)``; its last dimension is summed and the one before it is
                maximised, so presumably it returns [..., num_samples_boundary, num_event_types].
            compute_last_step_only (bool): whether to compute the last time step only.

        Returns:
            tensor: [batch_size, seq_len] (assuming ``intensity_fn`` keeps the leading two dimensions).
        """
        # Unpacking into two names keeps the implicit check that time_seq is 2-D.
        _, seq_len = time_seq.size()

        # Fractions in [0, 1], shaped [1, 1, num_samples_boundary] so they broadcast over batch and sequence.
        grid_fractions = torch.linspace(start=0.0,
                                        end=1.0,
                                        steps=self.num_samples_boundary,
                                        device=self.device)[None, None, :]

        # [batch_size, seq_len, num_samples_boundary]: candidate offsets inside each inter-event gap.
        dtime_for_bound_sampled = time_delta_seq[:, :, None] * grid_fractions

        intensities_for_bound = intensity_fn(time_seq,
                                             time_delta_seq,
                                             event_seq,
                                             dtime_for_bound_sampled,
                                             max_steps=seq_len,
                                             compute_last_step_only=compute_last_step_only)

        # Sum over the last axis (total intensity), take the max over the grid axis, then inflate the bound.
        total_on_grid = intensities_for_bound.sum(dim=-1)
        bounds = total_on_grid.max(dim=-1)[0] * self.over_sample_rate

        return bounds

    def sample_exp_distribution(self, sample_rate):
        """Sample an exponential distribution.

        Args:
            sample_rate (tensor): [batch_size, seq_len], intensity rate.

        Returns:
            tensor: [batch_size, seq_len, num_exp], exp numbers at each event timestamp.
        """
        batch_size, seq_len = sample_rate.size()

        # Draw Exp(1) variates in place, then rescale: X / rate ~ Exp(rate).
        exp_numbers = torch.empty(size=[batch_size, seq_len, self.num_exp],
                                  dtype=torch.float32,
                                  device=self.device)
        exp_numbers.exponential_(1.0)

        exp_numbers = exp_numbers / sample_rate[:, :, None]

        return exp_numbers

    def sample_uniform_distribution(self, intensity_upper_bound):
        """Sample an uniform distribution

        Args:
            intensity_upper_bound (tensor): [batch_size, seq_len], upper bound intensity computed in the
                previous step. Only its shape is used.

        Returns:
            tensor: [batch_size, seq_len, num_sample, num_exp]
        """
        batch_size, seq_len = intensity_upper_bound.size()

        unif_numbers = torch.empty(size=[batch_size, seq_len, self.num_sample, self.num_exp],
                                   dtype=torch.float32,
                                   device=self.device)
        unif_numbers.uniform_(0.0, 1.0)

        return unif_numbers

    def sample_accept(self, unif_numbers, sample_rate, total_intensities, exp_numbers):
        """Do the sample-accept process.

        For the accumulated exp (delta) samples drawn for each event timestamp, find (from left to right) the first
        that makes the criterion < 1 and accept it as the sampled next-event time. If all exp samples are rejected
        (criterion >= 1), then we set the sampled next-event time dtime_max.

        Args:
            unif_numbers (tensor): [batch_size, max_len, num_sample, num_exp], sampled uniform random number.
            sample_rate (tensor): [batch_size, max_len], sample rate (intensity).
            total_intensities (tensor): [batch_size, seq_len, num_sample, num_exp]
            exp_numbers (tensor): [batch_size, seq_len, num_sample, num_exp]: sampled exp numbers (delta in Algorithm 2).

        Returns:
            result (tensor): [batch_size, seq_len, num_sample], sampled next-event times.
        """
        # Thinning test: candidate j is accepted when u_j * rate / lambda(t_j) < 1.
        criterion = unif_numbers * sample_rate[:, :, None, None] / total_intensities
        accepted = criterion < 1

        # True where every candidate along the num_exp axis was rejected.
        none_accepted = ~accepted.any(dim=3, keepdim=True)

        # argmax returns the first maximal index, i.e. the left-most accepted candidate
        # (index 0 when nothing was accepted; those entries are overwritten below).
        first_accepted_indexer = accepted.to(torch.int64).argmax(dim=3, keepdim=True)

        candidate_times = torch.gather(exp_numbers, 3, first_accepted_indexer)

        # masked_fill keeps the device and dtype of exp_numbers, so no stand-alone tensor has to be
        # created for dtime_max (which would otherwise live on the CPU with an inferred dtype).
        result = candidate_times.masked_fill(none_accepted, self.dtime_max)

        return result.squeeze(dim=-1)

    def draw_next_time_one_step(self, time_seq, time_delta_seq, event_seq, dtime_boundary,
                                intensity_fn, compute_last_step_only=False):
        """Compute next event time based on Thinning algorithm.

        Args:
            time_seq (tensor): [batch_size, seq_len], timestamp seqs.
            time_delta_seq (tensor): [batch_size, seq_len], time delta seqs.
            event_seq (tensor): [batch_size, seq_len], event type seqs.
            dtime_boundary (tensor): [batch_size, seq_len], dtime upper bound. Not read by this
                implementation; kept in the signature for callers.
            intensity_fn (fn): a function to compute the intensity.
            compute_last_step_only (bool, optional): whether to compute last event timestep only. Defaults to False.

        Returns:
            tuple: next event time prediction and weight, both [batch_size, seq_len, num_sample]. The
            predictions are clamped to at most 1e5 and the weights are uniform (1 / num_sample).
        """
        # 1. Upper bound of the total intensity for every position: [batch_size, seq_len].
        intensity_upper_bound = self.compute_intensity_upper_bound(time_seq,
                                                                   time_delta_seq,
                                                                   event_seq,
                                                                   intensity_fn,
                                                                   compute_last_step_only)

        # 2. Candidate arrival times: cumulative sums of Exp(bound) gaps, [batch_size, seq_len, num_exp].
        exp_numbers = self.sample_exp_distribution(intensity_upper_bound)
        exp_numbers = torch.cumsum(exp_numbers, dim=-1)

        # 3. Intensity at every candidate time, summed over the last axis to get the total intensity.
        intensities_at_sampled_times = intensity_fn(time_seq,
                                                    time_delta_seq,
                                                    event_seq,
                                                    exp_numbers,
                                                    max_steps=time_seq.size(1),
                                                    compute_last_step_only=compute_last_step_only)
        total_intensities = intensities_at_sampled_times.sum(dim=-1)

        # 4. Repeat along a new num_sample axis so each sample reuses the same candidates:
        #    [batch_size, seq_len, num_sample, num_exp].
        tile_dims = [1, 1, self.num_sample, 1]
        total_intensities = torch.tile(total_intensities[:, :, None, :], tile_dims)
        exp_numbers = torch.tile(exp_numbers[:, :, None, :], tile_dims)

        # 5. Independent uniforms per (sample, candidate) drive the accept/reject decision.
        unif_numbers = self.sample_uniform_distribution(intensity_upper_bound)

        # 6. First accepted candidate per sample (or dtime_max): [batch_size, seq_len, num_sample].
        res = self.sample_accept(unif_numbers, intensity_upper_bound, total_intensities, exp_numbers)

        # Every sample gets the same weight, so the weights sum to one over the num_sample axis.
        weights = torch.ones_like(res) / res.shape[2]

        # Cap the sampled times at 1e5 to keep extreme values out of the result.
        return res.clamp(max=1e5), weights
|
|
|