from torch.utils.data import Dataset import tqdm import torch import random import numpy as np class BERTDataset(Dataset): def __init__(self, corpus_path, vocab, seq_len, corpus_lines=None, encoding="utf-8", on_memory=True, predict_mode=False): self.vocab = vocab self.seq_len = seq_len self.on_memory = on_memory self.corpus_lines = corpus_lines self.corpus_path = corpus_path self.encoding = encoding self.predict_mode = predict_mode self.lines = corpus_path self.corpus_lines = len(self.lines) if not on_memory: self.file = open(corpus_path, "r", encoding=encoding) self.random_file = open(corpus_path, "r", encoding=encoding) for _ in range(random.randint(self.corpus_lines if self.corpus_lines < 1000 else 1000)): self.random_file.__next__() def __len__(self): return self.corpus_lines def __getitem__(self, item): t1, t2, is_next_label = self.random_sent(item) t1_random, t1_label = self.random_word(t1) t2_random, t2_label = self.random_word(t2) # [CLS] tag = SOS tag, [SEP] tag = EOS tag t1 = [self.vocab.sos_index] + t1_random + [self.vocab.eos_index] t2 = t2_random + [self.vocab.eos_index] t1_label = [self.vocab.pad_index] + t1_label + [self.vocab.pad_index] t2_label = t2_label + [self.vocab.pad_index] segment_label = ([1 for _ in range(len(t1))] + [2 for _ in range(len(t2))])[:self.seq_len] bert_input = (t1 + t2)[:self.seq_len] bert_label = (t1_label + t2_label)[:self.seq_len] padding = [self.vocab.pad_index for _ in range(self.seq_len - len(bert_input))] bert_input.extend(padding), bert_label.extend(padding), segment_label.extend(padding) output = {"bert_input": bert_input, "bert_label": bert_label, "segment_label": segment_label, "is_next": is_next_label} return {key: torch.tensor(value) for key, value in output.items()} def random_word(self, sentence): tokens = list(sentence) output_label = [] for i, token in enumerate(tokens): prob = random.random() # replace 15% of tokens in a sequence to a masked token if prob < 0.15: if self.predict_mode: tokens[i] = self.vocab.mask_index output_label.append(self.vocab.stoi.get(token, self.vocab.unk_index)) continue prob /= 0.15 # 80% randomly change token to mask token if prob < 0.8: tokens[i] = self.vocab.mask_index # 10% randomly change token to random token elif prob < 0.9: tokens[i] = random.randrange(len(self.vocab)) # 10% randomly change token to current token else: tokens[i] = self.vocab.stoi.get(token, self.vocab.unk_index) output_label.append(self.vocab.stoi.get(token, self.vocab.unk_index)) else: tokens[i] = self.vocab.stoi.get(token, self.vocab.unk_index) output_label.append(0) return tokens, output_label def random_sent(self, index): t1, t2 = self.get_corpus_line(index) if self.predict_mode: return t1, t2, 1 # output_text, label(isNotNext:0, isNext:1) if random.random() > 0.5: return t1, t2, 1 else: return t1, self.get_random_line(), 0 def get_corpus_line(self, item): if self.on_memory: return self.lines[item][0], self.lines[item][1] else: line = self.file.__next__() if line is None: self.file.close() self.file = open(self.corpus_path, "r", encoding=self.encoding) line = self.file.__next__() t1, t2 = line[:-1].split("\t") return t1, t2 def get_random_line(self): if self.on_memory: return self.lines[random.randrange(len(self.lines))][1] line = self.file.__next__() if line is None: self.file.close() self.file = open(self.corpus_path, "r", encoding=self.encoding) for _ in range(random.randint(self.corpus_lines if self.corpus_lines < 1000 else 1000)): self.random_file.__next__() line = self.random_file.__next__() return line[:-1].split("\t")[1]