|
|
import torch |
|
|
import json |
|
|
from transformers.models.bart.modeling_bart import shift_tokens_right |
|
|
from torch.utils.data import Dataset |
|
|
|
|
|
class GlycoBertTokenizer:
    """Whitespace tokenizer with a BERT-style special-token layout.

    The special tokens always occupy the first vocabulary ids
    ([PAD]=0, [CLS]=1, [SEP]=2, [UNK]=3, [MASK]=4); the remaining ids
    follow the order of ``vocab_list``.
    """

    def __init__(self, vocab_list, max_seq_length=512):
        """
        Args:
            vocab_list: iterable of token strings. Any special-token symbols
                it already contains are dropped so the reserved ids stay stable.
            max_seq_length: fixed length every encoded sequence is padded or
                truncated to.
        """
        self.special_tokens = {
            'pad_token': '[PAD]',
            'cls_token': '[CLS]',
            'sep_token': '[SEP]',
            'unk_token': '[UNK]',
            'mask_token': '[MASK]'
        }

        special_token_symbols = list(self.special_tokens.values())

        # Remove duplicates so special tokens always map to ids 0..4.
        vocab_list = [word for word in vocab_list if word not in special_token_symbols]

        combined_list = special_token_symbols + vocab_list

        self.vocab = {word: idx for idx, word in enumerate(combined_list)}
        self.reverse_vocab = {idx: word for word, idx in self.vocab.items()}
        self.max_seq_length = max_seq_length

        # Cache frequently-used ids once instead of re-deriving them per call
        # (decode() previously rebuilt the special-id list for every token).
        self._pad_id = self.vocab[self.special_tokens['pad_token']]
        self._cls_id = self.vocab[self.special_tokens['cls_token']]
        self._sep_id = self.vocab[self.special_tokens['sep_token']]
        self._unk_id = self.vocab[self.special_tokens['unk_token']]
        self._special_ids = {self.vocab[tok] for tok in special_token_symbols}

    def tokenize(self, text):
        """Split ``text`` on whitespace into a list of tokens."""
        return text.split()

    def encode(self, texts):
        """Encode a string (or list of strings) into fixed-length id tensors.

        Each sequence is wrapped as ``[CLS] ... [SEP]``, then padded with
        ``[PAD]`` (mask 0) or truncated to ``max_seq_length``.

        Returns:
            dict with ``token_ids`` and ``attention_mask``, each a LongTensor
            of shape (batch, max_seq_length).
        """
        if isinstance(texts, str):
            texts = [texts]

        batch_token_ids = []
        batch_attention_masks = []

        for text in texts:
            tokens = self.tokenize(text)
            token_ids = [self.vocab.get(token, self._unk_id) for token in tokens]

            token_ids = [self._cls_id] + token_ids + [self._sep_id]
            attention_mask = [1] * len(token_ids)

            if len(token_ids) > self.max_seq_length:
                # BUG FIX: a plain slice would silently drop the trailing
                # [SEP]; keep it as the final token after truncation.
                token_ids = token_ids[:self.max_seq_length - 1] + [self._sep_id]
                attention_mask = attention_mask[:self.max_seq_length]
            else:
                padding_length = self.max_seq_length - len(token_ids)
                token_ids += [self._pad_id] * padding_length
                attention_mask += [0] * padding_length

            batch_token_ids.append(torch.tensor(token_ids))
            batch_attention_masks.append(torch.tensor(attention_mask))

        return {
            "token_ids": torch.stack(batch_token_ids),
            "attention_mask": torch.stack(batch_attention_masks)
        }

    def decode(self, batch_token_ids, skip_special_tokens=False):
        """Convert id tensors back to whitespace-joined token strings.

        [PAD] ids are always dropped; with ``skip_special_tokens`` every
        special token is dropped.

        Returns:
            a list of strings, or a bare string for a single-element batch
            (preserved API quirk).
        """
        if batch_token_ids.dim() == 1:
            batch_token_ids = batch_token_ids.unsqueeze(0)

        # Hoisted out of the per-token comprehension: previously the set of
        # special ids was recomputed for every single token.
        drop_ids = self._special_ids if skip_special_tokens else {self._pad_id}

        decoded_texts = []
        for token_ids in batch_token_ids:
            words = [
                self.reverse_vocab[token_id.item()]
                for token_id in token_ids
                if token_id.item() not in drop_ids
            ]
            decoded_texts.append(' '.join(words))

        return decoded_texts if len(decoded_texts) > 1 else decoded_texts[0]

    def save_vocabulary(self, path="vocab.json"):
        """Write the token -> id mapping to ``path`` as JSON."""
        with open(path, 'w') as file:
            json.dump(self.vocab, file)

    @property
    def vocab_size(self):
        """Returns the size of the vocabulary."""
        return len(self.vocab)

    @classmethod
    def load_vocabulary(cls, path="vocab.json", max_seq_length=512):
        """Rebuild a tokenizer from a vocabulary JSON written by
        :meth:`save_vocabulary`.

        Note: relies on ``json.load`` preserving key insertion order so the
        special tokens keep their reserved leading ids.
        """
        with open(path, 'r') as file:
            loaded_vocab = json.load(file)
        return cls(list(loaded_vocab.keys()), max_seq_length)
