File size: 3,053 Bytes
f43af3c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 |
import math
from typing import Dict
import numpy as np
from torch.utils.data import Dataset, DataLoader
from easy_tpp.preprocess.data_collator import TPPDataCollator
from easy_tpp.preprocess.event_tokenizer import EventTokenizer
from easy_tpp.utils import py_assert
class TPPDataset(Dataset):
def __init__(self, data: Dict):
self.data_dict = data
self.time_seqs = self.data_dict['time_seqs']
self.time_delta_seqs = self.data_dict['time_delta_seqs']
self.type_seqs = self.data_dict['type_seqs']
def __len__(self):
"""
Returns: length of the dataset
"""
py_assert(len(self.time_seqs) == len(self.type_seqs) and len(self.time_delta_seqs) == len(self.type_seqs),
ValueError,
f"Inconsistent lengths for data! time_seq_len:{len(self.time_seqs)}, event_len: "
f"{len(self.type_seqs)}, time_delta_seq_len: {len(self.time_delta_seqs)}")
return len(self.time_seqs)
def __getitem__(self, idx):
"""
Args:
idx: iteration index
Returns:
dict: a dict of time_seqs, time_delta_seqs and type_seqs element
"""
return dict({'time_seqs': self.time_seqs[idx], 'time_delta_seqs': self.time_delta_seqs[idx],
'type_seqs': self.type_seqs[idx]})
def get_dt_stats(self):
x_bar, s_2_x, n = 0., 0., 0
min_dt, max_dt = np.inf, -np.inf
for dts, marks in zip(self.time_delta_seqs, self.type_seqs):
dts = np.array(dts[1:-1 if marks[-1] == -1 else None])
min_dt = min(min_dt, dts.min())
max_dt = max(max_dt, dts.max())
y_bar = dts.mean()
s_2_y = dts.var()
m = dts.shape[0]
n += m
# Formula taken from https://math.stackexchange.com/questions/3604607/can-i-work-out-the-variance-in-batches
s_2_x = (((n - 1) * s_2_x + (m - 1) * s_2_y) / (n + m - 1)) + (
(n * m * ((x_bar - y_bar) ** 2)) / ((n + m) * (n + m - 1)))
x_bar = (n * x_bar + m * y_bar) / (n + m)
print(x_bar, (s_2_x ** 0.5))
print(f'min_dt: {min_dt}')
print(f'max_dt: {max_dt}')
return x_bar, (s_2_x ** 0.5), min_dt, max_dt
def get_data_loader(dataset: TPPDataset, backend: str, tokenizer: EventTokenizer, **kwargs):
assert backend == 'torch', 'Only torch backend is supported.'
padding = True if tokenizer.padding_strategy is None else tokenizer.padding_strategy
truncation = False if tokenizer.truncation_strategy is None else tokenizer.truncation_strategy
data_collator = TPPDataCollator(tokenizer=tokenizer,
return_tensors='pt',
max_length=tokenizer.model_max_length,
padding=padding,
truncation=truncation)
return DataLoader(dataset,
collate_fn=data_collator,
**kwargs)
|