File size: 528 Bytes
98e2ea5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
from transformers import LongformerTokenizerFast, AutoTokenizer, PreTrainedTokenizerFast


def get_tokenizer(model_str: str) -> PreTrainedTokenizerFast:
    """Load a fast (Rust-backed) tokenizer for ``model_str``.

    Longformer checkpoints are routed through ``LongformerTokenizerFast``
    with ``add_prefix_space=True``; all other checkpoints are loaded via
    ``AutoTokenizer`` with ``use_fast=True``.

    Args:
        model_str: Hugging Face model identifier or local checkpoint path.

    Returns:
        A ``PreTrainedTokenizerFast`` instance for the given checkpoint.
        NOTE(review): ``AutoTokenizer`` can fall back to a slow tokenizer
        if no fast implementation exists for the checkpoint — the return
        annotation assumes a fast one is available; confirm for the
        models actually used.
    """
    # Case-insensitive match so ids like "Longformer-base" are also
    # routed through the Longformer-specific loader (previously only
    # all-lowercase "longformer" matched).
    if "longformer" in model_str.lower():
        return LongformerTokenizerFast.from_pretrained(
            model_str,
            # Needed so tokens that don't begin the sequence are encoded
            # the same as elsewhere in a sentence (byte-level BPE prefix
            # space convention).
            add_prefix_space=True,
            clean_up_tokenization_spaces=True,
        )
    return AutoTokenizer.from_pretrained(
        model_str, use_fast=True, clean_up_tokenization_spaces=True
    )