import json

import torch
from torch.utils.data import Dataset
from transformers.models.bart.modeling_bart import shift_tokens_right


class GlycoBertTokenizer:
    """Whitespace tokenizer with BERT-style special tokens for glycan sequences."""

    def __init__(self, vocab_list, max_seq_length=512):
        self.special_tokens = {
            'pad_token': '[PAD]',
            'cls_token': '[CLS]',
            'sep_token': '[SEP]',
            'unk_token': '[UNK]',
            'mask_token': '[MASK]'
        }

        # Special tokens take the first ids in the vocabulary.
        special_token_symbols = list(self.special_tokens.values())

        # Drop special tokens that already appear in the supplied vocabulary
        # so that no token is assigned two ids.
        vocab_list = [word for word in vocab_list if word not in special_token_symbols]

        combined_list = special_token_symbols + vocab_list

        self.vocab = {word: idx for idx, word in enumerate(combined_list)}
        self.reverse_vocab = {idx: word for word, idx in self.vocab.items()}
        self.max_seq_length = max_seq_length

    def tokenize(self, text):
        """Split on whitespace; the corpus is expected to be pre-segmented."""
        return text.split()

    def encode(self, texts):
        """Encode a string or list of strings into fixed-length id/mask tensors."""
        if isinstance(texts, str):
            texts = [texts]

        batch_token_ids = []
        batch_attention_masks = []

        for text in texts:
            tokens = self.tokenize(text)
            token_ids = [self.vocab.get(token, self.vocab[self.special_tokens['unk_token']]) for token in tokens]

            # Wrap the sequence as [CLS] ... [SEP].
            token_ids = [self.vocab[self.special_tokens['cls_token']]] + token_ids + [self.vocab[self.special_tokens['sep_token']]]

            attention_mask = [1] * len(token_ids)

            if len(token_ids) < self.max_seq_length:
                # Pad up to the fixed maximum length.
                padding_length = self.max_seq_length - len(token_ids)
                token_ids += [self.vocab[self.special_tokens['pad_token']]] * padding_length
                attention_mask += [0] * padding_length
            else:
                # Truncate, keeping [SEP] as the final token.
                token_ids = token_ids[:self.max_seq_length]
                token_ids[-1] = self.vocab[self.special_tokens['sep_token']]
                attention_mask = attention_mask[:self.max_seq_length]

            batch_token_ids.append(torch.tensor(token_ids))
            batch_attention_masks.append(torch.tensor(attention_mask))

        return {
            "token_ids": torch.stack(batch_token_ids),
            "attention_mask": torch.stack(batch_attention_masks)
        }

    def decode(self, batch_token_ids, skip_special_tokens=False):
        """Decode a 1D or 2D tensor of token ids back into whitespace-joined strings."""
        if batch_token_ids.dim() == 1:
            batch_token_ids = batch_token_ids.unsqueeze(0)

        special_ids = {self.vocab[val] for val in self.special_tokens.values()}
        pad_id = self.vocab[self.special_tokens['pad_token']]

        decoded_texts = []
        for token_ids in batch_token_ids:
            if skip_special_tokens:
                kept = [self.reverse_vocab[token_id.item()] for token_id in token_ids
                        if token_id.item() not in special_ids]
            else:
                # Padding is always stripped; other special tokens are kept.
                kept = [self.reverse_vocab[token_id.item()] for token_id in token_ids
                        if token_id.item() != pad_id]
            decoded_texts.append(' '.join(kept))

        return decoded_texts if len(decoded_texts) > 1 else decoded_texts[0]

    def save_vocabulary(self, path="vocab.json"):
        with open(path, 'w') as file:
            json.dump(self.vocab, file)

    @property
    def vocab_size(self):
        """Returns the size of the vocabulary."""
        return len(self.vocab)

    @classmethod
    def load_vocabulary(cls, path="vocab.json", max_seq_length=512):
        # json preserves key order, so ids are reassigned exactly as saved.
        with open(path, 'r') as file:
            loaded_vocab = json.load(file)
        return cls(list(loaded_vocab.keys()), max_seq_length)

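
# A minimal round-trip sketch for GlycoBertTokenizer (illustrative vocabulary and
# sequence only, not part of the training pipeline):
#
#   tokenizer = GlycoBertTokenizer(["Gal", "GlcNAc", "Man", "b1-4", "a1-3"], max_seq_length=16)
#   encoded = tokenizer.encode("Gal b1-4 GlcNAc")
#   encoded["token_ids"].shape                                    # torch.Size([1, 16])
#   tokenizer.decode(encoded["token_ids"], skip_special_tokens=True)
#   # -> 'Gal b1-4 GlcNAc'
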

class GlycoBartTokenizer:
    """Whitespace tokenizer with BART-style special tokens for glycan sequences."""

    def __init__(self, vocab_list, max_seq_length=512):
        self.special_tokens = {
            'pad_token': '<pad>',
            'bos_token': '<s>',
            'eos_token': '</s>',
            'sep_token': '<sep>',
            'cls_token': '<cls>',
            'unk_token': '<unk>',
            'mask_token': '<mask>'
        }

        # Special tokens take the first ids in the vocabulary.
        special_token_symbols = list(self.special_tokens.values())

        # Drop special tokens that already appear in the supplied vocabulary
        # so that no token is assigned two ids.
        vocab_list = [word for word in vocab_list if word not in special_token_symbols]

        combined_list = special_token_symbols + vocab_list

        self.vocab = {word: idx for idx, word in enumerate(combined_list)}
        self.reverse_vocab = {idx: word for word, idx in self.vocab.items()}
        self.max_seq_length = max_seq_length

    def tokenize(self, text):
        """Split on whitespace; the corpus is expected to be pre-segmented."""
        return text.split()

    def encode(self, texts):
        """Encode a string or list of strings into fixed-length id/mask tensors."""
        if isinstance(texts, str):
            texts = [texts]

        batch_token_ids = []
        batch_attention_masks = []

        for text in texts:
            tokens = self.tokenize(text)
            token_ids = [self.vocab.get(token, self.vocab[self.special_tokens['unk_token']]) for token in tokens]

            # Wrap the sequence as <s> ... </s>.
            token_ids = [self.vocab[self.special_tokens['bos_token']]] + token_ids + [self.vocab[self.special_tokens['eos_token']]]

            attention_mask = [1] * len(token_ids)

            if len(token_ids) < self.max_seq_length:
                # Pad up to the fixed maximum length.
                padding_length = self.max_seq_length - len(token_ids)
                token_ids += [self.vocab[self.special_tokens['pad_token']]] * padding_length
                attention_mask += [0] * padding_length
            else:
                # Truncate, keeping </s> as the final token.
                token_ids = token_ids[:self.max_seq_length]
                token_ids[-1] = self.vocab[self.special_tokens['eos_token']]
                attention_mask = attention_mask[:self.max_seq_length]

            batch_token_ids.append(torch.tensor(token_ids))
            batch_attention_masks.append(torch.tensor(attention_mask))

        return {
            "token_ids": torch.stack(batch_token_ids),
            "attention_mask": torch.stack(batch_attention_masks)
        }

    def decode(self, batch_token_ids, skip_special_tokens=False):
        """Decode a 1D or 2D tensor of token ids back into whitespace-joined strings."""
        if batch_token_ids.dim() == 1:
            batch_token_ids = batch_token_ids.unsqueeze(0)

        special_ids = {self.vocab[val] for val in self.special_tokens.values()}
        pad_id = self.vocab[self.special_tokens['pad_token']]

        decoded_texts = []
        for token_ids in batch_token_ids:
            if skip_special_tokens:
                kept = [self.reverse_vocab[token_id.item()] for token_id in token_ids
                        if token_id.item() not in special_ids]
            else:
                # Padding is always stripped; other special tokens are kept.
                kept = [self.reverse_vocab[token_id.item()] for token_id in token_ids
                        if token_id.item() != pad_id]
            decoded_texts.append(' '.join(kept))

        return decoded_texts if len(decoded_texts) > 1 else decoded_texts[0]

    def save_vocabulary(self, path="vocab.json"):
        with open(path, 'w') as file:
            json.dump(self.vocab, file)

    @property
    def vocab_size(self):
        """Returns the size of the vocabulary."""
        return len(self.vocab)

    @classmethod
    def load_vocabulary(cls, path="vocab.json", max_seq_length=512):
        # json preserves key order, so ids are reassigned exactly as saved.
        with open(path, 'r') as file:
            loaded_vocab = json.load(file)
        return cls(list(loaded_vocab.keys()), max_seq_length)


class GlycanTranslationData(Dataset):
    """Pairs encoded source and target corpora for sequence-to-sequence training."""

    def __init__(self, input_corpus, output_corpus, pad_token_id, eos_token_id):
        self.input_ids = input_corpus["token_ids"]
        self.input_attention_masks = input_corpus["attention_mask"]

        self.output_ids = output_corpus["token_ids"]
        self.output_attention_masks = output_corpus["attention_mask"]

        self.pad_token_id = pad_token_id
        self.eos_token_id = eos_token_id

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        output_ids_for_idx = self.output_ids[idx]

        # shift_tokens_right expects a 2D (batch, seq_len) tensor.
        if len(output_ids_for_idx.shape) == 1:
            output_ids_for_idx = output_ids_for_idx.unsqueeze(0)

        # Teacher forcing: the decoder input is the target shifted one position to
        # the right, with the end-of-sequence token used as the decoder start token.
        decoder_input_ids = shift_tokens_right(output_ids_for_idx, self.pad_token_id, self.eos_token_id).squeeze(0)

        # Shift the target attention mask the same way: prepend a 1 for the
        # decoder start token and drop the last position.
        decoder_attention_mask = torch.cat([torch.tensor([1]), self.output_attention_masks[idx]])[:-1]

        return {
            "input_ids": self.input_ids[idx],
            "attention_mask": self.input_attention_masks[idx],
            "decoder_input_ids": decoder_input_ids,
            "decoder_attention_mask": decoder_attention_mask,
            "labels": self.output_ids[idx]
        }
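

# A minimal usage sketch, assuming whitespace-segmented source/target strings and
# illustrative vocabularies; the sequences below are placeholders, not real data.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    vocab = ["Gal", "GlcNAc", "Man", "b1-4", "a1-3"]
    src_tokenizer = GlycoBartTokenizer(vocab, max_seq_length=32)
    tgt_tokenizer = GlycoBartTokenizer(vocab, max_seq_length=32)

    # Encode toy parallel corpora into fixed-length tensors.
    src_corpus = src_tokenizer.encode(["Gal b1-4 GlcNAc", "Man a1-3 Man"])
    tgt_corpus = tgt_tokenizer.encode(["GlcNAc b1-4 Gal", "Man a1-3 Man"])

    dataset = GlycanTranslationData(
        src_corpus,
        tgt_corpus,
        pad_token_id=tgt_tokenizer.vocab[tgt_tokenizer.special_tokens['pad_token']],
        eos_token_id=tgt_tokenizer.vocab[tgt_tokenizer.special_tokens['eos_token']],
    )

    loader = DataLoader(dataset, batch_size=2, shuffle=True)
    batch = next(iter(loader))
    print({key: tensor.shape for key, tensor in batch.items()})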