tyfsadik committed
Commit 88cda3e · verified · 1 Parent(s): e59b6d4

Upload 4 files

utils/filter_brackets.py ADDED
@@ -0,0 +1,35 @@
import argparse
import re

from helpers import write_lines


def filter_line(line):
    if "-LRB-" in line and "-RRB-" in line:
        rep = re.sub(r'\-.*?LRB.*?\-.*?\-.*?RRB.*?\-', '', line)
        line_cleaned = rep
    elif ("-LRB-" in line and "-RRB-" not in line) or (
            "-LRB-" not in line and "-RRB-" in line):
        line_cleaned = line.replace("-LRB-", '"').replace("-RRB-", '"')
    else:
        line_cleaned = line
    return line_cleaned


def main(args):
    with open(args.source) as f:
        data = [row.rstrip() for row in f]

    write_lines(args.output, [filter_line(row) for row in data])


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-s', '--source',
                        help='Path to the source file',
                        required=True)
    parser.add_argument('-o', '--output',
                        help='Path to the output file',
                        required=True)
    args = parser.parse_args()
    main(args)
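
A quick illustrative check of filter_line (from the module above) on Penn Treebank-style bracket tokens; the sample strings are made up:

    filter_line("He smiled -LRB- briefly -RRB- and left .")
    # -> 'He smiled  and left .'   (a matched -LRB- ... -RRB- span is removed, leaving a double space)
    filter_line("He said -LRB- hello .")
    # -> 'He said " hello .'       (an unmatched bracket token becomes a double quote)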
utils/helpers.py ADDED
@@ -0,0 +1,233 @@
import os
from pathlib import Path


VOCAB_DIR = Path(__file__).resolve().parent.parent / "data"
PAD = "@@PADDING@@"
UNK = "@@UNKNOWN@@"
START_TOKEN = "$START"
SEQ_DELIMETERS = {"tokens": " ",
                  "labels": "SEPL|||SEPR",
                  "operations": "SEPL__SEPR"}
REPLACEMENTS = {
    "''": '"',
    '--': '—',
    '`': "'",
    "'ve": "' ve",
}


def get_verb_form_dicts():
    path_to_dict = os.path.join(VOCAB_DIR, "verb-form-vocab.txt")
    encode, decode = {}, {}
    with open(path_to_dict, encoding="utf-8") as f:
        for line in f:
            words, tags = line.split(":")
            word1, word2 = words.split("_")
            tag1, tag2 = tags.split("_")
            decode_key = f"{word1}_{tag1}_{tag2.strip()}"
            if decode_key not in decode:
                encode[words] = tags
                decode[decode_key] = word2
    return encode, decode


ENCODE_VERB_DICT, DECODE_VERB_DICT = get_verb_form_dicts()


def get_target_sent_by_edits(source_tokens, edits):
    target_tokens = source_tokens[:]
    shift_idx = 0
    for edit in edits:
        start, end, label, _ = edit
        target_pos = start + shift_idx
        source_token = target_tokens[target_pos] \
            if len(target_tokens) > target_pos >= 0 else ''
        if label == "":
            del target_tokens[target_pos]
            shift_idx -= 1
        elif start == end:
            word = label.replace("$APPEND_", "")
            target_tokens[target_pos: target_pos] = [word]
            shift_idx += 1
        elif label.startswith("$TRANSFORM_"):
            word = apply_reverse_transformation(source_token, label)
            if word is None:
                word = source_token
            target_tokens[target_pos] = word
        elif start == end - 1:
            word = label.replace("$REPLACE_", "")
            target_tokens[target_pos] = word
        elif label.startswith("$MERGE_"):
            target_tokens[target_pos + 1: target_pos + 1] = [label]
            shift_idx += 1

    return replace_merge_transforms(target_tokens)


def replace_merge_transforms(tokens):
    if all(not x.startswith("$MERGE_") for x in tokens):
        return tokens

    target_line = " ".join(tokens)
    target_line = target_line.replace(" $MERGE_HYPHEN ", "-")
    target_line = target_line.replace(" $MERGE_SPACE ", "")
    return target_line.split()


def convert_using_case(token, smart_action):
    if not smart_action.startswith("$TRANSFORM_CASE_"):
        return token
    if smart_action.endswith("LOWER"):
        return token.lower()
    elif smart_action.endswith("UPPER"):
        return token.upper()
    elif smart_action.endswith("CAPITAL"):
        return token.capitalize()
    elif smart_action.endswith("CAPITAL_1"):
        return token[0] + token[1:].capitalize()
    elif smart_action.endswith("UPPER_-1"):
        return token[:-1].upper() + token[-1]
    else:
        return token


def convert_using_verb(token, smart_action):
    key_word = "$TRANSFORM_VERB_"
    if not smart_action.startswith(key_word):
        raise Exception(f"Unknown action type {smart_action}")
    encoding_part = f"{token}_{smart_action[len(key_word):]}"
    decoded_target_word = decode_verb_form(encoding_part)
    return decoded_target_word


def convert_using_split(token, smart_action):
    key_word = "$TRANSFORM_SPLIT"
    if not smart_action.startswith(key_word):
        raise Exception(f"Unknown action type {smart_action}")
    target_words = token.split("-")
    return " ".join(target_words)


def convert_using_plural(token, smart_action):
    if smart_action.endswith("PLURAL"):
        return token + "s"
    elif smart_action.endswith("SINGULAR"):
        return token[:-1]
    else:
        raise Exception(f"Unknown action type {smart_action}")


def apply_reverse_transformation(source_token, transform):
    if transform.startswith("$TRANSFORM"):
        # deal with equal
        if transform == "$KEEP":
            return source_token
        # deal with case
        if transform.startswith("$TRANSFORM_CASE"):
            return convert_using_case(source_token, transform)
        # deal with verb
        if transform.startswith("$TRANSFORM_VERB"):
            return convert_using_verb(source_token, transform)
        # deal with split
        if transform.startswith("$TRANSFORM_SPLIT"):
            return convert_using_split(source_token, transform)
        # deal with single/plural
        if transform.startswith("$TRANSFORM_AGREEMENT"):
            return convert_using_plural(source_token, transform)
        # raise an exception if no matching transform type was found
        raise Exception(f"Unknown action type {transform}")
    else:
        return source_token


def read_parallel_lines(fn1, fn2):
    lines1 = read_lines(fn1, skip_strip=True)
    lines2 = read_lines(fn2, skip_strip=True)
    assert len(lines1) == len(lines2)
    out_lines1, out_lines2 = [], []
    for line1, line2 in zip(lines1, lines2):
        if not line1.strip() or not line2.strip():
            continue
        else:
            out_lines1.append(line1)
            out_lines2.append(line2)
    return out_lines1, out_lines2


def read_lines(fn, skip_strip=False):
    if not os.path.exists(fn):
        return []
    with open(fn, 'r', encoding='utf-8') as f:
        lines = f.readlines()
    return [s.strip() for s in lines if s.strip() or skip_strip]


def write_lines(fn, lines, mode='w'):
    if mode == 'w' and os.path.exists(fn):
        os.remove(fn)
    with open(fn, encoding='utf-8', mode=mode) as f:
        f.writelines(['%s\n' % s for s in lines])


def decode_verb_form(original):
    return DECODE_VERB_DICT.get(original)


def encode_verb_form(original_word, corrected_word):
    decoding_request = original_word + "_" + corrected_word
    decoding_response = ENCODE_VERB_DICT.get(decoding_request, "").strip()
    if original_word and decoding_response:
        answer = decoding_response
    else:
        answer = None
    return answer


def get_weights_name(transformer_name, lowercase):
    if transformer_name == 'bert' and lowercase:
        return 'bert-base-uncased'
    if transformer_name == 'bert' and not lowercase:
        return 'bert-base-cased'
    if transformer_name == 'bert-large' and not lowercase:
        return 'bert-large-cased'
    if transformer_name == 'distilbert':
        if not lowercase:
            print('Warning! This model was trained only on uncased sentences.')
        return 'distilbert-base-uncased'
    if transformer_name == 'albert':
        if not lowercase:
            print('Warning! This model was trained only on uncased sentences.')
        return 'albert-base-v1'
    if lowercase:
        print('Warning! This model was trained only on cased sentences.')
    if transformer_name == 'roberta':
        return 'roberta-base'
    if transformer_name == 'roberta-large':
        return 'roberta-large'
    if transformer_name == 'gpt2':
        return 'gpt2'
    if transformer_name == 'transformerxl':
        return 'transfo-xl-wt103'
    if transformer_name == 'xlnet':
        return 'xlnet-base-cased'
    if transformer_name == 'xlnet-large':
        return 'xlnet-large-cased'


def remove_double_tokens(sent):
    tokens = sent.split(' ')
    deleted_idx = []
    for i in range(len(tokens) - 1):
        if tokens[i] == tokens[i + 1]:
            deleted_idx.append(i + 1)
    if deleted_idx:
        tokens = [tokens[i] for i in range(len(tokens)) if i not in deleted_idx]
    return ' '.join(tokens)


def normalize(sent):
    sent = remove_double_tokens(sent)
    for fr, to in REPLACEMENTS.items():
        sent = sent.replace(fr, to)
    return sent.lower()
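
To illustrate the edit format that get_target_sent_by_edits consumes, here is a minimal, hand-checked sketch (hypothetical sentence and edits; each edit is a (start, end, label, extra) tuple):

    source = "he go to school yesterday".split()
    edits = [(0, 1, "$TRANSFORM_CASE_CAPITAL", None),  # he -> He
             (1, 2, "$REPLACE_went", None),            # go -> went
             (5, 5, "$APPEND_.", None)]                # insert "." after the last token
    print(" ".join(get_target_sent_by_edits(source, edits)))
    # -> He went to school yesterday .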
utils/prepare_clc_fce_data.py ADDED
@@ -0,0 +1,123 @@
#!/usr/bin/env python
"""
Convert CLC-FCE dataset (The Cambridge Learner Corpus) to the parallel sentences format.
"""

import argparse
import glob
import os
import re
from xml.etree import cElementTree

from nltk.tokenize import sent_tokenize, word_tokenize
from tqdm import tqdm


def annotate_fce_doc(xml):
    """Takes an FCE xml document and returns its text with errors annotated inline."""
    result = []
    doc = cElementTree.fromstring(xml)
    paragraphs = doc.findall('head/text/*/coded_answer/p')
    for p in paragraphs:
        text = _get_formatted_text(p)
        result.append(text)

    return '\n'.join(result)


def _get_formatted_text(elem, ignore_tags=None):
    text = elem.text or ''
    ignore_tags = [tag.upper() for tag in (ignore_tags or [])]
    correct = None
    mistake = None

    for child in elem.getchildren():
        tag = child.tag.upper()
        if tag == 'NS':
            text += _get_formatted_text(child)

        elif tag == 'UNKNOWN':
            text += ' UNKNOWN '

        elif tag == 'C':
            assert correct is None
            correct = _get_formatted_text(child)

        elif tag == 'I':
            assert mistake is None
            mistake = _get_formatted_text(child)

        elif tag in ignore_tags:
            pass

        else:
            raise ValueError(f"Unknown tag `{child.tag}`", text)

    if correct or mistake:
        correct = correct or ''
        mistake = mistake or ''
        if '=>' not in mistake:
            text += f'{{{mistake}=>{correct}}}'
        else:
            text += mistake

    text += elem.tail or ''
    return text


def convert_fce(fce_dir):
    """Processes the whole FCE directory. Returns a list of annotated documents (strings)."""

    # Ensure we got a valid dataset path
    if not os.path.isdir(fce_dir):
        raise UserWarning(
            f"{fce_dir} is not a valid path")

    dataset_dir = os.path.join(fce_dir, 'dataset')
    if not os.path.exists(dataset_dir):
        raise UserWarning(
            f"{fce_dir} doesn't point to a dataset's root dir")

    # Convert XML docs to the corpora format
    filenames = sorted(glob.glob(os.path.join(dataset_dir, '*/*.xml')))

    docs = []
    for filename in filenames:
        with open(filename, encoding='utf-8') as f:
            doc = annotate_fce_doc(f.read())
            docs.append(doc)
    return docs


def main():
    fce = convert_fce(args.fce_dataset_path)
    with open(args.output + "/fce-original.txt", 'w', encoding='utf-8') as out_original, \
            open(args.output + "/fce-applied.txt", 'w', encoding='utf-8') as out_applied:
        for doc in tqdm(fce, unit='doc'):
            sents = re.split(r"\n +\n", doc)
            for sent in sents:
                tokenized_sents = sent_tokenize(sent)
                for i in range(len(tokenized_sents)):
                    if re.search(r"[{>][.?!]$", tokenized_sents[i]):
                        tokenized_sents[i + 1] = tokenized_sents[i] + " " + tokenized_sents[i + 1]
                        tokenized_sents[i] = ""
                    regexp = r'{([^{}]*?)=>([^{}]*?)}'
                    original = re.sub(regexp, r"\1", tokenized_sents[i])
                    applied = re.sub(regexp, r"\2", tokenized_sents[i])
                    # filter out nested alerts
                    if original != "" and applied != "" and not re.search(r"[{}=]", original) \
                            and not re.search(r"[{}=]", applied):
                        out_original.write(" ".join(word_tokenize(original)) + "\n")
                        out_applied.write(" ".join(word_tokenize(applied)) + "\n")


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description=(
        "Convert CLC-FCE dataset to the parallel sentences format."))
    parser.add_argument('fce_dataset_path',
                        help='Path to the folder with the FCE dataset')
    parser.add_argument('--output',
                        help='Path to the output folder')
    args = parser.parse_args()

    main()
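
For orientation, annotate_fce_doc marks each error span inline as {mistake=>correction}, and main() then splits that annotation into the two parallel files. A made-up illustration:

    # Hypothetical annotated sentence produced by annotate_fce_doc:
    #   I {goed=>went} to the market .
    # main() applies re.sub(r'{([^{}]*?)=>([^{}]*?)}', ...) twice, keeping group 1
    # for the source side and group 2 for the corrected side:
    #   fce-original.txt : I goed to the market .
    #   fce-applied.txt  : I went to the market .

One portability note: _get_formatted_text iterates with elem.getchildren(), which was removed in Python 3.9, so on newer interpreters the loop needs to iterate over the element directly (list(elem)).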
utils/preprocess_data.py ADDED
@@ -0,0 +1,488 @@
import argparse
import os
from difflib import SequenceMatcher

import Levenshtein
import numpy as np
from tqdm import tqdm

from helpers import write_lines, read_parallel_lines, encode_verb_form, \
    apply_reverse_transformation, SEQ_DELIMETERS, START_TOKEN


def perfect_align(t, T, insertions_allowed=0,
                  cost_function=Levenshtein.distance):
    # dp[i, j, k] is a minimal cost of matching first `i` tokens of `t` with
    # first `j` tokens of `T`, after making `k` insertions after last match of
    # token from `t`. In other words t[:i] aligned with T[:j].

    # Initialize with INFINITY (unknown)
    shape = (len(t) + 1, len(T) + 1, insertions_allowed + 1)
    dp = np.ones(shape, dtype=int) * int(1e9)
    come_from = np.ones(shape, dtype=int) * int(1e9)
    come_from_ins = np.ones(shape, dtype=int) * int(1e9)

    dp[0, 0, 0] = 0  # The only known starting point. Nothing matched to nothing.
    for i in range(len(t) + 1):  # Go inclusive
        for j in range(len(T) + 1):  # Go inclusive
            for q in range(insertions_allowed + 1):  # Go inclusive
                if i < len(t):
                    # Given matched sequence of t[:i] and T[:j], match token
                    # t[i] with following tokens T[j:k].
                    for k in range(j, len(T) + 1):
                        transform = \
                            apply_transformation(t[i], ' '.join(T[j:k]))
                        if transform:
                            cost = 0
                        else:
                            cost = cost_function(t[i], ' '.join(T[j:k]))
                        current = dp[i, j, q] + cost
                        if dp[i + 1, k, 0] > current:
                            dp[i + 1, k, 0] = current
                            come_from[i + 1, k, 0] = j
                            come_from_ins[i + 1, k, 0] = q
                if q < insertions_allowed:
                    # Given matched sequence of t[:i] and T[:j], create
                    # insertion with following tokens T[j:k].
                    for k in range(j, len(T) + 1):
                        cost = len(' '.join(T[j:k]))
                        current = dp[i, j, q] + cost
                        if dp[i, k, q + 1] > current:
                            dp[i, k, q + 1] = current
                            come_from[i, k, q + 1] = j
                            come_from_ins[i, k, q + 1] = q

    # Solution is in the dp[len(t), len(T), *]. Backtracking from there.
    alignment = []
    i = len(t)
    j = len(T)
    q = dp[i, j, :].argmin()
    while i > 0 or q > 0:
        is_insert = (come_from_ins[i, j, q] != q) and (q != 0)
        j, k, q = come_from[i, j, q], j, come_from_ins[i, j, q]
        if not is_insert:
            i -= 1

        if is_insert:
            alignment.append(['INSERT', T[j:k], (i, i)])
        else:
            alignment.append([f'REPLACE_{t[i]}', T[j:k], (i, i + 1)])

    assert j == 0

    return dp[len(t), len(T)].min(), list(reversed(alignment))

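# Illustrative note (hypothetical inputs): with insertions_allowed=0,
#   perfect_align(["a", "cat"], ["a", "cat", "sat"])
# pairs each source token with a (possibly multi-word) span of the target and
# returns (4, [['REPLACE_a', ['a'], (0, 1)], ['REPLACE_cat', ['cat', 'sat'], (1, 2)]]),
# the 4 being the Levenshtein cost of "cat" -> "cat sat"; the extra word of a
# multi-word span is later emitted as an $APPEND_ edit by convert_alignments_into_edits.
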
def _split(token):
    if not token:
        return []
    parts = token.split()
    return parts or [token]


def apply_merge_transformation(source_tokens, target_words, shift_idx):
    edits = []
    if len(source_tokens) > 1 and len(target_words) == 1:
        # check merge
        transform = check_merge(source_tokens, target_words)
        if transform:
            for i in range(len(source_tokens) - 1):
                edits.append([(shift_idx + i, shift_idx + i + 1), transform])
            return edits

    if len(source_tokens) == len(target_words) == 2:
        # check swap
        transform = check_swap(source_tokens, target_words)
        if transform:
            edits.append([(shift_idx, shift_idx + 1), transform])
    return edits


def is_sent_ok(sent, delimeters=SEQ_DELIMETERS):
    for del_val in delimeters.values():
        if del_val in sent and del_val != delimeters["tokens"]:
            return False
    return True


def check_casetype(source_token, target_token):
    if source_token.lower() != target_token.lower():
        return None
    if source_token.lower() == target_token:
        return "$TRANSFORM_CASE_LOWER"
    elif source_token.capitalize() == target_token:
        return "$TRANSFORM_CASE_CAPITAL"
    elif source_token.upper() == target_token:
        return "$TRANSFORM_CASE_UPPER"
    elif source_token[1:].capitalize() == target_token[1:] and source_token[0] == target_token[0]:
        return "$TRANSFORM_CASE_CAPITAL_1"
    elif source_token[:-1].upper() == target_token[:-1] and source_token[-1] == target_token[-1]:
        return "$TRANSFORM_CASE_UPPER_-1"
    else:
        return None


def check_equal(source_token, target_token):
    if source_token == target_token:
        return "$KEEP"
    else:
        return None


def check_split(source_token, target_tokens):
    if source_token.split("-") == target_tokens:
        return "$TRANSFORM_SPLIT_HYPHEN"
    else:
        return None


def check_merge(source_tokens, target_tokens):
    if "".join(source_tokens) == "".join(target_tokens):
        return "$MERGE_SPACE"
    elif "-".join(source_tokens) == "-".join(target_tokens):
        return "$MERGE_HYPHEN"
    else:
        return None


def check_swap(source_tokens, target_tokens):
    if source_tokens == [x for x in reversed(target_tokens)]:
        return "$MERGE_SWAP"
    else:
        return None


def check_plural(source_token, target_token):
    if source_token.endswith("s") and source_token[:-1] == target_token:
        return "$TRANSFORM_AGREEMENT_SINGULAR"
    elif target_token.endswith("s") and source_token == target_token[:-1]:
        return "$TRANSFORM_AGREEMENT_PLURAL"
    else:
        return None


def check_verb(source_token, target_token):
    encoding = encode_verb_form(source_token, target_token)
    if encoding:
        return f"$TRANSFORM_VERB_{encoding}"
    else:
        return None


def apply_transformation(source_token, target_token):
    target_tokens = target_token.split()
    if len(target_tokens) > 1:
        # check split
        transform = check_split(source_token, target_tokens)
        if transform:
            return transform
    checks = [check_equal, check_casetype, check_verb, check_plural]
    for check in checks:
        transform = check(source_token, target_token)
        if transform:
            return transform
    return None


def align_sequences(source_sent, target_sent):
    # check if sent is OK
    if not is_sent_ok(source_sent) or not is_sent_ok(target_sent):
        return None
    source_tokens = source_sent.split()
    target_tokens = target_sent.split()
    matcher = SequenceMatcher(None, source_tokens, target_tokens)
    diffs = list(matcher.get_opcodes())
    all_edits = []
    for diff in diffs:
        tag, i1, i2, j1, j2 = diff
        source_part = _split(" ".join(source_tokens[i1:i2]))
        target_part = _split(" ".join(target_tokens[j1:j2]))
        if tag == 'equal':
            continue
        elif tag == 'delete':
            # delete all words separately
            for j in range(i2 - i1):
                edit = [(i1 + j, i1 + j + 1), '$DELETE']
                all_edits.append(edit)
        elif tag == 'insert':
            # append to the previous word
            for target_token in target_part:
                edit = ((i1 - 1, i1), f"$APPEND_{target_token}")
                all_edits.append(edit)
        else:
            # check merge first of all
            edits = apply_merge_transformation(source_part, target_part,
                                               shift_idx=i1)
            if edits:
                all_edits.extend(edits)
                continue

            # normalize alignments if needed (make them singletons)
            _, alignments = perfect_align(source_part, target_part,
                                          insertions_allowed=0)
            for alignment in alignments:
                new_shift = alignment[2][0]
                edits = convert_alignments_into_edits(alignment,
                                                      shift_idx=i1 + new_shift)
                all_edits.extend(edits)

    # get labels
    labels = convert_edits_into_labels(source_tokens, all_edits)
    # match tags to source tokens
    sent_with_tags = add_labels_to_the_tokens(source_tokens, labels)
    return sent_with_tags


def convert_edits_into_labels(source_tokens, all_edits):
    # make sure that edits are flat
    flat_edits = []
    for edit in all_edits:
        (start, end), edit_operations = edit
        if isinstance(edit_operations, list):
            for operation in edit_operations:
                new_edit = [(start, end), operation]
                flat_edits.append(new_edit)
        elif isinstance(edit_operations, str):
            flat_edits.append(edit)
        else:
            raise Exception("Unknown operation type")
    all_edits = flat_edits[:]
    labels = []
    total_labels = len(source_tokens) + 1
    if not all_edits:
        labels = [["$KEEP"] for x in range(total_labels)]
    else:
        for i in range(total_labels):
            edit_operations = [x[1] for x in all_edits if x[0][0] == i - 1
                               and x[0][1] == i]
            if not edit_operations:
                labels.append(["$KEEP"])
            else:
                labels.append(edit_operations)
    return labels


def convert_alignments_into_edits(alignment, shift_idx):
    edits = []
    action, target_tokens, new_idx = alignment
    source_token = action.replace("REPLACE_", "")

    # check if delete
    if not target_tokens:
        edit = [(shift_idx, 1 + shift_idx), "$DELETE"]
        return [edit]

    # check splits
    for i in range(1, len(target_tokens)):
        target_token = " ".join(target_tokens[:i + 1])
        transform = apply_transformation(source_token, target_token)
        if transform:
            edit = [(shift_idx, shift_idx + 1), transform]
            edits.append(edit)
            target_tokens = target_tokens[i + 1:]
            for target in target_tokens:
                edits.append([(shift_idx, shift_idx + 1), f"$APPEND_{target}"])
            return edits

    transform_costs = []
    transforms = []
    for target_token in target_tokens:
        transform = apply_transformation(source_token, target_token)
        if transform:
            cost = 0
            transforms.append(transform)
        else:
            cost = Levenshtein.distance(source_token, target_token)
            transforms.append(None)
        transform_costs.append(cost)
    min_cost_idx = transform_costs.index(min(transform_costs))
    # append to the previous word
    for i in range(0, min_cost_idx):
        target = target_tokens[i]
        edit = [(shift_idx - 1, shift_idx), f"$APPEND_{target}"]
        edits.append(edit)
    # replace/transform target word
    transform = transforms[min_cost_idx]
    target = transform if transform is not None \
        else f"$REPLACE_{target_tokens[min_cost_idx]}"
    edit = [(shift_idx, 1 + shift_idx), target]
    edits.append(edit)
    # append to this word
    for i in range(min_cost_idx + 1, len(target_tokens)):
        target = target_tokens[i]
        edit = [(shift_idx, 1 + shift_idx), f"$APPEND_{target}"]
        edits.append(edit)
    return edits


def add_labels_to_the_tokens(source_tokens, labels, delimeters=SEQ_DELIMETERS):
    tokens_with_all_tags = []
    source_tokens_with_start = [START_TOKEN] + source_tokens
    for token, label_list in zip(source_tokens_with_start, labels):
        all_tags = delimeters['operations'].join(label_list)
        comb_record = token + delimeters['labels'] + all_tags
        tokens_with_all_tags.append(comb_record)
    return delimeters['tokens'].join(tokens_with_all_tags)


def convert_data_from_raw_files(source_file, target_file, output_file, chunk_size):
    tagged = []
    source_data, target_data = read_parallel_lines(source_file, target_file)
    print(f"The size of raw dataset is {len(source_data)}")
    cnt_total, cnt_all, cnt_tp = 0, 0, 0
    for source_sent, target_sent in tqdm(zip(source_data, target_data)):
        try:
            aligned_sent = align_sequences(source_sent, target_sent)
        except Exception:
            aligned_sent = align_sequences(source_sent, target_sent)
        if source_sent != target_sent:
            cnt_tp += 1
        alignments = [aligned_sent]
        cnt_all += len(alignments)
        try:
            check_sent = convert_tagged_line(aligned_sent)
        except Exception:
            # debug mode
            aligned_sent = align_sequences(source_sent, target_sent)
            check_sent = convert_tagged_line(aligned_sent)

        if "".join(check_sent.split()) != "".join(
                target_sent.split()):
            # do it again for debugging
            aligned_sent = align_sequences(source_sent, target_sent)
            check_sent = convert_tagged_line(aligned_sent)
            print(f"Incorrect pair: \n{target_sent}\n{check_sent}")
            continue
        if alignments:
            cnt_total += len(alignments)
            tagged.extend(alignments)
        if len(tagged) > chunk_size:
            write_lines(output_file, tagged, mode='a')
            tagged = []

    print(f"Overall extracted {cnt_total}. "
          f"Original TP {cnt_tp}."
          f" Original TN {cnt_all - cnt_tp}")
    if tagged:
        write_lines(output_file, tagged, 'a')


def convert_labels_into_edits(labels):
    all_edits = []
    for i, label_list in enumerate(labels):
        if label_list == ["$KEEP"]:
            continue
        else:
            edit = [(i - 1, i), label_list]
            all_edits.append(edit)
    return all_edits


def get_target_sent_by_levels(source_tokens, labels):
    relevant_edits = convert_labels_into_edits(labels)
    target_tokens = source_tokens[:]
    leveled_target_tokens = {}
    if not relevant_edits:
        target_sentence = " ".join(target_tokens)
        return leveled_target_tokens, target_sentence
    max_level = max([len(x[1]) for x in relevant_edits])
    for level in range(max_level):
        rest_edits = []
        shift_idx = 0
        for edits in relevant_edits:
            (start, end), label_list = edits
            label = label_list[0]
            target_pos = start + shift_idx
            source_token = target_tokens[target_pos] if target_pos >= 0 else START_TOKEN
            if label == "$DELETE":
                del target_tokens[target_pos]
                shift_idx -= 1
            elif label.startswith("$APPEND_"):
                word = label.replace("$APPEND_", "")
                target_tokens[target_pos + 1: target_pos + 1] = [word]
                shift_idx += 1
            elif label.startswith("$REPLACE_"):
                word = label.replace("$REPLACE_", "")
                target_tokens[target_pos] = word
            elif label.startswith("$TRANSFORM"):
                word = apply_reverse_transformation(source_token, label)
                if word is None:
                    word = source_token
                target_tokens[target_pos] = word
            elif label.startswith("$MERGE_"):
                # apply merge only on last stage
                if level == (max_level - 1):
                    target_tokens[target_pos + 1: target_pos + 1] = [label]
                    shift_idx += 1
                else:
                    rest_edit = [(start + shift_idx, end + shift_idx), [label]]
                    rest_edits.append(rest_edit)
            rest_labels = label_list[1:]
            if rest_labels:
                rest_edit = [(start + shift_idx, end + shift_idx), rest_labels]
                rest_edits.append(rest_edit)

        leveled_tokens = target_tokens[:]
        # update next step
        relevant_edits = rest_edits[:]
        if level == (max_level - 1):
            leveled_tokens = replace_merge_transforms(leveled_tokens)
        leveled_labels = convert_edits_into_labels(leveled_tokens,
                                                   relevant_edits)
        leveled_target_tokens[level + 1] = {"tokens": leveled_tokens,
                                            "labels": leveled_labels}

    target_sentence = " ".join(leveled_target_tokens[max_level]["tokens"])
    return leveled_target_tokens, target_sentence


def replace_merge_transforms(tokens):
    if all(not x.startswith("$MERGE_") for x in tokens):
        return tokens
    target_tokens = tokens[:]
    allowed_range = (1, len(tokens) - 1)
    for i in range(len(tokens)):
        target_token = tokens[i]
        if target_token.startswith("$MERGE"):
            if target_token.startswith("$MERGE_SWAP") and i in allowed_range:
                target_tokens[i - 1] = tokens[i + 1]
                target_tokens[i + 1] = tokens[i - 1]
                target_tokens[i: i + 1] = []
    target_line = " ".join(target_tokens)
    target_line = target_line.replace(" $MERGE_HYPHEN ", "-")
    target_line = target_line.replace(" $MERGE_SPACE ", "")
    return target_line.split()


def convert_tagged_line(line, delimeters=SEQ_DELIMETERS):
    label_del = delimeters['labels']
    source_tokens = [x.split(label_del)[0]
                     for x in line.split(delimeters['tokens'])][1:]
    labels = [x.split(label_del)[1].split(delimeters['operations'])
              for x in line.split(delimeters['tokens'])]
    assert len(source_tokens) + 1 == len(labels)
    levels_dict, target_line = get_target_sent_by_levels(source_tokens, labels)
    return target_line


def main(args):
    convert_data_from_raw_files(args.source, args.target, args.output_file, args.chunk_size)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-s', '--source',
                        help='Path to the source file',
                        required=True)
    parser.add_argument('-t', '--target',
                        help='Path to the target file',
                        required=True)
    parser.add_argument('-o', '--output_file',
                        help='Path to the output file',
                        required=True)
    parser.add_argument('--chunk_size',
                        type=int,
                        help='Dump each chunk size.',
                        default=1000000)
    args = parser.parse_args()
    main(args)
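
To make the produced format concrete, here is a small hand-traced sketch of align_sequences on a hypothetical sentence pair (it assumes data/verb-form-vocab.txt contains the go/goes entry for the VB_VBZ form; without it the verb label falls back to $REPLACE_goes):

    tagged = align_sequences("he go to school", "He goes to school")
    # Each source token, prefixed by the $START token, is glued to its labels with
    # the SEQ_DELIMETERS separators ("SEPL|||SEPR" between token and labels):
    #   $START -> $KEEP
    #   he     -> $TRANSFORM_CASE_CAPITAL
    #   go     -> $TRANSFORM_VERB_VB_VBZ
    #   to     -> $KEEP
    #   school -> $KEEP
    # convert_tagged_line(tagged) rebuilds "He goes to school", which is the
    # round-trip check performed in convert_data_from_raw_files.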