import json
import os
from collections import defaultdict
from typing import Optional, Union, List

from tokenizers import AddedToken, decoders, trainers
from tokenizers import Tokenizer
from tokenizers.models import WordPiece
from tokenizers.normalizers import BertNormalizer
from tokenizers.pre_tokenizers import BertPreTokenizer


def generate_sentinel_tokens(num=100, start_id=0):
    tokens = [
        AddedToken(content=f"[S_{i}]", single_word=True, normalized=False)
        for i in range(start_id, num + start_id)
    ]
    return tokens

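# Illustrative example: generate_sentinel_tokens(num=3) produces the special
# tokens "[S_0]", "[S_1]" and "[S_2]", which stand in for masked-out spans
# during span masking (see split_by_sentinel / merge_span_masking below).
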
def generate_coord_tokens(bins=1000):
    """Extra tokens that are used for bounding box coordinates,
    xmin, ymin, xmax, ymax, but also other modalities like color
    maps, metadata, or poses.
    """
    tokens = []
    coords_str = ["v0={}", "v1={}", "v2={}", "v3={}"]
    for s in coords_str:
        for i in range(bins):
            tokens.append(AddedToken(content=s.format(i), single_word=True, normalized=False))
    return tokens

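# Illustrative example: with bins=1000 this adds "v0=0" ... "v3=999" as
# single-word tokens, so a quantized box (xmin, ymin, xmax, ymax) can be
# written as text such as "v0=12 v1=40 v2=250 v3=311" and tokenized without
# the coordinate tokens being split apart.
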
def generate_object_class_tokens(dataset="coco"):
    with open(os.path.join(os.path.dirname(__file__), 'object_classes.json')) as f:
        object_classes = json.load(f)[dataset]
    tokens = [
        AddedToken(content=class_name, single_word=True, normalized=True)
        for class_name in object_classes
    ]
    return tokens

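# Expected shape of object_classes.json (illustrative excerpt only; the real
# file ships alongside this module): a dict keyed by dataset name, mapping to
# a list of class-name strings, e.g.
#
#   {"coco": ["person", "bicycle", "car", ...], ...}
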
def train_unified_wordpiece_tokenizer(
    files,
    vocab_size,
    sentinel_tokens: List[Union[str, AddedToken]] = None,
    coord_tokens: List[Union[str, AddedToken]] = None,
    object_class_tokens: List[Union[str, AddedToken]] = None,
    unk_token: Union[str, AddedToken] = "[UNK]",
    pad_token: Union[str, AddedToken] = "[PAD]",
    sos_token: Union[str, AddedToken] = "[SOS]",
    eos_token: Union[str, AddedToken] = "[EOS]",
    additional_special_tokens: List[Union[str, AddedToken]] = None,
    min_frequency=0,
    clean_text: bool = True,
    handle_chinese_chars: bool = True,
    strip_accents: Optional[bool] = None,
    lowercase: bool = True,
    wordpieces_prefix: str = "##",
    show_progress=True,
):
    tokenizer = Tokenizer(WordPiece(unk_token=str(unk_token)))
    tokenizer.normalizer = BertNormalizer(
        clean_text=clean_text,
        handle_chinese_chars=handle_chinese_chars,
        strip_accents=strip_accents,
        lowercase=lowercase,
    )
    tokenizer.pre_tokenizer = BertPreTokenizer()
    tokenizer.decoder = decoders.WordPiece(prefix=wordpieces_prefix)

    special_tokens = []
    special_tokens.append(pad_token)
    special_tokens.append(unk_token)
    special_tokens.append(sos_token)
    special_tokens.append(eos_token)
    if sentinel_tokens is not None:
        special_tokens.extend(sentinel_tokens)
    if coord_tokens is not None:
        special_tokens.extend(coord_tokens)
    if object_class_tokens is not None:
        special_tokens.extend(object_class_tokens)
    if additional_special_tokens is not None:
        special_tokens.extend(additional_special_tokens)

    trainer = trainers.WordPieceTrainer(
        vocab_size=vocab_size,
        min_frequency=min_frequency,
        show_progress=show_progress,
        continuing_subword_prefix=wordpieces_prefix,
        special_tokens=special_tokens,
    )
    if isinstance(files, str):
        files = [files]
    tokenizer.train(files, trainer=trainer)
    return tokenizer

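# Usage sketch (hypothetical corpus path and sizes, shown for orientation only):
#
#   tokenizer = train_unified_wordpiece_tokenizer(
#       files="corpus.txt",
#       vocab_size=32_000,
#       sentinel_tokens=generate_sentinel_tokens(num=100),
#       coord_tokens=generate_coord_tokens(bins=1000),
#       object_class_tokens=generate_object_class_tokens("coco"),
#   )
#   tokenizer.save("unified_wordpiece.json")  # standard `tokenizers` JSON serialization
#
# Special tokens passed to the trainer are added to the vocabulary before
# training, so they receive the lowest ids.
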
def get_sentinel_to_id_mapping(tokenizer, match_str="[S_"):
    sentinel_tokens = {k: v for k, v in tokenizer.get_vocab().items() if k.startswith(match_str)}
    # Extract the sentinel index from tokens of the form "[S_0]", "[S_1]", etc.
    sentinel_to_id = {int(k.split("_")[1][:-1]): v for k, v in sorted(sentinel_tokens.items(), key=lambda x: x[1])}
    return sentinel_to_id

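# Illustrative result (the actual ids depend on the trained vocabulary):
#
#   get_sentinel_to_id_mapping(tokenizer) -> {0: 4, 1: 5, 2: 6, ...}
#
# i.e. the integer parsed from "[S_<i>]" mapped to that token's vocabulary id.
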
def split_by_sentinel(seq_ids, sentinel_ids):
    splits = defaultdict(list)
    cur_sentinel = None
    for token in seq_ids:
        if token in sentinel_ids:
            cur_sentinel = token
        else:
            splits[cur_sentinel].append(token)
    return splits

def merge_span_masking(input_seq, decoder_seq, sentinel_ids):
    decoder_splits = split_by_sentinel(decoder_seq, sentinel_ids)
    out_seq = []
    for token in input_seq:
        if token in sentinel_ids:
            out_seq.extend(decoder_splits[token])
        else:
            out_seq.append(token)
    return out_seq
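

if __name__ == "__main__":
    # Minimal, self-contained sketch of the span-masking round trip, using raw
    # integer ids so no trained tokenizer is needed; all id values are made up.
    sentinel_ids = {100, 101}          # stand-ins for the ids of "[S_0]", "[S_1]"
    encoder_seq = [7, 100, 9, 101, 4]  # masked input: each span replaced by a sentinel
    decoder_seq = [100, 1, 2, 101, 3]  # decoder output: sentinel followed by its span
    print(merge_span_masking(encoder_seq, decoder_seq, sentinel_ids))
    # -> [7, 1, 2, 9, 3, 4]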