|
|
import random |
|
|
from typing import Optional, Union, Dict, Any |
|
|
|
|
|
import numpy as np |
|
|
import torch |
|
|
from torch import nn |
|
|
from torch.utils.data import DataLoader |
|
|
|
|
|
from easy_tpp.config_factory import DataSpecConfig, Config |
|
|
from easy_tpp.model import TorchNHP as NHP |
|
|
from easy_tpp.preprocess import TPPDataset, EventTokenizer |
|
|
from easy_tpp.preprocess.data_collator import TPPDataCollator |
|
|
from easy_tpp.preprocess.event_tokenizer import BatchEncoding |
|
|
from easy_tpp.utils import PaddingStrategy |
|
|
|
|
|
|
|
|
def make_raw_data():
    """Generate a small synthetic set of event sequences with a `loan_amt` feature.

    Returns:
        list: six sequences; each is a list of event dicts with keys
            `time_since_last_event`, `time_since_start`, `type_event`
            and `loan_amt`.
    """
    # Hand-written first event (type, loan amount) for every sequence.
    seed_events = [(0, 10), (1, 10), (1, 20), (1, 20), (1, 20), (1, 30)]
    data = [
        [{"time_since_last_event": 0, "time_since_start": 0,
          "type_event": event_type, 'loan_amt': amount}]
        for event_type, amount in seed_events
    ]

    # Append a fixed number of random follow-up events to each sequence.
    extra_event_counts = [2, 5, 3, 2, 4, 2]
    for seq_idx, num_events in enumerate(extra_event_counts):
        elapsed = 0
        for _ in range(num_events):
            delta_t = random.random()
            elapsed += delta_t
            data[seq_idx].append({"time_since_last_event": delta_t,
                                  "time_since_start": elapsed,
                                  "type_event": random.randint(0, 10),
                                  'loan_amt': random.randint(10, 30)})

    return data
|
|
|
|
|
|
|
|
class TPPDatasetV2(TPPDataset):
    """TPPDataset variant that additionally carries a `loan_amt_seqs` field."""

    def __init__(self, data):
        super().__init__(data)
        # Extra per-event feature, parallel to the time/type sequences.
        self.loan_amt_seqs = self.data_dict['loan_amt_seqs']

    def __getitem__(self, idx):
        """Fetch one sequence by index.

        Args:
            idx: iteration index

        Returns:
            dict: a dict of time_seqs, time_delta_seqs, type_seqs and
                loan_amt_seqs elements
        """
        return {'time_seqs': self.time_seqs[idx],
                'time_delta_seqs': self.time_delta_seqs[idx],
                'type_seqs': self.type_seqs[idx],
                'loan_amt_seqs': self.loan_amt_seqs[idx]}
