# NOTE: scraped Hugging Face Space page chrome ("Spaces:", "Runtime error",
# "File size: 4,713 Bytes", commit 6f2ff70, line-number gutter) removed.
from torch.utils.data import Dataset
import torch
import random
import numpy as np
from collections import defaultdict
class LogDataset(Dataset):
    """Masked log-key dataset for BERT-style log modeling.

    Each item is a log session (a sequence of log-event keys) paired with a
    parallel sequence of time intervals. ``__getitem__`` applies BERT-style
    random masking to both sequences and prepends an SOS ([CLS]) token.
    """

    def __init__(self, log_corpus, time_corpus, vocab, seq_len, corpus_lines=None,
                 encoding="utf-8", on_memory=True, predict_mode=False, mask_ratio=0.15):
        """
        :param log_corpus: log sessions, one sequence of log keys per session
        :param time_corpus: parallel time-interval sequences, same lengths
        :param vocab: log event vocabulary; must provide ``pad_index``,
            ``unk_index``, ``sos_index``, ``mask_index``, ``stoi`` and ``__len__``
        :param seq_len: max sequence length used when padding (``None`` means
            fully dynamic padding in :meth:`collate_fn`)
        :param corpus_lines: accepted for interface compatibility but ignored;
            the session count is always taken from ``len(log_corpus)``
        :param encoding: kept for interface compatibility (unused here)
        :param on_memory: kept for interface compatibility (unused here)
        :param predict_mode: if True, every selected token is replaced by the
            mask token (no 80/10/10 split)
        :param mask_ratio: fraction of tokens selected for masking
        """
        self.vocab = vocab
        self.seq_len = seq_len
        self.on_memory = on_memory
        self.encoding = encoding
        self.predict_mode = predict_mode
        self.log_corpus = log_corpus
        self.time_corpus = time_corpus
        # corpus_lines parameter is ignored; the real count comes from the data.
        self.corpus_lines = len(log_corpus)
        self.mask_ratio = mask_ratio

    def __len__(self):
        """Return the number of log sessions."""
        return self.corpus_lines

    def __getitem__(self, idx):
        """Return one masked session: (keys, key labels, times, time labels).

        All four are plain Python lists; an SOS marker is prepended to the
        inputs and a pad label is prepended to the label sequences so the
        positions stay aligned.
        """
        keys, times = self.log_corpus[idx], self.time_corpus[idx]
        k_masked, k_label, t_masked, t_label = self.random_item(keys, times)
        # [CLS] tag = SOS tag, [SEP] tag = EOS tag
        k = [self.vocab.sos_index] + k_masked
        k_label = [self.vocab.pad_index] + k_label
        t = [0] + t_masked
        # NOTE(review): time labels are padded with the *token* pad index —
        # this is 0 in typical vocabs, but confirm it matches the loss masking.
        t_label = [self.vocab.pad_index] + t_label
        return k, k_label, t, t_label

    def random_item(self, k, t):
        """Apply BERT-style masking to one session.

        Roughly ``mask_ratio`` of positions are selected. For a selected
        position: 80% become the mask token, 10% a random vocabulary id, 10%
        the original id; its label is the true token id and its time interval
        is zeroed (time label keeps the true interval). Unselected positions
        are mapped through the vocabulary and labeled 0 (ignored by the loss).
        In ``predict_mode`` every selected position is simply mask-replaced.

        :returns: (masked tokens, token labels, masked intervals, time labels)
        """
        tokens = list(k)
        output_label = []
        time_intervals = list(t)
        time_label = []
        for i, token in enumerate(tokens):
            time_int = time_intervals[i]
            prob = random.random()
            # Select roughly mask_ratio of the positions for masking.
            if prob < self.mask_ratio:
                if self.predict_mode:
                    tokens[i] = self.vocab.mask_index
                    output_label.append(self.vocab.stoi.get(token, self.vocab.unk_index))
                    time_label.append(time_int)
                    time_intervals[i] = 0
                    continue
                # Rescale prob into [0, 1) to drive the 80/10/10 split.
                prob /= self.mask_ratio
                if prob < 0.8:
                    # 80%: replace with the mask token.
                    tokens[i] = self.vocab.mask_index
                elif prob < 0.9:
                    # 10%: replace with a random vocabulary id.
                    tokens[i] = random.randrange(len(self.vocab))
                else:
                    # 10%: keep the current token (mapped to its id).
                    tokens[i] = self.vocab.stoi.get(token, self.vocab.unk_index)
                output_label.append(self.vocab.stoi.get(token, self.vocab.unk_index))
                time_intervals[i] = 0  # time mask value = 0
                time_label.append(time_int)
            else:
                tokens[i] = self.vocab.stoi.get(token, self.vocab.unk_index)
                output_label.append(0)
                time_label.append(0)
        return tokens, output_label, time_intervals, time_label

    def collate_fn(self, batch, percentile=100, dynamical_pad=True):
        """Pad a batch of ``__getitem__`` tuples and convert to tensors.

        :param batch: list of (bert_input, bert_label, time_input, time_label)
        :param percentile: with dynamic padding, pad to this percentile of the
            batch lengths (100 = the longest sequence)
        :param dynamical_pad: pad per-batch when True, else to ``self.seq_len``
        :returns: dict with long tensors ``bert_input``/``bert_label`` of shape
            (batch, seq_len), float tensor ``time_input`` of shape
            (batch, seq_len, 1), and float tensor ``time_label``
        """
        lens = [len(seq[0]) for seq in batch]
        if dynamical_pad:
            # Dynamic padding: pad to the chosen percentile of batch lengths,
            # capped by self.seq_len when one is configured.
            seq_len = int(np.percentile(lens, percentile))
            if self.seq_len is not None:
                seq_len = min(seq_len, self.seq_len)
        else:
            # Fixed-length padding.
            seq_len = self.seq_len
        output = defaultdict(list)
        for seq in batch:
            # Slicing copies, so the caller's lists are never mutated.
            bert_input = seq[0][:seq_len]
            bert_label = seq[1][:seq_len]
            time_input = seq[2][:seq_len]
            time_label = seq[3][:seq_len]
            padding = [self.vocab.pad_index] * (seq_len - len(bert_input))
            bert_input.extend(padding)
            bert_label.extend(padding)
            # NOTE(review): time sequences reuse the token pad index as the pad
            # value — fine when pad_index == 0, confirm otherwise.
            time_input.extend(padding)
            time_label.extend(padding)
            # Add a trailing feature axis so time_input stacks to (B, T, 1).
            time_input = np.array(time_input)[:, np.newaxis]
            output["bert_input"].append(bert_input)
            output["bert_label"].append(bert_label)
            output["time_input"].append(time_input)
            output["time_label"].append(time_label)
        output["bert_input"] = torch.tensor(output["bert_input"], dtype=torch.long)
        output["bert_label"] = torch.tensor(output["bert_label"], dtype=torch.long)
        output["time_input"] = torch.tensor(output["time_input"], dtype=torch.float)
        output["time_label"] = torch.tensor(output["time_label"], dtype=torch.float)
        return output