# source snapshot: commit 19b8775 (5,671 bytes)
import argparse
import json
import os
import re
import sys
from collections import Counter
"""
Data is output in 4 files:
a file containing the mwt information
a file containing the words and sentences in conllu format
a file containing the raw text of each paragraph
a file of 0,1,2 indicating word break or sentence break on a character level for the raw text
1: end of word
2: end of sentence
"""
PARAGRAPH_BREAK = re.compile(r'\n\s*\n')
def is_para_break(index, text):
""" Detect if a paragraph break can be found, and return the length of the paragraph break sequence. """
if text[index] == '\n':
para_break = PARAGRAPH_BREAK.match(text, index)
if para_break:
break_len = len(para_break.group(0))
return True, break_len
return False, 0
def find_next_word(index, text, word, output):
"""
Locate the next word in the text. In case a paragraph break is found, also write paragraph break to labels.
"""
idx = 0
word_sofar = ''
while index < len(text) and idx < len(word):
para_break, break_len = is_para_break(index, text)
if para_break:
# multiple newlines found, paragraph break
if len(word_sofar) > 0:
assert re.match(r'^\s+$', word_sofar), 'Found non-empty string at the end of a paragraph that doesn\'t match any token: |{}|'.format(word_sofar)
word_sofar = ''
output.write('\n\n')
index += break_len - 1
elif re.match(r'^\s$', text[index]) and not re.match(r'^\s$', word[idx]):
# whitespace found, and whitespace is not part of a word
word_sofar += text[index]
else:
# non-whitespace char, or a whitespace char that's part of a word
word_sofar += text[index]
assert text[index].replace('\n', ' ') == word[idx], "Character mismatch: raw text contains |%s| but the next word is |%s|." % (word_sofar, word)
idx += 1
index += 1
return index, word_sofar
def main(args):
parser = argparse.ArgumentParser()
parser.add_argument('plaintext_file', type=str, help="Plaintext file containing the raw input")
parser.add_argument('conllu_file', type=str, help="CoNLL-U file containing tokens and sentence breaks")
parser.add_argument('-o', '--output', default=None, type=str, help="Output file name; output to the console if not specified (the default)")
parser.add_argument('-m', '--mwt_output', default=None, type=str, help="Output file name for MWT expansions; output to the console if not specified (the default)")
args = parser.parse_args(args=args)
with open(args.plaintext_file, 'r', encoding='utf-8') as f:
text = ''.join(f.readlines())
textlen = len(text)
if args.output is None:
output = sys.stdout
else:
outdir = os.path.split(args.output)[0]
os.makedirs(outdir, exist_ok=True)
output = open(args.output, 'w')
index = 0 # character offset in rawtext
mwt_expansions = []
with open(args.conllu_file, 'r', encoding='utf-8') as f:
buf = ''
mwtbegin = 0
mwtend = -1
expanded = []
last_comments = ""
for line in f:
line = line.strip()
if len(line):
if line[0] == "#":
# comment, don't do anything
if len(last_comments) == 0:
last_comments = line
continue
line = line.split('\t')
if '.' in line[0]:
# the tokenizer doesn't deal with ellipsis
continue
word = line[1]
if '-' in line[0]:
# multiword token
mwtbegin, mwtend = [int(x) for x in line[0].split('-')]
lastmwt = word
expanded = []
elif mwtbegin <= int(line[0]) < mwtend:
expanded += [word]
continue
elif int(line[0]) == mwtend:
expanded += [word]
expanded = [x.lower() for x in expanded] # evaluation doesn't care about case
mwt_expansions += [(lastmwt, tuple(expanded))]
if lastmwt[0].islower() and not expanded[0][0].islower():
print('Sentence ID with potential wrong MWT expansion: ', last_comments, file=sys.stderr)
mwtbegin = 0
mwtend = -1
lastmwt = None
continue
if len(buf):
output.write(buf)
index, word_found = find_next_word(index, text, word, output)
buf = '0' * (len(word_found)-1) + ('1' if '-' not in line[0] else '3')
else:
# sentence break found
if len(buf):
assert int(buf[-1]) >= 1
output.write(buf[:-1] + '{}'.format(int(buf[-1]) + 1))
buf = ''
last_comments = ''
status_line = ""
if args.output:
output.close()
status_line = 'Tokenizer labels written to %s\n ' % args.output
mwts = Counter(mwt_expansions)
if args.mwt_output is None:
print('MWTs:', mwts)
else:
with open(args.mwt_output, 'w') as f:
json.dump(list(mwts.items()), f, indent=2)
status_line = status_line + '{} unique MWTs found in data. MWTs written to {}'.format(len(mwts), args.mwt_output)
print(status_line)
if __name__ == '__main__':
main(sys.argv[1:])