|
|
|
|
|
|
|
|
class GlycoBartTokenizer:
    """Whitespace tokenizer with a BART-style special-token layout.

    The special tokens always occupy the first vocabulary ids
    (<pad>=0, <s>=1, </s>=2, <sep>=3, <cls>=4, <unk>=5, <mask>=6); the
    remaining ids follow the order of ``vocab_list``.
    """

    def __init__(self, vocab_list, max_seq_length=512):
        """
        Args:
            vocab_list: iterable of token strings. Any special-token symbols
                it already contains are dropped so the reserved ids stay stable.
            max_seq_length: fixed length every encoded sequence is padded or
                truncated to.
        """
        self.special_tokens = {
            'pad_token': '<pad>',
            'bos_token': '<s>',
            'eos_token': '</s>',
            'sep_token': '<sep>',
            'cls_token': '<cls>',
            'unk_token': '<unk>',
            'mask_token': '<mask>'
        }

        special_token_symbols = list(self.special_tokens.values())

        # Remove duplicates so special tokens always map to ids 0..6.
        vocab_list = [word for word in vocab_list if word not in special_token_symbols]

        combined_list = special_token_symbols + vocab_list

        self.vocab = {word: idx for idx, word in enumerate(combined_list)}
        self.reverse_vocab = {idx: word for word, idx in self.vocab.items()}
        self.max_seq_length = max_seq_length

        # Cache frequently-used ids once instead of re-deriving them per call
        # (decode() previously rebuilt the special-id list for every token).
        self._pad_id = self.vocab[self.special_tokens['pad_token']]
        self._bos_id = self.vocab[self.special_tokens['bos_token']]
        self._eos_id = self.vocab[self.special_tokens['eos_token']]
        self._unk_id = self.vocab[self.special_tokens['unk_token']]
        self._special_ids = {self.vocab[tok] for tok in special_token_symbols}

    def tokenize(self, text):
        """Split ``text`` on whitespace into a list of tokens."""
        return text.split()

    def encode(self, texts):
        """Encode a string (or list of strings) into fixed-length id tensors.

        Each sequence is wrapped as ``<s> ... </s>``, then padded with
        ``<pad>`` (mask 0) or truncated to ``max_seq_length``.

        Returns:
            dict with ``token_ids`` and ``attention_mask``, each a LongTensor
            of shape (batch, max_seq_length).
        """
        if isinstance(texts, str):
            texts = [texts]

        batch_token_ids = []
        batch_attention_masks = []

        for text in texts:
            tokens = self.tokenize(text)
            token_ids = [self.vocab.get(token, self._unk_id) for token in tokens]

            token_ids = [self._bos_id] + token_ids + [self._eos_id]
            attention_mask = [1] * len(token_ids)

            if len(token_ids) > self.max_seq_length:
                # BUG FIX: a plain slice would silently drop the trailing
                # </s>; keep it as the final token after truncation.
                token_ids = token_ids[:self.max_seq_length - 1] + [self._eos_id]
                attention_mask = attention_mask[:self.max_seq_length]
            else:
                padding_length = self.max_seq_length - len(token_ids)
                token_ids += [self._pad_id] * padding_length
                attention_mask += [0] * padding_length

            batch_token_ids.append(torch.tensor(token_ids))
            batch_attention_masks.append(torch.tensor(attention_mask))

        return {
            "token_ids": torch.stack(batch_token_ids),
            "attention_mask": torch.stack(batch_attention_masks)
        }

    def decode(self, batch_token_ids, skip_special_tokens=False):
        """Convert id tensors back to whitespace-joined token strings.

        <pad> ids are always dropped; with ``skip_special_tokens`` every
        special token is dropped.

        Returns:
            a list of strings, or a bare string for a single-element batch
            (preserved API quirk).
        """
        if batch_token_ids.dim() == 1:
            batch_token_ids = batch_token_ids.unsqueeze(0)

        # Hoisted out of the per-token comprehension: previously the set of
        # special ids was recomputed for every single token.
        drop_ids = self._special_ids if skip_special_tokens else {self._pad_id}

        decoded_texts = []
        for token_ids in batch_token_ids:
            words = [
                self.reverse_vocab[token_id.item()]
                for token_id in token_ids
                if token_id.item() not in drop_ids
            ]
            decoded_texts.append(' '.join(words))

        return decoded_texts if len(decoded_texts) > 1 else decoded_texts[0]

    def save_vocabulary(self, path="vocab.json"):
        """Write the token -> id mapping to ``path`` as JSON."""
        with open(path, 'w') as file:
            json.dump(self.vocab, file)

    @property
    def vocab_size(self):
        """Returns the size of the vocabulary."""
        return len(self.vocab)

    @classmethod
    def load_vocabulary(cls, path="vocab.json", max_seq_length=512):
        """Rebuild a tokenizer from a vocabulary JSON written by
        :meth:`save_vocabulary`.

        Note: relies on ``json.load`` preserving key insertion order so the
        special tokens keep their reserved leading ids.
        """
        with open(path, 'r') as file:
            loaded_vocab = json.load(file)
        return cls(list(loaded_vocab.keys()), max_seq_length)
