Spaces:
Runtime error
Runtime error
import torch
import random
from torch.utils.data import Dataset
class AbstractDataset(Dataset):
    """Dataset of GPT-2-style prompts built from (title, text, keywords) records.

    Each item is tokenized from the string
    ``<|BOS|>title<|SEP|>kw1,kw2,...<|SEP|>text<|EOS|>`` and returned as
    tensors suitable for causal-LM fine-tuning (labels mirror input_ids).
    """

    # Marker tokens inserted around the prompt sections; these must also be
    # registered with the tokenizer by the caller.
    special_tokens = {"bos_token": "<|BOS|>",
                      "eos_token": "<|EOS|>",
                      "unk_token": "<|UNK|>",
                      "pad_token": "<|PAD|>",
                      "sep_token": "<|SEP|>"}
    # Maximum tokenized sequence length (GPT-2 context window).
    max_length = 1024

    def __init__(self, data, tokenizer, randomize=True):
        """
        Args:
            data: mapping of id -> (title, text, keywords) triples; keys unused.
            tokenizer: callable accepting (text, truncation=, max_length=,
                padding=) and returning a dict with 'input_ids' and
                'attention_mask' (e.g. a HuggingFace tokenizer).
            randomize: when True, shuffle keyword order per sample (augmentation).
        """
        title, text, keywords = [], [], []
        for v in data.values():  # keys are never used; iterate values only
            title.append(v[0])
            text.append(v[1])
            keywords.append(v[2])
        self.randomize = randomize
        self.tokenizer = tokenizer
        self.title = title
        self.text = text
        self.keywords = keywords

    @staticmethod
    def join_keywords(keywords, randomize=True):
        """Join keywords into a comma-separated string, optionally shuffled.

        BUG FIX: this was defined without ``self`` yet invoked as
        ``self.join_keywords(keywords, self.randomize)``, which passed three
        arguments to a two-parameter function and raised TypeError at runtime.
        ``@staticmethod`` restores the intended call semantics.
        """
        if randomize:
            # Shuffle in place; safe because __getitem__ hands us a copy.
            random.shuffle(keywords)
        return ','.join(keywords)

    def __len__(self):
        return len(self.text)

    def __getitem__(self, i):
        # Copy so shuffling never mutates the stored keyword lists.
        keywords = self.keywords[i].copy()
        kw = self.join_keywords(keywords, self.randomize)
        # Renamed from `input` to avoid shadowing the builtin.
        prompt = (self.special_tokens['bos_token'] + self.title[i]
                  + self.special_tokens['sep_token'] + kw
                  + self.special_tokens['sep_token'] + self.text[i]
                  + self.special_tokens['eos_token'])
        encodings_dict = self.tokenizer(prompt,
                                        truncation=True,
                                        max_length=self.max_length,
                                        padding="max_length")
        input_ids = encodings_dict['input_ids']
        attention_mask = encodings_dict['attention_mask']
        # Labels are identical to input_ids for causal language modeling.
        return {'label': torch.tensor(input_ids),
                'input_ids': torch.tensor(input_ids),
                'attention_mask': torch.tensor(attention_mask)}