| |
| |
| |
| |
|
|
|
|
| from dataclasses import dataclass, field |
|
|
| from fairseq import file_utils |
| from fairseq.data.encoders import register_bpe |
| from fairseq.data.encoders.byte_utils import ( |
| SPACE, |
| SPACE_ESCAPE, |
| byte_encode, |
| smart_byte_decode, |
| ) |
| from fairseq.dataclass import FairseqDataclass |
|
|
|
|
@dataclass
class ByteBpeConfig(FairseqDataclass):
    """Configuration for the ``byte_bpe`` encoder."""

    # NOTE(review): "???" is presumably the framework's "value must be
    # supplied" placeholder -- confirm against FairseqDataclass handling.
    sentencepiece_model_path: str = field(
        metadata={"help": "path to sentencepiece model"},
        default="???",
    )
|
|
|
|
@register_bpe("byte_bpe", dataclass=ByteBpeConfig)
class ByteBPE(object):
    """BPE encoder that runs SentencePiece over byte-encoded text.

    ``encode`` passes the input through ``byte_encode`` and segments the
    result with a SentencePiece model; ``decode`` undoes the segmentation
    markers and hands the result to ``smart_byte_decode``.
    """

    def __init__(self, cfg):
        """Load the SentencePiece model named by ``cfg``.

        Args:
            cfg: configuration object exposing ``sentencepiece_model_path``;
                the path is resolved through ``file_utils.cached_path``.

        Raises:
            ImportError: if the ``sentencepiece`` package cannot be imported.
        """
        vocab = file_utils.cached_path(cfg.sentencepiece_model_path)

        # Guard only the import: a failure while constructing the processor
        # or loading the model must surface as-is instead of being reported
        # as a missing-package problem.
        try:
            import sentencepiece as spm
        except ImportError as err:
            raise ImportError(
                "Please install sentencepiece with: pip install sentencepiece"
            ) from err

        self.sp = spm.SentencePieceProcessor()
        self.sp.Load(vocab)

    def encode(self, x: str) -> str:
        """Return the SentencePiece pieces of byte-encoded ``x`` joined by ``SPACE``.

        Args:
            x: text to encode.

        Returns:
            A single string containing the pieces separated by ``SPACE``.
        """
        byte_encoded = byte_encode(x)
        return SPACE.join(self.sp.EncodeAsPieces(byte_encoded))

    @staticmethod
    def decode(x: str) -> str:
        """Turn an encoded string back into text.

        Every ``SPACE`` separator is removed, each ``SPACE_ESCAPE`` is
        replaced by ``SPACE``, and the result is passed to
        ``smart_byte_decode``.

        Args:
            x: string in the format produced by ``encode``.

        Returns:
            The value returned by ``smart_byte_decode``.
        """
        # Order matters: drop the piece separators first, then restore the
        # escaped spaces, so restored spaces are not stripped again.
        unescaped = x.replace(SPACE, "").replace(SPACE_ESCAPE, SPACE)
        return smart_byte_decode(unescaped)
|
|