|
|
|
|
|
|
|
|
class EventTokenizerV2(EventTokenizer):
    """Event tokenizer that also pads the extra `loan_amt_seqs` feature.

    Two names are appended to `model_input_names`; `_pad` below indexes that
    list positionally, assuming the final ordering is:
        [0] time_seqs, [1] time_delta_seqs, [2] type_seqs,
        [3] batch_non_pad_mask, [4] attention_mask,
        [5] loan_amt_seqs, [6] type_mask.
    NOTE(review): this ordering depends on what the base EventTokenizer puts
    in `model_input_names` — confirm against the easy_tpp source.
    """

    def __init__(self, config):
        super(EventTokenizerV2, self).__init__(config)
        self.model_input_names.append('loan_amt_seqs')
        self.model_input_names.append('type_mask')

    def _pad(
            self,
            encoded_inputs: Union[Dict[str, Any], BatchEncoding],
            max_length: Optional[int] = None,
            padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
            return_attention_mask: Optional[bool] = None,
    ) -> dict:
        """
        Pad encoded inputs (on left/right and up to predefined length or max length in the batch)

        Args:
            encoded_inputs:
                Dictionary of tokenized inputs (`List[int]`) or batch of tokenized inputs (`List[List[int]]`).
            max_length: maximum length of the returned list and optionally padding length (see below).
                Will truncate by taking into account the special tokens.
            padding_strategy: PaddingStrategy to use for padding.

                - PaddingStrategy.LONGEST Pad to the longest sequence in the batch
                - PaddingStrategy.MAX_LENGTH: Pad to the max length (default)
                - PaddingStrategy.DO_NOT_PAD: Do not pad
                The tokenizer padding sides are defined in self.padding_side:

                - 'left': pads on the left of the sequences
                - 'right': pads on the right of the sequences
            return_attention_mask:
                (optional) Set to False to avoid returning attention mask (default: set to model specifics)

        Returns:
            dict: padded model inputs, including the non-pad mask, the
                attention mask (or an empty list), the type mask and the
                padded loan amounts.
        """
        if return_attention_mask is None:
            return_attention_mask = "attention_mask" in self.model_input_names

        required_input = encoded_inputs[self.model_input_names[0]]

        if padding_strategy == PaddingStrategy.LONGEST:
            # BUG FIX: this was `len(required_input)`, i.e. the number of
            # sequences in the batch, not the length of the longest sequence.
            max_length = max(len(seq) for seq in required_input)

        # BUG FIX: the original applied `~` to np.prod(...) (an integer);
        # bitwise-not of 0 or 1 is -1 or -2, both truthy, so padding was
        # triggered unconditionally. Use plain booleans instead.
        all_seqs_at_max_len = all(len(seq) == max_length for seq in required_input)
        needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and not all_seqs_at_max_len

        if needs_to_be_padded:
            batch_output = dict()

            # time_seqs and time_delta_seqs: padded with the pad token id
            # using the helper's default (float) dtype.
            batch_output[self.model_input_names[0]] = self.make_pad_sequence(
                encoded_inputs[self.model_input_names[0]],
                self.pad_token_id,
                padding_side=self.padding_side,
                max_len=max_length)

            batch_output[self.model_input_names[1]] = self.make_pad_sequence(
                encoded_inputs[self.model_input_names[1]],
                self.pad_token_id,
                padding_side=self.padding_side,
                max_len=max_length)

            # type_seqs: integer event-type ids.
            batch_output[self.model_input_names[2]] = self.make_pad_sequence(
                encoded_inputs[self.model_input_names[2]],
                self.pad_token_id,
                padding_side=self.padding_side,
                max_len=max_length,
                dtype=np.int32)
        else:
            # All sequences already have the target length (or padding is
            # disabled): materialise numpy arrays so the mask arithmetic
            # below behaves the same as in the padded branch.
            # NOTE(review): with DO_NOT_PAD and ragged sequences this (like
            # the original pass-through) cannot produce valid masks.
            batch_output = {key: np.array(encoded_inputs[key])
                            for key in self.model_input_names[:3]}

        # Positions equal to the pad id are padding; the model consumes the
        # complement as the non-pad mask.
        seq_pad_mask = batch_output[self.model_input_names[2]] == self.pad_token_id
        batch_output[self.model_input_names[3]] = ~seq_pad_mask

        if return_attention_mask:
            batch_output[self.model_input_names[4]] = self.make_attn_mask_for_pad_sequence(
                batch_output[self.model_input_names[2]],
                self.pad_token_id)
        else:
            batch_output[self.model_input_names[4]] = []

        batch_output[self.model_input_names[6]] = self.make_type_mask_for_pad_sequence(
            batch_output[self.model_input_names[2]])

        # Pad the extra loan-amount feature the same way as the other inputs.
        batch_output[self.model_input_names[5]] = self.make_pad_sequence(
            encoded_inputs[self.model_input_names[-2]],
            self.pad_token_id,
            padding_side=self.padding_side,
            max_len=max_length)

        return batch_output
|
|
|
|
|
|
|
|
def make_data_loader():
    """Build a DataLoader over the synthetic data, collated by the V2 tokenizer."""
    source_data = make_raw_data()

    # Transpose the per-event dicts into per-field sequence lists.
    field_to_event_key = {'time_seqs': 'time_since_start',
                          'type_seqs': 'type_event',
                          'time_delta_seqs': 'time_since_last_event',
                          'loan_amt_seqs': 'loan_amt'}
    input_data = {field: [[event[key] for event in seq] for seq in source_data]
                  for field, key in field_to_event_key.items()}

    config = DataSpecConfig.parse_from_yaml_config({'num_event_types': 11, 'batch_size': 1,
                                                    'pad_token_id': 11})

    dataset = TPPDatasetV2(input_data)
    tokenizer = EventTokenizerV2(config)

    # Fall back to sensible collator defaults when the tokenizer does not
    # specify a padding/truncation strategy.
    padding = tokenizer.padding_strategy if tokenizer.padding_strategy is not None else True
    truncation = tokenizer.truncation_strategy if tokenizer.truncation_strategy is not None else False

    data_collator = TPPDataCollator(tokenizer=tokenizer,
                                    return_tensors='pt',
                                    max_length=tokenizer.model_max_length,
                                    padding=padding,
                                    truncation=truncation)

    return DataLoader(dataset, collate_fn=data_collator, batch_size=1)
