| import argparse |
| import os |
| from difflib import SequenceMatcher |
|
|
| import Levenshtein |
| import numpy as np |
| from tqdm import tqdm |
|
|
| from helpers import write_lines, read_parallel_lines, encode_verb_form, \ |
| apply_reverse_transformation, SEQ_DELIMETERS, START_TOKEN |
|
|
|
|
def perfect_align(t, T, insertions_allowed=0,
                  cost_function=Levenshtein.distance):
    """Find a minimum-cost alignment of source tokens ``t`` to target tokens ``T``.

    Every source token ``t[i]`` is matched, in order, to a contiguous (possibly
    empty) run of target tokens ``T[j:k]``. A match costs 0 when
    ``apply_transformation`` recognises a transformation from ``t[i]`` to the
    space-joined run, and ``cost_function(t[i], run)`` otherwise. When
    ``insertions_allowed`` > 0, up to that many runs of target tokens may also
    be inserted between consecutive matches; each insertion costs the length
    of the space-joined run.

    ``dp[i, j, q]`` holds the minimal cost of aligning ``t[:i]`` with ``T[:j]``
    having made ``q`` insertions since the last matched source token.
    ``come_from`` / ``come_from_ins`` store the predecessor ``j`` / ``q`` of each
    state for backtracking. Unreached states keep the sentinel ``int(1e9)``.

    The search performs O(len(t) * len(T)**2 * (insertions_allowed + 1))
    transformation/cost evaluations.

    Args:
        t: source tokens.
        T: target tokens.
        insertions_allowed: maximum number of insertions between two matches.
        cost_function: callable taking two strings and returning a numeric
            cost, used when no transformation applies.

    Returns:
        ``(cost, alignment)``: ``cost`` is the minimal total cost (a NumPy
        scalar taken from ``dp``); ``alignment`` is ordered left to right and
        each entry is either ``['REPLACE_<t[i]>', T[j:k], (i, i + 1)]`` or
        ``['INSERT', T[j:k], (i, i)]``.

    Raises:
        AssertionError: if backtracking does not end at the start of ``T``.
    """
    # Initialise every state with "infinity" (unknown / unreachable).
    shape = (len(t) + 1, len(T) + 1, insertions_allowed + 1)
    dp = np.ones(shape, dtype=int) * int(1e9)
    come_from = np.ones(shape, dtype=int) * int(1e9)
    come_from_ins = np.ones(shape, dtype=int) * int(1e9)

    # The only known starting point: nothing matched to nothing.
    dp[0, 0, 0] = 0
    for i in range(len(t) + 1):  # inclusive upper bounds
        for j in range(len(T) + 1):
            for q in range(insertions_allowed + 1):
                if i < len(t):
                    # Given t[:i] aligned with T[:j], match token t[i] with the
                    # following target run T[j:k] (k == j means t[i] maps to
                    # nothing). A match resets the insertion counter to 0.
                    for k in range(j, len(T) + 1):
                        transform = \
                            apply_transformation(t[i], ' '.join(T[j:k]))
                        if transform:
                            cost = 0
                        else:
                            cost = cost_function(t[i], ' '.join(T[j:k]))
                        current = dp[i, j, q] + cost
                        if dp[i + 1, k, 0] > current:
                            dp[i + 1, k, 0] = current
                            come_from[i + 1, k, 0] = j
                            come_from_ins[i + 1, k, 0] = q
                if q < insertions_allowed:
                    # Given t[:i] aligned with T[:j], insert the run T[j:k]
                    # without consuming a source token.
                    for k in range(j, len(T) + 1):
                        cost = len(' '.join(T[j:k]))
                        current = dp[i, j, q] + cost
                        if dp[i, k, q + 1] > current:
                            dp[i, k, q + 1] = current
                            come_from[i, k, q + 1] = j
                            come_from_ins[i, k, q + 1] = q

    # The solution lies in dp[len(t), len(T), :]; backtrack from its minimum.
    alignment = []
    i = len(t)
    j = len(T)
    q = dp[i, j, :].argmin()
    while i > 0 or q > 0:
        # A step is an insertion when it changed the insertion counter
        # (matches always land on q == 0).
        is_insert = (come_from_ins[i, j, q] != q) and (q != 0)
        j, k, q = come_from[i, j, q], j, come_from_ins[i, j, q]
        if not is_insert:
            i -= 1

        if is_insert:
            alignment.append(['INSERT', T[j:k], (i, i)])
        else:
            alignment.append([f'REPLACE_{t[i]}', T[j:k], (i, i + 1)])

    # Backtracking must have consumed the whole target sequence.
    assert j == 0

    return dp[len(t), len(T)].min(), list(reversed(alignment))
