| | import hashlib
|
| | import os
|
| | import uuid
|
| | from typing import List, Tuple, Union, Dict
|
| |
|
| | import regex as re
|
| | import sentencepiece as spm
|
| | from indicnlp.normalize import indic_normalize
|
| | from indicnlp.tokenize import indic_detokenize, indic_tokenize
|
| | from indicnlp.tokenize.sentence_tokenize import DELIM_PAT_NO_DANDA, sentence_split
|
| | from indicnlp.transliterate import unicode_transliterate
|
| | from mosestokenizer import MosesSentenceSplitter
|
| | from nltk.tokenize import sent_tokenize
|
| | from sacremoses import MosesDetokenizer, MosesPunctNormalizer, MosesTokenizer
|
| | from tqdm import tqdm
|
| |
|
| | from .flores_codes_map_indic import flores_codes, iso_to_flores
|
| | from .normalize_punctuation import punc_norm
|
| | from .normalize_regex_inference import EMAIL_PATTERN, normalize
|
| |
|
| |
|
def split_sentences(paragraph: str, lang: str) -> List[str]:
    """
    Break a text paragraph into sentences.

    English text is segmented twice, once with the Moses sentence splitter and once with
    NLTK's `sent_tokenize`; the segmentation with fewer sentences is kept (Moses wins ties)
    and soft hyphens (U+00AD) are removed from every sentence. Every other language is
    delegated to the `indic-nlp` sentence splitter using `DELIM_PAT_NO_DANDA`.

    Args:
        paragraph (str): input text paragraph.
        lang (str): flores language code of the paragraph.

    Returns:
        List[str]: list of sentences.
    """
    if lang != "eng_Latn":
        return sentence_split(paragraph, lang=flores_codes[lang], delim_pat=DELIM_PAT_NO_DANDA)

    with MosesSentenceSplitter(flores_codes[lang]) as moses_splitter:
        moses_sentences = moses_splitter([paragraph])
    nltk_sentences = sent_tokenize(paragraph)

    # Keep whichever splitter produced the coarser segmentation.
    chosen = nltk_sentences if len(nltk_sentences) < len(moses_sentences) else moses_sentences
    return [sentence.replace("\xad", "") for sentence in chosen]
|
| |
|
| |
|
def add_token(sent: str, src_lang: str, tgt_lang: str, delimiter: str = " ") -> str:
    """
    Prefix a sentence with its source and target language tags.

    The result has the layout "`{src_lang}{delimiter}{tgt_lang}{delimiter}{sent}`".

    Args:
        sent (str): input sentence to be translated.
        src_lang (str): flores lang code of the input sentence.
        tgt_lang (str): flores lang code in which the input sentence will be translated.
        delimiter (str): separator placed between the language tags and the sentence (default: " ").

    Returns:
        str: the sentence with both language tags prepended.
    """
    return delimiter.join((src_lang, tgt_lang, sent))
|
| |
|
| |
|
def apply_lang_tags(sents: List[str], src_lang: str, tgt_lang: str) -> List[str]:
    """
    Prefix every sentence of a batch with the source and target language tags.

    Each sentence is stripped of surrounding whitespace first, then rendered as
    "`{src_lang} {tgt_lang} {input_sentence}`".

    Args:
        sents (List[str]): input sentences to be translated.
        src_lang (str): flores lang code of the input sentences.
        tgt_lang (str): flores lang code in which the input sentences will be translated.

    Returns:
        List[str]: list of input sentences with the special tokens added to the start.
    """
    return [add_token(sentence.strip(), src_lang, tgt_lang) for sentence in sents]