|
|
|
|
|
|
|
|
class NHPV2(NHP):
    """NHP variant whose event embedding is fused with a loan-amount feature.

    Each event's scalar loan amount is projected to hidden_size and
    concatenated with the event-type embedding, then merged back down to
    hidden_size before entering the continuous-time LSTM cell.
    """

    def __init__(self, model_config):
        super(NHPV2, self).__init__(model_config)

        # Projects the scalar loan amount to the hidden dimension.
        self.layer_loan_amt = nn.Linear(1, model_config.hidden_size)

        # Merges [type embedding ; loan embedding] back to hidden_size.
        self.layer_merge = nn.Linear(model_config.hidden_size * 2, model_config.hidden_size)

    def forward(self, batch, **kwargs):
        """Call the model.

        Args:
            batch (tuple, list): batch input.

        Returns:
            list: hidden states, [batch_size, seq_len, hidden_dim], states right before the event happens;
                stacked decay states, [batch_size, max_seq_length, 4, hidden_dim], states right after
                the event happens.
        """
        time_seq, time_delta_seq, event_seq, batch_non_pad_mask, _, type_mask, loan_amt_seq = batch

        all_hiddens = []
        all_outputs = []
        all_cells = []
        all_cell_bars = []
        all_decays = []

        max_steps = kwargs.get('max_steps', None)

        # Decay horizon applied after the last observed event.
        max_decay_time = kwargs.get('max_decay_time', 5.0)

        # The last event has no successor, hence the `- 1`.
        max_seq_length = max_steps if max_steps is not None else event_seq.size(1) - 1

        batch_size = len(event_seq)
        c_t, c_bar_t, delta_t, o_t = self.get_init_state(batch_size)
        h_t = o_t
        c_bar_i = c_bar_t

        # NOTE(review): assumes loan_amt_seq is a float tensor — confirm the
        # collator emits floats, since nn.Linear cannot consume int tensors.
        if max_seq_length == 1:
            types_sub_batch = event_seq[:, 0]
            x_t = self.layer_type_emb(types_sub_batch)

            # BUG FIX: slice with 0:1 to keep a trailing feature dim of 1 for
            # nn.Linear(1, hidden), and pass torch.cat a *list* of tensors
            # with dim=-1, mirroring the multi-step branch below. The old
            # code (`loan_amt_seq[:, 0]` / `torch.cat(x_t, loan_t)`) raised
            # a TypeError at runtime.
            loan_t = self.layer_loan_amt(loan_amt_seq[:, 0:1])
            x_t = self.layer_merge(torch.cat([x_t, loan_t], dim=-1))

            cell_i, c_bar_i, decay_i, output_i = \
                self.rnn_cell(x_t, h_t, c_t, c_bar_i)

            all_outputs.append(output_i)
            all_decays.append(decay_i)
            all_cells.append(cell_i)
            all_cell_bars.append(c_bar_i)
            all_hiddens.append(h_t)
        else:
            for i in range(max_seq_length):
                if i == event_seq.size(1) - 1:
                    # No next event: decay over the fixed horizon instead.
                    dt = torch.ones_like(time_delta_seq[:, i]) * max_decay_time
                else:
                    dt = time_delta_seq[:, i + 1]
                types_sub_batch = event_seq[:, i]
                x_t = self.layer_type_emb(types_sub_batch)

                # Fuse the loan amount into the event embedding.
                loan_t = self.layer_loan_amt(loan_amt_seq[:, i:i + 1])
                x_t = self.layer_merge(torch.cat([x_t, loan_t], dim=-1))

                cell_i, c_bar_i, decay_i, output_i = \
                    self.rnn_cell(x_t, h_t, c_t, c_bar_i)

                # Decay the state until the next event time.
                c_t, h_t = self.rnn_cell.decay(cell_i,
                                               c_bar_i,
                                               decay_i,
                                               output_i,
                                               dt[:, None])

                all_outputs.append(output_i)
                all_decays.append(decay_i)
                all_cells.append(cell_i)
                all_cell_bars.append(c_bar_i)
                all_hiddens.append(h_t)

        cell_stack = torch.stack(all_cells, dim=1)
        cell_bar_stack = torch.stack(all_cell_bars, dim=1)
        decay_stack = torch.stack(all_decays, dim=1)
        output_stack = torch.stack(all_outputs, dim=1)

        # [batch_size, max_seq_length, hidden_dim]
        hiddens_stack = torch.stack(all_hiddens, dim=1)

        # [batch_size, max_seq_length, 4, hidden_dim]
        decay_states_stack = torch.stack((cell_stack,
                                          cell_bar_stack,
                                          decay_stack,
                                          output_stack),
                                         dim=2)

        return hiddens_stack, decay_states_stack

    def loglike_loss(self, batch):
        """Compute the loglike loss.

        Args:
            batch (list): batch input.

        Returns:
            list: loglike loss, num events.
        """
        time_seqs, time_delta_seqs, type_seqs, batch_non_pad_mask, _, type_mask, loan_amt_seq = batch

        hiddens_ti, decay_states = self.forward(batch)

        # Intensities at the observed event times.
        lambda_at_event = self.layer_intensity(hiddens_ti)

        # Monte-Carlo sample times inside each inter-event interval.
        interval_t_sample = self.make_dtime_loss_samples(time_delta_seqs[:, 1:])

        state_t_sample = self.compute_states_at_sample_times(decay_states, interval_t_sample)

        lambda_t_sample = self.layer_intensity(state_t_sample)

        type_seqs = type_seqs.long()
        event_ll, non_event_ll, num_events = self.compute_loglikelihood(
            time_delta_seq=time_delta_seqs[:, 1:],
            lambda_at_event=lambda_at_event,
            lambdas_loss_samples=lambda_t_sample,
            seq_mask=batch_non_pad_mask[:, 1:],
            type_seq=type_seqs[:, 1:]
        )

        # Negative log-likelihood, summed over events and batch.
        loss = - (event_ll - non_event_ll).sum()
        return loss, num_events

    def compute_states_at_sample_times(self, decay_states, sample_dtimes):
        """Decay the post-event states to the sampled inter-event times.

        Args:
            decay_states: (batch_size, seq_len, 4, hidden_dim) stacked
                (cell, cell_bar, decay, output) states after each event.
            sample_dtimes: (batch_size, seq_len, num_mc_sample) elapsed times
                at which to evaluate the decayed hidden state.

        Returns:
            hidden states at the sampled times, broadcast over the sample
            dimension: (batch_size, seq_len, num_mc_sample, hidden_dim).
        """
        cell_stack, cell_bar_stack, decay_stack, output_stack = torch.unbind(decay_states, dim=2)

        # Insert a singleton sample axis so decay broadcasts over samples.
        _, h_ts = self.rnn_cell.decay(
            cell_stack[:, :, None, :],
            cell_bar_stack[:, :, None, :],
            decay_stack[:, :, None, :],
            output_stack[:, :, None, :],
            sample_dtimes[..., None]
        )
        return h_ts
