# -*- coding: utf-8 -*-
# file: omnigenbench_wrapper.py
# time: 00:57 27/04/2024
# author: YANG, HENG <hy345@exeter.ac.uk> (杨恒)
# github: https://github.com/yangheng95
# huggingface: https://huggingface.co/yangheng
# google scholar: https://scholar.google.com/citations?user=NPq5a_0AAAAJ&hl=en
# Copyright (C) 2019-2024. All Rights Reserved.
import warnings

from transformers import AutoTokenizer

from omnigenbench import OmniKmersTokenizer


class Tokenizer(OmniKmersTokenizer):
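    """K-mer tokenizer wrapper around ``OmniKmersTokenizer``.

    Splits DNA/RNA sequences into (optionally overlapping) k-mers and maps them
    to ids with the vocabulary of the underlying Hugging Face base tokenizer.
    """
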
    def __init__(
        self, base_tokenizer=None, k=3, overlap=0, max_length=512, t2u=True, **kwargs
    ):
        super(Tokenizer, self).__init__(base_tokenizer, t2u=t2u, **kwargs)
        self.k = k
        self.overlap = overlap
        self.max_length = max_length
        self.metadata["tokenizer_name"] = self.__class__.__name__

    def __call__(self, sequence, **kwargs):
        # Optionally normalize the alphabet (DNA <-> RNA) before k-mer tokenization.
        if self.u2t:
            sequence = "".join([seq.replace("U", "T").upper() for seq in sequence])
        if self.t2u:
            sequence = "".join([seq.replace("T", "U").upper() for seq in sequence])

        # Truncate each k-mer sequence so that, together with the BOS/CLS and
        # EOS/SEP special tokens, it fits within max_length positions.
        max_length = kwargs.get("max_length") or self.max_length
        sequence_tokens = [
            tokens[: max_length - 2] for tokens in self.tokenize(sequence)
        ]
        tokenized_inputs = {
            "input_ids": [],
            "attention_mask": [],
        }
        # Fall back to the CLS/SEP ids for base tokenizers that define no BOS/EOS tokens.
        bos_id = (
            self.base_tokenizer.bos_token_id
            if self.base_tokenizer.bos_token_id is not None
            else self.base_tokenizer.cls_token_id
        )
        eos_id = (
            self.base_tokenizer.eos_token_id
            if self.base_tokenizer.eos_token_id is not None
            else self.base_tokenizer.sep_token_id
        )
        # Wrap every k-mer id sequence with the special tokens and build its attention mask.
        for tokens in sequence_tokens:
            tokenized_inputs["input_ids"].append(
                [bos_id] + self.base_tokenizer.convert_tokens_to_ids(tokens) + [eos_id]
            )
            tokenized_inputs["attention_mask"].append(
                [1] * len(tokenized_inputs["input_ids"][-1])
            )
        # Warn when a sequence is dominated by unknown tokens, which usually means the
        # k-mer size does not match the base tokenizer's vocabulary.
        for i, ids in enumerate(tokenized_inputs["input_ids"]):
            if ids.count(self.base_tokenizer.unk_token_id) / len(ids) > 0.1:
                warnings.warn(
                    f"More than 10% of the tokens in sequence {i} are unknown; "
                    "please check the tokenization (e.g., the k-mer size and vocabulary)."
                )
        # Pad every sequence to the same fixed length and return PyTorch tensors.
        tokenized_inputs = self.base_tokenizer.pad(
            tokenized_inputs,
            padding="max_length",
            max_length=max_length,
            pad_to_multiple_of=max_length,
            return_attention_mask=True,
            return_tensors="pt",
        )
        return tokenized_inputs

    @staticmethod
    def from_pretrained(model_name_or_path, **kwargs):
        # Load the base tokenizer from the Hugging Face Hub (or a local path) and
        # wrap it, so that the k-mer __call__/tokenize logic above is used.
        return Tokenizer(
            AutoTokenizer.from_pretrained(model_name_or_path, **kwargs)
        )

    def tokenize(self, sequence, **kwargs):
        # Accept either a single sequence or a batch of sequences.
        if isinstance(sequence, str):
            sequences = [sequence]
        else:
            sequences = sequence

        # Slide a window of size k with stride (k - overlap) over each sequence,
        # e.g. tokenize("ACGUAG") with k=3, overlap=0 -> [["ACG", "UAG"]].
        sequence_tokens = []
        for seq in sequences:
            tokens = []
            for j in range(0, len(seq), self.k - self.overlap):
                tokens.append(seq[j : j + self.k])
            sequence_tokens.append(tokens)
        return sequence_tokens

    def encode(self, input_ids, **kwargs):
        # Delegate plain encoding/decoding to the underlying Hugging Face tokenizer.
        return self.base_tokenizer.encode(input_ids, **kwargs)

    def decode(self, input_ids, **kwargs):
        return self.base_tokenizer.decode(input_ids, **kwargs)

    def encode_plus(self, sequence, **kwargs):
        raise NotImplementedError("The encode_plus() method is not implemented yet.")
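

# Minimal usage sketch. The checkpoint name below is only a placeholder; substitute
# any genomic model on the Hugging Face Hub whose tokenizer vocabulary contains
# k-mers of the chosen size, otherwise most tokens will map to the unknown token.
if __name__ == "__main__":
    tokenizer = Tokenizer.from_pretrained("your-org/your-genomic-model")  # placeholder
    inputs = tokenizer("ACGUAGGUAUCGUAGA", max_length=64)
    print(inputs["input_ids"].shape)       # (1, 64): one sequence padded to max_length
    print(inputs["attention_mask"].shape)  # (1, 64)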