|
| |
|
| |
|
def truncate_long_sentences(
    sents: List[str],
    placeholder_entity_map_sents: List[Dict],
    max_seq_len: int = 256,
) -> Tuple[List[str], List[Dict]]:
    """
    Splits the sentences that exceed the maximum sequence length into consecutive chunks.

    Length is measured in whitespace-separated tokens. A sentence within the limit is passed
    through untouched; a longer one is cut into chunks of at most `max_seq_len` tokens, and
    its placeholder entity map is repeated once per chunk so both returned lists stay aligned.
    The maximum sequence for the IndicTrans2 model is limited to 256 tokens.

    Args:
        sents (List[str]): list of input sentences to truncate.
        placeholder_entity_map_sents (List[Dict]): placeholder entity map of each input
            sentence, aligned by index with `sents`.
        max_seq_len (int, optional): maximum number of tokens per output sentence (default: 256).

    Returns:
        Tuple[List[str], List[Dict]]: tuple containing the list of sentences with truncation applied and the updated placeholder entity maps.

    Raises:
        ValueError: if `max_seq_len` is smaller than 1.
    """
    if max_seq_len < 1:
        raise ValueError(f"max_seq_len must be a positive integer, got {max_seq_len}")

    new_sents = []
    placeholders = []

    for j, sent in enumerate(sents):
        words = sent.split()
        if len(words) <= max_seq_len:
            # Short enough: keep the sentence exactly as given (no re-joining).
            new_sents.append(sent)
            placeholders.append(placeholder_entity_map_sents[j])
            continue

        # Step through the tokens with range() so that a length that is an exact multiple
        # of `max_seq_len` does not yield a trailing empty chunk.
        chunks = [
            " ".join(words[start : start + max_seq_len])
            for start in range(0, len(words), max_seq_len)
        ]
        new_sents.extend(chunks)
        # Every chunk shares the placeholder map of the sentence it came from.
        placeholders.extend([placeholder_entity_map_sents[j]] * len(chunks))

    return new_sents, placeholders
