import matplotlib.pyplot as plt
import numpy as np
from collections import Counter
from easy_tpp.preprocess.dataset import TPPDataset, get_data_loader
from easy_tpp.preprocess.event_tokenizer import EventTokenizer
from easy_tpp.utils import load_pickle, py_assert


class TPPDataLoader:
    """Build train/dev/test data loaders for temporal point process datasets."""

    def __init__(self, data_config, **kwargs):
        """Initialize the data loader.

        Args:
            data_config (EasyTPP.DataConfig): data config.
            **kwargs: optional keyword arguments, notably 'backend' (str,
                defaults to 'torch'), 'batch_size' (int, required by
                `get_loader`) and 'shuffle' (bool, defaults to False).
        """
        self.data_config = data_config
        self.num_event_types = data_config.data_specs.num_event_types
        self.backend = kwargs.get('backend', 'torch')
        self.kwargs = kwargs

    def build_input(self, source_dir, data_format, split):
        """Load and process a dataset split based on its file format.

        Args:
            source_dir (str): Path to the dataset directory or file.
            data_format (str): Input file format, 'pkl' or 'json'.
            split (str): Dataset split, e.g., 'train', 'dev', 'test'.

        Returns:
            dict: Dictionary containing sequences of event times, types, and intervals.
        """
        if data_format == 'pkl':
            return self._build_input_from_pkl(source_dir, split)
        elif data_format == 'json':
            return self._build_input_from_json(source_dir, split)
        else:
            raise ValueError(f"Unsupported file format: {data_format}")

    def _build_input_from_pkl(self, source_dir, split):
        """Load and process data from a pickle file.

        Args:
            source_dir (str): Path to the pickle file.
            split (str): Dataset split, e.g., 'train', 'dev', 'test'.

        Returns:
            dict: Dictionary with processed event sequences.
        """
        data = load_pickle(source_dir)
        py_assert(data["dim_process"] == self.num_event_types,
                  ValueError, "num_event_types in the data config does not match dim_process in the data.")

        source_data = data[split]
        return {
            'time_seqs': [[x["time_since_start"] for x in seq] for seq in source_data],
            'type_seqs': [[x["type_event"] for x in seq] for seq in source_data],
            'time_delta_seqs': [[x["time_since_last_event"] for x in seq] for seq in source_data]
        }
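
    # Expected pickle layout, inferred from the field accesses above: each
    # split maps to a list of sequences, each sequence a list of event dicts.
    # Values here are illustrative:
    #
    #     {'dim_process': 5,
    #      'train': [[{'time_since_start': 0.0,
    #                  'type_event': 2,
    #                  'time_since_last_event': 0.0}, ...], ...],
    #      'dev': [...], 'test': [...]}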

    def _build_input_from_json(self, source_dir, split):
        """Load and process data from a JSON file.

        Args:
            source_dir (str): Path to the JSON file or Hugging Face dataset name.
            split (str): Dataset split, e.g., 'train', 'dev', 'test'.

        Returns:
            dict: Dictionary with processed event sequences.
        """
        from datasets import load_dataset
        split_mapped = 'validation' if split == 'dev' else split
        if source_dir.endswith('.json'):
            data = load_dataset('json', data_files={split_mapped: source_dir}, split=split_mapped)
        elif source_dir.startswith('easytpp'):
            data = load_dataset(source_dir, split=split_mapped)
        else:
            raise ValueError("Unsupported source directory format for JSON.")

        py_assert(data['dim_process'][0] == self.num_event_types,
                  ValueError, "num_event_types in the data config does not match dim_process in the data.")

        return {
            'time_seqs': data['time_since_start'],
            'type_seqs': data['type_event'],
            'time_delta_seqs': data['time_since_last_event']
        }
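
    # The JSON / Hugging Face path assumes columns named 'dim_process',
    # 'time_since_start', 'type_event' and 'time_since_last_event', with one
    # row per sequence, e.g. (illustrative):
    #
    #     {"dim_process": 5, "time_since_start": [0.0, 1.2],
    #      "type_event": [2, 0], "time_since_last_event": [0.0, 1.2]}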

    def get_loader(self, split='train', **kwargs):
        """Get the data loader for the given split.

        Args:
            split (str, optional): one of 'train', 'dev' and 'test'. Defaults to 'train'.
            **kwargs: optional keyword arguments forwarded to the backend loader;
                'shuffle' overrides the value passed at construction time.

        Returns:
            EasyTPP.DataLoader: the data loader for TPP data.
        """
        data_dir = self.data_config.get_data_dir(split)
        data = self.build_input(data_dir, self.data_config.data_format, split)

        dataset = TPPDataset(data)
        tokenizer = EventTokenizer(self.data_config.data_specs)

        # Resolve 'shuffle': a per-call kwarg takes precedence over the
        # constructor kwarg; pop it so it is not passed through twice.
        shuffle = kwargs.pop('shuffle', self.kwargs.get('shuffle', False))

        loader = get_data_loader(dataset,
                                 self.backend,
                                 tokenizer,
                                 batch_size=self.kwargs['batch_size'],
                                 shuffle=shuffle,
                                 **kwargs)

        return loader
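
    # The shuffle resolution above means a per-call value wins over the value
    # given at construction time, which in turn defaults to False. E.g.
    # (sketch, `data_loader` built as in the __main__ example below):
    #
    #     dl = data_loader.get_loader('dev', shuffle=False)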

    def train_loader(self, **kwargs):
        """Return the train loader

        Returns:
            EasyTPP.DataLoader: data loader for train set.
        """
        return self.get_loader('train', **kwargs)

    def valid_loader(self, **kwargs):
        """Return the valid loader

        Returns:
            EasyTPP.DataLoader: data loader for valid set.
        """
        return self.get_loader('dev', **kwargs)

    def test_loader(self, **kwargs):
        """Return the test loader

        Returns:
            EasyTPP.DataLoader: data loader for test set.
        """
        # The test set is never shuffled.
        kwargs['shuffle'] = False
        return self.get_loader('test', **kwargs)

    def get_statistics(self, split='train'):
        """Get basic statistics about the dataset.

        Args:
            split (str): Dataset split, e.g., 'train', 'dev', 'test'. Default is 'train'.

        Returns:
            dict: Dictionary containing statistics about the dataset.
        """
        data_dir = self.data_config.get_data_dir(split)
        data = self.build_input(data_dir, self.data_config.data_format, split)

        num_sequences = len(data['time_seqs'])
        sequence_lengths = [len(seq) for seq in data['time_seqs']]
        avg_sequence_length = sum(sequence_lengths) / num_sequences if num_sequences else 0
        all_event_types = [event for seq in data['type_seqs'] for event in seq]
        event_type_counts = Counter(all_event_types)

        # Calculate time_delta_seqs statistics
        all_time_deltas = [delta for seq in data['time_delta_seqs'] for delta in seq]
        mean_time_delta = np.mean(all_time_deltas) if all_time_deltas else 0
        min_time_delta = np.min(all_time_deltas) if all_time_deltas else 0
        max_time_delta = np.max(all_time_deltas) if all_time_deltas else 0

        stats = {
            "num_sequences": num_sequences,
            "avg_sequence_length": avg_sequence_length,
            "event_type_distribution": dict(event_type_counts),
            "max_sequence_length": max(sequence_lengths),
            "min_sequence_length": min(sequence_lengths),
            "mean_time_delta": mean_time_delta,
            "min_time_delta": min_time_delta,
            "max_time_delta": max_time_delta
        }

        return stats
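
    # Illustrative return value (all numbers are made up):
    #
    #     {'num_sequences': 100, 'avg_sequence_length': 42.5,
    #      'event_type_distribution': {0: 1800, 1: 2450},
    #      'max_sequence_length': 90, 'min_sequence_length': 3,
    #      'mean_time_delta': 0.8, 'min_time_delta': 0.0,
    #      'max_time_delta': 12.4}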

    def plot_event_type_distribution(self, split='train'):
        """Plot the distribution of event types in the dataset.

        Args:
            split (str): Dataset split, e.g., 'train', 'dev', 'test'. Default is 'train'.
        """
        stats = self.get_statistics(split)
        event_type_distribution = stats['event_type_distribution']

        plt.figure(figsize=(8, 6))
        plt.bar(list(event_type_distribution.keys()),
                list(event_type_distribution.values()), color='skyblue')
        plt.xlabel('Event Types')
        plt.ylabel('Frequency')
        plt.title(f'Event Type Distribution ({split} set)')
        plt.show()

    def plot_event_delta_times_distribution(self, split='train'):
        """Plot the distribution of event delta times in the dataset.

        Args:
            split (str): Dataset split, e.g., 'train', 'dev', 'test'. Default is 'train'.
        """
        data_dir = self.data_config.get_data_dir(split)
        data = self.build_input(data_dir, self.data_config.data_format, split)

        # Flatten the time_delta_seqs to get all delta times
        all_time_deltas = [delta for seq in data['time_delta_seqs'] for delta in seq]

        plt.figure(figsize=(10, 6))
        plt.hist(all_time_deltas, bins=30, color='skyblue', edgecolor='black')
        plt.xlabel('Event Delta Times')
        plt.ylabel('Frequency')
        plt.title(f'Event Delta Times Distribution ({split} set)')
        plt.grid(axis='y', alpha=0.75)
        plt.show()

    def plot_sequence_length_distribution(self, split='train'):
        """Plot the distribution of sequence lengths in the dataset.

        Args:
            split (str): Dataset split, e.g., 'train', 'dev', 'test'. Default is 'train'.
        """
        data_dir = self.data_config.get_data_dir(split)
        data = self.build_input(data_dir, self.data_config.data_format, split)
        sequence_lengths = [len(seq) for seq in data['time_seqs']]

        plt.figure(figsize=(8, 6))
        plt.hist(sequence_lengths, bins=10, color='salmon', edgecolor='black')
        plt.xlabel('Sequence Length')
        plt.ylabel('Frequency')
        plt.title(f'Sequence Length Distribution ({split} set)')
        plt.show()
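

if __name__ == '__main__':
    # Minimal usage sketch, not part of the class itself. It assumes a YAML
    # experiment config readable by EasyTPP's Config factory; the path, the
    # experiment_id and the `data_config` attribute are placeholders to adapt
    # to your own setup.
    from easy_tpp.config_factory import Config

    config = Config.build_from_yaml_file('configs/experiment_config.yaml',
                                         experiment_id='NHP_train')
    data_loader = TPPDataLoader(config.data_config, backend='torch',
                                batch_size=32, shuffle=True)
    print(data_loader.get_statistics(split='train'))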