| |
| |
| |
| |
|
|
| from dataclasses import dataclass, field |
|
|
| from fairseq.data.encoders import register_tokenizer |
| from fairseq.dataclass import FairseqDataclass |
|
|
|
|
@dataclass
class MosesTokenizerConfig(FairseqDataclass):
    """Options controlling Moses tokenization and detokenization."""

    # Language pair handed to sacremoses; defaults mirror the common en->en setup.
    source_lang: str = field(
        default="en", metadata=dict(help="source language")
    )
    target_lang: str = field(
        default="en", metadata=dict(help="target language")
    )
    # When True, disable the aggressive dash-splitting rules.
    moses_no_dash_splits: bool = field(
        default=False, metadata=dict(help="don't apply dash split rules")
    )
    # When True, skip HTML escaping of special characters.
    moses_no_escape: bool = field(
        default=False,
        metadata=dict(
            help="don't perform HTML escaping on apostrophe, quotes, etc."
        ),
    )
|
|
|
|
@register_tokenizer("moses", dataclass=MosesTokenizerConfig)
class MosesTokenizer(object):
    """Tokenizer backed by sacremoses' Moses tokenizer/detokenizer."""

    def __init__(self, cfg: MosesTokenizerConfig):
        """Build the underlying sacremoses (de)tokenizers from *cfg*.

        Raises:
            ImportError: if the optional ``sacremoses`` dependency is not
                installed; the original import failure is chained as the cause.
        """
        self.cfg = cfg

        try:
            # NOTE: these imported names shadow this class, but only within
            # __init__'s local scope.
            from sacremoses import MosesTokenizer, MosesDetokenizer

            self.tok = MosesTokenizer(cfg.source_lang)
            self.detok = MosesDetokenizer(cfg.target_lang)
        except ImportError as err:
            # Chain the original exception so the real import failure
            # (e.g. a broken sacremoses install) stays visible in tracebacks.
            raise ImportError(
                "Please install Moses tokenizer with: pip install sacremoses"
            ) from err

    def encode(self, x: str) -> str:
        """Tokenize ``x``, honoring the dash-split / escape config flags."""
        return self.tok.tokenize(
            x,
            aggressive_dash_splits=(not self.cfg.moses_no_dash_splits),
            return_str=True,
            escape=(not self.cfg.moses_no_escape),
        )

    def decode(self, x: str) -> str:
        """Detokenize a space-separated token string back into plain text."""
        return self.detok.detokenize(x.split())
|
|