"""
Prepare character-level tokenizer labels from raw text and a CoNLL-U file.

Data is output in 4 files:

  a file containing the mwt information
  a file containing the words and sentences in conllu format
  a file containing the raw text of each paragraph
  a file of labels indicating word break or sentence break on a character
  level for the raw text
    0: no boundary after this character
    1: end of word
    2: end of sentence
    3: end of a multi-word token (MWT)
    4: end of an MWT that is also the end of a sentence

This script itself writes the label file (``--output``) and the MWT
expansions (``--mwt_output``); it reads the raw text and the CoNLL-U file.
"""

import argparse
import json
import os
import re
import sys
from collections import Counter

PARAGRAPH_BREAK = re.compile(r'\n\s*\n')

# Compiled once because they are evaluated for every character of the text.
_WHITESPACE_CHAR = re.compile(r'^\s$')
_WHITESPACE_ONLY = re.compile(r'^\s+$')


def is_para_break(index, text):
    """
    Detect if a paragraph break can be found, and return the length of the
    paragraph break sequence.

    Args:
        index: character offset into ``text`` to test.
        text: the raw text.

    Returns:
        ``(True, length)`` when ``text[index]`` starts a newline, optional
        whitespace, newline sequence; ``(False, 0)`` otherwise.
    """
    if text[index] != '\n':
        return False, 0
    para_break = PARAGRAPH_BREAK.match(text, index)
    if para_break is None:
        return False, 0
    return True, len(para_break.group(0))


def find_next_word(index, text, word, output):
    """
    Locate the next word in the text. In case a paragraph break is found,
    also write paragraph break to labels.

    Args:
        index: character offset in ``text`` at which to start scanning.
        text: the raw text.
        word: the token expected next in the raw text.
        output: writable text stream that receives ``'\\n\\n'`` for every
            paragraph break crossed while scanning.

    Returns:
        ``(new_index, word_found)`` where ``new_index`` is the offset just
        past the word and ``word_found`` is the consumed raw text, including
        any whitespace that preceded the word since the last paragraph break.

    Raises:
        AssertionError: if the raw text and ``word`` disagree on a character,
            or if non-whitespace text is pending when a paragraph ends.
    """
    idx = 0
    word_sofar = ''
    while index < len(text) and idx < len(word):
        para_break, break_len = is_para_break(index, text)
        if para_break:
            # multiple newlines found, paragraph break
            if len(word_sofar) > 0:
                assert _WHITESPACE_ONLY.match(word_sofar), 'Found non-empty string at the end of a paragraph that doesn\'t match any token: |{}|'.format(word_sofar)
                word_sofar = ''

            output.write('\n\n')
            # the shared "index += 1" below consumes the last break character
            index += break_len - 1
        elif _WHITESPACE_CHAR.match(text[index]) and not _WHITESPACE_CHAR.match(word[idx]):
            # whitespace found, and whitespace is not part of a word
            word_sofar += text[index]
        else:
            # non-whitespace char, or a whitespace char that's part of a word
            word_sofar += text[index]
            assert text[index].replace('\n', ' ') == word[idx], "Character mismatch: raw text contains |%s| but the next word is |%s|." % (word_sofar, word)
            idx += 1
        index += 1
    return index, word_sofar


def _flush_sentence_end(buf, output):
    """Write pending labels with the last one promoted to a sentence end; return the emptied buffer."""
    if buf:
        assert int(buf[-1]) >= 1
        # 1 -> 2 (word end -> sentence end), 3 -> 4 (MWT end -> MWT + sentence end)
        output.write(buf[:-1] + '{}'.format(int(buf[-1]) + 1))
    return ''


def _write_labels(conllu_lines, text, output):
    """Align CoNLL-U tokens with the raw text, write the labels, and return the list of MWT expansions."""
    index = 0  # character offset in rawtext
    mwt_expansions = []
    buf = ''  # labels of the most recent token, held back until we know if it ends a sentence
    mwtbegin = 0
    mwtend = -1
    lastmwt = None
    expanded = []
    last_comments = ""

    for line in conllu_lines:
        line = line.strip()
        if not line:
            # sentence break found
            buf = _flush_sentence_end(buf, output)
            last_comments = ''
            continue

        if line[0] == "#":
            # comment, only remember the first one of the sentence for diagnostics
            if not last_comments:
                last_comments = line
            continue

        fields = line.split('\t')
        token_id = fields[0]
        if '.' in token_id:
            # the tokenizer doesn't deal with ellipsis
            continue

        word = fields[1]
        is_mwt = '-' in token_id
        if is_mwt:
            # multiword token
            mwtbegin, mwtend = [int(x) for x in token_id.split('-')]
            lastmwt = word
            expanded = []
        elif mwtbegin <= int(token_id) < mwtend:
            # word inside the current MWT: not present in the raw text
            expanded.append(word)
            continue
        elif int(token_id) == mwtend:
            expanded.append(word)
            expanded = [x.lower() for x in expanded]  # evaluation doesn't care about case
            mwt_expansions.append((lastmwt, tuple(expanded)))
            if lastmwt[0].islower() and not expanded[0][0].islower():
                print('Sentence ID with potential wrong MWT expansion: ', last_comments, file=sys.stderr)
            mwtbegin = 0
            mwtend = -1
            lastmwt = None
            continue

        if buf:
            output.write(buf)
        index, word_found = find_next_word(index, text, word, output)
        buf = '0' * (len(word_found) - 1) + ('3' if is_mwt else '1')

    # A file that lacks the trailing blank line still ends its last sentence;
    # without this flush the labels of the final token would be lost.
    _flush_sentence_end(buf, output)
    return mwt_expansions


def main(args):
    """
    Command-line entry point.

    Reads the raw text and the CoNLL-U file, writes the character-level
    labels to ``--output`` (or stdout) and the counted MWT expansions as JSON
    to ``--mwt_output`` (or prints them to stdout).

    Args:
        args: list of command-line arguments, excluding the program name.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument('plaintext_file', type=str, help="Plaintext file containing the raw input")
    parser.add_argument('conllu_file', type=str, help="CoNLL-U file containing tokens and sentence breaks")
    parser.add_argument('-o', '--output', default=None, type=str, help="Output file name; output to the console if not specified (the default)")
    parser.add_argument('-m', '--mwt_output', default=None, type=str, help="Output file name for MWT expansions; output to the console if not specified (the default)")

    args = parser.parse_args(args=args)

    with open(args.plaintext_file, 'r', encoding='utf-8') as f:
        text = f.read()

    if args.output is None:
        output = sys.stdout
    else:
        outdir = os.path.split(args.output)[0]
        # A bare file name has no directory part; os.makedirs('') would raise.
        if outdir:
            os.makedirs(outdir, exist_ok=True)
        output = open(args.output, 'w', encoding='utf-8')

    try:
        with open(args.conllu_file, 'r', encoding='utf-8') as f:
            mwt_expansions = _write_labels(f, text, output)
    finally:
        # Close the label file even when alignment fails; never close stdout.
        if output is not sys.stdout:
            output.close()

    status_line = ""
    if args.output:
        status_line = 'Tokenizer labels written to %s\n ' % args.output

    mwts = Counter(mwt_expansions)
    if args.mwt_output is None:
        print('MWTs:', mwts)
    else:
        with open(args.mwt_output, 'w', encoding='utf-8') as f:
            json.dump(list(mwts.items()), f, indent=2)

        status_line = status_line + '{} unique MWTs found in data. MWTs written to {}'.format(len(mwts), args.mwt_output)
    print(status_line)


if __name__ == '__main__':
    main(sys.argv[1:])