|
|
|
|
| def _split(token): |
| if not token: |
| return [] |
| parts = token.split() |
| return parts or [token] |
|
|
|
|
def apply_merge_transformation(source_tokens, target_words, shift_idx):
    """Return merge or swap edits for a replaced span, or ``[]`` if neither applies.

    Several source tokens collapsing into one target word are checked with
    ``check_merge``; each source token except the last receives the merge
    label. Exactly two source tokens against two target words are checked
    with ``check_swap``, labelling the first token. Edit positions are offset
    by ``shift_idx``.
    """
    n_source = len(source_tokens)
    n_target = len(target_words)

    if n_source > 1 and n_target == 1:
        merge_label = check_merge(source_tokens, target_words)
        if merge_label:
            return [[(shift_idx + offset, shift_idx + offset + 1), merge_label]
                    for offset in range(n_source - 1)]

    if n_source == 2 and n_target == 2:
        swap_label = check_swap(source_tokens, target_words)
        if swap_label:
            return [[(shift_idx, shift_idx + 1), swap_label]]

    return []
|
|
|
|
def is_sent_ok(sent, delimeters=SEQ_DELIMETERS):
    """Return False if ``sent`` contains any delimiter other than the token delimiter."""
    return not any(
        delimiter in sent and delimiter != delimeters["tokens"]
        for delimiter in delimeters.values()
    )
|
|
|
|
def check_casetype(source_token, target_token):
    """Return the ``$TRANSFORM_CASE_*`` label turning ``source_token`` into ``target_token``.

    Only pure case changes qualify: ``None`` is returned when the tokens differ
    after lower-casing, or when none of the supported case patterns matches.
    """
    if source_token.lower() != target_token.lower():
        return None
    # Checked in priority order; the first matching pattern wins.
    patterns = (
        ("$TRANSFORM_CASE_LOWER",
         lambda: source_token.lower() == target_token),
        ("$TRANSFORM_CASE_CAPITAL",
         lambda: source_token.capitalize() == target_token),
        ("$TRANSFORM_CASE_UPPER",
         lambda: source_token.upper() == target_token),
        # first character kept as-is, the remainder capitalised
        ("$TRANSFORM_CASE_CAPITAL_1",
         lambda: (source_token[1:].capitalize() == target_token[1:]
                  and source_token[0] == target_token[0])),
        # all but the last character upper-cased, last character kept as-is
        ("$TRANSFORM_CASE_UPPER_-1",
         lambda: (source_token[:-1].upper() == target_token[:-1]
                  and source_token[-1] == target_token[-1])),
    )
    for label, matches in patterns:
        if matches():
            return label
    return None
|
|
|
|
def check_equal(source_token, target_token):
    """Return ``"$KEEP"`` when both tokens are identical, else ``None``."""
    return "$KEEP" if source_token == target_token else None
|
|
|
|
def check_split(source_token, target_tokens):
    """Return ``"$TRANSFORM_SPLIT_HYPHEN"`` if splitting ``source_token`` on '-' yields ``target_tokens``."""
    hyphen_parts = source_token.split("-")
    if hyphen_parts != target_tokens:
        return None
    return "$TRANSFORM_SPLIT_HYPHEN"
|
|
|
|
def check_merge(source_tokens, target_tokens):
    """Return the merge label joining ``source_tokens`` into ``target_tokens``.

    ``$MERGE_SPACE`` when the tokens are equal after concatenation,
    ``$MERGE_HYPHEN`` when equal after joining with '-', else ``None``.
    """
    for glue, label in (("", "$MERGE_SPACE"), ("-", "$MERGE_HYPHEN")):
        if glue.join(source_tokens) == glue.join(target_tokens):
            return label
    return None
|
|
|
|
def check_swap(source_tokens, target_tokens):
    """Return ``"$MERGE_SWAP"`` if ``source_tokens`` equals ``target_tokens`` reversed."""
    reversed_target = list(reversed(target_tokens))
    return "$MERGE_SWAP" if source_tokens == reversed_target else None