|
|
|
|
|
|
|
|
class GlycanTranslationData(Dataset):
    """Seq2seq dataset pairing encoder-side and decoder-side encodings.

    Both corpora are expected in the format produced by the tokenizers'
    ``encode``: ``{"token_ids": LongTensor, "attention_mask": LongTensor}``,
    each of shape (num_examples, max_seq_length).
    """

    def __init__(self, input_corpus, output_corpus, pad_token_id, eos_token_id):
        """
        Args:
            input_corpus: encoder-side batch dict (``token_ids`` / ``attention_mask``).
            output_corpus: decoder-side batch dict in the same format.
            pad_token_id: pad id forwarded to ``shift_tokens_right``.
            eos_token_id: used as the decoder start token when shifting
                (BART convention: decoder_start_token_id == eos).
        """
        self.input_ids = input_corpus["token_ids"]
        self.input_attention_masks = input_corpus["attention_mask"]

        self.output_ids = output_corpus["token_ids"]
        self.output_attention_masks = output_corpus["attention_mask"]

        self.pad_token_id = pad_token_id
        self.eos_token_id = eos_token_id

    def __len__(self):
        """Number of examples in the dataset."""
        return len(self.input_ids)

    def __getitem__(self, idx):
        """Return one training example with right-shifted decoder inputs."""
        output_ids_for_idx = self.output_ids[idx]

        # shift_tokens_right expects a batch dimension.
        if len(output_ids_for_idx.shape) == 1:
            output_ids_for_idx = output_ids_for_idx.unsqueeze(0)

        # Prepend the decoder start token (eos, per BART convention) and drop
        # the final position.
        decoder_input_ids = shift_tokens_right(
            output_ids_for_idx, self.pad_token_id, self.eos_token_id
        ).squeeze(0)

        # Mask aligned with the shifted sequence: the prepended start token is
        # always attended to; the last original position falls off.
        output_mask = self.output_attention_masks[idx]
        decoder_attention_mask = torch.cat(
            [torch.ones(1, dtype=output_mask.dtype), output_mask]
        )[:-1]

        return {
            "input_ids": self.input_ids[idx],
            "attention_mask": self.input_attention_masks[idx],
            "decoder_input_ids": decoder_input_ids,
            # BUG FIX: previously returned the UNSHIFTED output mask here while
            # the computed shifted mask was discarded, misaligning the mask
            # with decoder_input_ids by one position.
            "decoder_attention_mask": decoder_attention_mask,
            "labels": self.output_ids[idx]
        }