import json
import os
from collections import Counter
from collections import OrderedDict
from typing import List, Optional

import torch
from ordered_set import OrderedSet
from transformers import AutoTokenizer

from common.utils import download, unzip_file


def get_tokenizer(tokenizer_name: str):
    """Automatically get a tokenizer instance.

    Args:
        tokenizer_name (str): "word_tokenizer" or the name of any pretrained
            tokenizer on the Hugging Face hub.

    Returns:
        Any: tokenizer object
    """
    if tokenizer_name == "word_tokenizer":
        return WordTokenizer(tokenizer_name)
    else:
        return AutoTokenizer.from_pretrained(tokenizer_name)


def get_tokenizer_class(tokenizer_name: str):
    """Automatically get a tokenizer class or factory.

    Args:
        tokenizer_name (str): "word_tokenizer" or the name of any pretrained
            tokenizer on the Hugging Face hub.

    Returns:
        Any: the WordTokenizer class for "word_tokenizer", otherwise the
            AutoTokenizer.from_pretrained factory function
    """
    if tokenizer_name == "word_tokenizer":
        return WordTokenizer
    else:
        return AutoTokenizer.from_pretrained
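
# Illustrative usage (a sketch, not executed at import time). The checkpoint
# name "bert-base-uncased" is just an example Hugging Face model id:
#
#   word_tok = get_tokenizer("word_tokenizer")       # -> WordTokenizer instance
#   bert_tok = get_tokenizer("bert-base-uncased")    # -> transformers tokenizer
#   tok_cls = get_tokenizer_class("word_tokenizer")  # -> WordTokenizer class
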
# states used by WordTokenizer.__call__ to distinguish a batch of
# sequences from a single sequence
BATCH_STATE = 1
INSTANCE_STATE = 2


class WordTokenizer(object):
    def __init__(self, name):
        self.__name = name
        self.index2instance = OrderedSet()
        self.instance2index = OrderedDict()
        # Counter object recording how often each element
        # occurs in the raw text.
        self.counter = Counter()

        self.__sign_pad = "[PAD]"
        self.add_instance(self.__sign_pad)
        self.__sign_unk = "[UNK]"
        self.add_instance(self.__sign_unk)

    @property
    def padding_side(self):
        return "right"

    @property
    def all_special_ids(self):
        return [self.unk_token_id, self.pad_token_id]

    @property
    def name_or_path(self):
        return self.__name

    @property
    def vocab_size(self):
        return len(self.instance2index)

    @property
    def pad_token_id(self):
        return self.instance2index[self.__sign_pad]

    @property
    def unk_token_id(self):
        return self.instance2index[self.__sign_unk]
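
    # Example (sketch): the properties above mimic a small slice of the
    # transformers tokenizer API. On a fresh tokenizer the two reserved
    # tokens take the first two ids:
    #
    #   tok = WordTokenizer("word_tokenizer")
    #   tok.pad_token_id  # -> 0
    #   tok.unk_token_id  # -> 1
    #   tok.vocab_size    # -> 2
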
    def add_instance(self, instance):
        """ Add instances to the alphabet.

        1. Any iterable data structure containing elements of
           str type is supported.
        2. Added instances are counted, recording how often each
           instance occurs in the raw text.

        Args:
            instance: the given instance or a list of instances.
        """
        if isinstance(instance, (list, tuple)):
            for element in instance:
                self.add_instance(element)
            return

        # only elements of str type are supported
        assert isinstance(instance, str)

        # count the frequency of instances
        self.counter[instance] += 1

        if instance not in self.index2instance:
            self.instance2index[instance] = len(self.index2instance)
            self.index2instance.append(instance)
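
    # Example (sketch): building a vocabulary from pre-tokenized text.
    # Ids are assigned in insertion order after the two reserved tokens:
    #
    #   tok = WordTokenizer("word_tokenizer")
    #   tok.add_instance(["hello", "world", "hello"])
    #   tok.get_index("hello")   # -> 2
    #   tok.get_index("unseen")  # -> 1, the "[UNK]" id
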
    def __call__(self, instance,
                 return_tensors="pt",
                 is_split_into_words=True,
                 padding=True,
                 add_special_tokens=False,
                 truncation=True,
                 max_length=512,
                 **config):
        if isinstance(instance, (list, tuple)) and isinstance(instance[0], str) and is_split_into_words:
            # a single pre-tokenized sequence
            res = self.get_index(instance)
            state = INSTANCE_STATE
        elif isinstance(instance, str) and not is_split_into_words:
            # a single raw string
            res = self.get_index(instance.split(" "))
            state = INSTANCE_STATE
        elif not is_split_into_words and isinstance(instance, (list, tuple)):
            # a batch of raw strings
            res = [self.get_index(ins.split(" ")) for ins in instance]
            state = BATCH_STATE
        else:
            # a batch of pre-tokenized sequences
            res = [self.get_index(ins) for ins in instance]
            state = BATCH_STATE

        if truncation:
            if state == BATCH_STATE:
                res = [r[:max_length] for r in res]
            else:
                res = res[:max_length]

        pad_id = self.get_index(self.__sign_pad)
        if padding and state == BATCH_STATE:
            # right-pad every sequence to the longest one in the batch
            max_len = max(len(x) for x in res)
            for i in range(len(res)):
                res[i] = res[i] + [pad_id] * (max_len - len(res[i]))

        if return_tensors == "pt":
            input_ids = torch.Tensor(res).long()
            attention_mask = (input_ids != pad_id).long()
        elif state == BATCH_STATE:
            input_ids = res
            attention_mask = [[1 if r != pad_id else 0 for r in batch] for batch in res]
        else:
            input_ids = res
            attention_mask = [1 if r != pad_id else 0 for r in res]

        # only single sequences are supported here, so token_type_ids
        # simply mirrors attention_mask
        return TokenizedData(input_ids, token_type_ids=attention_mask, attention_mask=attention_mask)
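
    # Example (sketch): encoding a batch of pre-tokenized sequences, given
    # the vocabulary built above. Shorter sequences are right-padded with
    # the "[PAD]" id and masked out:
    #
    #   batch = tok([["hello", "world"], ["hello"]], return_tensors="pt")
    #   batch.input_ids       # tensor([[2, 3], [2, 0]])
    #   batch.attention_mask  # tensor([[1, 1], [1, 0]])
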
    def get_index(self, instance):
        """ Serialize the given instance and return its index.

        Unknown instances are mapped to the index of "[UNK]".

        Args:
            instance (Any): the given instance or a list of instances.

        Return:
            Any: the serialization of the query instance.
        """
        if isinstance(instance, (list, tuple)):
            return [self.get_index(elem) for elem in instance]

        assert isinstance(instance, str)
        try:
            return self.instance2index[instance]
        except KeyError:
            return self.instance2index[self.__sign_unk]

    def decode(self, index):
        """ Get the instance corresponding to the query index.

        If the index is invalid, an exception is raised.

        Args:
            index (int): the query index, possibly iterable.

        Returns:
            the corresponding instance.
        """
        if isinstance(index, list):
            return [self.decode(elem) for elem in index]
        if isinstance(index, torch.Tensor):
            index = index.tolist()
            return self.decode(index)
        return self.index2instance[index]

    def decode_batch(self, index, **kargs):
        """ Batch version of decode(); extra keyword arguments are ignored.

        Args:
            index: the query indices, possibly iterable.

        Returns:
            the corresponding instances.
        """
        return self.decode(index)
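
    # Example (sketch): decode() is the inverse of get_index() and accepts
    # ints, nested lists, or tensors:
    #
    #   tok.decode([[2, 3], [2, 0]])
    #   # -> [["hello", "world"], ["hello", "[PAD]"]]
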
    def save(self, path):
        """ Save the content of the alphabet to a file.

        The tokenizer name and the token-to-index map are written out
        as a single JSON object. Note that the frequency counter is
        not persisted.

        Args:
            path (str): the path to save the object to.
        """
        with open(path, 'w', encoding="utf8") as fw:
            fw.write(json.dumps({"name": self.__name, "token_map": self.instance2index}))

    @staticmethod
    def from_file(path):
        """ Load a WordTokenizer previously written by save().

        Args:
            path (str): the path of the saved JSON file.

        Returns:
            WordTokenizer: the restored tokenizer (without frequency counts).
        """
        with open(path, 'r', encoding="utf8") as fw:
            obj = json.load(fw)
        tokenizer = WordTokenizer(obj["name"])
        tokenizer.instance2index = OrderedDict(obj["token_map"])
        tokenizer.index2instance = OrderedSet(tokenizer.instance2index.keys())
        return tokenizer
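
    # Example (sketch): round-tripping a vocabulary through disk. The path
    # "save/tokenizer.json" is arbitrary; only the token map survives the
    # round trip, not the frequency counts:
    #
    #   tok.save("save/tokenizer.json")
    #   restored = WordTokenizer.from_file("save/tokenizer.json")
    #   assert restored.get_index("hello") == tok.get_index("hello")
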
    def __len__(self):
        return len(self.index2instance)

    def __str__(self):
        return 'Alphabet {} contains {} words: \n\t{}'.format(self.name_or_path, len(self), self.index2instance)

    def convert_tokens_to_ids(self, tokens):
        """convert a token sequence to an input ids sequence

        Args:
            tokens (Any): token sequence

        Returns:
            Any: input ids sequence
        """
        if isinstance(tokens, (list, tuple)):
            # map each token independently, so a single unknown token
            # does not collapse the whole sequence to one "[UNK]" id
            return [self.get_index(token) for token in tokens]
        return self.get_index(tokens)
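
    # Example (sketch): convert_tokens_to_ids mirrors get_index and is kept
    # for API compatibility with Hugging Face tokenizers:
    #
    #   tok.convert_tokens_to_ids(["hello", "unseen"])  # -> [2, 1]
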

class TokenizedData():
    """tokenized output data with input_ids, token_type_ids and attention_mask
    """

    def __init__(self, input_ids, token_type_ids, attention_mask):
        self.input_ids = input_ids
        self.token_type_ids = token_type_ids
        self.attention_mask = attention_mask

    def word_ids(self, index: int) -> List[Optional[int]]:
        """ get the word id list for one sequence

        Args:
            index (int): sequence index in the batch

        Returns:
            List[Optional[int]]: word id per token position, or None
                at padded positions
        """
        return [j if self.attention_mask[index][j] != 0 else None for j, _ in enumerate(self.input_ids[index])]

    def word_to_tokens(self, index, word_id, **kwargs):
        """map a word to its token span

        Each word corresponds to exactly one token for WordTokenizer,
        so the span is simply (word_id, word_id + 1).

        Args:
            index (int): unused
            word_id (int): word index in the sequence
        """
        return (word_id, word_id + 1)

    def to(self, device):
        """move all tensors to the given device

        Only valid when the data was built with return_tensors="pt".

        Args:
            device (str): supports ["cpu", "cuda"]
        """
        self.input_ids = self.input_ids.to(device)
        self.token_type_ids = self.token_type_ids.to(device)
        self.attention_mask = self.attention_mask.to(device)
        return self
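
# Example (sketch): moving a tokenized batch to an accelerator. This only
# works when the batch was built with return_tensors="pt":
#
#   batch = tok([["hello", "world"]], return_tensors="pt")
#   if torch.cuda.is_available():
#       batch = batch.to("cuda")
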

def load_embedding(tokenizer: WordTokenizer, glove_name: str):
    """ Load GloVe embeddings from the Stanford server or a local cache.

    Args:
        tokenizer (WordTokenizer): non-pretrained tokenizer
        glove_name (str): GloVe file name, e.g. "glove.6B.100d.txt"

    Returns:
        Any: word embedding matrix of shape (vocab_size, dim)
    """
    save_path = "save/" + glove_name + ".zip"
    if not os.path.exists(save_path):
        download("http://downloads.cs.stanford.edu/nlp/data/glove.6B.zip#" + glove_name, save_path)
        unzip_file(save_path, "save/" + glove_name)
    # the embedding dimension is encoded in the file name, e.g. "100d" -> 100
    dim = int(glove_name.split(".")[-2][:-1])
    # words missing from GloVe keep a random vector; padding is all zeros
    embedding_list = torch.rand((tokenizer.vocab_size, dim))
    embedding_list[tokenizer.pad_token_id] = torch.zeros((1, dim))
    with open("save/" + glove_name + "/" + glove_name, "r", encoding="utf8") as f:
        for line in f:
            datas = line.split(" ")
            word = datas[0]
            embedding = torch.Tensor([float(x) for x in datas[1:]])
            tokenized = tokenizer.convert_tokens_to_ids(word)
            if isinstance(tokenized, int) and tokenized != tokenizer.unk_token_id:
                embedding_list[tokenized] = embedding
    return embedding_list
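
# Example (sketch): initializing an embedding layer from the loaded GloVe
# vectors. "glove.6B.100d.txt" is one of the files inside the downloaded
# glove.6B.zip archive; the first call downloads ~800MB:
#
#   import torch.nn as nn
#   weights = load_embedding(tok, "glove.6B.100d.txt")
#   embedding = nn.Embedding.from_pretrained(weights, freeze=False,
#                                            padding_idx=tok.pad_token_id)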