|
|
|
|
def check_plural(source_token, target_token):
    """Detect adding or dropping a trailing 's' between the two tokens.

    Returns ``$TRANSFORM_AGREEMENT_SINGULAR`` when the target drops a final
    's', ``$TRANSFORM_AGREEMENT_PLURAL`` when it adds one, otherwise ``None``.
    """
    drops_s = source_token.endswith("s") and source_token[:-1] == target_token
    if drops_s:
        return "$TRANSFORM_AGREEMENT_SINGULAR"
    adds_s = target_token.endswith("s") and source_token == target_token[:-1]
    if adds_s:
        return "$TRANSFORM_AGREEMENT_PLURAL"
    return None
|
|
|
|
def check_verb(source_token, target_token):
    """Return ``$TRANSFORM_VERB_<code>`` when ``encode_verb_form`` yields a code, else ``None``."""
    verb_code = encode_verb_form(source_token, target_token)
    return f"$TRANSFORM_VERB_{verb_code}" if verb_code else None
|
|
|
|
def apply_transformation(source_token, target_token):
    """Return the first transformation label mapping ``source_token`` to ``target_token``.

    A multi-word ``target_token`` is first tested as a hyphen split; then the
    equality, case, verb-form and plural checks are tried in that order.
    Returns ``None`` when nothing applies.
    """
    target_words = target_token.split()
    if len(target_words) > 1:
        split_label = check_split(source_token, target_words)
        if split_label:
            return split_label
    candidate_labels = (
        check(source_token, target_token)
        for check in (check_equal, check_casetype, check_verb, check_plural)
    )
    # The generator is lazy, so checks after the first hit are never run.
    return next((label for label in candidate_labels if label), None)
|
|
|
|
def align_sequences(source_sent, target_sent):
    """Tag ``source_sent`` with the edit labels that turn it into ``target_sent``.

    Both sentences are whitespace-tokenised and diffed with ``SequenceMatcher``.
    Deleted tokens get ``$DELETE``, inserted tokens become ``$APPEND_<word>``
    on the preceding position, and replaced spans are tried as merges/swaps
    first and otherwise aligned token-by-token with ``perfect_align``.

    Returns:
        The tagged sentence produced by ``add_labels_to_the_tokens``, or
        ``None`` if either sentence contains a reserved delimiter.
    """
    if not (is_sent_ok(source_sent) and is_sent_ok(target_sent)):
        return None

    source_tokens = source_sent.split()
    target_tokens = target_sent.split()
    opcodes = SequenceMatcher(None, source_tokens, target_tokens).get_opcodes()

    all_edits = []
    for tag, i1, i2, j1, j2 in opcodes:
        if tag == 'equal':
            continue
        source_part = _split(" ".join(source_tokens[i1:i2]))
        target_part = _split(" ".join(target_tokens[j1:j2]))

        if tag == 'delete':
            # every removed token is deleted on its own position
            all_edits.extend([(i1 + offset, i1 + offset + 1), '$DELETE']
                             for offset in range(i2 - i1))
        elif tag == 'insert':
            # new words are appended to the token preceding the gap
            all_edits.extend(((i1 - 1, i1), f"$APPEND_{word}")
                             for word in target_part)
        else:
            merge_edits = apply_merge_transformation(source_part, target_part,
                                                     shift_idx=i1)
            if merge_edits:
                all_edits.extend(merge_edits)
            else:
                # reduce the replaced span to per-token alignments
                _, alignments = perfect_align(source_part, target_part,
                                              insertions_allowed=0)
                for alignment in alignments:
                    token_offset = alignment[2][0]
                    all_edits.extend(convert_alignments_into_edits(
                        alignment, shift_idx=i1 + token_offset))

    labels = convert_edits_into_labels(source_tokens, all_edits)
    return add_labels_to_the_tokens(source_tokens, labels)
|
|
|
|
def convert_edits_into_labels(source_tokens, all_edits):
    """Turn span edits into one list of operation labels per token position.

    Args:
        source_tokens: source tokens; the result has ``len(source_tokens) + 1``
            entries because position 0 stands for the start token.
        all_edits: iterable of ``((start, end), operations)`` pairs where
            ``operations`` is a single label string or a list of labels.
            Position ``i`` collects the labels of edits whose span is exactly
            ``(i - 1, i)``; edits with any other span are ignored.

    Returns:
        A list of label lists, with ``["$KEEP"]`` for positions that have no
        labels. Labels for one position keep their order in ``all_edits``.

    Raises:
        TypeError: if an edit's operations are neither ``str`` nor ``list``.
    """
    # Group labels by span in a single pass; looking each position up in the
    # dict avoids rescanning every edit for every position.
    operations_by_span = {}
    for (start, end), edit_operations in all_edits:
        if isinstance(edit_operations, list):
            operations = edit_operations
        elif isinstance(edit_operations, str):
            operations = [edit_operations]
        else:
            raise TypeError("Unknown operation type")
        operations_by_span.setdefault((start, end), []).extend(operations)

    # `or` also maps a span whose operation list was empty to $KEEP.
    return [operations_by_span.get((i - 1, i)) or ["$KEEP"]
            for i in range(len(source_tokens) + 1)]
