|
|
import math |
|
|
from typing import Dict |
|
|
|
|
|
import numpy as np |
|
|
from torch.utils.data import Dataset, DataLoader |
|
|
|
|
|
from easy_tpp.preprocess.data_collator import TPPDataCollator |
|
|
from easy_tpp.preprocess.event_tokenizer import EventTokenizer |
|
|
from easy_tpp.utils import py_assert |
|
|
|
|
|
|
|
|
class TPPDataset(Dataset): |
|
|
def __init__(self, data: Dict): |
|
|
self.data_dict = data |
|
|
self.time_seqs = self.data_dict['time_seqs'] |
|
|
self.time_delta_seqs = self.data_dict['time_delta_seqs'] |
|
|
self.type_seqs = self.data_dict['type_seqs'] |
|
|
|
|
|
def __len__(self): |
|
|
""" |
|
|
|
|
|
Returns: length of the dataset |
|
|
|
|
|
""" |
|
|
|
|
|
py_assert(len(self.time_seqs) == len(self.type_seqs) and len(self.time_delta_seqs) == len(self.type_seqs), |
|
|
ValueError, |
|
|
f"Inconsistent lengths for data! time_seq_len:{len(self.time_seqs)}, event_len: " |
|
|
f"{len(self.type_seqs)}, time_delta_seq_len: {len(self.time_delta_seqs)}") |
|
|
|
|
|
return len(self.time_seqs) |
|
|
|
|
|
def __getitem__(self, idx): |
|
|
""" |
|
|
|
|
|
Args: |
|
|
idx: iteration index |
|
|
|
|
|
Returns: |
|
|
dict: a dict of time_seqs, time_delta_seqs and type_seqs element |
|
|
|
|
|
""" |
|
|
return dict({'time_seqs': self.time_seqs[idx], 'time_delta_seqs': self.time_delta_seqs[idx], |
|
|
'type_seqs': self.type_seqs[idx]}) |
|
|
|
|
|
def get_dt_stats(self): |
|
|
x_bar, s_2_x, n = 0., 0., 0 |
|
|
min_dt, max_dt = np.inf, -np.inf |
|
|
|
|
|
for dts, marks in zip(self.time_delta_seqs, self.type_seqs): |
|
|
dts = np.array(dts[1:-1 if marks[-1] == -1 else None]) |
|
|
min_dt = min(min_dt, dts.min()) |
|
|
max_dt = max(max_dt, dts.max()) |
|
|
y_bar = dts.mean() |
|
|
s_2_y = dts.var() |
|
|
m = dts.shape[0] |
|
|
n += m |
|
|
|
|
|
s_2_x = (((n - 1) * s_2_x + (m - 1) * s_2_y) / (n + m - 1)) + ( |
|
|
(n * m * ((x_bar - y_bar) ** 2)) / ((n + m) * (n + m - 1))) |
|
|
x_bar = (n * x_bar + m * y_bar) / (n + m) |
|
|
|
|
|
print(x_bar, (s_2_x ** 0.5)) |
|
|
print(f'min_dt: {min_dt}') |
|
|
print(f'max_dt: {max_dt}') |
|
|
return x_bar, (s_2_x ** 0.5), min_dt, max_dt |
|
|
|
|
|
|
|
|
def get_data_loader(dataset: TPPDataset, backend: str, tokenizer: EventTokenizer, **kwargs): |
|
|
assert backend == 'torch', 'Only torch backend is supported.' |
|
|
padding = True if tokenizer.padding_strategy is None else tokenizer.padding_strategy |
|
|
truncation = False if tokenizer.truncation_strategy is None else tokenizer.truncation_strategy |
|
|
data_collator = TPPDataCollator(tokenizer=tokenizer, |
|
|
return_tensors='pt', |
|
|
max_length=tokenizer.model_max_length, |
|
|
padding=padding, |
|
|
truncation=truncation) |
|
|
return DataLoader(dataset, |
|
|
collate_fn=data_collator, |
|
|
**kwargs) |
|
|
|