"""Remote-code tokenizer for Atom/Fusion GPT checkpoints. The tokenizer is intentionally HF-compatible: generic callers can use ``AutoTokenizer.from_pretrained(..., trust_remote_code=True)``. Arithmetic digit spans are reversed before tokenization so the model receives LSD-first numbers, matching pretraining. """ from __future__ import annotations import re from typing import Any from transformers import PreTrainedTokenizerFast class AtomTokenizer(PreTrainedTokenizerFast): vocab_files_names = {"tokenizer_file": "tokenizer.json"} model_input_names = ["input_ids", "attention_mask"] slow_tokenizer_class = None _digit_span_re = re.compile(r"\d+") def __init__(self, *args: Any, **kwargs: Any) -> None: kwargs.setdefault("bos_token", "<|bos|>") kwargs.setdefault("eos_token", "<|eos|>") kwargs.setdefault("unk_token", "<|unk|>") kwargs.setdefault("pad_token", "<|pad|>") super().__init__(*args, **kwargs) @classmethod def _reverse_digit_spans(cls, text: str) -> str: return cls._digit_span_re.sub(lambda match: match.group(0)[::-1], text) @classmethod def _transform_text(cls, value: Any) -> Any: if isinstance(value, str): return cls._reverse_digit_spans(value) if isinstance(value, tuple): return tuple(cls._transform_text(item) for item in value) if isinstance(value, list): return [cls._transform_text(item) for item in value] return value def __call__(self, text=None, text_pair=None, *args: Any, **kwargs: Any): return super().__call__( self._transform_text(text), self._transform_text(text_pair), *args, **kwargs, ) def encode(self, text, text_pair=None, *args: Any, **kwargs: Any): return super().encode( self._transform_text(text), self._transform_text(text_pair), *args, **kwargs, ) def batch_encode_plus(self, batch_text_or_text_pairs, *args: Any, **kwargs: Any): return super().batch_encode_plus( self._transform_text(batch_text_or_text_pairs), *args, **kwargs, ) def _decode(self, token_ids, skip_special_tokens: bool = False, **kwargs: Any) -> str: text = super()._decode( token_ids, skip_special_tokens=skip_special_tokens, **kwargs, ) return self._reverse_digit_spans(text)