|
|
|
|
def convert_alignments_into_edits(alignment, shift_idx):
    """Convert one ``perfect_align`` alignment entry into positional edits.

    Args:
        alignment: ``[action, target_tokens, span]`` where ``action`` is
            ``'REPLACE_<source token>'`` and ``target_tokens`` is the run of
            target tokens aligned to that source token.
        shift_idx: sentence position of the source token.

    Returns:
        A list of ``[(start, end), label]`` edits:
        - an empty run deletes the source token;
        - if a prefix of two or more target tokens is a known transformation
          of the source token (e.g. a hyphen split), that transformation is
          used and the remaining target tokens are appended;
        - otherwise the target token closest to the source token (cost 0 for
          a known transformation, Levenshtein distance otherwise) replaces
          it, preceding target tokens are appended to the previous position
          and following ones to this position.
    """
    edits = []
    action, target_tokens, _ = alignment
    # Strip only the leading marker: str.replace would also remove
    # "REPLACE_" occurring inside the source token itself.
    prefix = "REPLACE_"
    source_token = action[len(prefix):] if action.startswith(prefix) else action

    # Nothing aligned to the source token: delete it.
    if not target_tokens:
        return [[(shift_idx, 1 + shift_idx), "$DELETE"]]

    # A multi-token prefix may be a split of the source token.
    for i in range(1, len(target_tokens)):
        target_token = " ".join(target_tokens[:i + 1])
        transform = apply_transformation(source_token, target_token)
        if transform:
            edits.append([(shift_idx, shift_idx + 1), transform])
            for target in target_tokens[i + 1:]:
                edits.append([(shift_idx, shift_idx + 1), f"$APPEND_{target}"])
            return edits

    # Score every target token against the source token.
    transform_costs = []
    transforms = []
    for target_token in target_tokens:
        transform = apply_transformation(source_token, target_token)
        if transform:
            transform_costs.append(0)
            transforms.append(transform)
        else:
            transform_costs.append(Levenshtein.distance(source_token, target_token))
            transforms.append(None)
    # first index holding the minimal cost
    min_cost_idx = transform_costs.index(min(transform_costs))

    # tokens before the best match are appended to the previous position
    for target in target_tokens[:min_cost_idx]:
        edits.append([(shift_idx - 1, shift_idx), f"$APPEND_{target}"])

    # the best match transforms or replaces the source token
    transform = transforms[min_cost_idx]
    label = transform if transform is not None \
        else f"$REPLACE_{target_tokens[min_cost_idx]}"
    edits.append([(shift_idx, 1 + shift_idx), label])

    # tokens after the best match are appended to this position
    for target in target_tokens[min_cost_idx + 1:]:
        edits.append([(shift_idx, 1 + shift_idx), f"$APPEND_{target}"])
    return edits
|
|
|
|
def add_labels_to_the_tokens(source_tokens, labels, delimeters=SEQ_DELIMETERS):
    """Serialise tokens and their label lists into a single tagged line.

    ``START_TOKEN`` is prepended so that ``labels[0]`` belongs to it. Each
    record is ``token + labels-delimiter + labels joined by the operations
    delimiter``; records are joined with the tokens delimiter.
    """
    operation_sep = delimeters['operations']
    label_sep = delimeters['labels']
    records = [
        token + label_sep + operation_sep.join(label_list)
        for token, label_list in zip([START_TOKEN] + source_tokens, labels)
    ]
    return delimeters['tokens'].join(records)
