from typing import List, Optional, Union
from transformers import PreTrainedTokenizerFast
from tokenizers.processors import TemplateProcessing
from tokenizers import Tokenizer
from transformers.tokenization_utils_base import BatchEncoding, EncodedInput, PreTokenizedInput, TextInput, TruncationStrategy
from transformers.utils import PaddingStrategy, TensorType
import torch
import numpy as np
import pdb


def create_tokenizer_custom(file):
    """Build a ``tokenizers.Tokenizer`` from a serialized tokenizer file.

    Args:
        file: Path of the file whose text content is handed to
            ``Tokenizer.from_str``.

    Returns:
        The object returned by ``Tokenizer.from_str``.
    """
    # Decode explicitly as UTF-8 so loading does not depend on the platform's
    # default locale encoding.
    with open(file, 'r', encoding='utf-8') as f:
        return Tokenizer.from_str(f.read())


class iPLMTokenizer(PreTrainedTokenizerFast):
    """Fast tokenizer for inputs written as ``"<structure>|<sequence>"``.

    Only the part after the last ``'|'`` is tokenized. The part before the
    first ``'|'`` is then written, character by character as ``ord`` code
    points, over the leading positions of each ``input_ids`` row.
    """

    def __init__(self, parallel=False, **kwargs):
        """Create the tokenizer from ``kwargs['tokenizer_file']``.

        Args:
            parallel: Stored on the instance as ``self.parallel``; it is not
                read anywhere else in this file.
            **kwargs: Forwarded unchanged to ``PreTrainedTokenizerFast``.
                The ``tokenizer_file`` entry is also used to build the
                backend tokenizer object.
        """
        super().__init__(tokenizer_object=create_tokenizer_custom(kwargs.get('tokenizer_file')), **kwargs)
        self.parallel = parallel

    def __call__(
        self,
        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
        n_queries = -1,  # -1 for vary-length prompt, int with larger than 0 for fix-length, 0 for no prompt
        text_pair: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None,
        text_target: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
        text_pair_target: Optional[
            Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]
        ] = None,
        add_special_tokens: bool = True,
        padding: Union[bool, str, PaddingStrategy] = False,
        truncation: Union[bool, str, TruncationStrategy] = None,
        max_length: Optional[int] = None,
        stride: int = 0,
        is_split_into_words: bool = False,
        pad_to_multiple_of: Optional[int] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        return_token_type_ids: Optional[bool] = None,
        return_attention_mask: Optional[bool] = None,
        return_overflowing_tokens: bool = False,
        return_special_tokens_mask: bool = False,
        return_offsets_mapping: bool = False,
        return_length: bool = False,
        verbose: bool = True,
        mutation_site: Optional[int] = -1,
        **kwargs,
    ) -> BatchEncoding:
        """Tokenize ``"<structure>|<sequence>"`` inputs and inject the structure ids.

        Args:
            text: One input string or a list of them. A non-list value is
                wrapped into a single-element list. Every item must contain
                ``'|'``.
            n_queries: Prompt length selector. A value ``> 0`` uses that fixed
                length, a value ``< 0`` uses the length of the sequence part
                (measured before any mutation edit), and ``0`` uses no prompt.
            mutation_site: Index into the sequence part whose character is
                replaced. ``-1`` or ``None`` leaves the sequence untouched.
                Other negative values are passed to list indexing unchanged.
            All remaining arguments are forwarded unchanged to
            ``PreTrainedTokenizerFast.__call__``.

        Returns:
            The ``BatchEncoding`` from the parent call, with the leading
            ``input_ids`` of each row overwritten by the structure code points
            and ``token_type_ids`` removed when present.

        Raises:
            AssertionError: If an input item does not contain ``'|'``.
            ValueError: If ``mutation_site`` is at or beyond the end of the
                sequence part.
        """
        if not isinstance(text, list):
            text = [text]

        # add prompt
        text_with_prompt = []
        for t in text:
            # Kept as an assert so callers still see AssertionError; under
            # ``python -O`` it is stripped and the structure step below simply
            # skips items that have no '|'.
            assert '|' in t, 'prompt not found'
            raw_text = t.split('|')[-1]

            if n_queries > 0:
                # use fix length prompt
                prompt_length = n_queries
            elif n_queries < 0:
                prompt_length = len(raw_text)
            else:
                prompt_length = 0

            # Mask the mutation site. ``None`` is accepted as "no mutation",
            # matching the Optional[int] annotation.
            if mutation_site is not None and mutation_site != -1:
                if mutation_site >= len(raw_text):
                    raise ValueError('mutation site out of range')
                residues = list(raw_text)
                # NOTE(review): the replacement literal is the empty string in
                # this source, so the character is removed rather than replaced
                # by a mask token -- confirm the intended token was not lost.
                residues[mutation_site] = ''
                raw_text = ''.join(residues)

            # NOTE(review): the repeated prompt literal is the empty string in
            # this source, so prompt_length currently has no effect on the
            # text -- confirm the intended prompt token was not lost.
            text_with_prompt.append('' * prompt_length + raw_text)

        batch = super().__call__(
            text=text_with_prompt,
            text_pair=text_pair,
            text_target=text_target,
            text_pair_target=text_pair_target,
            add_special_tokens=add_special_tokens,
            padding=padding,
            truncation=truncation,
            max_length=max_length,
            stride=stride,
            is_split_into_words=is_split_into_words,
            pad_to_multiple_of=pad_to_multiple_of,
            padding_side=None,
            return_tensors=return_tensors,
            return_token_type_ids=return_token_type_ids,
            return_attention_mask=return_attention_mask,
            return_overflowing_tokens=return_overflowing_tokens,
            return_special_tokens_mask=return_special_tokens_mask,
            return_offsets_mapping=return_offsets_mapping,
            return_length=return_length,
            verbose=verbose,
            **kwargs
        )

        # add structure ids
        for i, t in enumerate(text):
            if '|' not in t:
                continue
            structure_ids = t.split('|')[0]
            codes = [ord(c) for c in structure_ids]
            if return_tensors is None:
                # Plain Python lists: overwrite the leading ids in place.
                row = batch['input_ids'][i]
                for j, code in enumerate(codes):
                    row[j] = code
            else:
                batch['input_ids'][i, :len(codes)] = torch.tensor(codes)

        if "token_type_ids" in batch:
            del batch["token_type_ids"]
        return batch