|
| |
|
| |
|
class Model:
    """
    Model class to run the IndicTransv2 models using python interface.

    An instance bundles the Moses tools used for English, the Indic transliterator, the
    source/target sentencepiece models and one translation backend (CTranslate2 or fairseq).
    `self.translate_lines` is bound at construction time to the backend-specific method.
    """

    def __init__(
        self,
        ckpt_dir: str,
        device: str = "cuda",
        input_lang_code_format: str = "flores",
        model_type: str = "ctranslate2",
    ):
        """
        Initialize the model class.

        Args:
            ckpt_dir (str): path of the model checkpoint directory. The sentencepiece models
                are read from `<ckpt_dir>/vocab/model.SRC` and `<ckpt_dir>/vocab/model.TGT`.
            device (str, optional): where to load the model (defaults: cuda). Only forwarded
                to the `ctranslate2` backend.
            input_lang_code_format (str, optional): format of the language codes passed to the
                translate methods. When it is "iso" the codes are mapped through
                `iso_to_flores`; any other value leaves them untouched (defaults: flores).
            model_type (str, optional): translation backend, either "ctranslate2" or "fairseq"
                (defaults: ctranslate2).

        Raises:
            NotImplementedError: if `model_type` is not a supported backend.
        """
        self.ckpt_dir = ckpt_dir
        self.en_tok = MosesTokenizer(lang="en")
        self.en_normalizer = MosesPunctNormalizer()
        self.en_detok = MosesDetokenizer(lang="en")
        self.xliterator = unicode_transliterate.UnicodeIndicTransliterator()

        print("Initializing sentencepiece model for SRC and TGT")
        self.sp_src = spm.SentencePieceProcessor(
            model_file=os.path.join(ckpt_dir, "vocab", "model.SRC")
        )
        self.sp_tgt = spm.SentencePieceProcessor(
            model_file=os.path.join(ckpt_dir, "vocab", "model.TGT")
        )

        self.input_lang_code_format = input_lang_code_format

        print("Initializing model for translation")

        # Backends are imported lazily so only the selected one has to be installed.
        if model_type == "ctranslate2":
            import ctranslate2

            self.translator = ctranslate2.Translator(self.ckpt_dir, device=device)
            self.translate_lines = self.ctranslate2_translate_lines
        elif model_type == "fairseq":
            from .custom_interactive import Translator

            self.translator = Translator(
                data_dir=os.path.join(self.ckpt_dir, "final_bin"),
                checkpoint_path=os.path.join(self.ckpt_dir, "model", "checkpoint_best.pt"),
                batch_size=100,
            )
            self.translate_lines = self.fairseq_translate_lines
        else:
            raise NotImplementedError(f"Unknown model_type: {model_type}")

    def ctranslate2_translate_lines(self, lines: List[str]) -> List[str]:
        """
        Translate preprocessed lines with the CTranslate2 backend.

        Args:
            lines (List[str]): preprocessed input lines whose tokens are separated by a single space.

        Returns:
            List[str]: first hypothesis of every line, with its tokens joined by a single space.
        """
        tokenized_sents = [line.strip().split(" ") for line in lines]
        results = self.translator.translate_batch(
            tokenized_sents,
            max_batch_size=9216,
            batch_type="tokens",
            max_input_length=160,
            max_decoding_length=256,
            beam_size=5,
        )
        return [" ".join(result.hypotheses[0]) for result in results]

    def fairseq_translate_lines(self, lines: List[str]) -> List[str]:
        """
        Translate preprocessed lines with the fairseq backend.

        Args:
            lines (List[str]): preprocessed input lines.

        Returns:
            List[str]: whatever the fairseq `Translator.translate` call returns for `lines`.
        """
        return self.translator.translate(lines)

    def paragraphs_batch_translate__multilingual(self, batch_payloads: List[tuple]) -> List[str]:
        """
        Translates a batch of input paragraphs (including pre/post processing)
        from any language to any language.

        All sentences of all paragraphs are sent to the backend in a single call; the
        translations are then regrouped per paragraph and joined with a single space.

        Args:
            batch_payloads (List[tuple]): batch of long input-texts to be translated, each in format: (paragraph, src_lang, tgt_lang)

        Returns:
            List[str]: batch of paragraph-translations in the respective languages.
        """
        sentence_ranges = []  # (start, end) slice of the flat sentence list, one per paragraph
        target_langs = []  # flores target code, one per paragraph
        all_preprocessed_sents = []
        all_placeholder_entity_maps = []

        for paragraph, src_lang, tgt_lang in batch_payloads:
            if self.input_lang_code_format == "iso":
                src_lang, tgt_lang = iso_to_flores[src_lang], iso_to_flores[tgt_lang]

            sents = split_sentences(paragraph, src_lang)
            preprocessed_sents, placeholder_entity_map_sents = self.preprocess_batch(
                sents, src_lang, tgt_lang
            )

            # Ranges are measured on the preprocessed list because long sentences may have
            # been split into several lines during preprocessing.
            start = len(all_preprocessed_sents)
            all_preprocessed_sents.extend(preprocessed_sents)
            all_placeholder_entity_maps.extend(placeholder_entity_map_sents)
            sentence_ranges.append((start, len(all_preprocessed_sents)))
            target_langs.append(tgt_lang)

        translations = self.translate_lines(all_preprocessed_sents)

        translated_paragraphs = []
        for (start, end), tgt_lang in zip(sentence_ranges, target_langs):
            postprocessed_sents = self.postprocess(
                translations[start:end],
                all_placeholder_entity_maps[start:end],
                tgt_lang,
            )
            translated_paragraphs.append(" ".join(postprocessed_sents))

        return translated_paragraphs

    def batch_translate(self, batch: List[str], src_lang: str, tgt_lang: str) -> List[str]:
        """
        Translates a batch of input sentences (including pre/post processing)
        from source language to target language.

        Args:
            batch (List[str]): batch of input sentences to be translated.
            src_lang (str): source language code, in the format selected by `input_lang_code_format`.
            tgt_lang (str): target language code, in the format selected by `input_lang_code_format`.

        Returns:
            List[str]: batch of translated-sentences generated by the model.
        """
        assert isinstance(batch, list)

        if self.input_lang_code_format == "iso":
            src_lang, tgt_lang = iso_to_flores[src_lang], iso_to_flores[tgt_lang]

        preprocessed_sents, placeholder_entity_map_sents = self.preprocess_batch(
            batch, src_lang, tgt_lang
        )
        translations = self.translate_lines(preprocessed_sents)
        return self.postprocess(translations, placeholder_entity_map_sents, tgt_lang)

    def translate_paragraph(self, paragraph: str, src_lang: str, tgt_lang: str) -> str:
        """
        Translates an input text paragraph (including pre/post processing)
        from source language to target language.

        Args:
            paragraph (str): input text paragraph to be translated.
            src_lang (str): source language code, in the format selected by `input_lang_code_format`.
            tgt_lang (str): target language code, in the format selected by `input_lang_code_format`.

        Returns:
            str: paragraph translation generated by the model.
        """
        assert isinstance(paragraph, str)

        # Sentence splitting needs a flores code; `batch_translate` performs its own
        # conversion, so it receives the codes exactly as the caller supplied them.
        if self.input_lang_code_format == "iso":
            flores_src_lang = iso_to_flores[src_lang]
        else:
            flores_src_lang = src_lang

        sents = split_sentences(paragraph, flores_src_lang)
        postprocessed_sents = self.batch_translate(sents, src_lang, tgt_lang)
        return " ".join(postprocessed_sents)

    def preprocess_batch(
        self, batch: List[str], src_lang: str, tgt_lang: str
    ) -> Tuple[List[str], List[Dict]]:
        """
        Preprocess an array of sentences by normalizing, tokenization, and possibly transliterating it. It also tokenizes the
        normalized text sequences using sentence piece tokenizer and also adds language tags.

        Sentences longer than the model limit are split, so the returned lists may contain
        more entries than `batch`.

        Args:
            batch (List[str]): input list of sentences to preprocess.
            src_lang (str): flores language code of the input text sentences.
            tgt_lang (str): flores language code of the output text sentences.

        Returns:
            Tuple[List[str], List[Dict]]: a tuple of list of preprocessed input text sentences and also a corresponding list of dictionary
                mapping placeholders to their original values.
        """
        preprocessed_sents, placeholder_entity_map_sents = self.preprocess(batch, lang=src_lang)
        tokenized_sents = self.apply_spm(preprocessed_sents)
        tokenized_sents, placeholder_entity_map_sents = truncate_long_sentences(
            tokenized_sents, placeholder_entity_map_sents
        )
        tagged_sents = apply_lang_tags(tokenized_sents, src_lang, tgt_lang)
        return tagged_sents, placeholder_entity_map_sents

    def apply_spm(self, sents: List[str]) -> List[str]:
        """
        Applies sentence piece encoding to the batch of input sentences.

        Args:
            sents (List[str]): batch of the input sentences.

        Returns:
            List[str]: batch of encoded sentences with sentence piece model, pieces joined by a single space.
        """
        return [" ".join(self.sp_src.encode(sent, out_type=str)) for sent in sents]

    def preprocess_sent(
        self,
        sent: str,
        normalizer: Union[MosesPunctNormalizer, indic_normalize.IndicNormalizerFactory],
        lang: str,
    ) -> Tuple[str, Dict]:
        """
        Preprocess an input text sentence by normalizing, tokenization, and possibly transliterating it.

        Args:
            sent (str): input text sentence to preprocess.
            normalizer (Union[MosesPunctNormalizer, indic_normalize.IndicNormalizerFactory]): an object that performs normalization on the text.
                It is not used for English, which always goes through `self.en_normalizer`.
            lang (str): flores language code of the input text sentence.

        Returns:
            Tuple[str, Dict]: A tuple containing the preprocessed input text sentence and a corresponding dictionary
                mapping placeholders to their original values.
        """
        iso_lang = flores_codes[lang]
        sent = punc_norm(sent, iso_lang)
        sent, placeholder_entity_map = normalize(sent)

        # Sentences in these scripts are only normalized and tokenized, never transliterated.
        script_code = lang.split("_")[1]
        transliterate = script_code not in ["Arab", "Aran", "Olck", "Mtei", "Latn"]

        if iso_lang == "en":
            normalized = self.en_normalizer.normalize(sent.strip())
            processed_sent = " ".join(self.en_tok.tokenize(normalized, escape=False))
            return processed_sent, placeholder_entity_map

        processed_sent = " ".join(
            indic_tokenize.trivial_tokenize(normalizer.normalize(sent.strip()), iso_lang)
        )
        if transliterate:
            # Transliterate to the "hi" script, then re-attach virama signs that
            # tokenization left standing between spaces.
            processed_sent = self.xliterator.transliterate(processed_sent, iso_lang, "hi").replace(
                " ् ", "्"
            )

        return processed_sent, placeholder_entity_map

    def preprocess(self, sents: List[str], lang: str) -> Tuple[List[str], List[Dict]]:
        """
        Preprocess an array of sentences by normalizing, tokenization, and possibly transliterating it.

        Args:
            sents (List[str]): input list of sentences to preprocess.
            lang (str): flores language code of the input text sentences.

        Returns:
            Tuple[List[str], List[Dict]]: a tuple of list of preprocessed input text sentences and also a corresponding list of dictionary
                mapping placeholders to their original values.
        """
        if lang == "eng_Latn":
            # English is normalized with `self.en_normalizer` inside `preprocess_sent`.
            normalizer = None
        else:
            normfactory = indic_normalize.IndicNormalizerFactory()
            normalizer = normfactory.get_normalizer(flores_codes[lang])

        processed_sents, placeholder_entity_map_sents = [], []
        for sent in sents:
            processed_sent, placeholder_entity_map = self.preprocess_sent(sent, normalizer, lang)
            processed_sents.append(processed_sent)
            placeholder_entity_map_sents.append(placeholder_entity_map)

        return processed_sents, placeholder_entity_map_sents

    def postprocess(
        self,
        sents: List[str],
        placeholder_entity_map: List[Dict],
        lang: str,
        common_lang: str = "hin_Deva",
    ) -> List[str]:
        """
        Postprocesses a batch of input sentences after the translation generations.

        The sentencepiece segmentation is undone, placeholders are replaced by their original
        entities, and the text is detokenized (and, for non-English targets, transliterated
        from `common_lang` into the target language). The input list is left unmodified.

        Args:
            sents (List[str]): batch of translated sentences to postprocess.
            placeholder_entity_map (List[Dict]): one dictionary per sentence mapping placeholders to the original entity values.
            lang (str): flores language code of the input sentences.
            common_lang (str, optional): flores language code of the transliterated language (defaults: hin_Deva).

        Returns:
            List[str]: postprocessed batch of input sentences.
        """
        lang_code, script_code = lang.split("_")

        assert len(sents) == len(placeholder_entity_map)

        restored_sents = []
        for sent, entity_map in zip(sents, placeholder_entity_map):
            # Undo sentencepiece: drop the separators between pieces, then turn the
            # word-boundary marker back into a plain space.
            sent = sent.replace(" ", "").replace("▁", " ").strip()

            if script_code in {"Arab", "Aran"}:
                # Remove the space standing before Perso-Arabic punctuation marks.
                sent = sent.replace(" ؟", "؟").replace(" ۔", "۔").replace(" ،", "،")
                # Collapse the two-codepoint sequence into its single-codepoint form.
                sent = sent.replace("ٮ۪", "ؠ")

            for placeholder, entity in entity_map.items():
                sent = sent.replace(placeholder, entity)

            restored_sents.append(sent)

        if lang == "eng_Latn":
            return [self.en_detok.detokenize(sent.split(" ")) for sent in restored_sents]

        postprocessed_sents = []
        for sent in restored_sents:
            outstr = indic_detokenize.trivial_detokenize(
                self.xliterator.transliterate(sent, flores_codes[common_lang], flores_codes[lang]),
                flores_codes[lang],
            )

            if lang_code == "ory":
                # Replace the letter + nukta sequence with the single precomposed letter.
                outstr = outstr.replace("ଯ଼", "ୟ")

            postprocessed_sents.append(outstr)

        return postprocessed_sents
|
| |
|