import argparse
import json
import os
import re
import sys
from collections import Counter
"""
Data is output in 4 files:
a file containing the mwt information
a file containing the words and sentences in conllu format
a file containing the raw text of each paragraph
a file of 0,1,2 indicating word break or sentence break on a character level for the raw text
1: end of word
2: end of sentence
"""
PARAGRAPH_BREAK = re.compile(r'\n\s*\n')

def is_para_break(index, text):
    """ Detect whether a paragraph break starts at text[index]; return (found, length of the break sequence). """
if text[index] == '\n':
para_break = PARAGRAPH_BREAK.match(text, index)
if para_break:
break_len = len(para_break.group(0))
return True, break_len
return False, 0
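
# Illustrative values (not part of the pipeline):
#   is_para_break(2, 'ab\n\ncd') -> (True, 2)   two consecutive newlines
#   is_para_break(0, 'ab\n\ncd') -> (False, 0)  'a' is not a newline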

def find_next_word(index, text, word, output):
    """
    Locate the next word in the raw text, starting at character offset index.

    Returns the new offset and the text consumed, which may include leading whitespace.
    If a paragraph break is found along the way, it is also written to the labels output.
    """
idx = 0
word_sofar = ''
while index < len(text) and idx < len(word):
para_break, break_len = is_para_break(index, text)
if para_break:
# multiple newlines found, paragraph break
if len(word_sofar) > 0:
assert re.match(r'^\s+$', word_sofar), 'Found non-empty string at the end of a paragraph that doesn\'t match any token: |{}|'.format(word_sofar)
word_sofar = ''
output.write('\n\n')
index += break_len - 1
        elif re.match(r'^\s$', text[index]) and not re.match(r'^\s$', word[idx]):
            # whitespace in the text which is not part of the current word
            word_sofar += text[index]
else:
# non-whitespace char, or a whitespace char that's part of a word
word_sofar += text[index]
assert text[index].replace('\n', ' ') == word[idx], "Character mismatch: raw text contains |%s| but the next word is |%s|." % (word_sofar, word)
idx += 1
index += 1
return index, word_sofar
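
# Illustrative values (output only needs a .write() method, e.g. io.StringIO()):
#   find_next_word(0, 'Hi there', 'Hi', io.StringIO())    -> (2, 'Hi')
#   find_next_word(2, 'Hi there', 'there', io.StringIO()) -> (8, ' there')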

def main(args):
parser = argparse.ArgumentParser()
parser.add_argument('plaintext_file', type=str, help="Plaintext file containing the raw input")
parser.add_argument('conllu_file', type=str, help="CoNLL-U file containing tokens and sentence breaks")
parser.add_argument('-o', '--output', default=None, type=str, help="Output file name; output to the console if not specified (the default)")
parser.add_argument('-m', '--mwt_output', default=None, type=str, help="Output file name for MWT expansions; output to the console if not specified (the default)")
args = parser.parse_args(args=args)
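    # typical invocation (file names are illustrative):
    #   python prepare_tokenizer_data.py corpus.txt corpus.conllu -o corpus-labels.txt -m corpus-mwt.json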
    with open(args.plaintext_file, 'r', encoding='utf-8') as f:
        text = f.read()
    if args.output is None:
        output = sys.stdout
    else:
        outdir = os.path.split(args.output)[0]
        if outdir:
            # avoid os.makedirs('') when the output path has no directory component
            os.makedirs(outdir, exist_ok=True)
        output = open(args.output, 'w', encoding='utf-8')
    index = 0 # character offset in the raw text
    mwt_expansions = []
    with open(args.conllu_file, 'r', encoding='utf-8') as f:
        buf = ''
        mwtbegin = 0
        mwtend = -1
        lastmwt = None
        expanded = []
        last_comments = ""
for line in f:
line = line.strip()
if len(line):
if line[0] == "#":
# comment, don't do anything
if len(last_comments) == 0:
last_comments = line
continue
line = line.split('\t')
                if '.' in line[0]:
                    # decimal IDs mark empty (elided) nodes, which have no surface form
                    # for the tokenizer to predict, so skip them
                    continue
word = line[1]
if '-' in line[0]:
# multiword token
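                    # e.g. a hypothetical French line '2-3\tdu\t_\t...' followed by word
                    # lines for 'de' and 'le': record the surface form now and collect
                    # the expansion over the next iterations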
mwtbegin, mwtend = [int(x) for x in line[0].split('-')]
lastmwt = word
expanded = []
elif mwtbegin <= int(line[0]) < mwtend:
expanded += [word]
continue
elif int(line[0]) == mwtend:
expanded += [word]
expanded = [x.lower() for x in expanded] # evaluation doesn't care about case
mwt_expansions += [(lastmwt, tuple(expanded))]
if lastmwt[0].islower() and not expanded[0][0].islower():
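                        # a lowercase surface form expanding to a capitalized word is suspicious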
print('Sentence ID with potential wrong MWT expansion: ', last_comments, file=sys.stderr)
mwtbegin = 0
mwtend = -1
lastmwt = None
continue
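                # labels for a token are buffered so that a following sentence break
                # can bump the final label (1 -> 2 for a word, 3 -> 4 for an MWT)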
if len(buf):
output.write(buf)
index, word_found = find_next_word(index, text, word, output)
buf = '0' * (len(word_found)-1) + ('1' if '-' not in line[0] else '3')
else:
                # blank line: sentence break in the conllu file
if len(buf):
assert int(buf[-1]) >= 1
output.write(buf[:-1] + '{}'.format(int(buf[-1]) + 1))
buf = ''
last_comments = ''
status_line = ""
if args.output:
output.close()
        status_line = 'Tokenizer labels written to %s\n' % args.output
mwts = Counter(mwt_expansions)
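    # each Counter entry maps (surface form, expansion) to a count; serialized as JSON
    # below, a hypothetical entry looks like [["du", ["de", "le"]], 17]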
if args.mwt_output is None:
print('MWTs:', mwts)
else:
        with open(args.mwt_output, 'w', encoding='utf-8') as f:
json.dump(list(mwts.items()), f, indent=2)
status_line = status_line + '{} unique MWTs found in data. MWTs written to {}'.format(len(mwts), args.mwt_output)
print(status_line)

if __name__ == '__main__':
main(sys.argv[1:])