# MukeshKapoor25's picture
# changs
# 6f2ff70
from torch.utils.data import Dataset
import torch
import random
import numpy as np
from collections import defaultdict
class LogDataset(Dataset):
    """Dataset of log-key sequences paired with per-event time intervals.

    Each item is one session: its log keys are converted to vocabulary ids,
    a random subset of positions is masked (see :meth:`random_item`), and an
    SOS index is prepended.  :meth:`collate_fn` pads/truncates a list of such
    items and converts them to tensors.
    """

    def __init__(self, log_corpus, time_corpus, vocab, seq_len, corpus_lines=None, encoding="utf-8",
                 on_memory=True, predict_mode=False, mask_ratio=0.15):
        """
        :param log_corpus: indexable collection of sessions; each session is an
            iterable of log keys that are looked up in ``vocab.stoi``
        :param time_corpus: indexable collection aligned with ``log_corpus``;
            each entry holds one time value per log key of the session
        :param vocab: vocabulary object; this class reads ``stoi``,
            ``pad_index``, ``unk_index``, ``sos_index``, ``mask_index`` and
            ``len(vocab)``
        :param seq_len: max sequence length used by ``collate_fn``
        :param corpus_lines: accepted for interface compatibility but ignored;
            the dataset length is always ``len(log_corpus)``
        :param encoding: stored on the instance, not used by this class
        :param on_memory: stored on the instance, not used by this class
        :param predict_mode: if true, every selected position is replaced by
            the mask token (no random/keep variants)
        :param mask_ratio: probability that a position is selected for masking
        """
        self.vocab = vocab
        self.seq_len = seq_len

        self.on_memory = on_memory
        self.encoding = encoding

        self.predict_mode = predict_mode
        self.log_corpus = log_corpus
        self.time_corpus = time_corpus
        self.corpus_lines = len(log_corpus)

        self.mask_ratio = mask_ratio

    def __len__(self):
        """Return the number of sessions in ``log_corpus``."""
        return self.corpus_lines

    def __getitem__(self, idx):
        """Return ``(k, k_label, t, t_label)`` lists for session ``idx``.

        All four lists have length ``len(session) + 1`` because one leading
        position is prepended to each of them.
        """
        k, t = self.log_corpus[idx], self.time_corpus[idx]

        k_masked, k_label, t_masked, t_label = self.random_item(k, t)

        # Prepend the SOS index (used as the [CLS]-style start tag); no EOS
        # tag is appended.  The matching label/time slots are filled with the
        # pad index / 0 respectively.
        k = [self.vocab.sos_index] + k_masked
        k_label = [self.vocab.pad_index] + k_label

        t = [0] + t_masked
        t_label = [self.vocab.pad_index] + t_label

        return k, k_label, t, t_label

    def random_item(self, k, t):
        """Convert one session to ids and randomly mask positions.

        A position is selected with probability ``mask_ratio``.  For a
        selected position the key label is the original token id, the time
        label is the original time value, and the time input is set to 0.
        The key input becomes the mask index in predict mode; otherwise it is
        the mask index (80%), a random id in ``range(len(vocab))`` (10%) or
        the original id (10%).  Unselected positions keep their id and time
        value and get label 0 for both.

        :param k: iterable of log keys
        :param t: iterable of time values, at least as long as ``k``
        :return: ``(tokens, output_label, time_intervals, time_label)`` lists
        """
        vocab = self.vocab
        tokens = list(k)
        time_intervals = list(t)
        output_label = []
        time_label = []

        for i, token in enumerate(tokens):
            time_int = time_intervals[i]
            token_id = vocab.stoi.get(token, vocab.unk_index)
            prob = random.random()

            if prob >= self.mask_ratio:
                # Not selected: plain id, nothing to predict.
                # NOTE(review): literal 0 is used as the "ignore" label here
                # while padding uses vocab.pad_index; these agree only if
                # pad_index == 0 -- confirm against the vocabulary.
                tokens[i] = token_id
                output_label.append(0)
                time_label.append(0)
                continue

            if self.predict_mode:
                tokens[i] = vocab.mask_index
            else:
                # Rescale to [0, 1) so the 80/10/10 split below applies to
                # the selected positions only.
                prob /= self.mask_ratio
                if prob < 0.8:
                    tokens[i] = vocab.mask_index
                elif prob < 0.9:
                    tokens[i] = random.randrange(len(vocab))
                else:
                    tokens[i] = token_id

            output_label.append(token_id)
            time_label.append(time_int)
            time_intervals[i] = 0  # time mask value = 0

        return tokens, output_label, time_intervals, time_label

    def collate_fn(self, batch, percentile=100, dynamical_pad=True):
        """Pad/truncate a batch of items to a common length and tensorize it.

        :param batch: list of ``(k, k_label, t, t_label)`` items as returned
            by ``__getitem__``
        :param percentile: percentile of the item lengths used as the target
            length when ``dynamical_pad`` is true (passed to ``np.percentile``)
        :param dynamical_pad: if true, the target length is the chosen
            percentile of the batch lengths, capped at ``self.seq_len`` when
            that is not ``None``; otherwise ``self.seq_len`` is used as is
        :return: mapping with ``bert_input`` / ``bert_label`` (long tensors of
            shape ``(batch, seq_len)``), ``time_input`` (float tensor of shape
            ``(batch, seq_len, 1)``) and ``time_label`` (float tensor of shape
            ``(batch, seq_len)``)
        """
        lens = [len(seq[0]) for seq in batch]

        if dynamical_pad:
            # dynamical padding: length derived from this batch
            seq_len = int(np.percentile(lens, percentile))
            if self.seq_len is not None:
                seq_len = min(seq_len, self.seq_len)
        else:
            # fixed length padding
            seq_len = self.seq_len

        pad_index = self.vocab.pad_index
        output = defaultdict(list)
        for seq in batch:
            bert_input = list(seq[0][:seq_len])
            bert_label = list(seq[1][:seq_len])
            time_input = list(seq[2][:seq_len])
            time_label = list(seq[3][:seq_len])

            # The same padding (sized from the key sequence) is applied to
            # all four fields.
            padding = [pad_index] * (seq_len - len(bert_input))

            output["bert_input"].append(bert_input + padding)
            output["bert_label"].append(bert_label + padding)
            # Trailing axis of size 1 for the time feature.
            output["time_input"].append(np.array(time_input + padding)[:, np.newaxis])
            output["time_label"].append(time_label + padding)

        output["bert_input"] = torch.tensor(output["bert_input"], dtype=torch.long)
        output["bert_label"] = torch.tensor(output["bert_label"], dtype=torch.long)
        # Stack into one ndarray first: building a tensor from a Python list
        # of ndarrays is the slow path that torch warns about.
        output["time_input"] = torch.tensor(np.array(output["time_input"]), dtype=torch.float)
        output["time_label"] = torch.tensor(output["time_label"], dtype=torch.float)

        return output