|
|
|
|
def convert_data_from_raw_files(source_file, target_file, output_file, chunk_size):
    """Tag every source/target sentence pair and append the results to a file.

    Pairs are read with ``read_parallel_lines`` and tagged with
    ``align_sequences``. A tagged line is kept only if decoding it with
    ``convert_tagged_line`` reproduces the target sentence (ignoring
    whitespace). Kept lines are buffered and appended to ``output_file``
    whenever the buffer grows beyond ``chunk_size``, plus once at the end.

    Pairs that ``align_sequences`` rejects (it returns ``None`` when a
    sentence contains a reserved delimiter) and pairs whose tags do not
    reproduce the target are reported and skipped.
    """
    tagged = []
    source_data, target_data = read_parallel_lines(source_file, target_file)
    print(f"The size of raw dataset is {len(source_data)}")
    cnt_total, cnt_all, cnt_tp = 0, 0, 0
    pairs = zip(source_data, target_data)
    for source_sent, target_sent in tqdm(pairs, total=len(source_data)):
        aligned_sent = align_sequences(source_sent, target_sent)
        if source_sent != target_sent:
            cnt_tp += 1
        cnt_all += 1

        if aligned_sent is None:
            print(f"Skipped pair with reserved delimiters: "
                  f"\n{source_sent}\n{target_sent}")
            continue

        # Round-trip check: the tags must decode back to the target.
        check_sent = convert_tagged_line(aligned_sent)
        if "".join(check_sent.split()) != "".join(target_sent.split()):
            print(f"Incorrect pair: \n{target_sent}\n{check_sent}")
            continue

        cnt_total += 1
        tagged.append(aligned_sent)
        if len(tagged) > chunk_size:
            write_lines(output_file, tagged, mode='a')
            tagged = []

    print(f"Overall extracted {cnt_total}. "
          f"Original TP {cnt_tp}."
          f" Original TN {cnt_all - cnt_tp}")
    if tagged:
        write_lines(output_file, tagged, 'a')
|
|
|
|
def convert_labels_into_edits(labels):
    """Turn per-position label lists into ``[(i - 1, i), label_list]`` edits.

    Positions whose labels are exactly ``["$KEEP"]`` produce no edit.
    """
    return [[(position - 1, position), label_list]
            for position, label_list in enumerate(labels)
            if label_list != ["$KEEP"]]
|
|
|
|
def get_target_sent_by_levels(source_tokens, labels):
    """Apply per-position edit labels to ``source_tokens`` level by level.

    ``labels`` holds one list of operations per position, position 0 being
    the start token. At each level, the first remaining operation of every
    position is applied and the rest are carried to the next level, with
    positions shifted by the tokens inserted or deleted so far. ``$MERGE_*``
    labels are deferred to the final level, where they are inserted as
    placeholder tokens and resolved by ``replace_merge_transforms``.

    Returns:
        ``(leveled_target_tokens, target_sentence)``: a dict mapping the level
        number (starting at 1) to ``{"tokens": ..., "labels": ...}`` after that
        level, and the space-joined tokens of the final level. When no
        position has a label list other than ``["$KEEP"]``, the dict is empty
        and the sentence is the source tokens joined by spaces.
    """
    relevant_edits = convert_labels_into_edits(labels)
    target_tokens = source_tokens[:]
    leveled_target_tokens = {}
    if not relevant_edits:
        target_sentence = " ".join(target_tokens)
        return leveled_target_tokens, target_sentence
    # Number of levels = the longest operation list at any position.
    max_level = max([len(x[1]) for x in relevant_edits])
    for level in range(max_level):
        rest_edits = []
        # Net number of tokens inserted (+) / deleted (-) earlier in this level.
        shift_idx = 0
        for edits in relevant_edits:
            (start, end), label_list = edits
            label = label_list[0]
            target_pos = start + shift_idx
            # NOTE(review): a position of -1 denotes the start token; only
            # $APPEND_ labels are safe there, since $DELETE/$REPLACE_/$TRANSFORM
            # would index target_tokens[-1] (the last token).
            source_token = target_tokens[target_pos] if target_pos >= 0 else START_TOKEN
            if label == "$DELETE":
                del target_tokens[target_pos]
                shift_idx -= 1
            elif label.startswith("$APPEND_"):
                word = label.replace("$APPEND_", "")
                target_tokens[target_pos + 1: target_pos + 1] = [word]
                shift_idx += 1
            elif label.startswith("$REPLACE_"):
                word = label.replace("$REPLACE_", "")
                target_tokens[target_pos] = word
            elif label.startswith("$TRANSFORM"):
                word = apply_reverse_transformation(source_token, label)
                # Keep the current token if the transformation can't be undone.
                if word is None:
                    word = source_token
                target_tokens[target_pos] = word
            elif label.startswith("$MERGE_"):
                # Merges are applied only at the last level: insert the label
                # as a placeholder token right after the current position.
                if level == (max_level - 1):
                    target_tokens[target_pos + 1: target_pos + 1] = [label]
                    shift_idx += 1
                else:
                    rest_edit = [(start + shift_idx, end + shift_idx), [label]]
                    rest_edits.append(rest_edit)
            # Any other label leaves the tokens unchanged.
            rest_labels = label_list[1:]
            if rest_labels:
                rest_edit = [(start + shift_idx, end + shift_idx), rest_labels]
                rest_edits.append(rest_edit)

        leveled_tokens = target_tokens[:]
        # Operations not yet applied become the next level's edits.
        relevant_edits = rest_edits[:]
        if level == (max_level - 1):
            leveled_tokens = replace_merge_transforms(leveled_tokens)
        leveled_labels = convert_edits_into_labels(leveled_tokens,
                                                   relevant_edits)
        leveled_target_tokens[level + 1] = {"tokens": leveled_tokens,
                                            "labels": leveled_labels}

    target_sentence = " ".join(leveled_target_tokens[max_level]["tokens"])
    return leveled_target_tokens, target_sentence
