| import logging | |
| from typing import Iterable, Iterator, List, Union | |
| import chemdataextractor | |
| import sentencepiece as spm | |
| from chemdataextractor.data import Package | |
| from rxn.onmt_utils.internal_translation_utils import TranslationResult | |
| from rxn.onmt_utils.translator import Translator | |
# Module-level logger; NullHandler silences "no handler found" warnings when
# this library is used by applications that have not configured logging.
logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())
def download_cde_data() -> None:
    """Fetch the ChemDataExtractor sentence-splitting model, unless already cached locally."""
    package = Package("models/punkt_chem-1.0.pickle")
    if not package.local_exists():
        logger.info("Downloading the necessary ChemDataExtractor data...")
        package.download()
        logger.info("Downloading the necessary ChemDataExtractor data... Done.")
def split_into_sentences(text: str) -> List[str]:
    """Split *text* into sentences with ChemDataExtractor's chemistry-aware tokenizer."""
    sentences = chemdataextractor.doc.Paragraph(text).sentences
    return [s.text for s in sentences]
class SentencePieceTokenizer:
    """Thin wrapper around a SentencePiece model for tokenization and detokenization.

    Tokenized text is represented as a single string of space-separated pieces,
    which is the format the downstream OpenNMT translator expects.
    """

    def __init__(self, model_file: str):
        # Load the pretrained SentencePiece model from disk.
        self.sp = spm.SentencePieceProcessor()
        self.sp.Load(model_file)

    def tokenize(self, sentence: str) -> str:
        """Encode *sentence* into space-separated SentencePiece tokens."""
        return " ".join(self.sp.EncodeAsPieces(sentence))

    def detokenize(self, sentence: str) -> str:
        """Reassemble a space-separated token string back into plain text."""
        return self.sp.DecodePieces(sentence.split(" "))
class TranslatorWithSentencePiece:
    """Translation model that applies SentencePiece (de)tokenization around an OpenNMT translator."""

    def __init__(
        self, translation_model: Union[str, Iterable[str]], sentencepiece_model: str
    ):
        """
        Args:
            translation_model: path(s) to the OpenNMT translation model(s).
            sentencepiece_model: path to the SentencePiece model file.
        """
        self.sp = SentencePieceTokenizer(sentencepiece_model)
        self.translator = Translator.from_model_path(translation_model)

    def translate(self, sentences: List[str]) -> List[str]:
        """Return the single best translation for each input sentence."""
        result_groups = self.translate_multiple_with_scores(sentences)
        return [group[0].text for group in result_groups]

    def translate_multiple_with_scores(
        self, sentences: List[str], n_best=1
    ) -> Iterator[List[TranslationResult]]:
        """Yield, for each input sentence, its ``n_best`` translations with scores.

        Sentences are tokenized before translation and the resulting texts are
        detokenized in place before each group is yielded.
        """
        tokenized = [self.sp.tokenize(sentence) for sentence in sentences]
        raw_groups = self.translator.translate_multiple_with_scores(
            tokenized, n_best
        )
        for group in raw_groups:
            for result in group:
                result.text = self.sp.detokenize(result.text)
            yield group