|
|
|
|
|
def make_model():
    """Instantiate an NHPV2 model from the example YAML config."""
    config = Config.build_from_yaml_file('examples/configs/experiment_config.yaml', experiment_id='NHP_train')
    model_config = config.model_config

    # Override the event-type sizes to match the synthetic data:
    # 11 real event types plus one padding type (id 11).
    model_config.num_event_types = 11
    model_config.num_event_types_pad = 12
    model_config.pad_token_id = 11

    return NHPV2(model_config)
|
|
|
|
|
|
|
|
def main():
    """Train NHPV2 on the synthetic loan data for a few epochs, printing the loss."""
    data_loader = make_data_loader()

    model = make_model()

    num_epochs = 10

    opt = torch.optim.Adam(model.parameters(), lr=0.001)

    for i in range(num_epochs):
        total_loss = 0.0
        total_num_event = 0
        for batch in data_loader:
            with torch.set_grad_enabled(True):
                # NOTE(review): relies on the collator's dict preserving the
                # exact field order that loglike_loss unpacks — confirm.
                batch_loss, batch_num_event = model.loglike_loss(batch=batch.values())

                opt.zero_grad()
                batch_loss.backward()
                opt.step()

            # BUG FIX: accumulate a detached Python float; `+= batch_loss`
            # kept every batch's autograd graph alive for the whole epoch.
            total_loss += batch_loss.item()
            total_num_event += batch_num_event

        avg_loss = total_loss / total_num_event
        print(f'epochs {i}: loss {avg_loss}')

    return
|
|
|
|
|
|
|
|
# Script entry point: run the end-to-end training example.
if __name__ == '__main__':
    main()
|
|
|