|
|
|
|
def replace_merge_transforms(tokens):
    """Resolve ``$MERGE_*`` placeholder tokens into the final token list.

    - ``$MERGE_SWAP`` swaps its left and right neighbours and is removed; a
      swap placeholder without a neighbour on both sides is left untouched.
    - ``$MERGE_HYPHEN`` joins its neighbours with '-'.
    - ``$MERGE_SPACE`` joins its neighbours with nothing.

    Returns ``tokens`` itself when it contains no ``$MERGE_`` token,
    otherwise a new, re-split list.
    """
    if all(not x.startswith("$MERGE_") for x in tokens):
        return tokens
    target_tokens = tokens[:]
    # Swap placeholders already deleted from target_tokens; needed to map an
    # index in `tokens` to the matching index in the shrinking target_tokens.
    removed = 0
    last_idx = len(tokens) - 1
    for i, token in enumerate(tokens):
        # A swap needs a neighbour on each side (strict interval check).
        if token.startswith("$MERGE_SWAP") and 0 < i < last_idx:
            pos = i - removed
            target_tokens[pos - 1], target_tokens[pos + 1] = \
                target_tokens[pos + 1], target_tokens[pos - 1]
            del target_tokens[pos]
            removed += 1
    target_line = " ".join(target_tokens)
    target_line = target_line.replace(" $MERGE_HYPHEN ", "-")
    target_line = target_line.replace(" $MERGE_SPACE ", "")
    return target_line.split()
|
|
|
|
def convert_tagged_line(line, delimeters=SEQ_DELIMETERS):
    """Decode a tagged line back into the corrected sentence.

    The line is split into records with the tokens delimiter; each record is
    ``token<labels delimiter>labels``. The first record belongs to the start
    token, so only later records contribute source tokens, while every record
    contributes its label list (split by the operations delimiter).
    """
    label_del = delimeters['labels']
    source_tokens = []
    labels = []
    for position, record in enumerate(line.split(delimeters['tokens'])):
        pieces = record.split(label_del)
        if position > 0:
            source_tokens.append(pieces[0])
        labels.append(pieces[1].split(delimeters['operations']))
    assert len(source_tokens) + 1 == len(labels)
    _, target_line = get_target_sent_by_levels(source_tokens, labels)
    return target_line
|
|
|
|
def main(args):
    """Run the conversion with the paths and chunk size from parsed CLI args."""
    convert_data_from_raw_files(
        source_file=args.source,
        target_file=args.target,
        output_file=args.output_file,
        chunk_size=args.chunk_size,
    )
|
|
|
|
if __name__ == '__main__':
    # Command-line entry point: three required paths plus an optional flush size.
    parser = argparse.ArgumentParser()
    parser.add_argument('-s', '--source', required=True,
                        help='Path to the source file')
    parser.add_argument('-t', '--target', required=True,
                        help='Path to the target file')
    parser.add_argument('-o', '--output_file', required=True,
                        help='Path to the output file')
    parser.add_argument('--chunk_size', type=int, default=1000000,
                        help='Dump each chunk size.')
    args = parser.parse_args()
    main(args)
|
|