from onmt.utils.logging import logger
from onmt.transforms import register_transform
from .transform import Transform

import spacy
import ahocorasick
import re


class TermMatcher(object):
    def __init__(
        self,
        termbase_path,
        src_spacy_language_model,
        tgt_spacy_language_model,
        term_example_ratio,
        src_term_stoken,
        tgt_term_stoken,
        tgt_term_etoken,
        delimiter,
        term_corpus_ratio=0.2,
    ):
        self.term_example_ratio = term_example_ratio
        self.src_nlp = spacy.load(src_spacy_language_model, disable=["parser", "ner"])
        self.tgt_nlp = spacy.load(tgt_spacy_language_model, disable=["parser", "ner"])
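
        # Drop the tokenizer rules that contain apostrophes so contractions
        # are kept as single tokens, keeping spacy's tokenization consistent
        # with the whitespace-based corpus tokenization.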
        self.src_nlp.tokenizer.rules = {
            key: value
            for key, value in self.src_nlp.tokenizer.rules.items()
            if "'" not in key and "’" not in key and "‘" not in key
        }
        self.tgt_nlp.tokenizer.rules = {
            key: value
            for key, value in self.tgt_nlp.tokenizer.rules.items()
            if "'" not in key and "’" not in key and "‘" not in key
        }
        self.internal_termbase = self._create_internal_termbase(termbase_path)
        self.automaton = self._create_automaton()
        self.term_corpus_ratio = term_corpus_ratio
        self.src_term_stoken = src_term_stoken
        self.tgt_term_stoken = tgt_term_stoken
        self.tgt_term_etoken = tgt_term_etoken
        self.delimiter = delimiter

    def _create_internal_termbase(self, termbase_path):
        logger.debug("Creating termbase with lemmas for Terminology transform")
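
        # Use spacy's stopword lists to filter out junk termbase entries.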
        src_stopwords = self.src_nlp.Defaults.stop_words
        tgt_stopwords = self.tgt_nlp.Defaults.stop_words
        termbase = list()
        with open(termbase_path, mode="r", encoding="utf-8") as file:
            pairs = file.readlines()
            for pair in pairs:
                src_term, tgt_term = map(str, pair.split("\t"))
                # Lowercase the lemmas so they can match the lowercased
                # lemmatized sentences at matching time.
                src_lemma = " ".join(
                    "∥".join(tok.lemma_.lower().split())
                    for tok in self.src_nlp(src_term)
                ).strip()
                tgt_lemma = " ".join(
                    tok.lemma_.lower() for tok in self.tgt_nlp(tgt_term)
                ).strip()
                if (
                    src_lemma not in src_stopwords
                    and tgt_lemma not in tgt_stopwords
                ):
                    termbase.append((src_lemma, tgt_lemma))
        logger.debug(
            f"Created termbase with {len(termbase)} lemmas "
            f"for Terminology transform"
        )
        return termbase

    def _create_automaton(self):
        automaton = ahocorasick.Automaton()
        for src_lemma, tgt_lemma in self.internal_termbase:
            automaton.add_word(src_lemma, (src_lemma, tgt_lemma))
        automaton.make_automaton()
        return automaton

    def _src_sentence_with_terms(self, source_string, target_string) -> tuple:
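        # The source may carry an augmented part appended by another
        # transform (e.g. a fuzzy match); only the part before the
        # delimiter is searched for terms.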
        maybe_augmented = source_string.split(self.delimiter)
        source_only = maybe_augmented[0].strip()
        augmented_part = (
            maybe_augmented[1].strip() if len(maybe_augmented) > 1 else None
        )

        doc_src = self.src_nlp(source_only)
        doc_tgt = self.tgt_nlp(target_string)
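
        # Keep a 1:1 alignment between surface tokens and lemmas;
        # multi-token lemmas are glued with "∥" so each one still
        # counts as a single list entry.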
        tokenized_source = [tok.text for tok in doc_src]
        lemmatized_source = ["∥".join(tok.lemma_.lower().split()) for tok in doc_src]
        lemmatized_target = [tok.lemma_.lower() for tok in doc_tgt]

        lemmatized_source_string = " ".join(lemmatized_source)

        offset = 0
        source_with_terms = list()
        term_counter = 0

        max_terms_allowed = int(len(tokenized_source) * self.term_example_ratio)
        is_match = False
        for match_end, (src_entry, tgt_entry) in self.automaton.iter_long(
            lemmatized_source_string
        ):
            if term_counter == max_terms_allowed:
                break

            match_start = match_end - len(src_entry) + 1
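
            # Accept the match only if the target lemma also appears in
            # the target sentence and the source match falls on word
            # boundaries in the lemmatized source string.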
            if (
                (tgt_entry.lower() not in " ".join(lemmatized_target).lower())
                or (
                    len(lemmatized_source_string) != match_end + 1
                    and not lemmatized_source_string[match_end + 1].isspace()
                )
                or (
                    match_start != 0
                    and not lemmatized_source_string[match_start - 1].isspace()
                )
            ):
                continue
            else:
                term_counter += 1
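
            # Map the character offset of the match back to a token index
            # in the lemmatized source list (tokens are space-separated,
            # hence the "+ 1" for the joining space).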
            lemma_list_index = 0
            for i, w in enumerate(lemmatized_source):
                if lemma_list_index == match_start:
                    lemma_list_index = i
                    break
                else:
                    lemma_list_index += len(w) + 1
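
            # The lemma and token lists are aligned, so the term spans the
            # same number of entries in both.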
            num_words_in_src_term = len(src_entry.split())
            src_term = " ".join(
                tokenized_source[
                    lemma_list_index : lemma_list_index + num_words_in_src_term
                ]
            ).strip()
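
            # Glue multi-word target terms with "∥" so the annotated span
            # keeps the token count unchanged until detokenization.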
            tgt_term = tgt_entry.replace(" ", "∥").rstrip().lower()
            source_with_terms.append(
                f"{lemmatized_source_string[offset: match_start]}"
                f"{self.src_term_stoken}∥{src_term}∥{self.tgt_term_stoken}∥"
                f"{tgt_term}∥{self.tgt_term_etoken}"
            )

            offset = match_end + 1
            is_match = True

        if is_match:
            source_with_terms.append(lemmatized_source_string[offset:])
            tokenized_source_with_terms = "".join(source_with_terms).split()
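
            # If the annotation broke the 1:1 alignment between the
            # tokenized and the lemmatized source, play it safe and
            # return the original source without terms.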
            if not (
                len(tokenized_source)
                == len(lemmatized_source)
                == len(tokenized_source_with_terms)
            ):
                final_string = " ".join(tokenized_source)
                fixed_punct = re.sub(r" ([^\w\s⦅\-\–])", r"\1", final_string)
                return fixed_punct.split(), not is_match
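
            # Restore surface forms: entries whose lemma is unchanged get
            # their original token back, while annotated term spans keep
            # the markup (with "∥" turned back into spaces).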
            completed_tokenized_source = list()
            for idx in range(len(tokenized_source_with_terms)):
                src_lemma = tokenized_source_with_terms[idx].replace("∥", " ")
                if lemmatized_source[idx].replace("∥", " ") == src_lemma:
                    completed_tokenized_source.append(tokenized_source[idx])
                else:
                    completed_tokenized_source.append(src_lemma)

            # Re-attach any augmented part that followed the delimiter.
            if augmented_part is not None:
                final_string = " ".join(
                    completed_tokenized_source
                    + [self.delimiter]
                    + augmented_part.split()
                )
            else:
                final_string = " ".join(completed_tokenized_source)

            fixed_punct = re.sub(r" ([^\w\s⦅\-\–])", r"\1", final_string)
            return fixed_punct.split(), is_match
        else:
            final_string = " ".join(tokenized_source)
            fixed_punct = re.sub(r" ([^\w\s⦅\-\–])", r"\1", final_string)
            return fixed_punct.split(), is_match


@register_transform(name="terminology")
class TerminologyTransform(Transform):
    def __init__(self, opts):
        super().__init__(opts)

    @classmethod
    def add_options(cls, parser):
        """Available options for terminology matching."""
        group = parser.add_argument_group("Transform/Terminology")
        group.add(
            "--termbase_path",
            "-termbase_path",
            type=str,
            help="Path to a dictionary file with terms.",
        )
        group.add(
            "--src_spacy_language_model",
            "-src_spacy_language_model",
            type=str,
            help="Name of the spacy language model for the source corpus.",
        )
        group.add(
            "--tgt_spacy_language_model",
            "-tgt_spacy_language_model",
            type=str,
            help="Name of the spacy language model for the target corpus.",
        )
        group.add(
            "--term_corpus_ratio",
            "-term_corpus_ratio",
            type=float,
            default=0.3,
            help="Ratio of corpus to augment with terms.",
        )
        group.add(
            "--term_example_ratio",
            "-term_example_ratio",
            type=float,
            default=0.2,
            help="Maximum ratio of terms allowed in an example.",
        )
        group.add(
            "--src_term_stoken",
            "-src_term_stoken",
            type=str,
            help="The source term start token.",
            default="⦅src_term_start⦆",
        )
        group.add(
            "--tgt_term_stoken",
            "-tgt_term_stoken",
            type=str,
            help="The target term start token.",
            default="⦅tgt_term_start⦆",
        )
        group.add(
            "--tgt_term_etoken",
            "-tgt_term_etoken",
            type=str,
            help="The target term end token.",
            default="⦅tgt_term_end⦆",
        )
        group.add(
            "--term_source_delimiter",
            "-term_source_delimiter",
            type=str,
            help="Any special token used for augmented source sentences. "
            "The default is the fuzzy token used in the "
            "FuzzyMatch transform.",
            default="⦅fuzzy⦆",
        )

    def _parse_opts(self):
        self.termbase_path = self.opts.termbase_path
        self.src_spacy_language_model = self.opts.src_spacy_language_model
        self.tgt_spacy_language_model = self.opts.tgt_spacy_language_model
        self.term_corpus_ratio = self.opts.term_corpus_ratio
        self.term_example_ratio = self.opts.term_example_ratio
        self.term_source_delimiter = self.opts.term_source_delimiter
        self.src_term_stoken = self.opts.src_term_stoken
        self.tgt_term_stoken = self.opts.tgt_term_stoken
        self.tgt_term_etoken = self.opts.tgt_term_etoken

    @classmethod
    def get_specials(cls, opts):
        """Add the term tokens to the src vocab."""
        src_specials = list()
        src_specials.extend(
            [opts.src_term_stoken, opts.tgt_term_stoken, opts.tgt_term_etoken]
        )
        return (src_specials, list())

    def warm_up(self, vocabs=None):
        """Create the terminology matcher."""
        super().warm_up(None)
        self.termmatcher = TermMatcher(
            self.termbase_path,
            self.src_spacy_language_model,
            self.tgt_spacy_language_model,
            self.term_example_ratio,
            self.src_term_stoken,
            self.tgt_term_stoken,
            self.tgt_term_etoken,
            self.term_source_delimiter,
            self.term_corpus_ratio,
        )

    def batch_apply(self, batch, is_train=False, stats=None, **kwargs):
        bucket_size = len(batch)
        examples_with_terms = 0

        for i, (ex, _, _) in enumerate(batch):
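            # Consider only every other example for augmentation; the
            # term_corpus_ratio cap below further bounds how many examples
            # in the bucket actually get terms.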
            if i % 2 == 0:
                original_src = ex["src"]
                augmented_example, is_match = self.apply(ex, is_train, stats, **kwargs)
                if is_match and (
                    examples_with_terms < bucket_size * self.term_corpus_ratio
                ):
                    examples_with_terms += 1
                    ex["src"] = augmented_example["src"]
                else:
                    ex["src"] = original_src

        logger.debug(f"Added terms to {examples_with_terms}/{bucket_size} examples")
        return batch

    def apply(self, example, is_train=False, stats=None, **kwargs) -> tuple:
        """Add terms to a source example."""
        example["src"], is_match = self.termmatcher._src_sentence_with_terms(
            " ".join(example["src"]), " ".join(example["tgt"])
        )
        return example, is_match