from collections import defaultdict
import random

import numpy as np
import torch
from torch.utils.data import Dataset


class LogDataset(Dataset):
    """Masked-log-key dataset for BERT-style training on log sessions.

    Each item is one session: a sequence of log keys (events) plus the time
    interval attached to each event.  ``__getitem__`` applies BERT masked-LM
    corruption to the keys, zeroes the time interval at every masked
    position, and returns ``(keys, key_labels, times, time_labels)`` with a
    [CLS]/SOS token prepended.
    """

    def __init__(self, log_corpus, time_corpus, vocab, seq_len, corpus_lines=None,
                 encoding="utf-8", on_memory=True, predict_mode=False, mask_ratio=0.15):
        """
        :param log_corpus: list of log sessions, each a sequence of log keys
        :param time_corpus: list of time-interval sequences, parallel to log_corpus
        :param vocab: log-event vocabulary; must provide ``pad_index``,
            ``unk_index``, ``sos_index``, ``mask_index``, ``stoi`` and ``__len__``
        :param seq_len: max sequence length (``None`` = uncapped under dynamic padding)
        :param corpus_lines: unused; the session count is taken from ``len(log_corpus)``
        :param encoding: kept for interface compatibility (unused here)
        :param on_memory: kept for interface compatibility (unused here)
        :param predict_mode: if True, every selected position is replaced by
            [MASK] deterministically (no 80/10/10 randomization)
        :param mask_ratio: fraction of tokens selected for prediction
        """
        self.vocab = vocab
        self.seq_len = seq_len
        self.on_memory = on_memory
        self.encoding = encoding
        self.predict_mode = predict_mode
        self.log_corpus = log_corpus
        self.time_corpus = time_corpus
        self.corpus_lines = len(log_corpus)
        self.mask_ratio = mask_ratio

    def __len__(self):
        return self.corpus_lines

    def __getitem__(self, idx):
        keys, times = self.log_corpus[idx], self.time_corpus[idx]
        k_masked, k_label, t_masked, t_label = self.random_item(keys, times)

        # Prepend [CLS] (= SOS / EOS per BERT convention).  The label at the
        # prepended position is padding so the loss ignores it.
        k = [self.vocab.sos_index] + k_masked
        k_label = [self.vocab.pad_index] + k_label
        t = [0] + t_masked
        t_label = [self.vocab.pad_index] + t_label

        return k, k_label, t, t_label

    def random_item(self, k, t):
        """Apply BERT-style masking to one session.

        Returns ``(token_ids, token_labels, time_intervals, time_labels)``.
        Unselected positions keep their token id and get label 0 (ignored by
        the loss).  Selected positions get the true token id / time interval
        as label, and their time-interval input is zeroed.
        """
        tokens = list(k)
        output_label = []
        time_intervals = list(t)
        time_label = []

        for i, token in enumerate(tokens):
            time_int = time_intervals[i]
            prob = random.random()

            if prob >= self.mask_ratio:
                # Not selected: map to vocab id, emit ignore labels.
                tokens[i] = self.vocab.stoi.get(token, self.vocab.unk_index)
                output_label.append(0)
                time_label.append(0)
                continue

            # Selected for prediction (happens with probability mask_ratio).
            token_id = self.vocab.stoi.get(token, self.vocab.unk_index)

            if self.predict_mode:
                # Inference: mask deterministically, no randomization.
                tokens[i] = self.vocab.mask_index
            else:
                # Rescale prob back into [0, 1) to split the 80/10/10 branches.
                prob /= self.mask_ratio
                if prob < 0.8:
                    # 80%: replace with [MASK]
                    tokens[i] = self.vocab.mask_index
                elif prob < 0.9:
                    # 10%: replace with a random vocab id
                    tokens[i] = random.randrange(len(self.vocab))
                else:
                    # 10%: keep the current token (as its vocab id)
                    tokens[i] = token_id

            output_label.append(token_id)
            time_label.append(time_int)
            time_intervals[i] = 0  # masked time value

        return tokens, output_label, time_intervals, time_label

    def collate_fn(self, batch, percentile=100, dynamical_pad=True):
        """Pad/truncate a batch of ``__getitem__`` outputs and tensorize them.

        :param batch: list of ``(bert_input, bert_label, time_input, time_label)``
        :param percentile: under dynamic padding, pad to this percentile of
            the batch's sequence lengths (100 = batch maximum)
        :param dynamical_pad: if True, pad to the per-batch length (capped by
            ``self.seq_len``); otherwise pad to the fixed ``self.seq_len``
        :return: dict of tensors; ``time_input`` has shape (batch, seq_len, 1)
        """
        lens = [len(seq[0]) for seq in batch]

        if dynamical_pad:
            # Dynamic padding: per-batch target length.
            seq_len = int(np.percentile(lens, percentile))
            if self.seq_len is not None:
                seq_len = min(seq_len, self.seq_len)
        else:
            # Fixed-length padding.
            seq_len = self.seq_len

        output = defaultdict(list)
        for bert_input, bert_label, time_input, time_label in batch:
            bert_input = bert_input[:seq_len]
            bert_label = bert_label[:seq_len]
            time_input = time_input[:seq_len]
            time_label = time_label[:seq_len]

            # Build new padded lists instead of extend() so the caller's
            # batch items are never mutated in place.
            padding = [self.vocab.pad_index] * (seq_len - len(bert_input))
            output["bert_input"].append(bert_input + padding)
            output["bert_label"].append(bert_label + padding)
            output["time_input"].append(time_input + padding)
            output["time_label"].append(time_label + padding)

        output["bert_input"] = torch.tensor(output["bert_input"], dtype=torch.long)
        output["bert_label"] = torch.tensor(output["bert_label"], dtype=torch.long)
        # Stack once via numpy (fast path) and add a trailing feature axis:
        # (batch, seq_len, 1), matching the original per-sequence np.newaxis.
        time_input_arr = np.asarray(output["time_input"], dtype=np.float32)[:, :, np.newaxis]
        output["time_input"] = torch.tensor(time_input_arr, dtype=torch.float)
        output["time_label"] = torch.tensor(output["time_label"], dtype=torch.float)
        return output