from typing import Optional, Sequence, Set, final

from fairseq2.data.text import (
    SentencePieceDecoder,
    SentencePieceEncoder,
    SentencePieceModel,
    TextTokenDecoder,
    TextTokenEncoder,
    TextTokenizer,
    vocab_info_from_sentencepiece,
)
from fairseq2.data.typing import PathLike
from fairseq2.typing import Device, finaloverride

@final
class SPMTokenizer(TextTokenizer):
    """Represents a standard SentencePiece-based tokenizer used in machine translation tasks."""

    model: SentencePieceModel
    langs: Set[str]
    prepend_target_langtok_to_target: bool
    def __init__(
        self,
        pathname: PathLike,
        langs: Sequence[str],
        prepend_target_langtok_to_target: bool = True,
    ) -> None:
        """
        :param pathname:
            The pathname of the SentencePiece model file.
        :param langs:
            The list of supported languages.
        :param prepend_target_langtok_to_target:
            If ``True``, prepends the target language token to target-side
            sequences.
        """
        self.langs = set(langs)
        self.prepend_target_langtok_to_target = prepend_target_langtok_to_target

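        # Each supported language is registered as a SentencePiece control
        # symbol (e.g. '__eng__' for 'eng'), giving language tokens dedicated
        # ids in the vocabulary.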
        control_symbols = [self._lang_tok_to_internal(lang) for lang in sorted(langs)]

        self.model = SentencePieceModel(pathname, control_symbols)

        vocab_info = vocab_info_from_sentencepiece(self.model)

        super().__init__(vocab_info)

    @classmethod
    def _lang_tok_to_internal(cls, lang: str) -> str:
        return f"__{lang}__"

    @finaloverride
    def create_encoder(
        self,
        *,
        task: Optional[str] = None,
        lang: Optional[str] = None,
        mode: Optional[str] = None,
        device: Optional[Device] = None,
        pin_memory: bool = False,
    ) -> TextTokenEncoder:
        """Create a token encoder.

        :param task:
            Must be 'translation'. If ``None``, defaults to 'translation'.
        :param lang:
            A language from :attr:`langs`. Must be specified.
        :param mode:
            Must be 'source' or 'target'. If ``None``, defaults to 'source'.
        :param device:
            The device on which to construct tensors.
        :param pin_memory:
            If ``True``, uses pinned memory while constructing tensors.
        """
        if task is not None and task != "translation":
            raise ValueError(f"`task` must be 'translation', but is '{task}' instead.")

        if lang is None:
            raise ValueError("`lang` must be specified for translation encoding.")

        if lang not in self.langs:
            raise ValueError(
                f"`lang` must be a supported language, but is '{lang}' instead."
            )

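        # Source sequences are encoded as "<tokens> </s>". Target sequences are
        # encoded as "</s> __lang__ <tokens> </s>" when
        # `prepend_target_langtok_to_target` is enabled, otherwise as
        # "<tokens> </s>".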
        if mode is None or mode == "source":
            prefix_tokens = []
            suffix_tokens = ["</s>"]
        elif mode == "target":
            prefix_tokens = (
                ["</s>", self._lang_tok_to_internal(lang)]
                if self.prepend_target_langtok_to_target
                else []
            )
            suffix_tokens = ["</s>"]
        else:
            raise ValueError(
                f"`mode` must be 'source' or 'target', but is '{mode}' instead."
            )

        return SentencePieceEncoder(
            self.model,
            prefix_tokens=prefix_tokens,
            suffix_tokens=suffix_tokens,
            device=device,
            pin_memory=pin_memory,
        )

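    # Unlike `create_encoder`, the raw encoder passes no prefix or suffix
    # control tokens; it encodes the text as-is.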
    @finaloverride
    def create_raw_encoder(
        self, *, device: Optional[Device] = None, pin_memory: bool = False
    ) -> TextTokenEncoder:
        return SentencePieceEncoder(self.model, device=device, pin_memory=pin_memory)

    @finaloverride
    def create_decoder(self) -> TextTokenDecoder:
        return SentencePieceDecoder(self.model)
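
# Example usage (a minimal sketch; the model path and language codes below are
# hypothetical):
#
#   tokenizer = SPMTokenizer("/path/to/spm.model", langs=["eng", "fra"])
#   encoder = tokenizer.create_encoder(task="translation", lang="fra", mode="target")
#   decoder = tokenizer.create_decoder()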