| from enum import Enum |
| from typing import Dict, List, Optional, Sequence |
|
|
| import torch |
| from fairseq.data import Dictionary |
|
|
|
|
class EncoderLangtok(Enum):
    """
    Which language token to prepend to the beginning of the source sentence:
    either the source or the target language token (``src``/``tgt``).
    """

    # Prepend the source-language token.
    src = "src"
    # Prepend the target-language token.
    tgt = "tgt"
|
|
|
|
class LangTokSpec(Enum):
    """Language-token variants, passed as the ``spec`` of ``get_lang_tok``.

    ``get_lang_tok`` appends ``_dae`` to the language name for any spec ending
    in ``"dae"`` and leaves the name unchanged for ``main``.
    """

    # Plain language token with no suffix (e.g. ``__en__`` in the
    # multilingual style).
    main = "main"
    # Language token with a ``_dae`` suffix (e.g. ``__en_dae__``); presumably
    # meant for monolingual denoising-autoencoder data -- TODO confirm.
    # If this value is a key of ``augment_dictionary``'s ``extra_data``, a
    # ``<mask>`` symbol is also added to the dictionary.
    mono_dae = "mono_dae"
|
|
|
|
class LangTokStyle(Enum):
    """Surface format used to render a language token in ``get_lang_tok``."""

    # Renders a language ``xx`` as ``__xx__``.
    multilingual = "multilingual"
    # Renders a language ``xx`` as ``[xx]``; ``augment_dictionary`` also adds
    # a ``<mask>`` symbol whenever this style is selected.
    mbart = "mbart"
|
|
|
|
@torch.jit.export
def get_lang_tok(
    lang: str, lang_tok_style: str, spec: str = LangTokSpec.main.value
) -> str:
    """Return the language-token string for ``lang``.

    Args:
        lang: language name to embed in the token.
        lang_tok_style: a ``LangTokStyle`` value selecting the template:
            ``"mbart"`` -> ``[{}]``, ``"multilingual"`` -> ``__{}__``.
        spec: a ``LangTokSpec``-like string. A spec ending in ``"dae"`` adds a
            ``_dae`` suffix to the language; one ending in ``"mined"`` adds a
            ``_mined`` suffix; any other spec adds nothing.

    Returns:
        The formatted token, e.g. ``"__en__"`` or ``"[en_dae]"``.

    Raises:
        KeyError: if ``lang_tok_style`` is not one of the known styles.
    """
    # The template table is built inside the function rather than at module
    # level; presumably this keeps the function TorchScript-compatible (note
    # the @torch.jit.export decorator), so only scriptable constructs are used.
    token_styles: Dict[str, str] = {
        LangTokStyle.mbart.value: "[{}]",
        LangTokStyle.multilingual.value: "__{}__",
    }

    suffix = ""
    if spec.endswith("dae"):
        suffix = "_dae"
    elif spec.endswith("mined"):
        suffix = "_mined"

    template = token_styles[lang_tok_style]
    return template.format(lang + suffix)
|
|
|
|
def augment_dictionary(
    dictionary: Dictionary,
    language_list: List[str],
    lang_tok_style: str,
    langtoks_specs: Sequence[str] = (LangTokSpec.main.value,),
    extra_data: Optional[Dict[str, str]] = None,
) -> None:
    """Register language tokens (and possibly ``<mask>``) in ``dictionary``.

    One token per (spec, language) pair is passed to
    ``dictionary.add_symbol``. Specs form the outer loop and languages the
    inner loop. A ``<mask>`` symbol is added afterwards when the mbart style
    is used, or when ``extra_data`` has a ``LangTokSpec.mono_dae`` key.

    Args:
        dictionary: fairseq ``Dictionary`` updated in place via ``add_symbol``.
        language_list: languages to build tokens for.
        lang_tok_style: ``LangTokStyle`` value forwarded to ``get_lang_tok``.
        langtoks_specs: ``LangTokSpec`` values; defaults to ``("main",)``.
        extra_data: optional mapping; only its keys are inspected here.
    """
    # Lazy generator: each token is built right before it is added, and the
    # spec-major order is kept. Symbol ids presumably depend on insertion
    # order, so keep this ordering stable.
    lang_tokens = (
        get_lang_tok(lang=lang, lang_tok_style=lang_tok_style, spec=tok_spec)
        for tok_spec in langtoks_specs
        for lang in language_list
    )
    for token in lang_tokens:
        dictionary.add_symbol(token)

    needs_mask = lang_tok_style == LangTokStyle.mbart.value
    if not needs_mask and extra_data is not None:
        needs_mask = LangTokSpec.mono_dae.value in extra_data
    if needs_mask:
        dictionary.add_symbol("<mask>")
|
|