Spaces: Dit-document-layout-analysis (status: Runtime error)
File: unilm/decoding/GAD/fairseq/data/multilingual/multilingual_utils.py
| from enum import Enum | |
| from typing import Dict, List, Optional, Sequence | |
| import torch | |
| from fairseq.data import Dictionary | |
class EncoderLangtok(Enum):
    """Choose which language token is prepended to the source sentence.

    ``src`` prepends the source-language token; ``tgt`` prepends the
    target-language token.
    """

    # Source-side language token.
    src = "src"
    # Target-side language token.
    tgt = "tgt"
class LangTokSpec(Enum):
    """Token-spec values used when building language tokens.

    ``get_lang_tok`` receives one of these values as ``spec``. ``main`` leaves
    the language code unchanged. ``mono_dae`` ends in ``"dae"``, so the
    language code gets a ``"_dae"`` suffix (presumably for monolingual
    denoising data; confirm with callers).
    """

    # Plain language token, no suffix.
    main = "main"
    # Suffixed "_dae" variant of the language token.
    mono_dae = "mono_dae"
class LangTokStyle(Enum):
    """Surface format of a language token.

    ``multilingual`` renders a language ``xx`` as ``"__xx__"``; ``mbart``
    renders it as ``"[xx]"``.
    """

    # "__xx__"-style tokens.
    multilingual = "multilingual"
    # "[xx]"-style tokens.
    mbart = "mbart"
def get_lang_tok(
    lang: str, lang_tok_style: str, spec: str = LangTokSpec.main.value
) -> str:
    """Return the special vocabulary token that stands for ``lang``.

    Args:
        lang: language code to wrap.
        lang_tok_style: a ``LangTokStyle`` value; ``"mbart"`` yields
            ``"[<lang>]"`` and ``"multilingual"`` yields ``"__<lang>__"``.
        spec: token-spec string. A spec ending in ``"dae"`` appends
            ``"_dae"`` to ``lang``; one ending in ``"mined"`` appends
            ``"_mined"``; any other spec leaves ``lang`` unchanged.

    Returns:
        The formatted token string.

    Raises:
        KeyError: if ``lang_tok_style`` is not one of the known styles.
    """
    # The style table is built inside the function rather than at module
    # level so that the function remains TorchScriptable.
    templates: Dict[str, str] = {
        LangTokStyle.mbart.value: "[{}]",
        LangTokStyle.multilingual.value: "__{}__",
    }

    suffix = ""
    if spec.endswith("dae"):
        suffix = "_dae"
    elif spec.endswith("mined"):
        suffix = "_mined"
    tagged_lang = f"{lang}{suffix}"

    return templates[lang_tok_style].format(tagged_lang)
def augment_dictionary(
    dictionary: Dictionary,
    language_list: List[str],
    lang_tok_style: str,
    langtoks_specs: Sequence[str] = (LangTokSpec.main.value,),
    extra_data: Optional[Dict[str, str]] = None,
) -> None:
    """Register language tokens (and possibly ``<mask>``) in ``dictionary``.

    For every spec in ``langtoks_specs`` and, within it, every language in
    ``language_list``, the token from ``get_lang_tok`` is passed to
    ``dictionary.add_symbol``. Afterwards ``"<mask>"`` is also added when the
    mBART token style is used, or when ``extra_data`` is given and contains
    the ``mono_dae`` key.

    Args:
        dictionary: vocabulary object that receives the symbols.
        language_list: language codes to create tokens for.
        lang_tok_style: a ``LangTokStyle`` value.
        langtoks_specs: ``LangTokSpec`` values; one token per language is
            added for each spec.
        extra_data: optional mapping checked for the ``mono_dae`` key.
    """
    # The generator is consumed lazily, so each token is added right after it
    # is built, in spec-major / language-minor order. That order presumably
    # determines the assigned symbol ids; confirm against add_symbol.
    lang_toks = (
        get_lang_tok(lang=lang, lang_tok_style=lang_tok_style, spec=spec)
        for spec in langtoks_specs
        for lang in language_list
    )
    for tok in lang_toks:
        dictionary.add_symbol(tok)

    # extra_data is only inspected when the style alone does not require a
    # mask token.
    needs_mask = lang_tok_style == LangTokStyle.mbart.value
    if not needs_mask and extra_data is not None:
        needs_mask = LangTokSpec.mono_dae.value in extra_data
    if needs_mask:
        dictionary.add_symbol("<mask>")