Abigail99216's picture
Upload folder using huggingface_hub
f43af3c verified
import torch
import torch.nn as nn
from easy_tpp.utils import logger
class EventSampler(nn.Module):
    """Event sequence sampler based on the thinning algorithm.

    Corresponds to Algorithm 2 of "The Neural Hawkes Process: A Neurally
    Self-Modulating Multivariate Point Process", https://arxiv.org/abs/1612.09328.
    The implementation uses code from
    https://github.com/yangalan123/anhp-andtt/blob/master/anhp/esm/thinning.py.
    """

    def __init__(self, num_sample, num_exp, over_sample_rate, num_samples_boundary, dtime_max, patience_counter,
                 device):
        """Initialize the event sampler.

        Args:
            num_sample (int): number of sampled next event times via thinning algo for computing predictions.
            num_exp (int): number of i.i.d. Exp(intensity_bound) draws at one time in thinning algorithm.
            over_sample_rate (float): multiplier applied to the intensity upper bound.
            num_samples_boundary (int): number of sampled event times used to compute the intensity upper bound.
            dtime_max (float): value assigned to a sample when every exp draw is rejected.
            patience_counter (int): the maximum iteration used in adaptive thinning (stored only; not read by
                the methods of this class).
            device (torch.device): device on which random numbers and sampling grids are created.
        """
        super().__init__()
        self.num_sample = num_sample
        self.num_exp = num_exp
        self.over_sample_rate = over_sample_rate
        self.num_samples_boundary = num_samples_boundary
        self.dtime_max = dtime_max
        self.patience_counter = patience_counter
        self.device = device

    def compute_intensity_upper_bound(self, time_seq, time_delta_seq, event_seq, intensity_fn,
                                      compute_last_step_only):
        """Compute an upper bound of the total intensity at each position of the sequence.

        The intensity is probed at the dtime offsets
        ``time_delta_seq * linspace(0, 1, num_samples_boundary)``. At each position, the largest
        total intensity (summed over event types) is scaled by ``over_sample_rate``.

        Args:
            time_seq (tensor): [batch_size, seq_len], timestamp seqs.
            time_delta_seq (tensor): [batch_size, seq_len], time delta seqs.
            event_seq (tensor): [batch_size, seq_len], event type seqs.
            intensity_fn (fn): a function that computes the intensity; called as
                ``intensity_fn(time_seq, time_delta_seq, event_seq, dtimes, max_steps=..., compute_last_step_only=...)``.
            compute_last_step_only (bool): whether to compute the last time step only.

        Returns:
            tensor: [batch_size, seq_len], intensity upper bound.
        """
        # Unpacking also enforces that time_seq is 2-D.
        _, seq_len = time_seq.size()

        # [1, 1, num_samples_boundary]: fractions of each time delta at which to probe the intensity.
        time_for_bound_sampled = torch.linspace(start=0.0,
                                                end=1.0,
                                                steps=self.num_samples_boundary,
                                                device=self.device)[None, None, :]

        # [batch_size, seq_len, num_samples_boundary]
        dtime_for_bound_sampled = time_delta_seq[:, :, None] * time_for_bound_sampled

        # [batch_size, seq_len, num_samples_boundary, event_num]
        intensities_for_bound = intensity_fn(time_seq,
                                             time_delta_seq,
                                             event_seq,
                                             dtime_for_bound_sampled,
                                             max_steps=seq_len,
                                             compute_last_step_only=compute_last_step_only)

        # Sum over event types, take the max over the probe points, then over-sample for safety.
        # [batch_size, seq_len]
        bounds = intensities_for_bound.sum(dim=-1).max(dim=-1).values * self.over_sample_rate
        return bounds

    def sample_exp_distribution(self, sample_rate):
        """Sample from an exponential distribution with the given rate.

        Args:
            sample_rate (tensor): [batch_size, seq_len], intensity rate.

        Returns:
            tensor: [batch_size, seq_len, num_exp], exp numbers at each event timestamp.
        """
        batch_size, seq_len = sample_rate.size()

        # For fast approximation these draws are reused for all num_sample predictions
        # (see draw_next_time_one_step).
        # [batch_size, seq_len, num_exp]
        exp_numbers = torch.empty(size=[batch_size, seq_len, self.num_exp],
                                  dtype=torch.float32,
                                  device=self.device)
        exp_numbers.exponential_(1.0)

        # If E ~ Exp(1) then E / rate ~ Exp(rate),
        # see https://en.wikipedia.org/wiki/Exponential_distribution
        return exp_numbers / sample_rate[:, :, None]

    def sample_uniform_distribution(self, intensity_upper_bound):
        """Sample from a uniform distribution on [0, 1).

        Args:
            intensity_upper_bound (tensor): [batch_size, seq_len], upper bound intensity computed in the
                previous step; only its shape is used.

        Returns:
            tensor: [batch_size, seq_len, num_sample, num_exp]
        """
        batch_size, seq_len = intensity_upper_bound.size()

        unif_numbers = torch.empty(size=[batch_size, seq_len, self.num_sample, self.num_exp],
                                   dtype=torch.float32,
                                   device=self.device)
        unif_numbers.uniform_(0.0, 1.0)
        return unif_numbers

    def sample_accept(self, unif_numbers, sample_rate, total_intensities, exp_numbers):
        """Run the accept/reject step of thinning.

        For the accumulated exp (delta) samples drawn at each position, find (from left to right) the first
        one whose criterion ``u * sample_rate / total_intensity`` is < 1 and accept it as the sampled
        next-event time. If every exp sample is rejected (criterion >= 1 or NaN), the sampled next-event
        time is set to ``dtime_max``.

        Args:
            unif_numbers (tensor): [batch_size, max_len, num_sample, num_exp], sampled uniform random numbers.
            sample_rate (tensor): [batch_size, max_len], sample rate (intensity upper bound).
            total_intensities (tensor): [batch_size, max_len, num_sample, num_exp]
            exp_numbers (tensor): [batch_size, max_len, num_sample, num_exp], sampled exp numbers
                (delta in Algorithm 2).

        Returns:
            result (tensor): [batch_size, max_len, num_sample], sampled next-event times.
        """
        # [batch_size, max_len, num_sample, num_exp]
        criterion = unif_numbers * sample_rate[:, :, None, None] / total_intensities
        accepted = criterion < 1

        # [batch_size, max_len, num_sample]: True where no exp draw was accepted.
        non_accepted_filter = ~accepted.any(dim=3)

        # [batch_size, max_len, num_sample, 1]: argmax returns the first maximal index, i.e. the first
        # accepted draw (or 0 when none is accepted, which is masked out below).
        first_accepted_indexer = accepted.long().argmax(dim=3, keepdim=True)

        # [batch_size, max_len, num_sample]
        result_non_accepted_unfiltered = torch.gather(exp_numbers, 3, first_accepted_indexer).squeeze(3)

        # full_like keeps the fill value on the same device/dtype as the samples.
        dtime_max_fill = torch.full_like(result_non_accepted_unfiltered, self.dtime_max)
        return torch.where(non_accepted_filter, dtime_max_fill, result_non_accepted_unfiltered)

    def draw_next_time_one_step(self, time_seq, time_delta_seq, event_seq, dtime_boundary,
                                intensity_fn, compute_last_step_only=False):
        """Compute next event times based on the thinning algorithm.

        Args:
            time_seq (tensor): [batch_size, seq_len], timestamp seqs.
            time_delta_seq (tensor): [batch_size, seq_len], time delta seqs.
            event_seq (tensor): [batch_size, seq_len], event type seqs.
            dtime_boundary (tensor): [batch_size, seq_len], dtime upper bound. NOTE(review): currently not
                used by this implementation; kept for interface compatibility.
            intensity_fn (fn): a function to compute the intensity.
            compute_last_step_only (bool, optional): whether to compute last event timestep only.
                Defaults to False.

        Returns:
            tuple: (sampled next event times [batch_size, seq_len, num_sample] clamped to at most 1e5,
            uniform weights [batch_size, seq_len, num_sample] that sum to 1 over the last dim).
        """
        # 1. Upper bound of the total intensity at each position.
        # [batch_size, seq_len]
        intensity_upper_bound = self.compute_intensity_upper_bound(time_seq,
                                                                   time_delta_seq,
                                                                   event_seq,
                                                                   intensity_fn,
                                                                   compute_last_step_only)

        # 2. Exp draws with rate = intensity_upper_bound; the cumulative sum gives candidate times.
        # [batch_size, seq_len, num_exp]
        exp_numbers = self.sample_exp_distribution(intensity_upper_bound)
        exp_numbers = torch.cumsum(exp_numbers, dim=-1)

        # 3. Intensity at the candidate times.
        # [batch_size, seq_len, num_exp, event_num]
        intensities_at_sampled_times = intensity_fn(time_seq,
                                                    time_delta_seq,
                                                    event_seq,
                                                    exp_numbers,
                                                    max_steps=time_seq.size(1),
                                                    compute_last_step_only=compute_last_step_only)

        # [batch_size, seq_len, num_exp]
        total_intensities = intensities_at_sampled_times.sum(dim=-1)

        # Re-use the same candidates/intensities for every prediction sample (broadcast view, no copy).
        # [batch_size, seq_len, num_sample, num_exp]
        expanded_shape = (*total_intensities.shape[:2], self.num_sample, total_intensities.shape[-1])
        total_intensities = total_intensities[:, :, None, :].expand(expanded_shape)
        exp_numbers = exp_numbers[:, :, None, :].expand(expanded_shape)

        # 4. Uniform draws for the accept/reject test.
        # [batch_size, seq_len, num_sample, num_exp]
        unif_numbers = self.sample_uniform_distribution(intensity_upper_bound)

        # 5. Accept/reject.
        # [batch_size, seq_len, num_sample]
        res = self.sample_accept(unif_numbers, intensity_upper_bound, total_intensities, exp_numbers)

        # Each sample is weighted equally.
        # [batch_size, seq_len, num_sample]
        weights = torch.ones_like(res) / res.shape[2]

        # Upper bound in case sampled times explode, e.g., in ODE models.
        return res.clamp(max=1e5), weights