# multilingual-nlp / sitecustomize.py
# Tokenizer compatibility shim (commit 0b65f6d, verified)
from pathlib import Path
try:
import huggingface_hub
except Exception: # pragma: no cover
huggingface_hub = None
try:
from transformers import PreTrainedTokenizer, PreTrainedTokenizerBase, PreTrainedTokenizerFast
except Exception: # pragma: no cover
PreTrainedTokenizer = None
PreTrainedTokenizerBase = None
PreTrainedTokenizerFast = None
if huggingface_hub is not None and not hasattr(huggingface_hub, "HfFolder"):
class HfFolder:
path_token = Path.home() / ".cache" / "huggingface" / "token"
@classmethod
def save_token(cls, token: str) -> None:
cls.path_token.parent.mkdir(parents=True, exist_ok=True)
cls.path_token.write_text(token, encoding="utf-8")
@classmethod
def get_token(cls) -> str | None:
if cls.path_token.exists():
return cls.path_token.read_text(encoding="utf-8").strip() or None
return None
@classmethod
def delete_token(cls) -> None:
if cls.path_token.exists():
cls.path_token.unlink()
huggingface_hub.HfFolder = HfFolder
def _compat_batch_encode_plus(self, *args, **kwargs):
return self(*args, **kwargs)
# Attach the shim to every tokenizer base class that is importable but
# lacks a native batch_encode_plus implementation.
for _tok_cls in (PreTrainedTokenizerBase, PreTrainedTokenizer, PreTrainedTokenizerFast):
    if _tok_cls is None or hasattr(_tok_cls, "batch_encode_plus"):
        continue
    _tok_cls.batch_encode_plus = _compat_batch_encode_plus