Spaces:
Runtime error
Runtime error
File size: 4,612 Bytes
6f2ff70 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 | from torch.utils.data import Dataset
import tqdm
import torch
import random
import numpy as np
class BERTDataset(Dataset):
    """Dataset yielding masked-LM / next-sentence examples for BERT training.

    Each example is built from a pair of sentences ``(t1, t2)``. Sentences are
    tokenized per character (``list(sentence)``) and mapped to ids through
    ``vocab.stoi``; unknown characters map to ``vocab.unk_index``.

    Two storage modes are supported:

    * ``on_memory=True`` (default): ``corpus_path`` is itself the indexable
      collection of ``(t1, t2)`` pairs.
    * ``on_memory=False``: ``corpus_path`` is the path of a text file with one
      tab-separated ``t1<TAB>t2`` pair per line. Lines are streamed
      sequentially and the file is re-read from the start when exhausted.

    The ``vocab`` object must provide ``pad_index``, ``unk_index``,
    ``sos_index``, ``eos_index``, ``mask_index``, a ``stoi`` mapping and
    ``__len__``.
    """

    def __init__(self, corpus_path, vocab, seq_len, corpus_lines=None, encoding="utf-8", on_memory=True, predict_mode=False):
        """Store the configuration and prepare the corpus.

        Args:
            corpus_path: Sequence of ``(t1, t2)`` pairs when ``on_memory`` is
                true, otherwise the path of the tab-separated corpus file.
            vocab: Vocabulary object (see class docstring).
            seq_len: Fixed length every returned sequence is cut/padded to.
            corpus_lines: Number of lines in the corpus file. Only used when
                ``on_memory`` is false; counted from the file when ``None``.
            encoding: Text encoding of the corpus file.
            on_memory: Selects the storage mode (see class docstring).
            predict_mode: When true, every selected token is masked and the
                sentence pair is always labelled as "is next".
        """
        self.vocab = vocab
        self.seq_len = seq_len
        self.on_memory = on_memory
        self.corpus_lines = corpus_lines
        self.corpus_path = corpus_path
        self.encoding = encoding
        self.predict_mode = predict_mode
        self.file = None
        self.random_file = None

        if on_memory:
            self.lines = corpus_path
            self.corpus_lines = len(self.lines)
        else:
            self.lines = None
            if self.corpus_lines is None:
                # The length must be the number of corpus lines, not the
                # length of the path string.
                with open(corpus_path, "r", encoding=encoding) as corpus:
                    self.corpus_lines = sum(1 for _ in corpus)
            self.file = open(corpus_path, "r", encoding=encoding)
            self.random_file = open(corpus_path, "r", encoding=encoding)
            # Start the negative-sample stream at a random offset.
            self._skip_random_lines()

    def __len__(self):
        """Return the number of sentence pairs in the corpus."""
        return self.corpus_lines

    def __getitem__(self, item):
        """Build one training example.

        Returns:
            dict of ``torch.Tensor`` with keys ``bert_input``, ``bert_label``,
            ``segment_label`` (each of length ``seq_len``) and ``is_next``.
        """
        t1, t2, is_next_label = self.random_sent(item)
        t1_random, t1_label = self.random_word(t1)
        t2_random, t2_label = self.random_word(t2)

        # [CLS] tag = SOS tag, [SEP] tag = EOS tag
        t1 = [self.vocab.sos_index] + t1_random + [self.vocab.eos_index]
        t2 = t2_random + [self.vocab.eos_index]

        # Special tokens are never prediction targets.
        t1_label = [self.vocab.pad_index] + t1_label + [self.vocab.pad_index]
        t2_label = t2_label + [self.vocab.pad_index]

        segment_label = ([1] * len(t1) + [2] * len(t2))[:self.seq_len]
        bert_input = (t1 + t2)[:self.seq_len]
        bert_label = (t1_label + t2_label)[:self.seq_len]

        # All three sequences are padded with pad_index up to seq_len.
        padding = [self.vocab.pad_index] * (self.seq_len - len(bert_input))
        bert_input.extend(padding)
        bert_label.extend(padding)
        segment_label.extend(padding)

        output = {"bert_input": bert_input,
                  "bert_label": bert_label,
                  "segment_label": segment_label,
                  "is_next": is_next_label}
        return {key: torch.tensor(value) for key, value in output.items()}

    def random_word(self, sentence):
        """Apply BERT-style random masking to ``sentence``.

        Each character is selected with probability 0.15. A selected token is
        replaced by the mask id (80%), a random vocabulary id (10%) or kept
        (10%); in ``predict_mode`` it is always masked. The label is the
        original token id for selected tokens and ``0`` otherwise.

        Returns:
            ``(token_ids, labels)`` -- two lists with one entry per character.
        """
        tokens = list(sentence)
        output_label = []
        stoi = self.vocab.stoi
        unk_index = self.vocab.unk_index

        for i, token in enumerate(tokens):
            token_id = stoi.get(token, unk_index)
            prob = random.random()

            if prob >= 0.15:
                # Not selected: keep the token, no prediction target.
                tokens[i] = token_id
                output_label.append(0)
                continue

            output_label.append(token_id)
            if self.predict_mode:
                tokens[i] = self.vocab.mask_index
                continue

            # Rescale to [0, 1) to choose the replacement strategy.
            prob /= 0.15
            if prob < 0.8:
                tokens[i] = self.vocab.mask_index
            elif prob < 0.9:
                tokens[i] = random.randrange(len(self.vocab))
            else:
                tokens[i] = token_id

        return tokens, output_label

    def random_sent(self, index):
        """Return ``(t1, t2, label)`` with label 1 (is next) or 0 (not next).

        In ``predict_mode`` the true pair is always returned. Otherwise the
        second sentence is replaced by a random one about half of the time.
        """
        t1, t2 = self.get_corpus_line(index)
        if self.predict_mode:
            return t1, t2, 1

        # output_text, label(isNotNext:0, isNext:1)
        if random.random() > 0.5:
            return t1, t2, 1
        return t1, self.get_random_line(), 0

    def get_corpus_line(self, item):
        """Return the sentence pair ``(t1, t2)`` for ``item``.

        In file mode ``item`` is ignored: pairs are read sequentially and the
        file is restarted from the beginning once it is exhausted.
        """
        if self.on_memory:
            return self.lines[item][0], self.lines[item][1]

        fields = self._split_line(self._next_line("file"))
        return fields[0], fields[1]

    def get_random_line(self):
        """Return the second sentence of a randomly chosen corpus pair."""
        if self.on_memory:
            return self.lines[random.randrange(len(self.lines))][1]

        # Use the dedicated handle so the sequential stream is not disturbed.
        line = self._next_line("random_file", random_start=True)
        return self._split_line(line)[1]

    def close(self):
        """Close the corpus file handles opened in file mode (idempotent)."""
        for attr in ("file", "random_file"):
            handle = getattr(self, attr, None)
            if handle is not None:
                handle.close()
                setattr(self, attr, None)

    @staticmethod
    def _split_line(line):
        """Split a raw corpus line into its tab-separated fields."""
        # Strip only the newline so an unterminated last line stays intact.
        return line.rstrip("\n").split("\t")

    def _reopen(self, attr):
        """Close the handle stored under ``attr`` and open the corpus anew."""
        handle = getattr(self, attr)
        if handle is not None:
            handle.close()
        setattr(self, attr, open(self.corpus_path, "r", encoding=self.encoding))

    def _skip_random_lines(self):
        """Advance ``random_file`` by a random number of lines (< 1000)."""
        limit = min(self.corpus_lines, 1000)
        if limit <= 0:
            return
        # randrange(limit) < limit, so at least one line is left to read.
        for _ in range(random.randrange(limit)):
            if next(self.random_file, None) is None:
                break

    def _next_line(self, attr, random_start=False):
        """Read the next line from the handle ``attr``, wrapping at EOF.

        Raises:
            ValueError: If the corpus file contains no lines at all.
        """
        # next(..., None) is required: an exhausted file raises StopIteration
        # from __next__ rather than returning None.
        line = next(getattr(self, attr), None)
        if line is None:
            self._reopen(attr)
            if random_start:
                self._skip_random_lines()
            line = next(getattr(self, attr), None)
            if line is None and random_start:
                # corpus_lines overstated the file; restart without skipping.
                self._reopen(attr)
                line = next(getattr(self, attr), None)
        if line is None:
            raise ValueError(f"corpus file {self.corpus_path!r} contains no lines")
        return line
|