import numpy as np
import torch

from megatron import get_args, get_tokenizer
from megatron.data.bert_dataset import build_training_sample


class BertEmbeddingDataset(torch.utils.data.Dataset):
    '''Dataset to convert a text dataset to Bert tokens.'''

    def __init__(self, text_dataset, max_seq_length):

        super().__init__()

        args = get_args()

        # Wrapped dataset and Bert tokenizer.
        self.text_dataset = text_dataset
        self.bert_tokenizer = get_tokenizer()

        # Masking parameters.
        self.max_seq_length = max_seq_length
        self.seed = args.seed
        self.masked_lm_prob = args.mask_prob

        # Vocab info passed to build_training_sample().
        self.vocab_id_list = list(self.bert_tokenizer.inv_vocab.keys())
        self.vocab_id_to_token_dict = self.bert_tokenizer.inv_vocab
        self.cls_id = self.bert_tokenizer.cls
        self.sep_id = self.bert_tokenizer.sep
        self.mask_id = self.bert_tokenizer.mask
        self.pad_id = self.bert_tokenizer.pad

    def __len__(self):
        return len(self.text_dataset)

    def __getitem__(self, idx):

        # Raw text; strip any GPT end-of-document tokens.
        text_sample = self.text_dataset[idx]
        text = text_sample["text"]
        text = text.replace("<|endoftext|>", "")

        # Bert/Wordpiece token ids, truncated to leave room for CLS and SEP.
        bert_token_ids = self.bert_tokenizer.tokenize(text)
        bert_token_ids = bert_token_ids[:self.max_seq_length - 2]
        if not bert_token_ids:
            # Use the pad id stored in __init__ to avoid an empty sequence.
            bert_token_ids = [self.pad_id]

        # Numpy RNG, seeded per sample so masking is reproducible; numpy
        # requires the seed to lie in [0, 2**32).
        np_rng = np.random.RandomState(seed=((self.seed + idx) % 2**32))

        # Build a masked-LM training sample; binary_head=False disables the
        # next-sentence-prediction portion of the sample.
        sample = build_training_sample([bert_token_ids],
                                       len(bert_token_ids),
                                       len(bert_token_ids) + 2,  # +2: CLS, SEP
                                       self.vocab_id_list,
                                       self.vocab_id_to_token_dict,
                                       self.cls_id, self.sep_id,
                                       self.mask_id, self.pad_id,
                                       self.masked_lm_prob, np_rng,
                                       binary_head=False)
        sample["seq_length"] = len(sample["text"])
        return sample
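

# A minimal usage sketch (an illustrative assumption, not part of the original
# module): BertEmbeddingDataset only requires that `text_dataset[idx]` return
# a dict with a "text" key, and that Megatron has been initialized so that
# get_args() and get_tokenizer() work. `SimpleTextDataset` is hypothetical.
if __name__ == "__main__":
    from megatron.initialize import initialize_megatron

    class SimpleTextDataset(torch.utils.data.Dataset):
        '''Hypothetical stand-in for any dataset of {"text": str} samples.'''

        def __init__(self, texts):
            self.texts = texts

        def __len__(self):
            return len(self.texts)

        def __getitem__(self, idx):
            return {"text": self.texts[idx]}

    # Parses standard Megatron command-line args (e.g. --seed, --mask-prob,
    # tokenizer settings) and sets up global state.
    initialize_megatron()

    dataset = BertEmbeddingDataset(SimpleTextDataset(["An example sentence."]),
                                   max_seq_length=512)
    sample = dataset[0]
    print(sample["seq_length"])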