Spaces:
Runtime error
Runtime error
| from torch.utils.data import Dataset | |
| import tqdm | |
| import torch | |
| import random | |
| import numpy as np | |
class BERTDataset(Dataset):
    """Sentence-pair dataset producing masked-LM / next-sentence samples.

    Two storage modes are supported:

    * ``on_memory=True``: ``corpus_path`` is itself the corpus, an indexable
      sequence whose items expose the first sentence at ``[0]`` and the
      second sentence at ``[1]``.
    * ``on_memory=False``: ``corpus_path`` is the path of a text file with one
      ``sentence1<TAB>sentence2`` pair per line. Lines are streamed
      sequentially, so the index passed to ``__getitem__`` is ignored.

    ``vocab`` must provide ``sos_index``, ``eos_index``, ``pad_index``,
    ``mask_index``, ``unk_index``, a ``stoi`` mapping and ``__len__``.
    """

    # Upper bound on how many lines the random-sentence reader skips ahead.
    _MAX_RANDOM_SKIP = 1000

    def __init__(self, corpus_path, vocab, seq_len, corpus_lines=None,
                 encoding="utf-8", on_memory=True, predict_mode=False):
        """Store the configuration and prepare the corpus.

        Args:
            corpus_path: in-memory corpus (``on_memory=True``) or file path
                (``on_memory=False``); see the class docstring.
            vocab: vocabulary object used to map tokens to indices.
            seq_len: fixed length every output sequence is cut/padded to.
            corpus_lines: number of lines in the corpus file. Only used when
                ``on_memory=False``; counted from the file when ``None``.
            encoding: text encoding used to open the corpus file.
            on_memory: selects the storage mode.
            predict_mode: when true, sentence pairs are never shuffled and
                selected tokens are always replaced by the mask index.
        """
        self.vocab = vocab
        self.seq_len = seq_len
        self.on_memory = on_memory
        self.corpus_lines = corpus_lines
        self.corpus_path = corpus_path
        self.encoding = encoding
        self.predict_mode = predict_mode
        self.file = None
        self.random_file = None

        if on_memory:
            self.lines = corpus_path
            self.corpus_lines = len(self.lines)
            return

        self.lines = None
        if self.corpus_lines is None:
            # The caller did not say how long the file is: count it once.
            with self._open_corpus() as handle:
                self.corpus_lines = sum(1 for _ in handle)

        # One handle streams the pairs in order; the second one is an
        # independent cursor used to draw "random" negative sentences.
        self.file = self._open_corpus()
        self.random_file = self._open_corpus()
        self._skip_random_lines()

    def __len__(self):
        """Return the number of sentence pairs in the corpus."""
        return self.corpus_lines

    def __getitem__(self, item):
        """Build one training sample.

        Returns:
            dict of ``torch.Tensor`` with keys ``bert_input``, ``bert_label``,
            ``segment_label`` (each of length ``seq_len``) and ``is_next``.
        """
        t1, t2, is_next_label = self.random_sent(item)
        t1_random, t1_label = self.random_word(t1)
        t2_random, t2_label = self.random_word(t2)

        # [CLS] tag = SOS tag, [SEP] tag = EOS tag
        t1 = [self.vocab.sos_index] + t1_random + [self.vocab.eos_index]
        t2 = t2_random + [self.vocab.eos_index]

        # The special tokens are never prediction targets.
        t1_label = [self.vocab.pad_index] + t1_label + [self.vocab.pad_index]
        t2_label = t2_label + [self.vocab.pad_index]

        # Segment 1 covers the first sentence, segment 2 the second one.
        segment_label = ([1] * len(t1) + [2] * len(t2))[:self.seq_len]
        bert_input = (t1 + t2)[:self.seq_len]
        bert_label = (t1_label + t2_label)[:self.seq_len]

        padding = [self.vocab.pad_index] * (self.seq_len - len(bert_input))
        bert_input.extend(padding)
        bert_label.extend(padding)
        segment_label.extend(padding)

        output = {"bert_input": bert_input,
                  "bert_label": bert_label,
                  "segment_label": segment_label,
                  "is_next": is_next_label}
        return {key: torch.tensor(value) for key, value in output.items()}

    def random_word(self, sentence):
        """Convert ``sentence`` to indices and randomly corrupt some of them.

        Each token is selected with probability 0.15. In ``predict_mode`` a
        selected token is always replaced by the mask index; otherwise it is
        replaced by the mask index (80%), by a random vocabulary index (10%)
        or kept as its own index (10%).

        Returns:
            ``(tokens, labels)``: the index sequence fed to the model and,
            per position, the original token index for selected tokens or
            ``0`` for tokens that were left alone.
        """
        tokens = list(sentence)
        output_label = []

        for i, token in enumerate(tokens):
            token_index = self.vocab.stoi.get(token, self.vocab.unk_index)
            prob = random.random()

            if prob >= 0.15:
                # Not selected: plain lookup, no prediction target.
                # NOTE(review): the literal 0 presumably equals
                # vocab.pad_index; confirm against the vocabulary.
                tokens[i] = token_index
                output_label.append(0)
                continue

            if self.predict_mode:
                tokens[i] = self.vocab.mask_index
                output_label.append(token_index)
                continue

            # Rescale to [0, 1) to pick one of the three corruptions.
            prob /= 0.15
            if prob < 0.8:
                tokens[i] = self.vocab.mask_index
            elif prob < 0.9:
                tokens[i] = random.randrange(len(self.vocab))
            else:
                tokens[i] = token_index
            output_label.append(token_index)

        return tokens, output_label

    def random_sent(self, index):
        """Return ``(t1, t2, label)`` where label is 1 (isNext) or 0.

        In ``predict_mode`` the real pair is always returned with label 1.
        Otherwise the second sentence is replaced by a random one with
        probability 0.5 and the label becomes 0.
        """
        t1, t2 = self.get_corpus_line(index)
        if self.predict_mode:
            return t1, t2, 1

        # output_text, label(isNotNext:0, isNext:1)
        if random.random() > 0.5:
            return t1, t2, 1
        return t1, self.get_random_line(), 0

    def get_corpus_line(self, item):
        """Return the sentence pair for ``item``.

        In file mode ``item`` is ignored: the next line of the sequential
        reader is returned and the reader wraps to the top at end of file.

        Raises:
            ValueError: if the corpus file is empty, or (from unpacking) if a
                line does not hold exactly two tab-separated fields.
        """
        if self.on_memory:
            return self.lines[item][0], self.lines[item][1]

        # next(..., None) is required: a file iterator signals end of file
        # by raising StopIteration, it never returns None on its own.
        line = next(self.file, None)
        if line is None:
            self.file.close()
            self.file = self._open_corpus()
            line = next(self.file, None)
            if line is None:
                raise ValueError(
                    "corpus file is empty: %r" % (self.corpus_path,))
        t1, t2 = self._split_fields(line)
        return t1, t2

    def get_random_line(self):
        """Return the second sentence of a randomly chosen corpus entry.

        In file mode the sentence comes from the dedicated random reader, so
        the sequential reader used by ``get_corpus_line`` is not advanced.
        """
        if self.on_memory:
            return self.lines[random.randrange(len(self.lines))][1]

        line = next(self.random_file, None)
        if line is None:
            # End of file: restart from a fresh random offset.
            self._reopen_random_file()
            self._skip_random_lines()
            line = next(self.random_file, None)
        if line is None:
            # The skip ran past the end (corpus_lines overestimates the
            # file); fall back to the very first line.
            self._reopen_random_file()
            line = next(self.random_file, None)
        if line is None:
            raise ValueError("corpus file is empty: %r" % (self.corpus_path,))
        return self._split_fields(line)[1]

    def close(self):
        """Close the corpus file handles opened in file mode, if any."""
        for attr in ("file", "random_file"):
            handle = getattr(self, attr, None)
            if handle is not None:
                handle.close()
                setattr(self, attr, None)

    def _open_corpus(self):
        """Open the corpus file for reading with the configured encoding."""
        return open(self.corpus_path, "r", encoding=self.encoding)

    def _reopen_random_file(self):
        """Rewind the random reader by closing and reopening it."""
        self.random_file.close()
        self.random_file = self._open_corpus()

    def _skip_random_lines(self):
        """Advance the random reader by a random number of lines.

        The skip is strictly smaller than ``min(corpus_lines, 1000)`` so that
        at least one line stays readable when ``corpus_lines`` is accurate.
        """
        limit = min(self.corpus_lines, self._MAX_RANDOM_SKIP)
        skip = random.randrange(limit) if limit > 0 else 0
        for _ in range(skip):
            if next(self.random_file, None) is None:
                break

    @staticmethod
    def _split_fields(line):
        """Split one corpus line into its tab-separated fields.

        Only the line terminator is stripped, so a final line without a
        trailing newline keeps its last character.
        """
        return line.rstrip("\n").split("\t")